Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions spectec/src/exe-spectec/main.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type pass =
| TypeFamilyRemoval
| Else
| Undep
| SubExpansion
| Uncaseremoval
| AliasDemut
| Ite
Expand All @@ -41,6 +42,7 @@ let all_passes = [
Else;
Uncaseremoval;
Sideconditions;
SubExpansion;
Sub;
AliasDemut;
]
Expand Down Expand Up @@ -104,6 +106,7 @@ let pass_flag = function
| AliasDemut -> "alias-demut"
| Else -> "else"
| Undep -> "remove-indexed-types"
| SubExpansion -> "sub-expansion"
| Uncaseremoval -> "uncase-removal"
| Ite -> "ite"

Expand All @@ -115,6 +118,7 @@ let pass_desc = function
| TypeFamilyRemoval -> "Transform Type families into sum types"
| Else -> "Eliminate the otherwise premise in relations"
| Undep -> "Transform indexed types into types with well-formedness predicates"
| SubExpansion -> "Expands subtype matching"
| Uncaseremoval -> "Eliminate the uncase expression"
| AliasDemut -> "Lifts type aliases out of mutual groups"
| Ite -> "If-then-else introduction"
Expand All @@ -128,6 +132,7 @@ let run_pass : pass -> Il.Ast.script -> Il.Ast.script = function
| TypeFamilyRemoval -> Middlend.Typefamilyremoval.transform
| Else -> Middlend.Else.transform
| Undep -> Middlend.Undep.transform
| SubExpansion -> Middlend.Subexpansion.transform
| Uncaseremoval -> Middlend.Uncaseremoval.transform
| AliasDemut -> Middlend.AliasDemut.transform
| Ite -> Middlend.Ite.transform
Expand Down
1 change: 1 addition & 0 deletions spectec/src/middlend/dune
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
else
undep
utils
subexpansion
uncaseremoval
aliasDemut
ite
Expand Down
241 changes: 241 additions & 0 deletions spectec/src/middlend/subexpansion.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
(*
This pass expands the subtyping patterns that appear in the LHS of
function clauses and type family arguments.

It achieves this through the following steps:
* For each argument, we collect every unique sub expression.
* Then, for each sub expression, we collect every case that is
possible in the subtype. If the specific case additionally carries
values, then we generate binds to add in the function scope.
* With all of these cases, for each unique sub expression, we compute
the cartesian product in order to absolutely grab all the possible cases.
See $cvtop to see how this might be done.
* Once we have calculated the product, we generate a subst for each product
and proceed to generate the clause/type instance.
* Finally, we filter out binds that appear in the subst.

For example, take the following types and function:

syntax A = t1 nat | t2 nat nat
syntax B = t1 nat | t2 nat nat | t3 | t4

def $foo(B) : nat
def $foo(x : A <: B) = 1
def $foo(t3) = 2
def $foo(t4) = 3

Would be transformed as such:

def $foo(B) : nat
def $foo{n : nat}(t1(n)) = 1
def $foo{n1 : nat, n2 : nat}(t2(n1, n2)) = 1
def $foo(t3) = 2
def $foo(t4) = 3

*)

open Util
open Source
open Il.Ast
open Il

(* Errors *)

let error at msg = Error.error at "sub expression expansion" msg

(* Environment *)

(* Global IL env*)
let env_ref = ref Il.Env.empty

let empty_tuple_exp at = TupE [] $$ at % (TupT [] $ at)

