1+ (*
2+ This pass expands the subtyping patterns that appear in the LHS of
3+ function clauses and type family arguments.
4+
5+ It achieves this through the following steps:
6+ * For each argument, we collect every unique sub expression.
7+ * Then, for each sub expression, we collect every case that is
8+ possible in the subtype. If the specific case additionally carries
9+ values, then we generate binds to add in the function scope.
10+ * With all of these cases, for each unique sub expression, we compute
11+ the cartesian product in order to absolutely grab all the possible cases.
12+ See $cvtop to see how this might be done.
13+ * Once we have calculated the product, we generate a subst for each product
14+ and proceed to generate the clause/type instance.
15+ * Finally, we filter out binds that appear in the subst.
16+
17+ For example, take the following types and function:
18+
19+ syntax A = t1 nat | t2 nat nat
20+ syntax B = t1 nat | t2 nat nat | t3 | t4
21+
22+ def $foo(B) : nat
23+ def $foo(x : A <: B) = 1
24+ def $foo(t3) = 2
25+ def $foo(t4) = 3
26+
27+ Would be transformed as such:
28+
29+ def $foo(B) : nat
30+ def $foo{n : nat}(t1(n)) = 1
31+ def $foo{n1 : nat, n2 : nat}(t2(n1, n2)) = 1
32+ def $foo(t3) = 2
33+ def $foo(t4) = 3
34+
35+ *)
36+
37+ open Util
38+ open Source
39+ open Il.Ast
40+ open Il
41+
42+ (* Errors *)
43+
44+ let error at msg = Error. error at " sub expression expansion" msg
45+
46+ (* Environment *)
47+
48+ (* Global IL env*)
49+ let env_ref = ref Il.Env. empty
50+
51+ let empty_tuple_exp at = TupE [] $$ at % (TupT [] $ at)
52+
53+ (* Computes the cartesian product of a given list. *)
54+ let product_of_lists (lists : 'a list list ) =
55+ List. fold_left (fun acc lst ->
56+ List. concat_map (fun existing ->
57+ List. map (fun v -> v :: existing) lst) acc) [[]] lists
58+
59+ let product_of_lists_append (lists : 'a list list ) =
60+ List. fold_left (fun acc lst ->
61+ List. concat_map (fun existing ->
62+ List. map (fun v -> existing @ [v]) lst) acc) [[]] lists
63+
64+ let get_bind_id b =
65+ match b.it with
66+ | ExpB (id, _) | TypB id
67+ | DefB (id , _ , _ ) | GramB (id , _ , _ ) -> id.it
68+
69+ let eq_sube (id , t1 , t2 ) (id' , t1' , t2' ) =
70+ Eq. eq_id id id' && Eq. eq_typ t1 t1' && Eq. eq_typ t2 t2'
71+
72+ let rec collect_sube_exp e =
73+ let c_func = collect_sube_exp in
74+ match e.it with
75+ (* Assumption - nested sub expressions do not exist. Must also be a varE. *)
76+ | SubE ({it = VarE id ; _} , t1 , t2 ) -> [id, t1, t2]
77+ | CallE (_ , args ) -> List. concat_map collect_sube_arg args
78+ | StrE fields -> List. concat_map (fun (_a , e1 ) -> c_func e1) fields
79+ | UnE (_, _, e1) | CvtE (e1, _, _) | LiftE e1 | TheE e1 | OptE (Some e1)
80+ | ProjE (e1, _) | UncaseE (e1, _)
81+ | CaseE (_ , e1 ) | LenE e1 | DotE (e1 , _ ) -> c_func e1
82+ | BinE (_, _, e1, e2) | CmpE (_, _, e1, e2)
83+ | CompE (e1, e2) | MemE (e1, e2)
84+ | CatE (e1 , e2 ) | IdxE (e1 , e2 ) -> c_func e1 @ c_func e2
85+ | TupE exps | ListE exps -> List. concat_map collect_sube_exp exps
86+ | SliceE (e1 , e2 , e3 ) -> c_func e1 @ c_func e2 @ c_func e3
87+ | UpdE (e1, p, e2)
88+ | ExtE (e1 , p , e2 ) -> c_func e1 @ collect_fcalls_path p @ c_func e2
89+ | IterE (e1 , (iter , id_exp_pairs )) ->
90+ c_func e1 @ collect_sube_iter iter @
91+ List. concat_map (fun (_ , exp ) -> c_func exp) id_exp_pairs
92+ | _ -> []
93+
94+ and collect_sube_iter i =
95+ match i with
96+ | ListN (e1 , _ ) -> collect_sube_exp e1
97+ | _ -> []
98+
99+ and collect_sube_arg a =
100+ match a.it with
101+ | ExpA exp -> collect_sube_exp exp
102+ | _ -> []
103+
104+ and collect_fcalls_path p =
105+ match p.it with
106+ | RootP -> []
107+ | IdxP (p , e ) -> collect_fcalls_path p @ collect_sube_exp e
108+ | SliceP (p , e1 , e2 ) -> collect_fcalls_path p @ collect_sube_exp e1 @ collect_sube_exp e2
109+ | DotP (p , _ ) -> collect_fcalls_path p
110+
111+ let check_matching c_args match_args =
112+ Option. is_some (try
113+ Eval. match_list Eval. match_arg ! env_ref Subst. empty c_args match_args
114+ with Eval. Irred -> None )
115+
116+ let get_case_typ t =
117+ match t.it with
118+ | TupT typs -> typs
119+ | _ -> [VarE (" _" $ t.at) $$ t.at % t, t]
120+
121+ let collect_all_instances case_typ ids at inst =
122+ match inst.it with
123+ | InstD (_, _, {it = VariantT typcases; _}) when
124+ List. for_all (fun (_ , (_ , t , _ ), _ ) -> t.it = TupT [] ) typcases ->
125+ List. map (fun (m , _ , _ ) -> ([] , CaseE (m, empty_tuple_exp no_region) $$ at % case_typ)) typcases
126+ | InstD (_ , _ , {it = VariantT typcases ; _} ) ->
127+ let _, new_cases =
128+ List. fold_left (fun (ids' , acc ) (m , (_ , t , _ ), _ ) ->
129+ let typs = get_case_typ t in
130+ let new_binds, typs' = Utils. improve_ids_binders ids' true t.at typs in
131+ let exps = List. map fst typs' in
132+ let tup_exp = TupE exps $$ at % t in
133+ let case_exp = CaseE (m, tup_exp) $$ at % case_typ in
134+ let new_ids = List. map get_bind_id new_binds in
135+ (new_ids @ ids', (new_binds, case_exp) :: acc)
136+ ) (ids, [] ) typcases
137+ in
138+ new_cases
139+ | _ -> error at " Expected a variant type"
140+
141+ let rec collect_all_instances_typ ids at typ =
142+ match typ.it with
143+ | VarT (var_id , dep_args ) -> let (_, insts) = Il.Env. find_typ ! env_ref var_id in
144+ (match insts with
145+ | [] -> [] (* Should never happen *)
146+ | _ ->
147+ let inst_opt = List. find_opt (fun inst ->
148+ match inst.it with
149+ | InstD (_ , args , _ ) -> check_matching dep_args args
150+ ) insts in
151+ match inst_opt with
152+ | None -> error at (" Could not find specific instance for typ: " ^ Il.Print. string_of_typ typ)
153+ | Some inst -> collect_all_instances typ ids at inst
154+ )
155+ | TupT exp_typ_pairs ->
156+ let instances_list = List. map (fun (_ , t ) ->
157+ collect_all_instances_typ ids at t
158+ ) exp_typ_pairs in
159+ let product = product_of_lists_append instances_list in
160+ List. map (fun lst ->
161+ let binds, exps = List. split lst in
162+ List. concat binds, TupE exps $$ at % typ) product
163+ | _ -> []
164+
165+ let generate_subst_list lhs binds =
166+ (* Collect all unique sub expressions for each argument *)
167+ let subs = List. concat_map (fun a ->
168+ Lib.List. nub eq_sube (collect_sube_arg a)
169+ ) lhs in
170+ let ids = List. map get_bind_id binds in
171+
172+ (* Collect all cases for the specific subtype, generating any potential binds in the process *)
173+ let _, cases =
174+ List. fold_left (fun (binds , cases ) (id , t1 , _ ) ->
175+ let ids' = List. map get_bind_id binds @ ids in
176+ let instances = collect_all_instances_typ ids' id.at t1 in
177+ let new_binds = List. concat_map fst instances in
178+ let cases'' = List. map (fun case_data -> (id, case_data)) instances in
179+ (new_binds @ binds, cases'' :: cases)
180+ ) (binds, [] ) subs
181+ in
182+
183+ (* Compute cartesian product for all cases and generate a subst *)
184+ let cases' = product_of_lists cases in
185+ List. map (List. fold_left (fun (binds , subst ) (id , (binds' , exp )) ->
186+ (binds' @ binds, Il.Subst. add_varid subst id exp)) ([] , Il.Subst. empty)
187+ ) cases'
188+
189+ let t_clause clause =
190+ match clause.it with
191+ | DefD (binds , lhs , rhs , prems ) ->
192+ let subst_list = generate_subst_list lhs binds in
193+ List. map (fun (binds' , subst ) ->
194+ (* Subst all occurrences of the subE id *)
195+ let new_lhs = Il.Subst. subst_args subst lhs in
196+ let new_prems = Il.Subst. subst_list Il.Subst. subst_prem subst prems in
197+ let new_rhs = Il.Subst. subst_exp subst rhs in
198+
199+ (* Filtering binds - only the subst ids *)
200+ let binds_filtered = Lib.List. filter_not (fun b -> match b.it with
201+ | ExpB (id , _ ) -> Il.Subst. mem_varid subst id
202+ | _ -> false
203+ ) (binds' @ binds) in
204+ let new_binds, _ = Il.Subst. subst_binds subst binds_filtered in
205+ (* Reduction is done here to remove subtyping expressions *)
206+ DefD (new_binds, List. map (Il.Eval. reduce_arg ! env_ref) new_lhs, new_rhs, new_prems) $ clause.at
207+ ) subst_list
208+
209+ let t_inst inst =
210+ match inst.it with
211+ | InstD (binds , lhs , deftyp ) ->
212+ let subst_list = generate_subst_list lhs binds in
213+ List. map (fun (binds' , subst ) ->
214+ (* Subst all occurrences of the subE id *)
215+ let new_lhs = Il.Subst. subst_args subst lhs in
216+ let new_rhs = Il.Subst. subst_deftyp subst deftyp in
217+
218+ (* Filtering binds - only the subst ids *)
219+ let binds_filtered = Lib.List. filter_not (fun b -> match b.it with
220+ | ExpB (id , _ ) -> Il.Subst. mem_varid subst id
221+ | _ -> false
222+ ) (binds' @ binds) in
223+
224+ let new_binds, _ = Il.Subst. subst_binds subst binds_filtered in
225+ (* Reduction is done here to remove subtyping expressions *)
226+ InstD (new_binds, List. map (Il.Eval. reduce_arg ! env_ref) new_lhs, new_rhs) $ inst.at
227+ ) subst_list
228+
229+
230+ let rec t_def def =
231+ match def.it with
232+ | RecD defs -> { def with it = RecD (List. map t_def defs) }
233+ | DecD (id , params , typ , clauses ) ->
234+ { def with it = DecD (id, params, typ, List. concat_map t_clause clauses) }
235+ | TypD (id , params , insts ) ->
236+ { def with it = TypD (id, params, List. concat_map t_inst insts)}
237+ | _ -> def
238+
239+ let transform (defs : script ) =
240+ env_ref := Il.Env. env_of_script defs;
241+ List. map (t_def) defs
0 commit comments