Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions soteria/lib/bv_values/encoding.ml
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,150 @@ let encode_value (v : Svalue.t) =
|> Iter.to_list
|> bool_ands

module LetBinder = struct
module ValTbl = Hashtbl.Make (struct
type t = Svalue.t

let equal = Svalue.equal
let hash (v : t) = v.tag
end)

module VarSet = Hashset.Make (struct
type t = Svalue.Var.t

let equal = Svalue.Var.equal
let hash = Svalue.Var.to_int
let pp = Svalue.Var.pp
end)

(** Generates an appropriate [Var] that uniquely identifies this value. *)
let var_of_value (v : Svalue.t) = Svalue.Var.of_int v.tag

(** Like [var_of_value] but returns a properly typed [Svalue]. *)
let val_of_value (v : Svalue.t) = Svalue.mk_var (var_of_value v) v.node.ty

(** Returns the set of [Var]s mentionned in this value. *)
let dependencies v =
let deps = VarSet.with_capacity 31 in
Svalue.iter v (fun v ->
match v.node.kind with Var dep -> VarSet.add deps dep | _ -> ());
deps

(** [replace ?cond v] replaces [v] with a [Var] node with id [v.tag] if
[cond v] is true ([cond] defaults to [Fun.const true]); otherwise
recursively replaces the children. *)
let rec replace ?(cond = Fun.const true) (v : Svalue.t) =
let replace = replace ~cond in
if cond v then val_of_value v
else
match v.node.kind with
| Var _ | Bool _ | Float _ | BitVec _ -> v
| Ite (c, t, e) ->
let c' = replace c in
let t' = replace t in
let e' = replace e in
if c == c' && t == t' && e == e' then v else Svalue.Bool.ite c' t' e'
| Unop (unop, v1) ->
let v1' = replace v1 in
if v1 == v1' then v else Eval.eval_unop unop v1'
| Binop (binop, v1, v2) ->
let v1' = replace v1 in
let v2' = replace v2 in
if v1 == v1' && v2 == v2' then v else Eval.eval_binop binop v1' v2'
| Nop (Distinct, vs) ->
let vs', changed = List.map_changed replace vs in
if Stdlib.not changed then v else Svalue.Bool.distinct vs'
| Seq vs ->
let vs', changed = List.map_changed replace vs in
if Stdlib.not changed then v else Svalue.SSeq.mk ~seq_ty:v.node.ty vs'
| Ptr (l, o) ->
let l' = replace l in
let o' = replace o in
if l == l' && o == o' then v else Svalue.Ptr.mk l' o'

(** Makes the bindings for a map of bindings (ignores the pointed-to value).
Will return, for each binding, an appropriate variable for the binding
(using [var_of_value]), the encoded value (with possible sub-bindings
replaced), and a set containing the binding dependencies of this value. *)
let mk_bindings bind_map =
ValTbl.fold
(fun v _ acc ->
let var = var_of_value v in
let v =
replace ~cond:(fun w -> ValTbl.mem bind_map w && w.tag <> v.tag) v
in
let sexp = encode_value v in
let deps = dependencies v in
(var, sexp, deps) :: acc)
bind_map []

(** [order_bindings known_vars bindings] will order the given [bindings] into
groups of bindings; each group only needs the variables in the previous
groups or in [known_vars] to be defined. *)
let order_bindings known_vars bindings =
let for_all f s = Seq.for_all f (VarSet.to_seq s) in
let rec aux acc acc_full cur_visited
(bindings : (Svalue.Var.t * sexp * VarSet.t) list)
(next_bindings : (Svalue.Var.t * sexp * VarSet.t) list) :
(string * sexp) list list =
match bindings with
| [] -> (
match next_bindings with
| [] -> acc :: acc_full
| _ ->
List.iter (VarSet.add known_vars) cur_visited;
aux [] (acc :: acc_full) [] next_bindings [])
| ((v, sexp, deps) as binding) :: bindings ->
if VarSet.mem known_vars v then
aux acc acc_full cur_visited bindings next_bindings
else if for_all (VarSet.mem known_vars) deps then
aux
((Svalue.Var.to_string v, sexp) :: acc)
acc_full (v :: cur_visited) bindings next_bindings
else aux acc acc_full cur_visited bindings (binding :: next_bindings)
in
aux [] [] [] bindings []

(** Will calculate the let-bindings for the given value. Will only consider
bindings for values that occur at least [min_occurrences] times, and will
only return any bindings at all if at least [min_binds] are found matching
the above criterium. Will return the value with any relevant binding
substituted (but {b not} defined in the AST!) along with the bindings to
be applied after encoding. *)
let let_binds_for ?(min_binds = 5) ?(min_occurrences = 20) v =
let tag_counts = ValTbl.create 255 in

Svalue.iter v (fun v ->
match v.node.kind with
| Var _ | Bool _ | Float _ | BitVec _ -> ()
| _ ->
let curr =
ValTbl.find_opt tag_counts v |> Option.value ~default:0
in
ValTbl.replace tag_counts v (curr + 1));
ValTbl.filter_map_inplace
(fun _ count -> if count >= min_occurrences then Some count else None)
tag_counts;

let length = ValTbl.length tag_counts in
if length == 0 || length < min_binds then (v, [])
else
let bindings = mk_bindings tag_counts in

let known_vars = dependencies v in
let bindings = order_bindings known_vars bindings in

let v_subst = replace ~cond:(ValTbl.mem tag_counts) v in
(v_subst, bindings)

(** [apply_bindings bindings sexp] Applies the given groups of [bindings] to
[sexp]. *)
let apply_bindings bindings sexp =
List.fold_left (Fun.flip Simple_smt.let_) sexp bindings
end

let encode_value v =
let v, bindings = LetBinder.let_binds_for ~min_occurrences:10 v in
encode_value v |> LetBinder.apply_bindings bindings

let init_commands = []
3 changes: 3 additions & 0 deletions soteria/lib/solvers/smt_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ let bv_saddo l r = app_ "bvsaddo" [ l; r ]
let bv_umulo l r = app_ "bvumulo" [ l; r ]
let bv_smulo l r = app_ "bvsmulo" [ l; r ]

(* Redefine bool_ands, so that bool_ands [ x ] = x *)
let bool_ands = function [] -> bool_k true | [ p ] -> p | ps -> app_ "and" ps

(* Solver commands *)

let reset = simple_command [ "reset" ]