Skip to content

Commit 0308a06

Browse files
authored
Merge pull request #195 from Wasm-DSL/sub-expansion-pass
Sub-expansion pass
2 parents 93b7a35 + 3e673fc commit 0308a06

File tree

11 files changed

+30239
-1493
lines changed

11 files changed

+30239
-1493
lines changed

spectec/src/exe-spectec/main.ml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ type pass =
2222
| TypeFamilyRemoval
2323
| Else
2424
| Undep
25+
| SubExpansion
2526
| Uncaseremoval
2627
| AliasDemut
2728
| Ite
@@ -41,6 +42,7 @@ let all_passes = [
4142
Else;
4243
Uncaseremoval;
4344
Sideconditions;
45+
SubExpansion;
4446
Sub;
4547
AliasDemut;
4648
]
@@ -104,6 +106,7 @@ let pass_flag = function
104106
| AliasDemut -> "alias-demut"
105107
| Else -> "else"
106108
| Undep -> "remove-indexed-types"
109+
| SubExpansion -> "sub-expansion"
107110
| Uncaseremoval -> "uncase-removal"
108111
| Ite -> "ite"
109112

@@ -115,6 +118,7 @@ let pass_desc = function
115118
| TypeFamilyRemoval -> "Transform Type families into sum types"
116119
| Else -> "Eliminate the otherwise premise in relations"
117120
| Undep -> "Transform indexed types into types with well-formedness predicates"
121+
| SubExpansion -> "Expands subtype matching"
118122
| Uncaseremoval -> "Eliminate the uncase expression"
119123
| AliasDemut -> "Lifts type aliases out of mutual groups"
120124
| Ite -> "If-then-else introduction"
@@ -128,6 +132,7 @@ let run_pass : pass -> Il.Ast.script -> Il.Ast.script = function
128132
| TypeFamilyRemoval -> Middlend.Typefamilyremoval.transform
129133
| Else -> Middlend.Else.transform
130134
| Undep -> Middlend.Undep.transform
135+
| SubExpansion -> Middlend.Subexpansion.transform
131136
| Uncaseremoval -> Middlend.Uncaseremoval.transform
132137
| AliasDemut -> Middlend.AliasDemut.transform
133138
| Ite -> Middlend.Ite.transform

