Skip to content

Commit

Permalink
Use a faster internal data structure for RDA. (angr#4051)
Browse files Browse the repository at this point in the history
* A preliminary implementation of Liveness.

* Do not access RDState.analysis if it's None.

* Lint code.
  • Loading branch information
ltfish authored Jul 31, 2023
1 parent 4419073 commit 02e5aca
Show file tree
Hide file tree
Showing 12 changed files with 318 additions and 179 deletions.
45 changes: 15 additions & 30 deletions angr/analyses/decompiler/ail_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
BinaryOp,
)

from ...errors import SimMemoryMissingError
from ...engines.light import SpOffset
from ...code_location import CodeLocation
from ...analyses.reaching_definitions.external_codeloc import ExternalCodeLocation
Expand All @@ -35,7 +34,7 @@

if TYPE_CHECKING:
from ailment.manager import Manager
from angr.analyses.reaching_definitions import ReachingDefinitionsAnalysis
from angr.analyses.reaching_definitions import ReachingDefinitionsModel


_l = logging.getLogger(__name__)
Expand Down Expand Up @@ -103,7 +102,7 @@ def __init__(
):
self.func = func
self.func_graph = func_graph if func_graph is not None else func.graph
self._reaching_definitions: Optional[ReachingDefinitionsAnalysis] = None
self._reaching_definitions: Optional["ReachingDefinitionsModel"] = None
self._propagator = None

self._remove_dead_memdefs = remove_dead_memdefs
Expand Down Expand Up @@ -191,25 +190,20 @@ def _handler(node):
AILGraphWalker(self.func_graph, _handler, replace_nodes=True).walk()
self.blocks = {}

def _compute_reaching_definitions(self) -> "ReachingDefinitionsAnalysis":
def _compute_reaching_definitions(self) -> "ReachingDefinitionsModel":
# Compute reaching definitions or return the cached one
if self._reaching_definitions is not None:
return self._reaching_definitions
rd = self.project.analyses.ReachingDefinitions(
subject=self.func,
func_graph=self.func_graph,
# init_context=(), <-- in case of fire break glass
observe_all=True, # observe_callback=self._simplify_function_rd_observe_callback
observe_all=False,
use_callee_saved_regs_at_return=self._use_callee_saved_regs_at_return,
)
).model
self._reaching_definitions = rd
return rd

@staticmethod
# pylint:disable=unused-argument
def _simplify_function_rd_observe_callback(ob_type, **kwargs):
# Observation filter: accept every "node" observation point, and accept an "insn" observation point only when
# the "op_type" keyword argument equals OP_BEFORE. Any other combination evaluates to False.
return ob_type == "node" or (ob_type == "insn" and kwargs.get("op_type", None) == OP_BEFORE)

def _compute_propagation(self):
# Propagate expressions or return the existing result
if self._propagator is not None:
Expand Down Expand Up @@ -772,7 +766,7 @@ def _unify_local_variables(self) -> bool:
# ensure the expression that we want to replace with is still up-to-date
replace_with_original_def = self._find_atom_def_at(replace_with, rd, def_.codeloc)
if replace_with_original_def is not None and not self._check_atom_last_def(
replace_with, used_expr.size, u.ins_addr, rd, replace_with_original_def
replace_with, u, rd, replace_with_original_def
):
all_uses_replaced = False
continue
Expand Down Expand Up @@ -813,26 +807,18 @@ def _unify_local_variables(self) -> bool:
@staticmethod
def _find_atom_def_at(atom, rd, codeloc: CodeLocation) -> Optional[Definition]:
if isinstance(atom, Register):
observ = rd.observed_results[("insn", codeloc.ins_addr, OP_BEFORE)]
try:
reg_vals = observ.register_definitions.load(atom.reg_offset, size=atom.size)
defs = list(observ.extract_defs_from_mv(reg_vals))
return defs[0] if len(defs) == 1 else None
except SimMemoryMissingError:
pass
defs = rd.get_defs(atom, codeloc, OP_BEFORE)
return next(iter(defs)) if len(defs) == 1 else None

return None

@staticmethod
def _check_atom_last_def(atom, size, ins_addr, rd, the_def) -> bool:
def _check_atom_last_def(atom, codeloc, rd, the_def) -> bool:
if isinstance(atom, Register):
observ = rd.observed_results[("insn", ins_addr, OP_BEFORE)]
try:
reg_vals = observ.register_definitions.load(atom.reg_offset, size=size)
for existing_def in observ.extract_defs_from_mv(reg_vals):
if existing_def.codeloc != the_def.codeloc:
return False
except SimMemoryMissingError:
pass
defs = rd.get_defs(atom, codeloc, OP_BEFORE)
for d in defs:
if d.codeloc != the_def.codeloc:
return False

