diff --git a/angrop/chain_builder/builder.py b/angrop/chain_builder/builder.py index 2a761cbe..464edf38 100644 --- a/angrop/chain_builder/builder.py +++ b/angrop/chain_builder/builder.py @@ -696,7 +696,7 @@ def _normalize_jmp_reg(self, gadget, pre_preserve=None, to_set_regs=None): pass return None - def _normalize_jmp_mem(self, gadget, pre_preserve=None, post_preserve=None): + def _normalize_jmp_mem(self, gadget, pre_preserve=None, post_preserve=None, final_gadget=None): if not self.chain_builder._can_do_write: return None if pre_preserve is None: @@ -731,23 +731,27 @@ def _normalize_jmp_mem(self, gadget, pre_preserve=None, post_preserve=None): try: # step1: find a shifter that clean up the jmp_mem call - sc = abs(gadget.stack_change) + self.project.arch.bytes - shifter = None - # find the smallest shifter - shift_gadgets = self.chain_builder._shifter.shift_gadgets - keys = sorted(shift_gadgets.keys()) - shifter_list = [shift_gadgets[x] for x in keys if x >= sc] - if not shifter_list: - return None - shifter_list = itertools.chain.from_iterable(shifter_list) - for shifter in shifter_list: - if shifter.pc_offset < shift_size: - continue - if not shifter.changed_regs.intersection(post_preserve): - break + # if final_gadget is passed in, then it is the shifter + if final_gadget: + shifter = final_gadget else: - return None - assert shifter.transit_type == 'pop_pc' + sc = abs(gadget.stack_change) + self.project.arch.bytes + shifter = None + # find the smallest shifter + shift_gadgets = self.chain_builder._shifter.shift_gadgets + keys = sorted(shift_gadgets.keys()) + shifter_list = [shift_gadgets[x] for x in keys if x >= sc] + if not shifter_list: + return None + shifter_list = itertools.chain.from_iterable(shifter_list) + for shifter in shifter_list: + if shifter.pc_offset < shift_size: + continue + if not shifter.changed_regs.intersection(post_preserve): + break + else: + return None + assert shifter.transit_type == 'pop_pc' # step2: write the shifter to a writable location data = struct.pack(self.project.arch.struct_fmt(), shifter.addr) @@ -813,7 +817,7 @@ def _normalize_jmp_mem(self, gadget, pre_preserve=None, post_preserve=None): except (RopException, IndexError): return None - def normalize_gadget(self, gadget, pre_preserve=None, post_preserve=None, to_set_regs=None): + def normalize_gadget(self, gadget, pre_preserve=None, post_preserve=None, to_set_regs=None, final_gadget=None): """ pre_preserve: what registers to preserve before executing the gadget post_preserve: what registers to preserve after executing the gadget @@ -858,8 +862,13 @@ def normalize_gadget(self, gadget, pre_preserve=None, post_preserve=None, to_set if tmp is None: return None gadgets = tmp + gadgets + if final_gadget: + gadgets += [final_gadget] elif gadget.transit_type == 'jmp_mem': - rb = self._normalize_jmp_mem(gadget, pre_preserve=pre_preserve, post_preserve=post_preserve) + rb = self._normalize_jmp_mem(gadget, + pre_preserve=pre_preserve, + post_preserve=post_preserve, + final_gadget=final_gadget) return rb elif gadget.transit_type == 'pop_pc': pass @@ -872,6 +881,9 @@ def normalize_gadget(self, gadget, pre_preserve=None, post_preserve=None, to_set if rb is None: return None + if final_gadget: + return rb + # normalize non-positive stack_change if gadget.stack_change <= 0: shift_gadgets = self.chain_builder._shifter.shift_gadgets diff --git a/angrop/chain_builder/reg_mover.py b/angrop/chain_builder/reg_mover.py index 93ff7304..c16c7187 100644 --- a/angrop/chain_builder/reg_mover.py +++ b/angrop/chain_builder/reg_mover.py @@ -17,12 +17,19 @@ l = logging.getLogger(__name__) _global_reg_mover = None # type: ignore +_global_push_pop_mover = None # type: ignore + def _set_global_reg_mover(reg_mover, ptr_list): global _global_reg_mover# pylint: disable=global-statement _global_reg_mover = reg_mover Builder.used_writable_ptrs = ptr_list -def worker_func(t): +def _set_global_push_pop_mover(push_pop_mover, ptr_list): + global _global_push_pop_mover# pylint: disable=global-statement + _global_push_pop_mover = push_pop_mover + Builder.used_writable_ptrs = ptr_list + +def reg_mover_worker_func(t): new_move, gadget = t gadget.project = _global_reg_mover.project # type: ignore pre_preserve = {new_move.from_reg} @@ -35,6 +42,152 @@ def worker_func(t): solver = rb._blank_state.solver return new_move, gadget.addr, solver, rb +def push_pop_mover_worker_func(t): + key, value = t + from_reg, to_reg = key + project = _global_push_pop_mover.project # type: ignore + move = RopRegMove(from_reg, to_reg, project.arch.bits) + pusher, popper = value + pusher.project = project + popper.project = project + rb = _global_push_pop_mover.normalize_gadget(pusher, # type: ignore + pre_preserve={from_reg}, + post_preserve={to_reg}, + final_gadget=popper) + if rb is not None and move in rb.reg_moves: + solver = rb._blank_state.solver + else: + rb = None + solver = None + return move, solver, rb + +class PushPopMover(Builder): + """ + construct register moves by chaining push/pop, like push rax; jmp rbx => pop rdi; ret + to `mov rdi, rax` + """ + def __init__(self, chain_builder, reg_mover): + super().__init__(chain_builder) + self._stack_write_dict = defaultdict(list) + self._reg_mover = reg_mover + + def bootstrap(self): + gadgets = self.filter_gadgets(self.chain_builder.gadgets) + for g in gadgets: + for reg in g.stack_writes.values(): + self._stack_write_dict[reg].append(g) + + def _optimize_todos(self): + """ + only try to find push/pop register moves chains if there is no way to directly move + between these two registers + """ + todos = {} + graph = self._reg_mover._graph + reg_setting_dict = self.chain_builder._reg_setter._reg_setting_dict + for from_reg, to_reg in itertools.product(self.arch.reg_list, self.arch.reg_list): + if from_reg == to_reg: + continue + edge = (from_reg, to_reg) + if graph.has_edge(*edge): + if graph.get_edge_data(*edge)['bits'] == self.project.arch.bits: + continue + pop_gadgets = [g for g in reg_setting_dict[to_reg] if isinstance(g, RopGadget)] + if not pop_gadgets: + continue + push_gadgets = self._stack_write_dict[from_reg] + if not push_gadgets: + continue + good_pusher = None + good_popper = None + for offset in range(0, 0x20, self.project.arch.bytes): + matched_pop_gadgets = [rb for rb in pop_gadgets if any(pop.stack_offset == offset and pop.reg == to_reg + for pop in rb.reg_pops)] + matched_push_gadgets = [g for g in push_gadgets if + offset in g.stack_writes and g.stack_writes[offset] == from_reg] + matched_push_gadgets = sorted(matched_push_gadgets, key=lambda g: g.max_stack_offset) + matched_pop_gadgets = sorted(matched_pop_gadgets, key=lambda g: (g.stack_change, len(g.changed_regs))) + for pusher, popper in itertools.product(matched_push_gadgets, matched_pop_gadgets): + if popper.stack_change >= pusher.max_stack_offset + self.project.arch.bytes: + good_pusher = pusher + good_popper = popper + break + if good_pusher: + break + if not good_pusher: + continue + todos[(from_reg, to_reg)] = (good_pusher, good_popper) + return todos + + def normalize_single_threaded(self): + todos = self._optimize_todos() + for key, value in todos.items(): + from_reg, to_reg = key + move = RopRegMove(from_reg, to_reg, self.project.arch.bits) + pusher, popper = value + rb = self.normalize_gadget(pusher, pre_preserve={from_reg}, post_preserve={to_reg}, final_gadget=popper) + if rb and move in rb.reg_moves: + yield move, rb + + def normalize_multiprocessing(self, processes): + todos = self._optimize_todos() + with mp.Manager() as manager: + # HACK: ideally, used_ptrs should be a resource of each ropblock that can be reassigned + # when conflict happens. but currently, I'm being lazy and just make sure every pointer + # is different + ptr_list = manager.list(Builder.used_writable_ptrs) + initargs = (self, ptr_list) + with mp.Pool(processes=processes, initializer=_set_global_push_pop_mover, initargs=initargs) as pool: + for move, solver, rb in pool.imap_unordered(push_pop_mover_worker_func, todos.items()): + if rb is None: + continue + state = rop_utils.make_symbolic_state(self.project, self.arch.reg_list, 0) + state.solver = solver + rb.set_project(self.project) + rb.set_builder(self) + rb._blank_state = state + yield move, rb + Builder.used_writable_ptrs = list(ptr_list) + + def optimize(self, processes): + res = False + if processes == 1: + iterable = self.normalize_single_threaded() + else: + iterable = self.normalize_multiprocessing(processes) + for move, rb in iterable: + res = True + for move in rb.reg_moves: + edge = (move.from_reg, move.to_reg) + if self._reg_mover._graph.has_edge(*edge): + edge_data = self._reg_mover._graph.get_edge_data(*edge) + edge_blocks = edge_data['block'] + edge_blocks.append(rb) + edge_data['block'] = sorted(edge_blocks, key=lambda x: x.stack_change) + if move.bits > edge_data['bits']: + edge_data['bits'] = move.bits + else: + self._reg_mover._graph.add_edge(*edge, block=[rb], bits=move.bits) + + return res + + def filter_gadgets(self, gadgets): + """ + filter gadgets having the same effect + """ + # first: filter out gadgets that don't do stack_writes + gadgets = {g for g in gadgets if g.stack_writes and + not g.has_conditional_branch and not g.has_symbolic_access()} + gadgets = self._filter_gadgets(gadgets) + return gadgets + + def _effect_tuple(self, g): + d = tuple(sorted(list(g.stack_writes.items()), key=lambda x: x[0])) + return d + + def _comparison_tuple(self, g): + return (g.max_stack_offset, g.num_sym_mem_access, rop_utils.transit_num(g), g.isn_count) + class RegMover(Builder): """ handle register moves such as `mov rax, rcx` @@ -47,9 +200,12 @@ def __init__(self, chain_builder): self._graph: nx.Graph = None # type: ignore self._normalize_todos = {} + self._push_pop_mover = PushPopMover(chain_builder, self) + def bootstrap(self): self._reg_moving_gadgets = sorted(self.filter_gadgets(self.chain_builder.gadgets), key=lambda g:g.stack_change) self._reg_moving_blocks = {g for g in self._reg_moving_gadgets if g.self_contained} + self._push_pop_mover.bootstrap() self._build_move_graph() def build_normalize_todos(self): @@ -132,7 +288,7 @@ def normalize_multiprocessing(self, processes): ptr_list = manager.list(Builder.used_writable_ptrs) initargs = (self, ptr_list) with mp.Pool(processes=processes, initializer=_set_global_reg_mover, initargs=initargs) as pool: - for new_move, addr, solver, rb in pool.imap_unordered(worker_func, self.normalize_todos()): + for new_move, addr, solver, rb in pool.imap_unordered(reg_mover_worker_func, self.normalize_todos()): if rb is None: continue state = rop_utils.make_symbolic_state(self.project, self.arch.reg_list, 0) @@ -182,6 +338,8 @@ def optimize(self, processes): edge_data['bits'] = move.bits else: self._graph.add_edge(*edge, block=[rb], bits=move.bits) + + res |= self._push_pop_mover.optimize(processes) return res def _build_move_graph(self): diff --git a/angrop/chain_builder/reg_setter.py b/angrop/chain_builder/reg_setter.py index d885884b..d0815a6f 100644 --- a/angrop/chain_builder/reg_setter.py +++ b/angrop/chain_builder/reg_setter.py @@ -31,13 +31,12 @@ def __init__(self, chain_builder): self.hard_chain_cache: dict[tuple, list] = None # type: ignore # Estimate of how difficult it is to set each register. # all self-contained and not symbolic access - self._reg_setting_dict: dict[str, list] = None # type: ignore + self._reg_setting_dict: dict[str, list] = defaultdict(list) def bootstrap(self): self._reg_setting_gadgets = self.filter_gadgets(self.chain_builder.gadgets) # update reg_setting_dict - self._reg_setting_dict = defaultdict(list) for g in self._reg_setting_gadgets: if not g.self_contained: continue @@ -54,13 +53,14 @@ def bootstrap(self): self.hard_chain_cache = {} #### Utility Functions #### + def _insert_to_reg_dict(self, gs): for rb in gs: for reg in rb.popped_regs: self._reg_setting_dict[reg].append(rb) for reg in self._reg_setting_dict: lst = self._reg_setting_dict[reg] - self._reg_setting_dict[reg] = sorted(lst, key=lambda x: x.stack_change) + self._reg_setting_dict[reg] = sorted(lst, key=lambda x: (x.stack_change, x.isn_count)) def _expand_ropblocks(self, mixins): """ @@ -665,17 +665,6 @@ def _comparison_tuple(self, g): return (len(g.changed_regs-g.popped_regs), g.stack_change, g.num_sym_mem_access, g.isn_count, int(g.has_conditional_branch is True)) - def _same_effect(self, g1, g2): # pylint: disable=no-self-use - if g1.popped_regs != g2.popped_regs: - return False - if g1.concrete_regs != g2.concrete_regs: - return False - if g1.reg_dependencies != g2.reg_dependencies: - return False - if g1.transit_type != g2.transit_type: - return False - return True - def filter_gadgets(self, gadgets): """ process gadgets based on their effects diff --git a/angrop/gadget_finder/gadget_analyzer.py b/angrop/gadget_finder/gadget_analyzer.py index 7671ca28..db3359e1 100644 --- a/angrop/gadget_finder/gadget_analyzer.py +++ b/angrop/gadget_finder/gadget_analyzer.py @@ -1,3 +1,4 @@ +import re import math import ctypes import logging @@ -583,7 +584,13 @@ def _check_reg_changes(self, final_state, _, gadget): if extended is not None and bits == 64: if extended <= 32: bits = 32 - pop = RopRegPop(reg, bits) + ast_depth = ast.depth + if ast.op == 'Reverse': + ast_depth -= 1 + name = list(ast.variables)[0] + re_res = re.match(r"symbolic_stack_(\d+)_", name) + offset = int(re_res.group(1)) * self.project.arch.bytes # type:ignore + pop = RopRegPop(reg, bits, offset, ast_depth) gadget.reg_pops.add(pop) gadget.changed_regs.add(reg) else: @@ -921,6 +928,7 @@ def _build_mem_access(self, a, gadget, init_state, final_state): mem_access.data_dependencies = rop_utils.get_ast_dependency(a.data.ast) mem_access.data_controllers = rop_utils.get_ast_controllers(init_state, a.data.ast, mem_access.data_dependencies) + mem_access.data_depth = a.data.ast.depth else: mem_access.data_constant = init_state.solver.eval(a.data.ast) elif a.action == "read": @@ -1001,6 +1009,7 @@ def _build_mem_change(self, read_action, write_action, gadget, init_state, final mem_change.op = write_action.data.ast.op mem_change.data_dependencies = data_dependencies mem_change.data_stack_controllers = data_stack_controllers + mem_change.data_depth = sym_data.depth mem_change.data_controllers = data_controllers mem_change.data_size = write_action.data.ast.size() mem_change.addr_size = write_action.addr.ast.size() @@ -1138,6 +1147,11 @@ def _analyze_mem_access(self, final_state, init_state, gadget): if a.action == "read": gadget.mem_reads.append(mem_access) if a.action == "write": + # if two writes at the same address, the second one overrules the first one + if mem_access.stack_offset is not None: + for m in gadget.mem_writes: + if mem_access.stack_offset == m.stack_offset: + gadget.mem_writes.remove(m) gadget.mem_writes.append(mem_access) return True diff --git a/angrop/rop_effect.py b/angrop/rop_effect.py index 39fccbda..dd04e372 100644 --- a/angrop/rop_effect.py +++ b/angrop/rop_effect.py @@ -20,6 +20,7 @@ def __init__(self): self.data_dependencies = set() self.data_controllers = set() self.data_stack_controllers = set() + self.data_depth: int | None = None self.addr_constant = None self.stack_offset = None # addr_constant - init_sp self.data_constant = None @@ -98,10 +99,12 @@ class RopRegPop: """ a class to represent register pop effect """ - def __init__(self, reg, bits): + def __init__(self, reg, bits, offset, depth): assert type(reg) is str self.reg = reg self.bits = bits + self.stack_offset = offset + self.ast_depth = depth def __hash__(self): return hash((self.reg, self.bits)) @@ -183,6 +186,31 @@ def num_sym_mem_access(self): res -= 1 return res + @property + def stack_writes(self): + """ + offsets relative to the final sp + """ + d = {} + for m in self.mem_writes: + if m.stack_offset is None: + continue + if m.data_depth != 1: + continue + # gadgets like push [rax]; ret has no data_controllers because it is a symbolic read + if not m.data_controllers: + continue + reg = list(m.data_controllers)[0] + if hasattr(self, "transit_type"): + if self.transit_type == 'jmp_reg' and self.pc_reg == reg: # type: ignore + continue + if self.transit_type == 'jmp_mem': # type: ignore + pc_vars = self.pc_target.variables # type: ignore + if any(v.startswith(f'sreg_{reg}') for v in pc_vars): + continue + d[m.stack_offset - self.stack_change] = reg + return d + @property def popped_regs(self): return {x.reg for x in self.reg_pops} diff --git a/angrop/rop_gadget.py b/angrop/rop_gadget.py index 039ab5ac..6b8b9e97 100644 --- a/angrop/rop_gadget.py +++ b/angrop/rop_gadget.py @@ -22,7 +22,7 @@ def __init__(self, addr): # for jmp_reg, which register it jumps to self.pc_reg: str = None # type: ignore # for jmp_mem, where it jumps to - self.pc_target: int = None # type: ignore + self.pc_target = None # type: ignore @property def self_contained(self): diff --git a/tests/test_chainbuilder.py b/tests/test_chainbuilder.py index 74dcf0b8..f0d94775 100644 --- a/tests/test_chainbuilder.py +++ b/tests/test_chainbuilder.py @@ -1197,6 +1197,57 @@ def test_mem_changer(): # add is tested test_add_to_mem +def test_push_pop_move(): + # simple push-pop + proj = angr.load_shellcode( + """ + push rax; jmp rsi + pop rsi; ret + pop rdi; ret + """, + "x86_64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP() + rop.find_gadgets_single_threaded(show_progress=False) + chain = rop.move_regs(rdi='rax') + assert chain + + proj = angr.load_shellcode( + """ + push rax; jmp qword ptr [rsi] + pop rsi; ret + pop rdi; ret + pop rcx; pop rdx; ret + mov [rcx], rdx; ret + """, + "x86_64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP() + rop.find_gadgets_single_threaded(show_progress=False) + chain = rop.move_regs(rdi='rax') + assert chain + + proj = angr.load_shellcode( + """ + push rax; call qword ptr [rsi] + pop rsi; ret + pop rdx; pop rdi; ret + pop rcx; pop rdx; ret + mov [rcx], rdx; ret + """, + "x86_64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP() + rop.find_gadgets_single_threaded(show_progress=False) + chain = rop.move_regs(rdi='rax') + assert chain + def run_all(): functions = globals() all_functions = {x:y for x, y in functions.items() if x.startswith('test_')} diff --git a/tests/test_find_gadgets.py b/tests/test_find_gadgets.py index cdb0392c..08ffd8e5 100644 --- a/tests/test_find_gadgets.py +++ b/tests/test_find_gadgets.py @@ -188,7 +188,7 @@ def local_multiprocess_analyze_gadget_list(): assert gadgets[0].addr == 0x4006d8 assert gadgets[1].addr == 0x400864 -def test_gadget_filtering(): +def test_gadget_filtering1(): proj = angr.Project(os.path.join(tests_dir, "armel", "libc-2.31.so"), auto_load_libs=False) rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, is_thumb=True) rop.analyze_gadget(0x42bca5) diff --git a/tests/test_gadgets.py b/tests/test_gadgets.py index 6755eaf3..b4144853 100644 --- a/tests/test_gadgets.py +++ b/tests/test_gadgets.py @@ -496,6 +496,145 @@ def test_riscv_zero_register(): gs = rop.analyze_addr(0x0000000000011f32) assert len(gs) == 1 +def test_stack_writes(): + proj = angr.load_shellcode( + """ + push rax; jmp rsi + """, + "x86_64", + load_address=0, + auto_load_libs=False, + ) + rop = proj.analyses.ROP() + g = rop.analyze_gadget(0) + + assert g + assert g.stack_writes + assert 0 in g.stack_writes + assert g.stack_writes[0] == 'rax' + + proj = angr.load_shellcode( + """ + push rax; call rsi + """, + "x86_64", + load_address=0, + auto_load_libs=False, + ) + rop = proj.analyses.ROP() + g = rop.analyze_gadget(0) + + assert g + assert g.stack_writes + assert 8 in g.stack_writes + assert g.stack_writes[8] == 'rax' + + proj = angr.load_shellcode( + """ + push qword ptr [rax + rdx*2 - 0x7f]; jmp rcx + """, + "x86_64", + load_address=0, + auto_load_libs=False, + ) + rop = proj.analyses.ROP() + g = rop.analyze_gadget(0) + assert not g.stack_writes + + proj = angr.load_shellcode( + """ + push rdi; jmp qword ptr [rsi - 0x7f] + push -0x7e631700; push rdi; jmp qword ptr [rsi + 0x66] + """, + "x86_64", + load_address=0, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + addrs = [] + for gadgets in rop._chain_builder._reg_mover._push_pop_mover._stack_write_dict.values(): + addrs += [g.addr for g in gadgets] + assert 4 not in addrs + + proj = angr.load_shellcode( + """ + push rdi; popfq ; call rsi + """, + "x86_64", + load_address=0, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + g = rop.analyze_gadget(0) + assert len(g.mem_writes) == 1 + assert g.mem_writes[0].data_constant == 4 + + proj = angr.load_shellcode( + """ + push rax; jmp rax + push rax; jmp [rax] + """, + "x86_64", + load_address=0, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + g = rop.analyze_gadget(0) + assert not g.stack_writes + g = rop.analyze_gadget(3) + assert not g.stack_writes + +def test_reg_pops(): + proj = angr.load_shellcode( + """ + pop rax; ret + """, + "x86_64", + load_address=0, + auto_load_libs=False, + ) + rop = proj.analyses.ROP() + g = rop.analyze_gadget(0) + assert g.reg_pops + reg_pop = list(g.reg_pops)[0] + assert reg_pop.reg == 'rax' + assert reg_pop.bits == 64 + assert reg_pop.stack_offset == 0 + assert reg_pop.ast_depth == 1 + + proj = angr.load_shellcode( + """ + pop rax; mov eax, eax; ret + """, + "x86_64", + load_address=0, + auto_load_libs=False, + ) + rop = proj.analyses.ROP() + g = rop.analyze_gadget(0) + assert g.reg_pops + reg_pop = list(g.reg_pops)[0] + assert reg_pop.reg == 'rax' + assert reg_pop.bits == 32 + assert reg_pop.stack_offset == 0 + assert reg_pop.ast_depth > 1 + + proj = angr.load_shellcode( + """ + pop rax; pop rbx; ret + """, + "x86_64", + load_address=0, + auto_load_libs=False, + ) + rop = proj.analyses.ROP() + g = rop.analyze_gadget(0) + rbx_pops = [x for x in g.reg_pops if x.reg == 'rbx'] + assert len(rbx_pops) == 1 + assert rbx_pops[0].stack_offset == 8 + def run_all(): functions = globals() all_functions = {x:y for x, y in functions.items() if x.startswith('test_')}