Skip to content

Commit 1ba0266

Browse files
nomeataluisacicolini
authored andcommitted
feat: nested well-founded recursion via automatic preprocessing (leanprover#6744)
This PR extend the preprocessing of well-founded recursive definitions to bring assumptions like `h✝ : x ∈ xs` into scope automatically. This fixes leanprover#5471, and follows (roughly) the design written there. See the module docs at `src/Lean/Elab/PreDefinition/WF/AutoAttach.lean` for details on the implementation. This only works for higher-order functions that have a suitable setup. See for example section “Well-founded recursion preprocessing setup” in `src/Init/Data/List/Attach.lean`. This does not change the `decreasing_tactic`, so in some cases there is still the need for a manual termination proof some cases. We expect a better termination tactic in the near future.
1 parent e811f88 commit 1ba0266

18 files changed

+891
-153
lines changed

src/Init/ByCases.lean

+2-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ theorem apply_ite (f : α → β) (P : Prop) [Decidable P] (x y : α) :
3838
apply_dite f P (fun _ => x) (fun _ => y)
3939

4040
/-- A `dite` whose results do not actually depend on the condition may be reduced to an `ite`. -/
41-
@[simp] theorem dite_eq_ite [Decidable P] : (dite P (fun _ => a) fun _ => b) = ite P a b := rfl
41+
@[simp] theorem dite_eq_ite [Decidable P] :
42+
(dite P (fun _ => a) (fun _ => b)) = ite P a b := rfl
4243

4344
@[deprecated "Use `ite_eq_right_iff`" (since := "2024-09-18")]
4445
theorem ite_some_none_eq_none [Decidable P] :

src/Init/Control/Basic.lean

+7
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Author: Leonardo de Moura, Sebastian Ullrich
55
-/
66
prelude
77
import Init.Core
8+
import Init.BinderNameHint
89

910
universe u v w
1011

@@ -35,6 +36,12 @@ instance (priority := 500) instForInOfForIn' [ForIn' m ρ α d] : ForIn m ρ α
3536
simp [h]
3637
rfl
3738

39+
@[wf_preprocess] theorem forIn_eq_forin' [d : Membership α ρ] [ForIn' m ρ α d] {β} [Monad m]
40+
(x : ρ) (b : β) (f : (a : α) → β → m (ForInStep β)) :
41+
forIn x b f = forIn' x b (fun x h => binderNameHint x f <| binderNameHint h () <| f x) := by
42+
simp [binderNameHint]
43+
rfl -- very strange why `simp` did not close it
44+
3845
/-- Extract the value from a `ForInStep`, ignoring whether it is `done` or `yield`. -/
3946
def ForInStep.value (x : ForInStep α) : α :=
4047
match x with

src/Init/Data/Array/Attach.lean

+73
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,16 @@ and simplifies these to the function directly taking the value.
650650
rw [List.filterMap_subtype]
651651
simp [hf]
652652