return True

Expand Down Expand Up @@ -961,10 +947,9 @@ def _fold_call_exprs(self) -> bool:
defsite_defs_per_atom = defaultdict(set)
for dd in defsite_all_expr_uses:
defsite_defs_per_atom[dd.atom].add(dd)
usesite_rdstate = rd.observed_results[("stmt", (u.block_addr, u.block_idx, u.stmt_idx), 0)]
usesite_expr_def_outdated = False
for defsite_expr_atom, defsite_expr_uses in defsite_defs_per_atom.items():
usesite_expr_uses = set(usesite_rdstate.get_definitions(defsite_expr_atom))
usesite_expr_uses = set(rd.get_defs(defsite_expr_atom, u, OP_BEFORE))
if not usesite_expr_uses:
# the atom is not defined at the use site - it's fine
continue
Expand Down
18 changes: 11 additions & 7 deletions angr/analyses/decompiler/block_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,19 @@ def _compute_propagation(self, block):

def _compute_reaching_definitions(self, block):
def observe_callback(ob_type, addr=None, op_type=None, **kwargs) -> bool: # pylint:disable=unused-argument
return ob_type == "stmt" or ob_type == "node" and addr == block.addr and op_type == OP_AFTER
return ob_type == "node" and addr == block.addr and op_type == OP_AFTER

