Skip to content
50 changes: 31 additions & 19 deletions angrop/chain_builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ def _normalize_jmp_reg(self, gadget, pre_preserve=None, to_set_regs=None):
pass
return None

def _normalize_jmp_mem(self, gadget, pre_preserve=None, post_preserve=None):
def _normalize_jmp_mem(self, gadget, pre_preserve=None, post_preserve=None, final_gadget=None):
if not self.chain_builder._can_do_write:
return None
if pre_preserve is None:
Expand Down Expand Up @@ -731,23 +731,27 @@ def _normalize_jmp_mem(self, gadget, pre_preserve=None, post_preserve=None):

try:
# step1: find a shifter that clean up the jmp_mem call
sc = abs(gadget.stack_change) + self.project.arch.bytes
shifter = None
# find the smallest shifter
shift_gadgets = self.chain_builder._shifter.shift_gadgets
keys = sorted(shift_gadgets.keys())
shifter_list = [shift_gadgets[x] for x in keys if x >= sc]
if not shifter_list:
return None
shifter_list = itertools.chain.from_iterable(shifter_list)
for shifter in shifter_list:
if shifter.pc_offset < shift_size:
continue
if not shifter.changed_regs.intersection(post_preserve):
break
# if final_gadget is passed in, then it is the shifter
if final_gadget:
shifter = final_gadget
else:
return None
assert shifter.transit_type == 'pop_pc'
sc = abs(gadget.stack_change) + self.project.arch.bytes
shifter = None
# find the smallest shifter
shift_gadgets = self.chain_builder._shifter.shift_gadgets
keys = sorted(shift_gadgets.keys())
shifter_list = [shift_gadgets[x] for x in keys if x >= sc]
if not shifter_list:
return None
shifter_list = itertools.chain.from_iterable(shifter_list)
for shifter in shifter_list:
if shifter.pc_offset < shift_size:
continue
if not shifter.changed_regs.intersection(post_preserve):
break
else:
return None
assert shifter.transit_type == 'pop_pc'

# step2: write the shifter to a writable location
data = struct.pack(self.project.arch.struct_fmt(), shifter.addr)
Expand Down Expand Up @@ -813,7 +817,7 @@ def _normalize_jmp_mem(self, gadget, pre_preserve=None, post_preserve=None):
except (RopException, IndexError):
return None

