Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions src/Lean/Compiler/LCNF/SpecInfo.lean
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ inductive SpecParamInfo where
-/
| fixedInst
/--
A parameter that is a function and is fixed in recursive declarations. If the user tags a declaration
with `@[specialize]` without specifying which arguments should be specialized, Lean will specialize
`.fixedHO` arguments in addition to `.fixedInst`.
A parameter that is a function and is fixed in recursive declarations, or a parameter the type of
which is the polymorphic return type `α` of the declaration and which could be instantiated to a
function.
If the user tags a declaration with `@[specialize]` without specifying which arguments should be
specialized, Lean will specialize `.fixedHO` arguments in addition to `.fixedInst`.
-/
| fixedHO
/--
Expand Down Expand Up @@ -142,6 +144,17 @@ private def hasFwdDeps (decl : Decl) (paramsInfo : Array SpecParamInfo) (j : Nat
return true
return false

def isFixedPolymorphicReturnType (decl : Decl) (type : Expr) (specInfos : Array SpecParamInfo) : CompilerM Bool := do
-- logInfo m!"isFixedPolymorphicReturnType: {decl.name}, {type}, {specInfos}"
let some idx := decl.params.findIdx? fun p => type == p.toExpr
| return false
let α := decl.params[idx]!.toExpr
let retTy ← instantiateForall decl.type <| decl.params.map (mkFVar ·.fvarId)
-- logInfo m!"isFixedPolymorphicReturnType2: {decl.name}, {α}, {retTy}"
if specInfos[idx]! matches .fixedNeutral && retTy == α then
return true
return false

/--
Save parameter information for `decls`.

Expand All @@ -158,8 +171,11 @@ def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do
let specArgs? := getSpecializationArgs? (← getEnv) decl.name
let contains (i : Nat) : Bool := specArgs?.getD #[] |>.contains i
let mut paramsInfo : Array SpecParamInfo := #[]
-- logInfo m!"decl.type: {decl.name} {decl.params.map fun p => (mkFVar p.fvarId, p.type)} {decl.type}"
for h :i in *...decl.params.size do
let param := decl.params[i]
-- let b ← isFixedPolymorphicReturnType decl param.type paramsInfo
-- logInfo m!"isFixedPolymorphicReturnType: {b}"
let info ←
if contains i then
pure .user
Expand All @@ -178,7 +194,7 @@ def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do
specify which arguments must be specialized besides instances. In this case, we try to specialize
any "fixed higher-order argument"
-/
else if specArgs? == some #[] && param.type matches .forallE .. then
else if specArgs? == some #[] && (param.type matches .forallE .. || (← isFixedPolymorphicReturnType decl param.type paramsInfo)) then
pure .fixedHO
else
pure .other
Expand Down
61 changes: 49 additions & 12 deletions src/Lean/Compiler/LCNF/Specialize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ The keys never contain free variables or loose bound variables.

/--
Given the specialization mask `paramsInfo` and the arguments `args`,
collect their dependencies, and return an array `mask` of size `paramsInfo.size` s.t.
- `mask[i] = some args[i]` if `paramsInfo[i] != .other`
collect their dependencies, and return an array `mask` of size `args.size` s.t.
- `mask[i] = some args[i]` if `paramsInfo[i]? != some .other`
- `mask[i] = none`, otherwise.
That is, `mask` contains only the arguments that are contributing to the code specialization.
We use this information to compute a "key" to uniquely identify the code specialization, and
Expand All @@ -185,7 +185,9 @@ def collect (paramsInfo : Array SpecParamInfo) (args : Array Arg) : SpecializeM
!ctx.ground.contains fvarId
Closure.run (inScope := ctx.scope.contains) (abstract := abstract) do
let mut argMask := #[]
for paramInfo in paramsInfo, arg in args do
for i in *...args.size do
let paramInfo := paramsInfo[i]?.getD .fixedHO
let arg := args[i]!
match paramInfo with
| .other =>
argMask := argMask.push none
Expand All @@ -200,7 +202,9 @@ end Collector
Return `true` if it is worth using arguments `args` for specialization given the parameter specialization information.
-/
def shouldSpecialize (paramsInfo : Array SpecParamInfo) (args : Array Arg) : SpecializeM Bool := do
for paramInfo in paramsInfo, arg in args do
for i in *...args.size do
let arg := args[i]!
let paramInfo := paramsInfo[i]?.getD .fixedHO -- .fixedHO might be too aggressive
match paramInfo with
| .other => pure ()
| .fixedNeutral => pure () -- If we want to monomorphize types such as `Array`, we need to change here
Expand Down Expand Up @@ -267,21 +271,54 @@ where
let .code code := decl.value | panic! "can only specialize decls with code"
let mut params ← params.mapM internalizeParam
let decls ← decls.mapM internalizeCodeDecl
for param in decl.params, arg in argMask do
let mut bodyType := decl.type.instantiateLevelParamsNoCache decl.levelParams us
for arg in argMask, param in decl.params do
let .forallE _ d b _ := bodyType.headBeta
| panic! "has param of type {param.type}, but bodyType {bodyType} was not a forall"
if let some arg := arg then
let arg ← normArg arg
modify fun s => s.insert param.fvarId arg
bodyType := b.instantiate1 arg.toExpr
else
-- Keep the parameter
let param := { param with type := param.type.instantiateLevelParamsNoCache decl.levelParams us }
params := params.push (← internalizeParam param)
for param in decl.params[argMask.size...*] do
let param := { param with type := param.type.instantiateLevelParamsNoCache decl.levelParams us }
params := params.push (← internalizeParam param)
let param ← internalizeParam { param with type := d }
params := params.push param
bodyType := b.instantiate1 (.fvar param.fvarId)
let extraParams := decl.params[argMask.size...*] -- non-empty if undersaturated app
let extraMask := argMask[decl.params.size...*] -- non-empty if oversaturated app
-- Add extraneous parameters to decl
for param in extraParams do
let .forallE _ d b _ := bodyType.headBeta
| panic! "has param of type {param.type}, but bodyType {bodyType} was not a forall"
-- Keep the parameter
let param ← internalizeParam { param with type := d }
params := params.push param
bodyType := b.instantiate1 (.fvar param.fvarId)
let code := code.instantiateValueLevelParams decl.levelParams us
let code ← internalizeCode code
let code := attachCodeDecls decls code
let type ← code.inferType
-- Eta-expand to accomodate extraneous args (cf. `etaExpandCore`)
let code ←
if extraMask.size = 0 then
pure code
else
let mut extraArgs := #[]
for arg in extraMask do
let .forallE _ d b _ := bodyType.headBeta
| panic! "oversaturated arg mask but decl.type was not a forall"
if let some arg := arg then
let arg ← normArg arg
extraArgs := extraArgs.push arg
bodyType := b.instantiate1 arg.toExpr
else
let p ← mkAuxParam d
params := params.push p
extraArgs := extraArgs.push (.fvar p.fvarId)
bodyType := b.instantiate1 (.fvar p.fvarId)
code.bind fun fvarId => do
let auxDecl ← mkAuxLetDecl (.fvar fvarId extraArgs)
return .let auxDecl (.return auxDecl.fvarId)
let type := bodyType
let type ← mkForallParams params type
let value := .code code
let safe := decl.safe
Expand All @@ -298,7 +335,7 @@ def getRemainingArgs (paramsInfo : Array SpecParamInfo) (args : Array Arg) : Arr
for info in paramsInfo, arg in args do
if info matches .other then
result := result.push arg
return result ++ args[paramsInfo.size...*]
return result -- ++ args[paramsInfo.size...*]

def paramsToGroundVars (params : Array Param) : CompilerM FVarIdSet :=
params.foldlM (init := {}) fun r p => do
Expand Down
190 changes: 190 additions & 0 deletions tests/lean/run/10924.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
@[specialize]
def foo {α} : Nat → (α → α) → α → α
| 0, f => f
| n+1, f => foo n f

set_option trace.Compiler.saveBase true in
/--
trace: [Compiler.saveBase] size: 5
def foo._at_._example.spec_0 x.1 : Nat :=
cases x.1 : Nat
| Nat.zero =>
let _x.2 := 6;
return _x.2
| Nat.succ n.3 =>
let _x.4 := foo._at_._example.spec_0 n.3;
return _x.4
[Compiler.saveBase] size: 1
def _example n : Nat :=
let _x.1 := foo._at_._example.spec_0 n;
return _x.1
-/
#guard_msgs in
example {n} := foo n (· + 1) 5

set_option trace.Compiler.saveBase true in
/--
trace: [Compiler.saveBase] size: 9
def foo._at_._example.spec_0 x.1 : Nat :=
fun _f.2 x.3 : Nat :=
let _x.4 := 1;
let _x.5 := Nat.add x.3 _x.4;
return _x.5;
let _x.6 := 5;
cases x.1 : Nat
| Nat.zero =>
let _x.7 := _f.2 _x.6;
let _x.8 := _f.2 _x.7;
return _x.8
| Nat.succ n.9 =>
let _x.10 := foo._at_._example.spec_0 n.9;
return _x.10
[Compiler.saveBase] size: 1
def _example n : Nat :=
let _x.1 := foo._at_._example.spec_0 n;
return _x.1
-/
#guard_msgs in
example {n} := foo n (fun f a => f (f a)) (· + 1) 5

set_option trace.Compiler.saveBase true in
/--
trace: [Compiler.saveBase] size: 5
def foo._at_._example.spec_0 x.1 : Nat :=
let _x.2 := 5;
cases x.1 : Nat
| Nat.zero =>
return _x.2
| Nat.succ n.3 =>
let _x.4 := foo._at_._example.spec_0 n.3;
return _x.4
[Compiler.saveBase] size: 1
def _example n : Nat :=
let _x.1 := foo._at_._example.spec_0 n;
return _x.1
-/
#guard_msgs in
example {n} := foo n id id id id id id 5

set_option trace.Compiler.saveBase true in
/--
trace: [Compiler.saveBase] size: 9
def foo._at_._example.spec_0 x.1 : Nat :=
fun _f.2 f g : Nat :=
let _x.3 := f g;
let _x.4 := f _x.3;
return _x.4;
fun _f.5 _y.6 : Nat :=
return _y.6;
let _x.7 := 5;
cases x.1 : Nat
| Nat.zero =>
let _x.8 := _f.2 _f.5;
let _x.9 := _f.2 _x.8 _x.7;
return _x.9
| Nat.succ n.10 =>
let _x.11 := foo._at_._example.spec_0 n.10;
return _x.11
[Compiler.saveBase] size: 1
def _example n : Nat :=
let _x.1 := foo._at_._example.spec_0 n;
return _x.1
-/
#guard_msgs in
example {n} := foo n (fun f g => f <| f g) (fun f g => f <| f g) id 5

@[specialize]
def List.forBreak_ {α : Type u} {m : Type w → Type x} [Monad m] (xs : List α) (body : α → ExceptCpsT PUnit m PUnit) : m PUnit :=
match xs with
| [] => pure ⟨⟩
| x :: xs => body x (fun _ => forBreak_ xs body) (fun _ => pure ⟨⟩)

-- This one still does not properly specialize for the success and error continuations
-- (`_y.4`, `_y.5`). The reason is that the loop body is not yet inlined when the specializer looks
-- at the recursive call site in `List.forBreak_._at_._example.spec_0`, so it allocates another,
-- strictly less general specialization `…spec_0.spec_0`.
-- The reason the loop body is not yet inlined is that it occurs in the recursive call site as well,
-- but only pre-specialization.
set_option trace.Compiler.saveBase true in
/--
trace: [Compiler.saveBase] size: 23
def List.forBreak_._at_.List.forBreak_._at_._example.spec_0.spec_0 _x.1 _y.2 _y.3 _y.4 _y.5 xs : _y.3 :=
cases xs : _y.3
| List.nil =>
let _x.6 := PUnit.unit;
let _x.7 := @Prod.mk _ _ _x.6 _y.2;
let _x.8 := _y.4 _x.7;
return _x.8
| List.cons head.9 tail.10 =>
let _x.11 := 0;
let _x.12 := instDecidableEqNat _y.2 _x.11;
cases _x.12 : _y.3
| Decidable.isFalse x.13 =>
let _x.14 := 10;
let _x.15 := Nat.decLt _x.14 _y.2;
cases _x.15 : _y.3
| Decidable.isFalse x.16 =>
let _x.17 := Nat.add _y.2 head.9;
let _x.18 := List.forBreak_._at_.List.forBreak_._at_._example.spec_0.spec_0 _x.1 _x.17 _y.3 _y.4 _y.5 tail.10;
return _x.18
| Decidable.isTrue x.19 =>
let _x.20 := Nat.add _y.2 _x.1;
let _x.21 := PUnit.unit;
let _x.22 := @Prod.mk _ _ _x.21 _x.20;
let _x.23 := _y.4 _x.22;
return _x.23
| Decidable.isTrue x.24 =>
let _x.25 := _y.5 _y.2;
return _x.25
[Compiler.saveBase] size: 19
def List.forBreak_._at_._example.spec_0 _x.1 xs : Nat :=
let x := 42;
fun _f.2 a : Nat :=
cases a : Nat
| Prod.mk fst.3 snd.4 =>
return snd.4;
fun _f.5 _y.6 : Nat :=
return _y.6;
cases xs : Nat
| List.nil =>
return x
| List.cons head.7 tail.8 =>
let _x.9 := 0;
let _x.10 := instDecidableEqNat x _x.9;
cases _x.10 : Nat
| Decidable.isFalse x.11 =>
let _x.12 := 10;
let _x.13 := Nat.decLt _x.12 x;
cases _x.13 : Nat
| Decidable.isFalse x.14 =>
let _x.15 := Nat.add x head.7;
let _x.16 := List.forBreak_._at_.List.forBreak_._at_._example.spec_0.spec_0 _x.1 _x.15 _ _f.2 _f.5 tail.8;
return _x.16
| Decidable.isTrue x.17 =>
let _x.18 := Nat.add x _x.1;
return _x.18
| Decidable.isTrue x.19 =>
return x
[Compiler.saveBase] size: 8
def _example : Nat :=
let _x.1 := 1;
let _x.2 := 2;
let _x.3 := 3;
let _x.4 := @List.nil _;
let _x.5 := @List.cons _ _x.3 _x.4;
let _x.6 := @List.cons _ _x.2 _x.5;
let _x.7 := @List.cons _ _x.1 _x.6;
let _x.8 := List.forBreak_._at_._example.spec_0 _x.1 _x.7;
return _x.8
-/
#guard_msgs in
-- set_option trace.Compiler.specialize.candidate true in
-- set_option trace.Compiler.specialize.step true in
example := Id.run <| ExceptCpsT.runCatch do
let x := 42;
let ((), x) ←
(List.forBreak_ (m:=StateT Nat (ExceptCpsT Nat Id)) [1, 2, 3] fun i _β «continue» «break» x =>
if x = 0 then throw x
else if x > 10 then «break» () (x + 1)
else «continue» PUnit.unit (x + i)).run x
return x
Loading