if self._reaching_definitions is None:
self._reaching_definitions = self.project.analyses[ReachingDefinitionsAnalysis].prep()(
subject=block,
track_tmps=True,
stack_pointer_tracker=self._stack_pointer_tracker,
observe_all=False,
observe_callback=observe_callback,
self._reaching_definitions = (
self.project.analyses[ReachingDefinitionsAnalysis]
.prep()(
subject=block,
track_tmps=True,
stack_pointer_tracker=self._stack_pointer_tracker,
observe_all=False,
observe_callback=observe_callback,
)
.model
)
return self._reaching_definitions

Expand Down
56 changes: 14 additions & 42 deletions angr/analyses/propagator/engine_ail.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import claripy
from ailment import Stmt, Expr

from angr.errors import SimMemoryMissingError
from angr.knowledge_plugins.propagations.prop_value import PropValue, Detail
from angr.analyses.reaching_definitions.external_codeloc import ExternalCodeLocation
from angr.knowledge_plugins.key_definitions.atoms import Register
from ...utils.constants import is_alignment_mask
from ...engines.light import SimEngineLightAILMixin
from ...sim_variable import SimStackVariable, SimMemoryVariable
from ..reaching_definitions.reaching_definitions import OP_BEFORE, OP_AFTER
from ..reaching_definitions.reaching_definitions import OP_BEFORE
from .engine_base import SimEnginePropagatorBase

if TYPE_CHECKING:
Expand Down Expand Up @@ -378,26 +378,17 @@ def _test_concatenation(pv: PropValue):
reg_defat = None
if self._reaching_definitions is not None:
codeloc = self._codeloc()
key = "stmt", (codeloc.block_addr, codeloc.block_idx, codeloc.stmt_idx), OP_BEFORE
if key in self._reaching_definitions.observed_results:
o = self._reaching_definitions.observed_results[key]
try:
mv = o.register_definitions.load(expr.reg_offset, size=expr.size)
except SimMemoryMissingError:
mv = None
if mv is not None:
reg_defs = o.extract_defs_from_mv(mv)
reg_defat_codelocs = {reg_def.codeloc for reg_def in reg_defs}
if len(reg_defat_codelocs) == 1:
reg_defat = next(iter(reg_defat_codelocs))
defat_key = "stmt", (reg_defat.block_addr, reg_defat.block_idx, reg_defat.stmt_idx), OP_BEFORE
if defat_key not in self._reaching_definitions.observed_results:
# the observation point does not exist. probably it's because the observation point is in a
# callee function.
reg_defat = None
if isinstance(reg_defat, ExternalCodeLocation):
# there won't be an observed result for external code location. give up
reg_defat = None
reg_defat_defs = self._reaching_definitions.get_defs(
Register(expr.reg_offset, expr.size), codeloc, OP_BEFORE
)
reg_defat_codelocs = {reg_def.codeloc for reg_def in reg_defat_defs}
if len(reg_defat_codelocs) == 1:
reg_defat = next(iter(reg_defat_codelocs))
if reg_defat.stmt_idx is None:
# the observation point is in a callee function
reg_defat = None
if isinstance(reg_defat, ExternalCodeLocation):
reg_defat = None

if new_expr is not None:
# check if this new_expr uses any expression that has been overwritten
Expand Down Expand Up @@ -1139,36 +1130,17 @@ def is_using_outdated_def(
l.warning("Unknown where the expression is defined. Assume the definition is out-dated.")
return True, False

key_defat = "stmt", (expr_defat.block_addr, expr_defat.block_idx, expr_defat.stmt_idx), OP_AFTER
if key_defat not in self._reaching_definitions.observed_results:
l.warning(
"Required reaching definition state at instruction address %#x is not found. Assume the definition is "
"out-dated.",
expr_defat.ins_addr,
)
return True, False

key_currloc = "stmt", (current_loc.block_addr, current_loc.block_idx, current_loc.stmt_idx), OP_BEFORE
if key_currloc not in self._reaching_definitions.observed_results:
l.warning(
"Required reaching definition state at instruction address %#x is not found. Assume the definition is "
"out-dated.",
current_loc.ins_addr,
)
return True, False

from .outdated_definition_walker import OutdatedDefinitionWalker # pylint:disable=import-outside-toplevel

walker = OutdatedDefinitionWalker(
expr,
expr_defat,
self._reaching_definitions.observed_results[key_defat],
current_loc,
self._reaching_definitions.observed_results[key_currloc],
self.state,
self.arch,
avoid=avoid,
extract_offset_to_sp=self.extract_offset_to_sp,
rda=self._reaching_definitions,
)
walker.walk_expression(expr)
return walker.out_dated, walker.has_avoid
Expand Down
70 changes: 17 additions & 53 deletions angr/analyses/propagator/outdated_definition_walker.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# pylint:disable=consider-using-in
from typing import Optional, Callable, TYPE_CHECKING

from ailment import Block, Stmt, Expr, AILBlockWalker

from ...errors import SimMemoryMissingError
from ...code_location import CodeLocation
from ...knowledge_plugins.key_definitions.constants import OP_BEFORE, OP_AFTER
from ...knowledge_plugins.key_definitions import atoms

if TYPE_CHECKING:
from archinfo import Arch
from .propagator import PropagatorAILState
from angr.storage.memory_mixins.paged_memory.pages.multi_values import MultiValues
from angr.knowledge_plugins.key_definitions import LiveDefinitions
from angr.analyses.reaching_definitions import ReachingDefinitionsModel


class OutdatedDefinitionWalker(AILBlockWalker):
Expand All @@ -21,20 +22,17 @@ def __init__(
self,
expr,
expr_defat: CodeLocation,
livedefs_defat: "LiveDefinitions",
current_loc: CodeLocation,
livedefs_currentloc: "LiveDefinitions",
state: "PropagatorAILState",
arch: "Arch",
avoid: Optional[Expr.Expression] = None,
extract_offset_to_sp: Callable = None,
rda: "ReachingDefinitionsModel" = None,
):
super().__init__()
self.expr = expr
self.expr_defat = expr_defat
self.livedefs_defat = livedefs_defat
self.current_loc = current_loc
self.livedefs_currentloc = livedefs_currentloc
self.state = state
self.avoid = avoid
self.arch = arch
Expand All @@ -45,6 +43,7 @@ def __init__(
self.expr_handlers[Expr.VEXCCallExpression] = self._handle_VEXCCallExpression
self.out_dated = False
self.has_avoid = False
self.rda = rda

# pylint:disable=unused-argument
def _handle_Tmp(self, expr_idx: int, expr: Expr.Tmp, stmt_idx: int, stmt: Stmt.Assignment, block: Optional[Block]):
Expand All @@ -63,51 +62,28 @@ def _handle_Register(
self.has_avoid = True

# is the used register still alive at this point?
try:
reg_vals: "MultiValues" = self.livedefs_defat.register_definitions.load(expr.reg_offset, size=expr.size)
defs_defat = list(self.livedefs_defat.extract_defs_from_mv(reg_vals))
except SimMemoryMissingError:
defs_defat = []

try:
reg_vals: "MultiValues" = self.livedefs_currentloc.register_definitions.load(
expr.reg_offset, size=expr.size
)
defs_currentloc = list(self.livedefs_currentloc.extract_defs_from_mv(reg_vals))
except SimMemoryMissingError:
defs_currentloc = []
defs_defat = self.rda.get_defs(atoms.Register(expr.reg_offset, expr.size), self.expr_defat, OP_AFTER)
defs_currentloc = self.rda.get_defs(atoms.Register(expr.reg_offset, expr.size), self.current_loc, OP_BEFORE)

codelocs_defat = {def_.codeloc for def_ in defs_defat}
codelocs_currentloc = {def_.codeloc for def_ in defs_currentloc}
if not (codelocs_defat and codelocs_currentloc and codelocs_defat == codelocs_currentloc):
self.out_dated = True

def _handle_Load(self, expr_idx: int, expr: Expr.Load, stmt_idx: int, stmt: Stmt.Statement, block: Optional[Block]):
if self.avoid is not None and ( # pylint:disable=consider-using-in
expr == self.avoid or expr.addr == self.avoid
):
if self.avoid is not None and (expr == self.avoid or expr.addr == self.avoid):
self.has_avoid = True

if isinstance(expr.addr, Expr.StackBaseOffset):
sp_offset = self.extract_offset_to_sp(expr.addr)

if sp_offset is not None:
stack_addr = self.livedefs_defat.stack_offset_to_stack_addr(sp_offset)
try:
mem_vals: "MultiValues" = self.livedefs_defat.stack_definitions.load(
stack_addr, size=expr.size, endness=expr.endness
)
defs_defat = list(self.livedefs_defat.extract_defs_from_mv(mem_vals))
except SimMemoryMissingError:
defs_defat = []

try:
mem_vals: "MultiValues" = self.livedefs_currentloc.stack_definitions.load(
stack_addr, size=expr.size, endness=expr.endness
)
defs_currentloc = list(self.livedefs_defat.extract_defs_from_mv(mem_vals))
except SimMemoryMissingError:
defs_currentloc = []
defs_defat = self.rda.get_defs(
atoms.MemoryLocation(atoms.SpOffset(expr.bits, sp_offset), expr.size), self.expr_defat, OP_AFTER
)
defs_currentloc = self.rda.get_defs(
atoms.MemoryLocation(atoms.SpOffset(expr.bits, sp_offset), expr.size), self.current_loc, OP_BEFORE
)

codelocs_defat = {def_.codeloc for def_ in defs_defat}
codelocs_currentloc = {def_.codeloc for def_ in defs_currentloc}
Expand All @@ -126,21 +102,9 @@ def _handle_Load(self, expr_idx: int, expr: Expr.Load, stmt_idx: int, stmt: Stmt

elif isinstance(expr.addr, Expr.Const):
mem_addr = expr.addr.value
try:
mem_vals: "MultiValues" = self.livedefs_defat.memory_definitions.load(
mem_addr, size=expr.size, endness=expr.endness
)
defs_defat = list(self.livedefs_defat.extract_defs_from_mv(mem_vals))
except SimMemoryMissingError:
defs_defat = []

try:
mem_vals: "MultiValues" = self.livedefs_currentloc.memory_definitions.load(
mem_addr, size=expr.size, endness=expr.endness
)
defs_currentloc = list(self.livedefs_defat.extract_defs_from_mv(mem_vals))
except SimMemoryMissingError:
defs_currentloc = []
defs_defat = self.rda.get_defs(atoms.MemoryLocation(mem_addr, expr.size), self.expr_defat, OP_AFTER)
defs_currentloc = self.rda.get_defs(atoms.MemoryLocation(mem_addr, expr.size), self.current_loc, OP_BEFORE)

codelocs_defat = {def_.codeloc for def_ in defs_defat}
codelocs_currentloc = {def_.codeloc for def_ in defs_currentloc}
Expand Down
22 changes: 14 additions & 8 deletions angr/analyses/reaching_definitions/engine_ail.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,20 +100,26 @@ def _external_codeloc(self):

def _set_codeloc(self):
# TODO do we want a better mechanism to specify context updates?
self.state.move_codelocs(
CodeLocation(
self.block.addr,
self.stmt_idx,
ins_addr=self.ins_addr,
block_idx=self.block.idx,
context=self.state.codeloc.context,
)
new_codeloc = CodeLocation(
self.block.addr,
self.stmt_idx,
ins_addr=self.ins_addr,
block_idx=self.block.idx,
context=self.state.codeloc.context,
)
self.state.move_codelocs(new_codeloc)
self.state.analysis.model.at_new_stmt(new_codeloc)

#
# AIL statement handlers
#

def _process_Stmt(self, whitelist=None):
# Delegate the actual statement processing to the parent class, forwarding the whitelist unchanged.
super()._process_Stmt(whitelist=whitelist)

# Only touch the model when the state has an analysis attached; the guard skips the call when
# state.analysis is falsy (presumably None when the state is used without an analysis - confirm).
if self.state.analysis:
self.state.analysis.model.complete_loc()

def _handle_Stmt(self, stmt):
if self.state.analysis:
self.state.analysis.stmt_observe(self.stmt_idx, stmt, self.block, self.state, OP_BEFORE)
Expand Down
Loading

0 comments on commit 02e5aca

Please sign in to comment.