def normalize_gadget(self, gadget, pre_preserve=None, post_preserve=None, to_set_regs=None):
def normalize_gadget(self, gadget, pre_preserve=None, post_preserve=None, to_set_regs=None, final_gadget=None):
"""
pre_preserve: what registers to preserve before executing the gadget
post_preserve: what registers to preserve after executing the gadget
Expand Down Expand Up @@ -858,8 +862,13 @@ def normalize_gadget(self, gadget, pre_preserve=None, post_preserve=None, to_set
if tmp is None:
return None
gadgets = tmp + gadgets
if final_gadget:
gadgets += [final_gadget]
elif gadget.transit_type == 'jmp_mem':
rb = self._normalize_jmp_mem(gadget, pre_preserve=pre_preserve, post_preserve=post_preserve)
rb = self._normalize_jmp_mem(gadget,
pre_preserve=pre_preserve,
post_preserve=post_preserve,
final_gadget=final_gadget)
return rb
elif gadget.transit_type == 'pop_pc':
pass
Expand All @@ -872,6 +881,9 @@ def normalize_gadget(self, gadget, pre_preserve=None, post_preserve=None, to_set
if rb is None:
return None

if final_gadget:
return rb

# normalize non-positive stack_change
if gadget.stack_change <= 0:
shift_gadgets = self.chain_builder._shifter.shift_gadgets
Expand Down
162 changes: 160 additions & 2 deletions angrop/chain_builder/reg_mover.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,19 @@
l = logging.getLogger(__name__)

_global_reg_mover = None # type: ignore
_global_push_pop_mover = None # type: ignore

def _set_global_reg_mover(reg_mover, ptr_list):
    """
    pool initializer for the reg mover workers: point Builder.used_writable_ptrs
    at ptr_list and stash reg_mover in the module-level _global_reg_mover
    """
    global _global_reg_mover# pylint: disable=global-statement
    Builder.used_writable_ptrs = ptr_list
    _global_reg_mover = reg_mover

def worker_func(t):
def _set_global_push_pop_mover(push_pop_mover, ptr_list):
    """
    pool initializer for the push/pop mover workers: point Builder.used_writable_ptrs
    at ptr_list and stash push_pop_mover in the module-level _global_push_pop_mover
    """
    global _global_push_pop_mover# pylint: disable=global-statement
    Builder.used_writable_ptrs = ptr_list
    _global_push_pop_mover = push_pop_mover

def reg_mover_worker_func(t):
new_move, gadget = t
gadget.project = _global_reg_mover.project # type: ignore
pre_preserve = {new_move.from_reg}
Expand All @@ -35,6 +42,152 @@ def worker_func(t):
solver = rb._blank_state.solver
return new_move, gadget.addr, solver, rb

def push_pop_mover_worker_func(t):
    """
    worker routine run on one ((from_reg, to_reg), (pusher, popper)) item

    normalizes the pusher with the popper as its final gadget, using the mover
    stored in _global_push_pop_mover.
    returns (move, solver, rb); solver and rb are both None when normalization
    fails or the resulting block does not contain the expected register move
    """
    (from_reg, to_reg), (pusher, popper) = t
    mover = _global_push_pop_mover
    project = mover.project # type: ignore
    move = RopRegMove(from_reg, to_reg, project.arch.bits)

    # attach the project to both gadgets before normalizing them
    pusher.project = project
    popper.project = project

    rb = mover.normalize_gadget(pusher, # type: ignore
                                pre_preserve={from_reg},
                                post_preserve={to_reg},
                                final_gadget=popper)
    if rb is None or move not in rb.reg_moves:
        return move, None, None
    return move, rb._blank_state.solver, rb

class PushPopMover(Builder):
    """
    construct register moves by chaining push/pop, like push rax; jmp rbx => pop rdi; ret
    to `mov rdi, rax`
    """

    # exclusive upper bound on the stack offsets scanned when matching a pusher with a popper
    _max_push_offset = 0x20

    def __init__(self, chain_builder, reg_mover):
        """
        chain_builder: passed through to Builder
        reg_mover: the object whose `_graph` is read and extended with the new moves
        """
        super().__init__(chain_builder)
        # register -> gadgets that write that register onto the stack
        self._stack_write_dict = defaultdict(list)
        self._reg_mover = reg_mover

    def bootstrap(self):
        """
        index the filtered stack-writing gadgets by the registers they write to the stack
        """
        gadgets = self.filter_gadgets(self.chain_builder.gadgets)
        for g in gadgets:
            # a gadget may write the same register at several offsets;
            # record it only once per register so later searches don't see duplicates
            for reg in set(g.stack_writes.values()):
                self._stack_write_dict[reg].append(g)

    def _find_push_pop_pair(self, from_reg, to_reg, push_gadgets, pop_gadgets):
        """
        look for a (pusher, popper) pair where the pusher writes from_reg at some stack
        offset and the popper pops to_reg from that same offset

        offsets are tried in increasing order; for each offset, pushers are tried by
        increasing max_stack_offset and poppers by increasing (stack_change, number of
        changed registers). the first pair whose popper stack_change is at least the
        pusher's max_stack_offset plus one word is returned, None if there is no such pair
        """
        word = self.project.arch.bytes
        for offset in range(0, self._max_push_offset, word):
            poppers = [g for g in pop_gadgets
                       if any(pop.stack_offset == offset and pop.reg == to_reg for pop in g.reg_pops)]
            pushers = [g for g in push_gadgets
                       if offset in g.stack_writes and g.stack_writes[offset] == from_reg]
            if not pushers or not poppers:
                # nothing to pair at this offset, skip the sorting
                continue
            pushers.sort(key=lambda g: g.max_stack_offset)
            poppers.sort(key=lambda g: (g.stack_change, len(g.changed_regs)))
            for pusher, popper in itertools.product(pushers, poppers):
                if popper.stack_change >= pusher.max_stack_offset + word:
                    return pusher, popper
        return None

    def _optimize_todos(self):
        """
        only try to find push/pop register moves chains if there is no way to directly move
        between these two registers

        returns a dict mapping (from_reg, to_reg) to the chosen (pusher, popper) gadgets
        """
        todos = {}
        graph = self._reg_mover._graph
        reg_setting_dict = self.chain_builder._reg_setter._reg_setting_dict
        full_bits = self.project.arch.bits
        for from_reg, to_reg in itertools.product(self.arch.reg_list, self.arch.reg_list):
            if from_reg == to_reg:
                continue
            # a full-width move already exists for this pair, nothing to improve
            if graph.has_edge(from_reg, to_reg) and \
                    graph.get_edge_data(from_reg, to_reg)['bits'] == full_bits:
                continue
            pop_gadgets = [g for g in reg_setting_dict[to_reg] if isinstance(g, RopGadget)]
            if not pop_gadgets:
                continue
            push_gadgets = self._stack_write_dict[from_reg]
            if not push_gadgets:
                continue
            pair = self._find_push_pop_pair(from_reg, to_reg, push_gadgets, pop_gadgets)
            if pair is not None:
                todos[(from_reg, to_reg)] = pair
        return todos

    def normalize_single_threaded(self):
        """
        normalize every todo in the current process
        yields (move, rb) for each todo whose normalized block contains the expected move
        """
        todos = self._optimize_todos()
        for key, value in todos.items():
            from_reg, to_reg = key
            move = RopRegMove(from_reg, to_reg, self.project.arch.bits)
            pusher, popper = value
            rb = self.normalize_gadget(pusher, pre_preserve={from_reg}, post_preserve={to_reg}, final_gadget=popper)
            if rb and move in rb.reg_moves:
                yield move, rb

    def normalize_multiprocessing(self, processes):
        """
        normalize every todo in a pool of `processes` workers
        yields (move, rb) for each successful todo, after re-attaching the project,
        the builder and a fresh symbolic state carrying the worker's solver to rb
        """
        todos = self._optimize_todos()
        with mp.Manager() as manager:
            # HACK: ideally, used_ptrs should be a resource of each ropblock that can be reassigned
            # when conflict happens. but currently, I'm being lazy and just make sure every pointer
            # is different
            ptr_list = manager.list(Builder.used_writable_ptrs)
            initargs = (self, ptr_list)
            with mp.Pool(processes=processes, initializer=_set_global_push_pop_mover, initargs=initargs) as pool:
                for move, solver, rb in pool.imap_unordered(push_pop_mover_worker_func, todos.items()):
                    if rb is None:
                        continue
                    state = rop_utils.make_symbolic_state(self.project, self.arch.reg_list, 0)
                    state.solver = solver
                    rb.set_project(self.project)
                    rb.set_builder(self)
                    rb._blank_state = state
                    yield move, rb
            # copy the pointers recorded in the shared list back into the class attribute
            Builder.used_writable_ptrs = list(ptr_list)

    def optimize(self, processes):
        """
        build push/pop move blocks and add them to the reg mover's graph

        processes: 1 runs in the current process, anything else uses a worker pool
        returns True if at least one block was produced, False otherwise
        """
        res = False
        if processes == 1:
            iterable = self.normalize_single_threaded()
        else:
            iterable = self.normalize_multiprocessing(processes)
        graph = self._reg_mover._graph
        # the yielded move is not needed here: every move of the block is registered below
        for _, rb in iterable:
            res = True
            for move in rb.reg_moves:
                edge = (move.from_reg, move.to_reg)
                if graph.has_edge(*edge):
                    edge_data = graph.get_edge_data(*edge)
                    edge_blocks = edge_data['block']
                    edge_blocks.append(rb)
                    edge_data['block'] = sorted(edge_blocks, key=lambda x: x.stack_change)
                    if move.bits > edge_data['bits']:
                        edge_data['bits'] = move.bits
                else:
                    graph.add_edge(*edge, block=[rb], bits=move.bits)

        return res

    def filter_gadgets(self, gadgets):
        """
        filter gadgets having the same effect
        """
        # first: filter out gadgets that don't do stack_writes
        gadgets = {g for g in gadgets if g.stack_writes and
                   not g.has_conditional_branch and not g.has_symbolic_access()}
        gadgets = self._filter_gadgets(gadgets)
        return gadgets

    def _effect_tuple(self, g):
        """
        the gadget's stack_writes items as a tuple ordered by key
        """
        return tuple(sorted(g.stack_writes.items(), key=lambda x: x[0]))

    def _comparison_tuple(self, g):
        """
        tuple of (max_stack_offset, num_sym_mem_access, transit_num, isn_count) for the gadget
        """
        return (g.max_stack_offset, g.num_sym_mem_access, rop_utils.transit_num(g), g.isn_count)

class RegMover(Builder):
"""
handle register moves such as `mov rax, rcx`
Expand All @@ -47,9 +200,12 @@ def __init__(self, chain_builder):
self._graph: nx.Graph = None # type: ignore
self._normalize_todos = {}