spectec/src/middlend/dune

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
else
1111
undep
1212
utils
13+
subexpansion
1314
uncaseremoval
1415
aliasDemut
1516
ite
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
(*
2+
This pass expands the subtyping patterns that appear in the LHS of
3+
function clauses and type family arguments.
4+
5+
It achieves this through the following steps:
6+
* For each argument, we collect every unique sub expression.
7+
* Then, for each sub expression, we collect every case that is
8+
possible in the subtype. If the specific case additionally carries
9+
values, then we generate binds to add in the function scope.
10+
* With all of these cases, for each unique sub expression, we compute
11+
the cartesian product in order to absolutely grab all the possible cases.
12+
See $cvtop to see how this might be done.
13+
* Once we have calculated the product, we generate a subst for each product
14+
and proceed to generate the clause/type instance.
15+
* Finally, we filter out binds that appear in the subst.
16+
17+
For example, take the following types and function:
18+
19+
syntax A = t1 nat | t2 nat nat
20+
syntax B = t1 nat | t2 nat nat | t3 | t4
21+
22+
def $foo(B) : nat
23+
def $foo(x : A <: B) = 1
24+
def $foo(t3) = 2
25+
def $foo(t4) = 3
26+
27+
Would be transformed as such:
28+
29+
def $foo(B) : nat
30+
def $foo{n : nat}(t1(n)) = 1
31+
def $foo{n1 : nat, n2 : nat}(t2(n1, n2)) = 1
32+
def $foo(t3) = 2
33+
def $foo(t4) = 3
34+
35+
*)
36+
37+
open Util
38+
open Source
39+
open Il.Ast
40+
open Il
41+
42+
(* Errors *)
43+
44+
let error at msg = Error.error at "sub expression expansion" msg
45+
46+
(* Environment *)
47+
48+
(* Global IL env*)
49+
let env_ref = ref Il.Env.empty
50+
51+
let empty_tuple_exp at = TupE [] $$ at % (TupT [] $ at)
52+
53+
(* Computes the cartesian product of a given list. *)
54+
let product_of_lists (lists : 'a list list) =
55+
List.fold_left (fun acc lst ->
56+
List.concat_map (fun existing ->
57+
List.map (fun v -> v :: existing) lst) acc) [[]] lists
58+
59+
let product_of_lists_append (lists : 'a list list) =
60+
List.fold_left (fun acc lst ->
61+
List.concat_map (fun existing ->
62+
List.map (fun v -> existing @ [v]) lst) acc) [[]] lists
63+
64+
let get_bind_id b =
65+
match b.it with
66+
| ExpB (id, _) | TypB id
67+
| DefB (id, _, _) | GramB (id, _, _) -> id.it
68+
69+
let eq_sube (id, t1, t2) (id', t1', t2') =
70+
Eq.eq_id id id' && Eq.eq_typ t1 t1' && Eq.eq_typ t2 t2'
71+
72+
let rec collect_sube_exp e =
73+
let c_func = collect_sube_exp in
74+
match e.it with
75+
(* Assumption - nested sub expressions do not exist. Must also be a varE. *)
76+
| SubE ({it = VarE id; _}, t1, t2) -> [id, t1, t2]
77+
| CallE (_, args) -> List.concat_map collect_sube_arg args
78+
| StrE fields -> List.concat_map (fun (_a, e1) -> c_func e1) fields
79+
| UnE (_, _, e1) | CvtE (e1, _, _) | LiftE e1 | TheE e1 | OptE (Some e1)
80+
| ProjE (e1, _) | UncaseE (e1, _)
81+
| CaseE (_, e1) | LenE e1 | DotE (e1, _) -> c_func e1
82+
| BinE (_, _, e1, e2) | CmpE (_, _, e1, e2)
83+
| CompE (e1, e2) | MemE (e1, e2)
84+
| CatE (e1, e2) | IdxE (e1, e2) -> c_func e1 @ c_func e2
85+
| TupE exps | ListE exps -> List.concat_map collect_sube_exp exps
86+
| SliceE (e1, e2, e3) -> c_func e1 @ c_func e2 @ c_func e3
87+
| UpdE (e1, p, e2)
88+
| ExtE (e1, p, e2) -> c_func e1 @ collect_fcalls_path p @ c_func e2
89+
| IterE (e1, (iter, id_exp_pairs)) ->
90+
c_func e1 @ collect_sube_iter iter @
91+
List.concat_map (fun (_, exp) -> c_func exp) id_exp_pairs
92+
| _ -> []
93+
94+
and collect_sube_iter i =
95+
match i with
96+
| ListN (e1, _) -> collect_sube_exp e1
97+
| _ -> []
98+
99+
and collect_sube_arg a =
100+
match a.it with
101+
| ExpA exp -> collect_sube_exp exp
102+
| _ -> []
103+
104+
and collect_fcalls_path p =
105+
match p.it with
106+
| RootP -> []
107+
| IdxP (p, e) -> collect_fcalls_path p @ collect_sube_exp e
108+
| SliceP (p, e1, e2) -> collect_fcalls_path p @ collect_sube_exp e1 @ collect_sube_exp e2
109+
| DotP (p, _) -> collect_fcalls_path p
110+
111+
let check_matching c_args match_args =
112+
Option.is_some (try
113+
Eval.match_list Eval.match_arg !env_ref Subst.empty c_args match_args
114+
with Eval.Irred -> None)
115+
116+
let get_case_typ t =
117+
match t.it with
118+
| TupT typs -> typs
119+
| _ -> [VarE ("_" $ t.at) $$ t.at % t, t]
120+
121+
let collect_all_instances case_typ ids at inst =
122+
match inst.it with
123+
| InstD (_, _, {it = VariantT typcases; _}) when
124+
List.for_all (fun (_, (_, t, _), _) -> t.it = TupT []) typcases ->
125+
List.map (fun (m, _, _) -> ([], CaseE (m, empty_tuple_exp no_region) $$ at % case_typ)) typcases
126+
| InstD (_, _, {it = VariantT typcases; _}) ->
127+
let _, new_cases =
128+
List.fold_left (fun (ids', acc) (m, (_, t, _), _) ->
129+
let typs = get_case_typ t in
130+
let new_binds, typs' = Utils.improve_ids_binders ids' true t.at typs in
131+
let exps = List.map fst typs' in
132+
let tup_exp = TupE exps $$ at % t in
133+
let case_exp = CaseE (m, tup_exp) $$ at % case_typ in
134+
let new_ids = List.map get_bind_id new_binds in
135+
(new_ids @ ids', (new_binds, case_exp) :: acc)
136+
) (ids, []) typcases
137+
in
138+
new_cases
139+
| _ -> error at "Expected a variant type"
140+
141+
let rec collect_all_instances_typ ids at typ =
142+
match typ.it with
143+
| VarT (var_id, dep_args) -> let (_, insts) = Il.Env.find_typ !env_ref var_id in
144+
(match insts with
145+
| [] -> [] (* Should never happen *)
146+
| _ ->
147+
let inst_opt = List.find_opt (fun inst ->
148+
match inst.it with
149+
| InstD (_, args, _) -> check_matching dep_args args
150+
) insts in
151+
match inst_opt with
152+
| None -> error at ("Could not find specific instance for typ: " ^ Il.Print.string_of_typ typ)
153+
| Some inst -> collect_all_instances typ ids at inst
154+
)
155+
| TupT exp_typ_pairs ->
156+
let instances_list = List.map (fun (_, t) ->
157+
collect_all_instances_typ ids at t
158+
) exp_typ_pairs in
159+
let product = product_of_lists_append instances_list in
160+
List.map (fun lst ->
161+
let binds, exps = List.split lst in
162+
List.concat binds, TupE exps $$ at % typ) product
163+
| _ -> []
164+
165+
let generate_subst_list lhs binds =
166+
(* Collect all unique sub expressions for each argument *)
167+
let subs = List.concat_map (fun a ->
168+
Lib.List.nub eq_sube (collect_sube_arg a)
169+
) lhs in
170+
let ids = List.map get_bind_id binds in
171+
172+
(* Collect all cases for the specific subtype, generating any potential binds in the process *)
173+
let _, cases =
174+
List.fold_left (fun (binds, cases) (id, t1, _) ->
175+
let ids' = List.map get_bind_id binds @ ids in
176+
let instances = collect_all_instances_typ ids' id.at t1 in
177+
let new_binds = List.concat_map fst instances in
178+
let cases'' = List.map (fun case_data -> (id, case_data)) instances in
179+
(new_binds @ binds, cases'' :: cases)
180+
) (binds, []) subs
181+
in
182+
183+
(* Compute cartesian product for all cases and generate a subst *)
184+
let cases' = product_of_lists cases in
185+
List.map (List.fold_left (fun (binds, subst) (id, (binds', exp)) ->
186+
(binds' @ binds, Il.Subst.add_varid subst id exp)) ([], Il.Subst.empty)
187+
) cases'
188+
189+
let t_clause clause =
190+
match clause.it with
191+
| DefD (binds, lhs, rhs, prems) ->
192+
let subst_list = generate_subst_list lhs binds in
193+
List.map (fun (binds', subst) ->
194+
(* Subst all occurrences of the subE id *)
195+
let new_lhs = Il.Subst.subst_args subst lhs in
196+
let new_prems = Il.Subst.subst_list Il.Subst.subst_prem subst prems in
197+
let new_rhs = Il.Subst.subst_exp subst rhs in
198+
199+
(* Filtering binds - only the subst ids *)
200+
let binds_filtered = Lib.List.filter_not (fun b -> match b.it with
201+
| ExpB (id, _) -> Il.Subst.mem_varid subst id
202+
| _ -> false
203+
) (binds' @ binds) in
204+
let new_binds, _ = Il.Subst.subst_binds subst binds_filtered in
205+
(* Reduction is done here to remove subtyping expressions *)
206+
DefD (new_binds, List.map (Il.Eval.reduce_arg !env_ref) new_lhs, new_rhs, new_prems) $ clause.at
207+
) subst_list
208+
209+
let t_inst inst =
210+
match inst.it with
211+
| InstD (binds, lhs, deftyp) ->
212+
let subst_list = generate_subst_list lhs binds in
213+
List.map (fun (binds', subst) ->
214+
(* Subst all occurrences of the subE id *)
215+
let new_lhs = Il.Subst.subst_args subst lhs in
216+
let new_rhs = Il.Subst.subst_deftyp subst deftyp in
217+
218+
(* Filtering binds - only the subst ids *)
219+
let binds_filtered = Lib.List.filter_not (fun b -> match b.it with
220+
| ExpB (id, _) -> Il.Subst.mem_varid subst id
221+
| _ -> false
222+
) (binds' @ binds) in
223+
224+
let new_binds, _ = Il.Subst.subst_binds subst binds_filtered in
225+
(* Reduction is done here to remove subtyping expressions *)
226+
InstD (new_binds, List.map (Il.Eval.reduce_arg !env_ref) new_lhs, new_rhs) $ inst.at
227+
) subst_list
228+
229+
230+
let rec t_def def =
231+
match def.it with
232+
| RecD defs -> { def with it = RecD (List.map t_def defs) }
233+
| DecD (id, params, typ, clauses) ->
234+
{ def with it = DecD (id, params, typ, List.concat_map t_clause clauses) }
235+
| TypD (id, params, insts) ->
236+
{ def with it = TypD (id, params, List.concat_map t_inst insts)}
237+
| _ -> def
238+
239+
let transform (defs : script) =
240+
env_ref := Il.Env.env_of_script defs;
241+
List.map (t_def) defs
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
val transform : Il.Ast.script -> Il.Ast.script

spectec/src/middlend/undep.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ let create_well_formed_predicate id env inst =
352352
| TupT tups -> tups
353353
| _ -> [(VarE ("_" $ id.at) $$ id.at % case_typ, case_typ)]
354354
in
355-
let extra_binds, t_pairs = Utils.improve_ids_binders false id.at exp_typ_pairs in
355+
let extra_binds, t_pairs = Utils.improve_ids_binders [] false id.at exp_typ_pairs in
356356
let new_binds = case_binds @ extra_binds in
357357
let exp = TupE (List.map fst t_pairs) $$ at % (TupT t_pairs $ at) in
358358
let case_exp = CaseE (m, exp) $$ at % user_typ in
@@ -386,7 +386,7 @@ let create_well_formed_predicate id env inst =
386386
([wrapped], tups, prems)
387387
) typfields) in
388388

389-
let (rule_binds, pairs') = Utils.improve_ids_binders true at pairs in
389+
let (rule_binds, pairs') = Utils.improve_ids_binders [] true at pairs in
390390
let new_prems = (List.filter_map get_exp_typ rule_binds |> List.concat_map (get_wf_pred env)) @ rule_prems in
391391
let str_exp = StrE (List.map2 (fun a ((e, t), wrapped) ->
392392
let tupt = TupT [(e, t)] $ at in

spectec/src/middlend/utils.ml

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ let generate_var ids id =
3737
let max = 1000 in
3838
let rec go prefix c =
3939
if max <= c then assert false else
40-
let name = fresh_prefix ^ "_" ^ Int.to_string c in
40+
let name = prefix ^ "_" ^ Int.to_string c in
4141
if (List.mem name ids)
4242
then go prefix (c + 1)
4343
else name
@@ -47,7 +47,7 @@ let generate_var ids id =
4747
| s when List.mem s ids -> go s start
4848
| _ -> id
4949

50-
let improve_ids_binders generate_all_binds at exp_typ_pairs =
50+
let improve_ids_binders ids generate_all_binds at exp_typ_pairs =
5151
let get_id_from_exp e =
5252
match e.it with
5353
| VarE id -> Some id.it
@@ -71,11 +71,23 @@ let improve_ids_binders generate_all_binds at exp_typ_pairs =
7171
let tupt = TupT pairs $ typ_at in
7272
let tupe = TupE (List.map fst pairs) $$ exp_at % tupt in
7373
(binds' @ binds, (tupe, tupt) :: pairs')
74+
| ({it = IterE (_, (_, iter_binds)); _}, {it = IterT _; _}) as b :: bs' ->
75+
let new_binds = if generate_all_binds
76+
then
77+
List.filter_map (fun (_, e) ->
78+
match e.it with
79+
| VarE id -> Some (ExpB (id, e.note) $ e.at)
80+
| _ -> None
81+
) iter_binds
82+
else []
83+
in
84+
let (binds, pairs) = improve_ids_helper ids bs' in
85+
(new_binds @ binds, b :: pairs)
7486
| b :: bs' ->
7587
let (binds, pairs) = improve_ids_helper ids bs' in
7688
(binds, b :: pairs)
7789
in
78-
improve_ids_helper [] exp_typ_pairs
90+
improve_ids_helper ids exp_typ_pairs
7991

8092
let get_param_id p =
8193
match p.it with

spectec/test-middlend/dune.inc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@
66
(rule (alias runtest) (deps (alias dune.inc) (file specification.act) (glob_files_rec specification.exp/*)) (action (no-infer (diff specification.exp/05-else.il specification.act/05-else.il))))
77
(rule (alias runtest) (deps (alias dune.inc) (file specification.act) (glob_files_rec specification.exp/*)) (action (no-infer (diff specification.exp/06-uncase-removal.il specification.act/06-uncase-removal.il))))
88
(rule (alias runtest) (deps (alias dune.inc) (file specification.act) (glob_files_rec specification.exp/*)) (action (no-infer (diff specification.exp/07-sideconditions.il specification.act/07-sideconditions.il))))
9-
(rule (alias runtest) (deps (alias dune.inc) (file specification.act) (glob_files_rec specification.exp/*)) (action (no-infer (diff specification.exp/08-sub.il specification.act/08-sub.il))))
10-
(rule (alias runtest) (deps (alias dune.inc) (file specification.act) (glob_files_rec specification.exp/*)) (action (no-infer (diff specification.exp/09-alias-demut.il specification.act/09-alias-demut.il))))
9+
(rule (alias runtest) (deps (alias dune.inc) (file specification.act) (glob_files_rec specification.exp/*)) (action (no-infer (diff specification.exp/08-sub-expansion.il specification.act/08-sub-expansion.il))))
10+
(rule (alias runtest) (deps (alias dune.inc) (file specification.act) (glob_files_rec specification.exp/*)) (action (no-infer (diff specification.exp/09-sub.il specification.act/09-sub.il))))
11+
(rule (alias runtest) (deps (alias dune.inc) (file specification.act) (glob_files_rec specification.exp/*)) (action (no-infer (diff specification.exp/10-alias-demut.il specification.act/10-alias-demut.il))))

0 commit comments

Comments
 (0)