(* Computes the cartesian product of a given list. *)
let product_of_lists (lists : 'a list list) =
List.fold_left (fun acc lst ->
List.concat_map (fun existing ->
List.map (fun v -> v :: existing) lst) acc) [[]] lists

let product_of_lists_append (lists : 'a list list) =
List.fold_left (fun acc lst ->
List.concat_map (fun existing ->
List.map (fun v -> existing @ [v]) lst) acc) [[]] lists

let get_bind_id b =
match b.it with
| ExpB (id, _) | TypB id
| DefB (id, _, _) | GramB (id, _, _) -> id.it

let eq_sube (id, t1, t2) (id', t1', t2') =
Eq.eq_id id id' && Eq.eq_typ t1 t1' && Eq.eq_typ t2 t2'

let rec collect_sube_exp e =
let c_func = collect_sube_exp in
match e.it with
(* Assumption - nested sub expressions do not exist. Must also be a varE. *)
Comment on lines +72 to +75
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You used the Iter module in #191, any reason to not use it here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No real reason, probably was trying to get a more functional approach instead of using Iter.

Also since its only traversing through exp I maybe thought it was overkill.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good point. No issue, was just curious.

I wish we didn't have to duplicate the traversal logic in every pass, but 🤷🏻

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah... I really wish there was an easier way to make a generic traversal pass that didn't involve side effects 😅 . I'll look into it more to see if something can be done

| SubE ({it = VarE id; _}, t1, t2) -> [id, t1, t2]
| CallE (_, args) -> List.concat_map collect_sube_arg args
| StrE fields -> List.concat_map (fun (_a, e1) -> c_func e1) fields
| UnE (_, _, e1) | CvtE (e1, _, _) | LiftE e1 | TheE e1 | OptE (Some e1)
| ProjE (e1, _) | UncaseE (e1, _)
| CaseE (_, e1) | LenE e1 | DotE (e1, _) -> c_func e1
| BinE (_, _, e1, e2) | CmpE (_, _, e1, e2)
| CompE (e1, e2) | MemE (e1, e2)
| CatE (e1, e2) | IdxE (e1, e2) -> c_func e1 @ c_func e2
| TupE exps | ListE exps -> List.concat_map collect_sube_exp exps
| SliceE (e1, e2, e3) -> c_func e1 @ c_func e2 @ c_func e3
| UpdE (e1, p, e2)
| ExtE (e1, p, e2) -> c_func e1 @ collect_fcalls_path p @ c_func e2
| IterE (e1, (iter, id_exp_pairs)) ->
c_func e1 @ collect_sube_iter iter @
List.concat_map (fun (_, exp) -> c_func exp) id_exp_pairs
| _ -> []

and collect_sube_iter i =
match i with
| ListN (e1, _) -> collect_sube_exp e1
| _ -> []

and collect_sube_arg a =
match a.it with
| ExpA exp -> collect_sube_exp exp
| _ -> []

and collect_fcalls_path p =
match p.it with
| RootP -> []
| IdxP (p, e) -> collect_fcalls_path p @ collect_sube_exp e
| SliceP (p, e1, e2) -> collect_fcalls_path p @ collect_sube_exp e1 @ collect_sube_exp e2
| DotP (p, _) -> collect_fcalls_path p

let check_matching c_args match_args =
Option.is_some (try
Eval.match_list Eval.match_arg !env_ref Subst.empty c_args match_args
with Eval.Irred -> None)

let get_case_typ t =
match t.it with
| TupT typs -> typs
| _ -> [VarE ("_" $ t.at) $$ t.at % t, t]

let collect_all_instances case_typ ids at inst =
match inst.it with
| InstD (_, _, {it = VariantT typcases; _}) when
List.for_all (fun (_, (_, t, _), _) -> t.it = TupT []) typcases ->
List.map (fun (m, _, _) -> ([], CaseE (m, empty_tuple_exp no_region) $$ at % case_typ)) typcases
| InstD (_, _, {it = VariantT typcases; _}) ->
let _, new_cases =
List.fold_left (fun (ids', acc) (m, (_, t, _), _) ->
let typs = get_case_typ t in
let new_binds, typs' = Utils.improve_ids_binders ids' true t.at typs in
let exps = List.map fst typs' in
let tup_exp = TupE exps $$ at % t in
let case_exp = CaseE (m, tup_exp) $$ at % case_typ in
let new_ids = List.map get_bind_id new_binds in
(new_ids @ ids', (new_binds, case_exp) :: acc)
) (ids, []) typcases
in
new_cases
| _ -> error at "Expected a variant type"

let rec collect_all_instances_typ ids at typ =
match typ.it with
| VarT (var_id, dep_args) -> let (_, insts) = Il.Env.find_typ !env_ref var_id in
(match insts with
| [] -> [] (* Should never happen *)
| _ ->
let inst_opt = List.find_opt (fun inst ->
match inst.it with
| InstD (_, args, _) -> check_matching dep_args args
) insts in
match inst_opt with
| None -> error at ("Could not find specific instance for typ: " ^ Il.Print.string_of_typ typ)
| Some inst -> collect_all_instances typ ids at inst
)
| TupT exp_typ_pairs ->
let instances_list = List.map (fun (_, t) ->
collect_all_instances_typ ids at t
) exp_typ_pairs in
let product = product_of_lists_append instances_list in
List.map (fun lst ->
let binds, exps = List.split lst in
List.concat binds, TupE exps $$ at % typ) product
| _ -> []

let generate_subst_list lhs binds =
(* Collect all unique sub expressions for each argument *)
let subs = List.concat_map (fun a ->
Lib.List.nub eq_sube (collect_sube_arg a)
) lhs in
let ids = List.map get_bind_id binds in

(* Collect all cases for the specific subtype, generating any potential binds in the process *)
let _, cases =
List.fold_left (fun (binds, cases) (id, t1, _) ->
let ids' = List.map get_bind_id binds @ ids in
let instances = collect_all_instances_typ ids' id.at t1 in
let new_binds = List.concat_map fst instances in
let cases'' = List.map (fun case_data -> (id, case_data)) instances in
(new_binds @ binds, cases'' :: cases)
) (binds, []) subs
in

(* Compute cartesian product for all cases and generate a subst *)
let cases' = product_of_lists cases in
List.map (List.fold_left (fun (binds, subst) (id, (binds', exp)) ->
(binds' @ binds, Il.Subst.add_varid subst id exp)) ([], Il.Subst.empty)
) cases'

let t_clause clause =
match clause.it with
| DefD (binds, lhs, rhs, prems) ->
let subst_list = generate_subst_list lhs binds in
List.map (fun (binds', subst) ->
(* Subst all occurrences of the subE id *)
let new_lhs = Il.Subst.subst_args subst lhs in
let new_prems = Il.Subst.subst_list Il.Subst.subst_prem subst prems in
let new_rhs = Il.Subst.subst_exp subst rhs in

(* Filtering binds - only the subst ids *)
let binds_filtered = Lib.List.filter_not (fun b -> match b.it with
| ExpB (id, _) -> Il.Subst.mem_varid subst id
| _ -> false
) (binds' @ binds) in
let new_binds, _ = Il.Subst.subst_binds subst binds_filtered in
(* Reduction is done here to remove subtyping expressions *)
DefD (new_binds, List.map (Il.Eval.reduce_arg !env_ref) new_lhs, new_rhs, new_prems) $ clause.at
) subst_list

let t_inst inst =
match inst.it with
| InstD (binds, lhs, deftyp) ->
let subst_list = generate_subst_list lhs binds in
List.map (fun (binds', subst) ->
(* Subst all occurrences of the subE id *)
let new_lhs = Il.Subst.subst_args subst lhs in
let new_rhs = Il.Subst.subst_deftyp subst deftyp in

(* Filtering binds - only the subst ids *)
let binds_filtered = Lib.List.filter_not (fun b -> match b.it with
| ExpB (id, _) -> Il.Subst.mem_varid subst id
| _ -> false
) (binds' @ binds) in

let new_binds, _ = Il.Subst.subst_binds subst binds_filtered in
(* Reduction is done here to remove subtyping expressions *)
InstD (new_binds, List.map (Il.Eval.reduce_arg !env_ref) new_lhs, new_rhs) $ inst.at
) subst_list


let rec t_def def =
match def.it with
| RecD defs -> { def with it = RecD (List.map t_def defs) }
| DecD (id, params, typ, clauses) ->
{ def with it = DecD (id, params, typ, List.concat_map t_clause clauses) }
| TypD (id, params, insts) ->
{ def with it = TypD (id, params, List.concat_map t_inst insts)}
| _ -> def

let transform (defs : script) =
env_ref := Il.Env.env_of_script defs;
List.map (t_def) defs
1 change: 1 addition & 0 deletions spectec/src/middlend/subexpansion.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
val transform : Il.Ast.script -> Il.Ast.script
4 changes: 2 additions & 2 deletions spectec/src/middlend/undep.ml
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ let create_well_formed_predicate id env inst =
| TupT tups -> tups
| _ -> [(VarE ("_" $ id.at) $$ id.at % case_typ, case_typ)]
in
let extra_binds, t_pairs = Utils.improve_ids_binders false id.at exp_typ_pairs in
let extra_binds, t_pairs = Utils.improve_ids_binders [] false id.at exp_typ_pairs in
let new_binds = case_binds @ extra_binds in
let exp = TupE (List.map fst t_pairs) $$ at % (TupT t_pairs $ at) in
let case_exp = CaseE (m, exp) $$ at % user_typ in
Expand Down Expand Up @@ -386,7 +386,7 @@ let create_well_formed_predicate id env inst =
([wrapped], tups, prems)
) typfields) in

let (rule_binds, pairs') = Utils.improve_ids_binders true at pairs in
let (rule_binds, pairs') = Utils.improve_ids_binders [] true at pairs in
let new_prems = (List.filter_map get_exp_typ rule_binds |> List.concat_map (get_wf_pred env)) @ rule_prems in
let str_exp = StrE (List.map2 (fun a ((e, t), wrapped) ->
let tupt = TupT [(e, t)] $ at in
Expand Down
18 changes: 15 additions & 3 deletions spectec/src/middlend/utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ let generate_var ids id =
let max = 1000 in
let rec go prefix c =
if max <= c then assert false else
let name = fresh_prefix ^ "_" ^ Int.to_string c in
let name = prefix ^ "_" ^ Int.to_string c in
if (List.mem name ids)
then go prefix (c + 1)
else name
Expand All @@ -47,7 +47,7 @@ let generate_var ids id =
| s when List.mem s ids -> go s start
| _ -> id

let improve_ids_binders generate_all_binds at exp_typ_pairs =
let improve_ids_binders ids generate_all_binds at exp_typ_pairs =
let get_id_from_exp e =
match e.it with
| VarE id -> Some id.it
Expand All @@ -71,11 +71,23 @@ let improve_ids_binders generate_all_binds at exp_typ_pairs =
let tupt = TupT pairs $ typ_at in
let tupe = TupE (List.map fst pairs) $$ exp_at % tupt in
(binds' @ binds, (tupe, tupt) :: pairs')
| ({it = IterE (_, (_, iter_binds)); _}, {it = IterT _; _}) as b :: bs' ->
let new_binds = if generate_all_binds
then
List.filter_map (fun (_, e) ->
match e.it with
| VarE id -> Some (ExpB (id, e.note) $ e.at)
| _ -> None
) iter_binds
else []
in
let (binds, pairs) = improve_ids_helper ids bs' in
(new_binds @ binds, b :: pairs)
| b :: bs' ->
let (binds, pairs) = improve_ids_helper ids bs' in
(binds, b :: pairs)
in
improve_ids_helper [] exp_typ_pairs
improve_ids_helper ids exp_typ_pairs

let get_param_id p =
match p.it with
Expand Down
5 changes: 3 additions & 2 deletions spectec/test-middlend/dune.inc
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
(rule (alias runtest) (deps (alias dune.inc) (file specification.act) (glob_files_rec specification.exp/*)) (action (no-infer (diff specification.exp/05-else.il specification.act/05-else.il))))
(rule (alias runtest) (deps (alias dune.inc) (file specification.act) (glob_files_rec specification.exp/*)) (action (no-infer (diff specification.exp/06-uncase-removal.il specification.act/06-uncase-removal.il))))
(rule (alias runtest) (deps (alias dune.inc) (file specification.act) (glob_files_rec specification.exp/*)) (action (no-infer (diff specification.exp/07-sideconditions.il specification.act/07-sideconditions.il))))
(rule (alias runtest) (deps (alias dune.inc) (file specification.act) (glob_files_rec specification.exp/*)) (action (no-infer (diff specification.exp/08-sub.il specification.act/08-sub.il))))
(rule (alias runtest) (deps (alias dune.inc) (file specification.act) (glob_files_rec specification.exp/*)) (action (no-infer (diff specification.exp/09-alias-demut.il specification.act/09-alias-demut.il))))
(rule (alias runtest) (deps (alias dune.inc) (file specification.act) (glob_files_rec specification.exp/*)) (action (no-infer (diff specification.exp/08-sub-expansion.il specification.act/08-sub-expansion.il))))
(rule (alias runtest) (deps (alias dune.inc) (file specification.act) (glob_files_rec specification.exp/*)) (action (no-infer (diff specification.exp/09-sub.il specification.act/09-sub.il))))
(rule (alias runtest) (deps (alias dune.inc) (file specification.act) (glob_files_rec specification.exp/*)) (action (no-infer (diff specification.exp/10-alias-demut.il specification.act/10-alias-demut.il))))
Loading