self._push_pop_mover = PushPopMover(chain_builder, self)

def bootstrap(self):
    """
    collect the register moving gadgets ordered by stack_change, pick out the
    self-contained ones, bootstrap the push/pop mover and build the move graph
    """
    candidates = self.filter_gadgets(self.chain_builder.gadgets)
    ordered = sorted(candidates, key=lambda g:g.stack_change)
    self._reg_moving_gadgets = ordered
    self._reg_moving_blocks = {g for g in ordered if g.self_contained}
    self._push_pop_mover.bootstrap()
    self._build_move_graph()

def build_normalize_todos(self):
Expand Down Expand Up @@ -132,7 +288,7 @@ def normalize_multiprocessing(self, processes):
ptr_list = manager.list(Builder.used_writable_ptrs)
initargs = (self, ptr_list)
with mp.Pool(processes=processes, initializer=_set_global_reg_mover, initargs=initargs) as pool:
for new_move, addr, solver, rb in pool.imap_unordered(worker_func, self.normalize_todos()):
for new_move, addr, solver, rb in pool.imap_unordered(reg_mover_worker_func, self.normalize_todos()):
if rb is None:
continue
state = rop_utils.make_symbolic_state(self.project, self.arch.reg_list, 0)
Expand Down Expand Up @@ -182,6 +338,8 @@ def optimize(self, processes):
edge_data['bits'] = move.bits
else:
self._graph.add_edge(*edge, block=[rb], bits=move.bits)

