diff --git a/soteria/lib/bv_values/encoding.ml b/soteria/lib/bv_values/encoding.ml index e8c9b44a5..6403dd411 100644 --- a/soteria/lib/bv_values/encoding.ml +++ b/soteria/lib/bv_values/encoding.ml @@ -132,4 +132,150 @@ let encode_value (v : Svalue.t) = |> Iter.to_list |> bool_ands +module LetBinder = struct + module ValTbl = Hashtbl.Make (struct + type t = Svalue.t + + let equal = Svalue.equal + let hash (v : t) = v.tag + end) + + module VarSet = Hashset.Make (struct + type t = Svalue.Var.t + + let equal = Svalue.Var.equal + let hash = Svalue.Var.to_int + let pp = Svalue.Var.pp + end) + + (** Generates an appropriate [Var] that uniquely identifies this value. *) + let var_of_value (v : Svalue.t) = Svalue.Var.of_int v.tag + + (** Like [var_of_value] but returns a properly typed [Svalue]. *) + let val_of_value (v : Svalue.t) = Svalue.mk_var (var_of_value v) v.node.ty + + (** Returns the set of [Var]s mentionned in this value. *) + let dependencies v = + let deps = VarSet.with_capacity 31 in + Svalue.iter v (fun v -> + match v.node.kind with Var dep -> VarSet.add deps dep | _ -> ()); + deps + + (** [replace ?cond v] replaces [v] with a [Var] node with id [v.tag] if + [cond v] is true ([cond] defaults to [Fun.const true]); otherwise + recursively replaces the children. *) + let rec replace ?(cond = Fun.const true) (v : Svalue.t) = + let replace = replace ~cond in + if cond v then val_of_value v + else + match v.node.kind with + | Var _ | Bool _ | Float _ | BitVec _ -> v + | Ite (c, t, e) -> + let c' = replace c in + let t' = replace t in + let e' = replace e in + if c == c' && t == t' && e == e' then v else Svalue.Bool.ite c' t' e' + | Unop (unop, v1) -> + let v1' = replace v1 in + if v1 == v1' then v else Eval.eval_unop unop v1' + | Binop (binop, v1, v2) -> + let v1' = replace v1 in + let v2' = replace v2 in + if v1 == v1' && v2 == v2' then v else Eval.eval_binop binop v1' v2' + | Nop (Distinct, vs) -> + let vs', changed = List.map_changed replace vs in + if Stdlib.not changed then v else Svalue.Bool.distinct vs' + | Seq vs -> + let vs', changed = List.map_changed replace vs in + if Stdlib.not changed then v else Svalue.SSeq.mk ~seq_ty:v.node.ty vs' + | Ptr (l, o) -> + let l' = replace l in + let o' = replace o in + if l == l' && o == o' then v else Svalue.Ptr.mk l' o' + + (** Makes the bindings for a map of bindings (ignores the pointed-to value). + Will return, for each binding, an appropriate variable for the binding + (using [var_of_value]), the encoded value (with possible sub-bindings + replaced), and a set containing the binding dependencies of this value. *) + let mk_bindings bind_map = + ValTbl.fold + (fun v _ acc -> + let var = var_of_value v in + let v = + replace ~cond:(fun w -> ValTbl.mem bind_map w && w.tag <> v.tag) v + in + let sexp = encode_value v in + let deps = dependencies v in + (var, sexp, deps) :: acc) + bind_map [] + + (** [order_bindings known_vars bindings] will order the given [bindings] into + groups of bindings; each group only needs the variables in the previous + groups or in [known_vars] to be defined. *) + let order_bindings known_vars bindings = + let for_all f s = Seq.for_all f (VarSet.to_seq s) in + let rec aux acc acc_full cur_visited + (bindings : (Svalue.Var.t * sexp * VarSet.t) list) + (next_bindings : (Svalue.Var.t * sexp * VarSet.t) list) : + (string * sexp) list list = + match bindings with + | [] -> ( + match next_bindings with + | [] -> acc :: acc_full + | _ -> + List.iter (VarSet.add known_vars) cur_visited; + aux [] (acc :: acc_full) [] next_bindings []) + | ((v, sexp, deps) as binding) :: bindings -> + if VarSet.mem known_vars v then + aux acc acc_full cur_visited bindings next_bindings + else if for_all (VarSet.mem known_vars) deps then + aux + ((Svalue.Var.to_string v, sexp) :: acc) + acc_full (v :: cur_visited) bindings next_bindings + else aux acc acc_full cur_visited bindings (binding :: next_bindings) + in + aux [] [] [] bindings [] + + (** Will calculate the let-bindings for the given value. Will only consider + bindings for values that occur at least [min_occurrences] times, and will + only return any bindings at all if at least [min_binds] are found matching + the above criterium. Will return the value with any relevant binding + substituted (but {b not} defined in the AST!) along with the bindings to + be applied after encoding. *) + let let_binds_for ?(min_binds = 5) ?(min_occurrences = 20) v = + let tag_counts = ValTbl.create 255 in + + Svalue.iter v (fun v -> + match v.node.kind with + | Var _ | Bool _ | Float _ | BitVec _ -> () + | _ -> + let curr = + ValTbl.find_opt tag_counts v |> Option.value ~default:0 + in + ValTbl.replace tag_counts v (curr + 1)); + ValTbl.filter_map_inplace + (fun _ count -> if count >= min_occurrences then Some count else None) + tag_counts; + + let length = ValTbl.length tag_counts in + if length == 0 || length < min_binds then (v, []) + else + let bindings = mk_bindings tag_counts in + + let known_vars = dependencies v in + let bindings = order_bindings known_vars bindings in + + let v_subst = replace ~cond:(ValTbl.mem tag_counts) v in + (v_subst, bindings) + + (** [apply_bindings bindings sexp] Applies the given groups of [bindings] to + [sexp]. *) + let apply_bindings bindings sexp = + List.fold_left (Fun.flip Simple_smt.let_) sexp bindings +end + +let encode_value v = + let v, bindings = LetBinder.let_binds_for ~min_occurrences:10 v in + encode_value v |> LetBinder.apply_bindings bindings + let init_commands = [] diff --git a/soteria/lib/solvers/smt_utils.ml b/soteria/lib/solvers/smt_utils.ml index 33d34a451..c926d6b4e 100644 --- a/soteria/lib/solvers/smt_utils.ml +++ b/soteria/lib/solvers/smt_utils.ml @@ -115,6 +115,9 @@ let bv_saddo l r = app_ "bvsaddo" [ l; r ] let bv_umulo l r = app_ "bvumulo" [ l; r ] let bv_smulo l r = app_ "bvsmulo" [ l; r ] +(* Redefine bool_ands, so that bool_ands [ x ] = x *) +let bool_ands = function [] -> bool_k true | [ p ] -> p | ps -> app_ "and" ps + (* Solver commands *) let reset = simple_command [ "reset" ]