653+
654+
@[simp] theorem flatMap_subtype {p : α → Prop} {l : Array { x // p x }}
655+
{f : { x // p x } → Array β} {g : α → Array β} (hf : ∀ x h, f ⟨x, h⟩ = g x) :
656+
(l.flatMap f) = l.unattach.flatMap g := by
657+
cases l
658+
simp only [size_toArray, List.flatMap_toArray, List.unattach_toArray, List.length_unattach,
659+
mk.injEq]
660+
rw [List.flatMap_subtype]
661+
simp [hf]
662+
653663
@[simp] theorem findSome?_subtype {p : α → Prop} {l : Array { x // p x }}
654664
{f : { x // p x } → Option β} {g : α → Option β} (hf : ∀ x h, f ⟨x, h⟩ = g x) :
655665
l.findSome? f = l.unattach.findSome? g := by
@@ -695,4 +705,67 @@ and simplifies these to the function directly taking the value.
695705
(Array.mkArray n x).unattach = Array.mkArray n x.1 := by
696706
simp [unattach]
697707

708+
/-! ### Well-founded recursion preprocessing setup -/
709+
710+
@[wf_preprocess] theorem Array.map_wfParam (xs : Array α) (f : α → β) :
711+
(wfParam xs).map f = xs.attach.unattach.map f := by
712+
simp [wfParam]
713+
714+
@[wf_preprocess] theorem Array.map_unattach (P : α → Prop) (xs : Array (Subtype P)) (f : α → β) :
715+
xs.unattach.map f = xs.map fun ⟨x, h⟩ =>
716+
binderNameHint x f <| binderNameHint h () <| f (wfParam x) := by
717+
simp [wfParam]
718+
719+
@[wf_preprocess] theorem foldl_wfParam (xs : Array α) (f : β → α → β) (x : β) :
720+
(wfParam xs).foldl f x = xs.attach.unattach.foldl f x := by
721+
simp [wfParam]
722+
723+
@[wf_preprocess] theorem foldl_unattach (P : α → Prop) (xs : Array (Subtype P)) (f : β → α → β) (x : β):
724+
xs.unattach.foldl f x = xs.foldl (fun s ⟨x, h⟩ =>
725+
binderNameHint s f <| binderNameHint x (f s) <| binderNameHint h () <| f s (wfParam x)) x := by
726+
simp [wfParam]
727+
728+
@[wf_preprocess] theorem foldr_wfParam (xs : Array α) (f : α → β → β) (x : β) :
729+
(wfParam xs).foldr f x = xs.attach.unattach.foldr f x := by
730+
simp [wfParam]
731+
732+
@[wf_preprocess] theorem foldr_unattach (P : α → Prop) (xs : Array (Subtype P)) (f : α → β → β) (x : β):
733+
xs.unattach.foldr f x = xs.foldr (fun ⟨x, h⟩ s =>
734+
binderNameHint x f <| binderNameHint s (f x) <| binderNameHint h () <| f (wfParam x) s) x := by
735+
simp [wfParam]
736+
737+
@[wf_preprocess] theorem filter_wfParam (xs : Array α) (f : α → Bool) :
738+
(wfParam xs).filter f = xs.attach.unattach.filter f:= by
739+
simp [wfParam]
740+
741+
@[wf_preprocess] theorem filter_unattach (P : α → Prop) (xs : Array (Subtype P)) (f : α → Bool) :
742+
xs.unattach.filter f = (xs.filter (fun ⟨x, h⟩ =>
743+
binderNameHint x f <| binderNameHint h () <| f (wfParam x))).unattach := by
744+
simp [wfParam]
745+
746+
@[wf_preprocess] theorem reverse_wfParam (xs : Array α) :
747+
(wfParam xs).reverse = xs.attach.unattach.reverse := by simp [wfParam]
748+
749+
@[wf_preprocess] theorem reverse_unattach (P : α → Prop) (xs : Array (Subtype P)) :
750+
xs.unattach.reverse = xs.reverse.unattach := by simp
751+
752+
@[wf_preprocess] theorem filterMap_wfParam (xs : Array α) (f : α → Option β) :
753+
(wfParam xs).filterMap f = xs.attach.unattach.filterMap f := by
754+
simp [wfParam]
755+
756+
@[wf_preprocess] theorem filterMap_unattach (P : α → Prop) (xs : Array (Subtype P)) (f : α → Option β) :
757+
xs.unattach.filterMap f = xs.filterMap fun ⟨x, h⟩ =>
758+
binderNameHint x f <| binderNameHint h () <| f (wfParam x) := by
759+
simp [wfParam]
760+
761+
@[wf_preprocess] theorem flatMap_wfParam (xs : Array α) (f : α → Array β) :
762+
(wfParam xs).flatMap f = xs.attach.unattach.flatMap f := by
763+
simp [wfParam]
764+
765+
@[wf_preprocess] theorem flatMap_unattach (P : α → Prop) (xs : Array (Subtype P)) (f : α → Array β) :
766+
xs.unattach.flatMap f = xs.flatMap fun ⟨x, h⟩ =>
767+
binderNameHint x f <| binderNameHint h () <| f (wfParam x) := by
768+
simp [wfParam]
769+
770+
698771
end Array

src/Init/Data/List/Attach.lean

+63
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Authors: Mario Carneiro
66
prelude
77
import Init.Data.List.Count
88
import Init.Data.Subtype
9+
import Init.BinderNameHint
910

1011
namespace List
1112

@@ -796,4 +797,66 @@ and simplifies these to the function directly taking the value.
796797
(List.replicate n x).unattach = List.replicate n x.1 := by
797798
simp [unattach, -map_subtype]
798799

800+
/-! ### Well-founded recursion preprocessing setup -/
801+
802+
@[wf_preprocess] theorem map_wfParam (xs : List α) (f : α → β) :
803+
(wfParam xs).map f = xs.attach.unattach.map f := by
804+
simp [wfParam]
805+
806+
@[wf_preprocess] theorem map_unattach (P : α → Prop) (xs : List (Subtype P)) (f : α → β) :
807+
xs.unattach.map f = xs.map fun ⟨x, h⟩ =>
808+
binderNameHint x f <| binderNameHint h () <| f (wfParam x) := by
809+
simp [wfParam]
810+
811+
@[wf_preprocess] theorem foldl_wfParam (xs : List α) (f : β → α → β) (x : β) :
812+
(wfParam xs).foldl f x = xs.attach.unattach.foldl f x := by
813+
simp [wfParam]
814+
815+
@[wf_preprocess] theorem foldl_unattach (P : α → Prop) (xs : List (Subtype P)) (f : β → α → β) (x : β):
816+
xs.unattach.foldl f x = xs.foldl (fun s ⟨x, h⟩ =>
817+
binderNameHint s f <| binderNameHint x (f s) <| binderNameHint h () <| f s (wfParam x)) x := by
818+
simp [wfParam]
819+
820+
@[wf_preprocess] theorem foldr_wfParam (xs : List α) (f : α → β → β) (x : β) :
821+
(wfParam xs).foldr f x = xs.attach.unattach.foldr f x := by
822+
simp [wfParam]
823+
824+
@[wf_preprocess] theorem foldr_unattach (P : α → Prop) (xs : List (Subtype P)) (f : α → β → β) (x : β):
825+
xs.unattach.foldr f x = xs.foldr (fun ⟨x, h⟩ s =>
826+
binderNameHint x f <| binderNameHint s (f x) <| binderNameHint h () <| f (wfParam x) s) x := by
827+
simp [wfParam]
828+
829+
@[wf_preprocess] theorem filter_wfParam (xs : List α) (f : α → Bool) :
830+
(wfParam xs).filter f = xs.attach.unattach.filter f:= by
831+
simp [wfParam]
832+
833+
@[wf_preprocess] theorem filter_unattach (P : α → Prop) (xs : List (Subtype P)) (f : α → Bool) :
834+
xs.unattach.filter f = (xs.filter (fun ⟨x, h⟩ =>
835+
binderNameHint x f <| binderNameHint h () <| f (wfParam x))).unattach := by
836+
simp [wfParam]
837+
838+
@[wf_preprocess] theorem reverse_wfParam (xs : List α) :
839+
(wfParam xs).reverse = xs.attach.unattach.reverse := by simp [wfParam]
840+
841+
@[wf_preprocess] theorem reverse_unattach (P : α → Prop) (xs : List (Subtype P)) :
842+
xs.unattach.reverse = xs.reverse.unattach := by simp
843+
844+
@[wf_preprocess] theorem filterMap_wfParam (xs : List α) (f : α → Option β) :
845+
(wfParam xs).filterMap f = xs.attach.unattach.filterMap f := by
846+
simp [wfParam]
847+
848+
@[wf_preprocess] theorem filterMap_unattach (P : α → Prop) (xs : List (Subtype P)) (f : α → Option β) :
849+
xs.unattach.filterMap f = xs.filterMap fun ⟨x, h⟩ =>
850+
binderNameHint x f <| binderNameHint h () <| f (wfParam x) := by
851+
simp [wfParam]
852+
853+
@[wf_preprocess] theorem flatMap_wfParam (xs : List α) (f : α → List β) :
854+
(wfParam xs).flatMap f = xs.attach.unattach.flatMap f := by
855+
simp [wfParam]
856+
857+
@[wf_preprocess] theorem flatMap_unattach (P : α → Prop) (xs : List (Subtype P)) (f : α → List β) :
858+
xs.unattach.flatMap f = xs.flatMap fun ⟨x, h⟩ =>
859+
binderNameHint x f <| binderNameHint h () <| f (wfParam x) := by
860+
simp [wfParam]
861+
799862
end List

src/Init/WF.lean

+16
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Author: Leonardo de Moura
55
-/
66
prelude
77
import Init.SizeOf
8+
import Init.BinderNameHint
89
import Init.Data.Nat.Basic
910

1011
universe u v
@@ -414,3 +415,18 @@ theorem mkSkipLeft {α : Type u} {β : Type v} {b₁ b₂ : β} {s : β → β
414415
end
415416

416417
end PSigma
418+
419+
/--
420+
The `wfParam` gadget is used internally during the construction of recursive functions by
421+
wellfounded recursion, to keep track of the parameter for which the automatic introduction
422+
of `List.attach` (or similar) is plausible.
423+
-/
424+
def wfParam {α : Sort u} (a : α) : α := a
425+
426+
/--
427+
Reverse direction of `dite_eq_ite`. Used by the well-founded definition preprocessor to extend the
428+
context of a termination proof inside `if-then-else` with the condition.
429+
-/
430+
@[wf_preprocess] theorem ite_eq_dite [Decidable P] :
431+
ite P a b = (dite P (fun h => binderNameHint h () a) (fun h => binderNameHint h () b)) := by
432+
rfl

src/Lean/Elab/PreDefinition/WF/AutoAttach.lean

-19
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/-
2+
Copyright (c) 2021 Microsoft Corporation. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Leonardo de Moura
5+
-/
6+
prelude
7+
import Lean.Meta.Transform
8+
import Lean.Elab.RecAppSyntax
9+
10+
namespace Lean.Elab.WF
11+
open Meta
12+
13+
/--
14+
Preprocesses the expressions to improve the effectiveness of `wfRecursion`.
15+
16+
* Floats out the RecApp markers.
17+
Example:
18+
```
19+
def f : Nat → Nat
20+
| 0 => 1
21+
| i+1 => (f x) i
22+
```
23+
24+
Unlike `Lean.Elab.Structural.preprocess`, do _not_ beta-reduce, as it could
25+
remove `let_fun`-lambdas that contain explicit termination proofs.
26+
-/
27+
def floatRecApp (e : Expr) : CoreM Expr :=
28+
Core.transform e
29+
(post := fun e => do
30+
if e.isApp && e.getAppFn.isMData then
31+
let .mdata m f := e.getAppFn | unreachable!
32+
if m.isRecApp then
33+
return .done (.mdata m (f.beta e.getAppArgs))
34+
return .continue)
35+
36+
end Lean.Elab.WF

src/Lean/Elab/PreDefinition/WF/Ite.lean

-30
This file was deleted.

src/Lean/Elab/PreDefinition/WF/Main.lean

+10-9
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@ import Lean.Elab.PreDefinition.Basic
88
import Lean.Elab.PreDefinition.TerminationMeasure
99
import Lean.Elab.PreDefinition.Mutual
1010
import Lean.Elab.PreDefinition.WF.PackMutual
11-
import Lean.Elab.PreDefinition.WF.Preprocess
11+
import Lean.Elab.PreDefinition.WF.FloatRecApp
1212
import Lean.Elab.PreDefinition.WF.Rel
1313
import Lean.Elab.PreDefinition.WF.Fix
1414
import Lean.Elab.PreDefinition.WF.Unfold
15-
import Lean.Elab.PreDefinition.WF.Ite
16-
import Lean.Elab.PreDefinition.WF.AutoAttach
15+
import Lean.Elab.PreDefinition.WF.Preprocess
1716
import Lean.Elab.PreDefinition.WF.GuessLex
1817

1918
namespace Lean.Elab
@@ -23,8 +22,8 @@ open Meta
2322
def wfRecursion (preDefs : Array PreDefinition) (termMeasure?s : Array (Option TerminationMeasure)) : TermElabM Unit := do
2423
let termMeasures? := termMeasure?s.mapM id -- Either all or none, checked by `elabTerminationByHints`
2524
let preDefs ← preDefs.mapM fun preDef =>
26-
return { preDef with value := (← preprocess preDef.value) }
27-
let (fixedPrefixSize, argsPacker, unaryPreDef) ← withoutModifyingEnv do
25+
return { preDef with value := (← floatRecApp preDef.value) }
26+
let (fixedPrefixSize, argsPacker, unaryPreDef, wfPreprocessProofs) ← withoutModifyingEnv do
2827
for preDef in preDefs do
2928
addAsAxiom preDef
3029
let fixedPrefixSize ← Mutual.getFixedPrefix preDefs
@@ -34,8 +33,10 @@ def wfRecursion (preDefs : Array PreDefinition) (termMeasure?s : Array (Option T
3433
if varNames.isEmpty then
3534
throwError "well-founded recursion cannot be used, '{preDef.declName}' does not take any (non-fixed) arguments"
3635
let argsPacker := { varNamess }
37-
let preDefsDIte ← preDefs.mapM fun preDef => return { preDef with value := (← iteToDIte preDef.value) }
38-
return (fixedPrefixSize, argsPacker, ← packMutual fixedPrefixSize argsPacker preDefsDIte)
36+
let (preDefsAttached, wfPreprocessProofs) ← Array.unzip <$> preDefs.mapM fun preDef => do
37+
let result ← preprocess preDef.value
38+
return ({preDef with value := result.expr}, result)
39+
return (fixedPrefixSize, argsPacker, ← packMutual fixedPrefixSize argsPacker preDefsAttached, wfPreprocessProofs)
3940

4041
let wf : TerminationMeasures ← do
4142
if let some tms := termMeasures? then pure tms else
@@ -62,10 +63,10 @@ def wfRecursion (preDefs : Array PreDefinition) (termMeasure?s : Array (Option T
6263
Mutual.addPreDefsFromUnary preDefs preDefsNonrec preDefNonRec
6364
let preDefs ← Mutual.cleanPreDefs preDefs
6465
registerEqnsInfo preDefs preDefNonRec.declName fixedPrefixSize argsPacker
65-
for preDef in preDefs do
66+
for preDef in preDefs, wfPreprocessProof in wfPreprocessProofs do
6667
unless preDef.kind.isTheorem do
6768
unless (← isProp preDef.type) do
68-
WF.mkUnfoldEq preDef preDefNonRec.declName
69+
WF.mkUnfoldEq preDef preDefNonRec.declName wfPreprocessProof
6970
Mutual.addPreDefAttributes preDefs
7071

7172
builtin_initialize registerTraceClass `Elab.definition.wf

0 commit comments

Comments
 (0)