res |= self._push_pop_mover.optimize(processes)
return res

def _build_move_graph(self):
Expand Down
17 changes: 3 additions & 14 deletions angrop/chain_builder/reg_setter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,12 @@ def __init__(self, chain_builder):
self.hard_chain_cache: dict[tuple, list] = None # type: ignore
# Estimate of how difficult it is to set each register.
# all self-contained and not symbolic access
self._reg_setting_dict: dict[str, list] = None # type: ignore
self._reg_setting_dict: dict[str, list] = defaultdict(list)

def bootstrap(self):
self._reg_setting_gadgets = self.filter_gadgets(self.chain_builder.gadgets)

# update reg_setting_dict
self._reg_setting_dict = defaultdict(list)
for g in self._reg_setting_gadgets:
if not g.self_contained:
continue
Expand All @@ -54,13 +53,14 @@ def bootstrap(self):
self.hard_chain_cache = {}

#### Utility Functions ####

def _insert_to_reg_dict(self, gs):
    """
    register every block in gs under each register it pops, then re-sort all
    per-register lists by (stack_change, isn_count)

    gs: iterable of objects exposing popped_regs, stack_change and isn_count
    """
    for rb in gs:
        for reg in rb.popped_regs:
            self._reg_setting_dict[reg].append(rb)
    # one stable sort on the tuple key is enough; sorting by stack_change first
    # and then by (stack_change, isn_count) gives exactly the same order
    for reg, lst in self._reg_setting_dict.items():
        self._reg_setting_dict[reg] = sorted(lst, key=lambda x: (x.stack_change, x.isn_count))

def _expand_ropblocks(self, mixins):
"""
Expand Down Expand Up @@ -665,17 +665,6 @@ def _comparison_tuple(self, g):
return (len(g.changed_regs-g.popped_regs), g.stack_change, g.num_sym_mem_access,
g.isn_count, int(g.has_conditional_branch is True))

def _same_effect(self, g1, g2): # pylint: disable=no-self-use
    """
    True when both gadgets agree on popped_regs, concrete_regs, reg_dependencies
    and transit_type (compared in that order), False otherwise
    """
    compared = ("popped_regs", "concrete_regs", "reg_dependencies", "transit_type")
    return not any(getattr(g1, attr) != getattr(g2, attr) for attr in compared)

def filter_gadgets(self, gadgets):
"""
process gadgets based on their effects
Expand Down
16 changes: 15 additions & 1 deletion angrop/gadget_finder/gadget_analyzer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import math
import ctypes
import logging
Expand Down Expand Up @@ -583,7 +584,13 @@ def _check_reg_changes(self, final_state, _, gadget):
if extended is not None and bits == 64:
if extended <= 32:
bits = 32
pop = RopRegPop(reg, bits)
ast_depth = ast.depth
if ast.op == 'Reverse':
ast_depth -= 1
name = list(ast.variables)[0]
re_res = re.match(r"symbolic_stack_(\d+)_", name)
offset = int(re_res.group(1)) * self.project.arch.bytes # type:ignore
pop = RopRegPop(reg, bits, offset, ast_depth)
gadget.reg_pops.add(pop)
gadget.changed_regs.add(reg)
else:
Expand Down Expand Up @@ -921,6 +928,7 @@ def _build_mem_access(self, a, gadget, init_state, final_state):
mem_access.data_dependencies = rop_utils.get_ast_dependency(a.data.ast)
mem_access.data_controllers = rop_utils.get_ast_controllers(init_state, a.data.ast,
mem_access.data_dependencies)
mem_access.data_depth = a.data.ast.depth
else:
mem_access.data_constant = init_state.solver.eval(a.data.ast)
elif a.action == "read":
Expand Down Expand Up @@ -1001,6 +1009,7 @@ def _build_mem_change(self, read_action, write_action, gadget, init_state, final
mem_change.op = write_action.data.ast.op
mem_change.data_dependencies = data_dependencies
mem_change.data_stack_controllers = data_stack_controllers
mem_change.data_depth = sym_data.depth
mem_change.data_controllers = data_controllers
mem_change.data_size = write_action.data.ast.size()
mem_change.addr_size = write_action.addr.ast.size()
Expand Down Expand Up @@ -1138,6 +1147,11 @@ def _analyze_mem_access(self, final_state, init_state, gadget):
if a.action == "read":
gadget.mem_reads.append(mem_access)
if a.action == "write":
# if two writes at the same address, the second one overrules the first one
if mem_access.stack_offset is not None:
for m in gadget.mem_writes:
if mem_access.stack_offset == m.stack_offset:
gadget.mem_writes.remove(m)
gadget.mem_writes.append(mem_access)
return True

Expand Down
Loading
Loading