diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 372f09a..547c317 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -74,3 +74,7 @@ jobs: continue-on-error: true run: | lake exe graphiti benchmarks/post-processed/gsum_many.dot -o out.dot --oracle $(pwd)/bin/graphiti_oracle --reverse + + - name: test img_avg example + run: | + lake exe graphiti benchmarks/post-processed/img_avg.dot -o out.dot --oracle $(pwd)/bin/graphiti_oracle --reverse diff --git a/Graphiti/Core/Basic.lean b/Graphiti/Core/Basic.lean index d51111e..397ab0f 100644 --- a/Graphiti/Core/Basic.lean +++ b/Graphiti/Core/Basic.lean @@ -75,6 +75,14 @@ meta def fromExpr? (e : Expr) : SimpM (Option Nat) := meta def fromExpr?' (e : Expr) : SimpM (Option (Array Char)) := getListLitOf? e getCharValue? +def ofOption {ε α σ} (e : ε) : Option α → EStateM ε σ α +| some o => pure o +| none => throw e + +def ofOption' {ε α} (e : ε) : Option α → Except ε α +| some o => pure o +| none => throw e + /-- Reduce `toString 5` to `"5"` -/ diff --git a/Graphiti/Core/DotParser.lean b/Graphiti/Core/DotParser.lean index 23f1bed..4ce350f 100644 --- a/Graphiti/Core/DotParser.lean +++ b/Graphiti/Core/DotParser.lean @@ -8,6 +8,7 @@ module public import Lean public import Graphiti.Core.ExprHigh +public import Graphiti.Core.DynamaticTypes public section @@ -242,6 +243,9 @@ def translateSize' : String → Except String String def translateSize (s : String) : Except String String := translateSize' (s.takeWhile (·.isDigit)) +def toSizeInfo (l : List Parser.DotAttr) (k : String) := + toNat? <$> parseIOSizes l k + /-- Parse a dot expression that comes from Dynamatic. It returns the graph expression, as well as a list of additional attributes that should be bypassed @@ -269,50 +273,16 @@ def dotToExprHigh (d : Parser.DotGraph) : Except String (ExprHigh String String current_extra_args ← addOpt current_extra_args "taggers_num" current_extra_args ← addOpt current_extra_args "tagger_id" - -- Different bitwidths matter so we distinguish them with types: Unit vs. Bool vs. T - if typVal = "Branch" then - if keyStartsWith l "in" "in1:0" then - typVal := "branch Unit" - else if keyStartsWith l "in" "in1:1" then - typVal := "branch Bool" - -- Else, if the bitwidth is 32, then: - else typVal := "branch T" - - -- Different bitwidths matter so we distinguish them with types: Unit vs. Bool vs. T - if typVal = "Fork" then - let [sizesIn] ← parseIOSizes l "in" |>.mapM translateSize - | throw "more that one input in fork" - typVal := s!"fork{keyArgNumbers l "out"} {sizesIn}" - - if typVal = "Mux" then - current_extra_args ← addOpt current_extra_args "delay" - if splitAndSearch l "in" "in2:0" then - typVal := s!"mux Unit" - else typVal := s!"mux T" - - if typVal = "Merge" then + if typVal == "Mux" || typVal == "Merge" then current_extra_args ← addOpt current_extra_args "delay" - if splitAndSearch l "in" "in0:0" then - typVal := s!"merge{keyArgNumbers l "in"} Unit" - else typVal := s!"merge{keyArgNumbers l "in"} T" if typVal = "Entry" then current_extra_args ← addOpt current_extra_args "control" - if typVal = "Constant" then let constVal ← keyArg l "value" |> Parser.hexParser.run - typVal := s!"constant {constVal}" current_extra_args ← addOpt current_extra_args "value" - if typVal = "Sink" then - if splitAndSearch l "in" "in1:0" then - typVal := s!"sink Unit 1" - else if splitAndSearch l "in" "in1:1" then - typVal := s!"sink Bool 1" - else - typVal := s!"sink T 1" - if typVal = "Operator" then if splitAndSearch l "op" "mc_store_op" || splitAndSearch l "op" "mc_load_op" then current_extra_args ← addOpt current_extra_args "portId" @@ -324,26 +294,14 @@ def dotToExprHigh (d : Parser.DotGraph) : Except String (ExprHigh String String current_extra_args ← addOpt current_extra_args "constants" current_extra_args ← add current_extra_args "op" - if splitAndSearch l "op" "mc_load_op" then - typVal := s!"load T T" - else if splitAndSearch l "op" "cast" then - typVal := s!"pure T Bool" - else - let sizesIn ← parseIOSizes l "in" |>.mapM translateSize - let sizesOut ← parseIOSizes l "out" |>.mapM translateSize - typVal := s!"operator{keyArgNumbers l "in"} {" ".intercalate sizesIn} {" ".intercalate sizesOut} {keyArg l "op"}" - - -- portId= 0, offset= 0 -- if mc_store_op and mc_load_op - - if typVal = "MC" then - let sizesIn ← parseIOSizes l "in" |>.mapM translateSize - let sizesOut ← parseIOSizes l "out" |>.filter (·.trim.endsWith "*e" |> not) |>.mapM translateSize - typVal := s!"operator{keyArgNumbers l "in"} {" ".intercalate sizesIn} {" ".intercalate sizesOut} MC" + if typVal == "MC" then current_extra_args ← addOpt current_extra_args "memory" current_extra_args ← addOpt current_extra_args "bbcount" current_extra_args ← addOpt current_extra_args "ldcount" current_extra_args ← addOpt current_extra_args "stcount" + typVal := dynamaticToGraphiti typVal (toSizeInfo l "in") (toSizeInfo l "out") + let cluster := l.find? (·.key = "cluster") |>.getD ⟨"cluster", "false"⟩ let .ok clusterB := Parser.parseBool.run cluster.value.trim | throw s!"{s}: cluster could not be parsed" diff --git a/Graphiti/Core/DynamaticPrinter.lean b/Graphiti/Core/DynamaticPrinter.lean index cd4dfb6..abdd042 100644 --- a/Graphiti/Core/DynamaticPrinter.lean +++ b/Graphiti/Core/DynamaticPrinter.lean @@ -8,6 +8,8 @@ module public import Graphiti.Core.Rewriter public import Graphiti.Core.TypeExpr +public import Graphiti.Core.DynamaticTypes +public import Graphiti.Core.WellTyped public section @@ -107,54 +109,6 @@ def translateTypes (key : String) : Option String × String × String × List ( def removeLetter (ch : Char) (s : String) : String := String.mk (s.toList.filter (λ c => c ≠ ch)) -def returnNatInstring (s : String) : Option Nat := - -- Convert the string to a list of characters - let chars := s.toList - let result := List.foldl (λ acc c => - if c.isDigit then - acc * 10 + (Char.toNat c - Char.toNat '0') - else - acc) 0 chars - -- If no non-digit character was encountered, return the result - -- if result = 0 then - -- if s.isEmpty then some 0 else none - -- else - some result - -def incrementDefinitionPortIdx (s direction: String) : String := - -- Split the string by spaces into individual parts (like "out0:32") - let parts := s.splitOn " " - -- Map over each part, incrementing the number after "out" - let updatedParts := parts.map (λ part => - match part.splitOn ":" with - | [pref, num] => - match (returnNatInstring pref) with - | some n => - -- Increment the number found - let incremented := n + 1 - -- Reconstruct the string with the incremented number - direction ++ Nat.repr incremented ++ ":" ++ (num) - | none => part -- If no number is found, keep the part unchanged - | _ => part -- If the part doesn't have ":" or a valid number, keep it unchanged - ) - -- Join the updated parts into a single string with spaces - String.intercalate " " updatedParts - --- #eval incrementDefinitionPortIdx "out1:32" "out" --out1:324 out2:32 out3:32" "out" -- Output: "out1:32 out4:32 out3:32" - --- #eval "out132".splitOn ":" - -def incrementConnectionPortIdx (s direction: String) : String := - match returnNatInstring s with - | some n => - let incremented := n + 1 - -- Convert incremented number to a string and concatenate with the direction part - let incrementedStr := Nat.repr incremented - direction ++ incrementedStr - | none => s -- If no number is found, return the original string - --- #eval incrementConnectionPortIdx "out33" "out" - -- Function became messy... def formatOptions : List (String × String) → String | x :: l => l.foldl @@ -167,56 +121,43 @@ def formatOptions : List (String × String) → String s!", {x.1} = {v2_}") | [] => "" -def extractStandardType (s : String) : String := - let parts := s.splitOn " " - parts[0]! - -def capitalizeFirstChar (s : String) : String := - match s.get? 0 with - | none => s -- If the string is empty, return it as is - | some c => - let newChar := if 'a' ≤ c ∧ c ≤ 'z' then - Char.ofNat (c.toNat - ('a'.toNat - 'A'.toNat)) - else - c - newChar.toString ++ s.drop 1 - --- Join is taken in Dynamatic so rename to Concat -def RenameJoinToConcat (s : String) : String := - if String.isPrefixOf "join" s then - "Concat" - else - s -- Otherwise, return the original string - -def fixComponentNames (s : String) : String := - String.intercalate "_" (s.splitOn "__") +def inferTypeInPortMap (t : TypeUF) (p : PortMap String (InternalPort String)) (sn : String × Nat) : Except String (PortMap String TypeExpr) := + p.foldlM (λ st k v => do + let tc ← toTypeConstraint sn k.name + let concr := t.findConcr tc |>.getD (.var 1000) -- TODO: better handling of not finding a concretization + return st.cons k concr + ) ∅ + +def inferTypeInPortMapping (t : TypeUF) (p : PortMapping String) (sn : String × Nat) : Except String (PortMap String TypeExpr × PortMap String TypeExpr) := do + let inp ← inferTypeInPortMap t p.input sn + let out ← inferTypeInPortMap t p.output sn + return (inp, out) + +def toPortList (typs : PortMap String TypeExpr) : String := + typs.foldl (λ s k v => s ++ s!"{removeLetter 'p' k.name}:{TypeExpr.Parser.getSize v} ") "" --fmt.1: Type --fmt.2.1 and fmt.2.2.1: Input and output attributes --fmt.2.2.2: Additional options. -def dynamaticString (a: ExprHigh String (String × Nat)) (m : AssocList String (AssocList String String)): Option String := do - -- let instances := - -- a.modules.foldl (λ s inst mod => s ++ s!"\n {inst} [mod = \"{mod}\"];") "" - let a ← a.normaliseNames +def dynamaticString (a: ExprHigh String (String × Nat)) (t : TypeUF) (m : AssocList String (AssocList String String)): Except String String := do + let a ← ofOption' "could not normalise names" a.normaliseNames let modules ← a.modules.foldlM (λ s k v => do -- search for the type of the passed node in interfaceTypes - let fmt := translateTypes v.snd.1 + let typeName := graphitiToDynamatic v.2.1 |>.1 match m.find? k with | some input_fmt => -- If the node is found to be coming from the input, -- retrieve its attributes from what we saved and bypass it -- without looking for it in interfaceTypes - return (RenameJoinToConcat s) ++ s!"\"{k}\" [type = \"{capitalizeFirstChar (extractStandardType (fmt.1.getD v.snd.1))}\"{formatOptions input_fmt.toList}];\n" - --return s ++ s!"\"{k}\" [type = \"{fmt.1.getD v.snd}\"{formatOptions input_fmt.toList}];\n" + return s ++ s!"\"{k}\" [type = \"{typeName}\"{formatOptions input_fmt.toList}];\n" | none => + let typs ← inferTypeInPortMapping t v.1 v.2 -- If this is a new node, then we sue `fmt` to correctly add the right -- arguments from what is given in interfaceTypes. We should never be generating constructs like MC, so -- this shouldn't be a problem. - return (RenameJoinToConcat s) ++ s!"\"{k}\" [type = \"{capitalizeFirstChar (extractStandardType (fmt.1.getD v.snd.1))}\", in = \"{removeLetter 'p' fmt.2.1}\", out = \"{fmt.2.2.1}\"{formatOptions fmt.2.2.2}];\n" - --return s ++ s!"\"{k}\" [type = \"{fmt.1.getD v.snd}\", in = \"{removeLetter 'p' fmt.2.1}\", out = \" {fmt.2.2.1} \"{formatOptions fmt.2.2.2}];\n" - + return s ++ s!"\"{k}\" [type = \"{typeName}\", in = \"{toPortList typs.1}\", out = \"{toPortList typs.2}\"];\n" ) "" let connections := a.connections.foldl @@ -225,7 +166,7 @@ def dynamaticString (a: ExprHigh String (String × Nat)) (m : AssocList String ( ++ s!"[from = \"{oport.name}\"," ++ s!" to = \"{removeLetter 'p' iport.name}\" " ++ "];") "" - s!"Digraph G \{ + .ok s!"Digraph G \{ {modules} {connections} }" diff --git a/Graphiti/Core/DynamaticTypes.lean b/Graphiti/Core/DynamaticTypes.lean new file mode 100644 index 0000000..8de7bce --- /dev/null +++ b/Graphiti/Core/DynamaticTypes.lean @@ -0,0 +1,110 @@ +/- +Copyright (c) 2025 VCA Lab, EPFL. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Yann Herklotz +-/ + +module + +public import Graphiti.Core.AssocList +public import Graphiti.Core.Basic + +public section + +open Batteries (AssocList) + +namespace Graphiti + +def graphitiPrefix := "_graphiti_" + +/-- +Mapping from Graphiti names to Dynamatic names. +-/ +def dynamatic_types : AssocList String (String × List (Option Nat) × List (Option Nat)) := + [ ("join", ("Concat", [.none, .none], [.none])) + , ("split", ("Split", [.none], List.replicate 2 .none)) + , ("branch", ("Branch", [.none, .some 1], List.replicate 2 .none)) + , ("fork2", ("Fork", [.none], List.replicate 2 .none)) + , ("fork3", ("Fork", [.none], List.replicate 3 .none)) + , ("fork4", ("Fork", [.none], List.replicate 4 .none)) + , ("fork5", ("Fork", [.none], List.replicate 5 .none)) + , ("fork6", ("Fork", [.none], List.replicate 6 .none)) + , ("fork7", ("Fork", [.none], List.replicate 7 .none)) + , ("fork8", ("Fork", [.none], List.replicate 8 .none)) + , ("fork9", ("Fork", [.none], List.replicate 9 .none)) + , ("fork10", ("Fork", [.none], List.replicate 10 .none)) + , ("merge2", ("Merge", List.replicate 2 .none, [.none])) + , ("operator1", ("Operator", [.none], [.none])) + , ("operator2", ("Operator", List.replicate 2 .none, [.none])) + , ("operator3", ("Operator", List.replicate 3 .none, [.none])) + , ("operator4", ("Operator", List.replicate 4 .none, [.none])) + , ("operator5", ("Operator", List.replicate 5 .none, [.none])) + , ("mc", ("MC", [.some 32], [.some 32, .some 0])) + , ("mux", ("Mux", [.some 1, .none, .none], [.none])) + , ("input", ("Entry", [.some 0], [.some 0])) + , ("inputNat", ("Entry", [.some 32], [.some 32])) + , ("inputBool", ("Entry", [.some 1], [.some 1])) + , ("outputNat0", ("Exit", [], [.some 32])) + , ("output0", ("Exit", [], [.some 0])) + , ("outputBool0", ("Exit", [], [.some 1])) + , ("outputNat1", ("Exit", [.none], [.some 32])) + , ("output1", ("Exit", [.none], [.some 0])) + , ("outputBool1", ("Exit", [.none], [.some 1])) + , ("outputNat2", ("Exit", [.none, .none], [.some 32])) + , ("output2", ("Exit", [.none, .none], [.some 0])) + , ("outputBool2", ("Exit", [.none, .none], [.some 1])) + , ("outputNat3", ("Exit", [.none, .none, .none], [.some 32])) + , ("output3", ("Exit", [.none, .none, .none], [.some 0])) + , ("outputBool3", ("Exit", [.none, .none, .none], [.some 1])) + , ("outputNat4", ("Exit", [.none, .none, .none, .none], [.some 32])) + , ("output4", ("Exit", [.none, .none, .none, .none], [.some 0])) + , ("outputBool4", ("Exit", [.none, .none, .none, .none], [.some 1])) + , ("outputNat5", ("Exit", [.none, .none, .none, .none, .none], [.some 32])) + , ("output5", ("Exit", [.none, .none, .none, .none, .none], [.some 0])) + , ("outputBool5", ("Exit", [.none, .none, .none, .none, .none], [.some 1])) + , ("sink", ("Sink", [.none], [])) + , ("constantNat", ("Constant", [.some 0], [.some 32])) + , ("constantBool", ("Constant", [.some 0], [.some 1])) + , ("initBool", ("Init", [.some 1], [.some 1])) + , ("tag_untagger_val", ("TaggerUntagger", [.none, .none], [.none, .none])) + , ("load", ("Operator", [.some 32, .some 32], [.some 32, .some 32])) + ].toAssocList + +def similar {α} [BEq α] (l1 l2 : List (Option α)) : Bool := + match l1, l2 with + | [], [] => true + | .cons a1 b1, .cons a2 b2 => + match a1, a2 with + | .none, _ => similar b1 b2 + | _, .none => similar b1 b2 + | .some v1, .some v2 => v1 == v2 && similar b1 b2 + | _, _ => false + +def dynamaticToGraphiti (s : String) (inp out : List (Option Nat)) : String := + match dynamatic_types.findEntryP? (λ k v => v.1 == s && similar v.2.1 inp && similar v.2.2 out) with + | some (k, v) => k + | none => graphitiPrefix ++ s + +def graphitiToDynamatic (s : String) : String × List (Option Nat) × List (Option Nat) := + match dynamatic_types.find? s with + | some s' => s' + | none => + if graphitiPrefix.isPrefixOf s then (s.drop graphitiPrefix.length, [], []) else ("__unknown__"++s, [], []) + +def fromDynamaticPorts {α} (pref : String) (l : List α) : PortMap String α := + l.foldl (λ st a => (st.1.cons ↑s!"{pref}{st.2}" a, st.2+1)) (∅, 1) |>.1 + +def fromDynamaticInputPorts {α} := @fromDynamaticPorts α "inp" +def fromDynamaticOutputPorts {α} := @fromDynamaticPorts α "out" + +def toDynamaticPorts {α} (pref : String) (p : PortMap String α) : Nat → Option (List α) +| 0 => .some [] +| n+1 => do + let el ← p.find? ↑s!"{pref}{n}" + let l ← toDynamaticPorts pref p n + return l.concat el + +def toDynamaticInputPorts {α} := @toDynamaticPorts α "inp" +def toDynamaticOutputPorts {α} := @toDynamaticPorts α "out" + +end Graphiti diff --git a/Graphiti/Core/Rewriter.lean b/Graphiti/Core/Rewriter.lean index 13a9569..91a2977 100644 --- a/Graphiti/Core/Rewriter.lean +++ b/Graphiti/Core/Rewriter.lean @@ -39,8 +39,10 @@ structure RuntimeEntry where input_graph : ExprHigh String (String × Nat) output_graph : ExprHigh String (String × Nat) matched_subgraph : List String + matched_subgraph_types : List Nat renamed_input_nodes : AssocList String (Option String) new_output_nodes : List String + fresh_types : Nat debug : Option String := .none name : Option String := .none deriving Repr, Inhabited, DecidableEq @@ -60,11 +62,13 @@ instance : Lean.ToJson RuntimeEntry where Lean.Json.mkObj [ ("type", Lean.Format.pretty <| repr r.type) , ("name", Lean.toJson r.name) - , ("input_graph", toString (repr r.input_graph)) - , ("output_graph", toString (repr r.output_graph)) + , ("input_graph", toString <| repr r.input_graph) + , ("output_graph", toString <| repr r.output_graph) , ("matched_subgraph", Lean.toJson r.matched_subgraph) + , ("matched_subgraph_types", Lean.toJson r.matched_subgraph_types) , ("renamed_input_nodes", Lean.Json.mkObj <| r.renamed_input_nodes.toList.map (λ a => (a.1, Lean.toJson a.2))) , ("new_output_nodes", Lean.toJson r.new_output_nodes) + , ("fresh_types", Lean.toJson r.fresh_types) , ("debug", Lean.toJson r.debug) ] @@ -121,14 +125,6 @@ variable {Ident Typ} variable [Inhabited Ident] variable [Inhabited Typ] -def ofOption {ε α σ} (e : ε) : Option α → EStateM ε σ α -| some o => pure o -| none => throw e - -def ofOption' {ε α} (e : ε) : Option α → Except ε α -| some o => pure o -| none => throw e - def liftError {α σ} : Except String α → EStateM RewriteError σ α | .ok o => pure o | .error s => throw (.error s) @@ -187,6 +183,10 @@ def addRuntimeEntry (rinfo : RuntimeEntry) : RewriteResult Unit := do let l ← EStateM.get EStateM.set <| ⟨l.1.concat rinfo, l.2, l.3⟩ +def incrFreshType (n : Nat) : RewriteResult Unit := do + let l ← EStateM.get + EStateM.set <| ⟨l.1, l.2, l.3+n⟩ + def addRuntimeMarker (s : String) : RewriteResult Unit := do let l ← EStateM.get EStateM.set <| ⟨l.1.concat (RuntimeEntry.marker s), l.2, l.3⟩ @@ -234,7 +234,11 @@ however, currently the low-level expression language does not remember any names -- Pattern match on the graph and extract the first list of nodes that correspond to the first subgraph. let (sub, types) ← rewrite.pattern g |>.runWithState + + addRuntimeEntry <| RuntimeEntry.mk EntryType.debug g default default default default .nil current_state.fresh_type .none rewrite.name + let def_rewrite := rewrite.rewrite types current_state.fresh_type + incrFreshType rewrite.fresh_types -- Extract the actual subgraph from the input graph using the list of nodes `sub`. let (g₁, g₂) ← ofOption (.error "could not extract graph") <| g.extract sub @@ -249,8 +253,11 @@ however, currently the low-level expression language does not remember any names let sub' ← ofOption (.error "could not extract base information") <| sub.mapM (λ a => g.modules.find? a) let g_lower := canon <| ExprLow.comm_bases sub'.reverse g_lower - addRuntimeEntry <| RuntimeEntry.mk EntryType.rewrite g default sub default .nil .none rewrite.name - updRuntimeEntry λ rw => {rw with debug := (.some <| (toString <| repr e_sub) ++ "\n\n" ++ ((toString <| repr def_rewrite.input_expr)))} + updRuntimeEntry λ rw => { rw with + matched_subgraph := sub + matched_subgraph_types := types.toList + debug := (.some <| (toString <| repr e_sub) ++ "\n\n" ++ ((toString <| repr def_rewrite.input_expr))) + } -- beq is an α-equivalence check that returns a mapping to rename one expression into the other. This mapping is -- split into the external mapping and internal mapping. @@ -258,8 +265,6 @@ however, currently the low-level expression language does not remember any names let (ext_mapping, int_mapping) ← liftError <| def_rewrite.input_expr.weak_beq e_sub let comb_mapping := ext_mapping.append int_mapping |>.filterId - -- EStateM.guard (.error "input mapping not invertible") <| comb_mapping.input.invertible - -- EStateM.guard (.error "output mapping not invertible") <| comb_mapping.output.invertible updRuntimeEntry λ rw => {rw with debug := (.some (toString ext_mapping))} @@ -288,7 +293,6 @@ however, currently the low-level expression language does not remember any names -- Finally we do the actual replacement. let (rewritten, b) := g_lower.force_replace (canon e_sub_input) e_sub_output - -- throw (.error s!"mods :: {repr sub'}rhs :: {repr g_lower}\n\ndep :: {repr (canon e_sub_input)}") EStateM.guard (.error s!"rewrite: subexpression not found in the graph: {repr g_lower}\n\n{repr (canon e_sub_input)}") b let out ← rewritten |> ExprLow.higher_correct PortMapping.hashPortMapping @@ -301,10 +305,19 @@ however, currently the low-level expression language does not remember any names |>.map (renamePortMapping · e_output_norm) |>.map PortMapping.hashPortMapping - -- Using comb_mapping to find the portMap does not work because with rewrites where there is a single module, the name - -- won't even appear in the rewrite. - updRuntimeEntry <| λ _ => RuntimeEntry.mk EntryType.rewrite g out sub (sub.zip renamedNodes).toAssocList - addedNodes (.some (toString renamedNodes ++ "\n\n" ++ toString addedNodes)) rewrite.name + updRuntimeEntry <| λ _ => { + type := EntryType.rewrite + input_graph := g + output_graph := out + matched_subgraph := sub + matched_subgraph_types := types.toList + renamed_input_nodes := (sub.zip renamedNodes).toAssocList + new_output_nodes := addedNodes + fresh_types := current_state.fresh_type + debug := (.some (toString renamedNodes ++ "\n\n" ++ toString addedNodes)) + name := rewrite.name + } + -- updRuntimeEntry λ rw => {rw with debug := (.some (toString e_output_norm))} EStateM.guard (.error s!"found duplicate node") out.modules.keysList.Nodup @@ -357,7 +370,7 @@ def reverse_rewrite' (def_rewrite : DefiniteRewrite String (String × Nat)) (rin let lhs_renamed ← ofOption (.error "could not rename") <| def_rewrite.input_expr.renamePorts full_renaming let rhs_renamed ← ofOption (.error "could not rename") <| def_rewrite.output_expr.renamePorts full_renaming - addRuntimeEntry <| RuntimeEntry.mk EntryType.debug default default default default .nil (.some <| s!"{repr lhs_renamed}\n\n{repr rhs_renamed}\n\n{repr full_renaming}\n\n{repr rhs_renaming}\n\n{repr lhs_renaming}") s!"rev-{rinfo.name.getD "unknown"}" + addRuntimeEntry <| RuntimeEntry.mk EntryType.debug default default default default default .nil 0 (.some <| s!"{repr lhs_renamed}\n\n{repr rhs_renamed}\n\n{repr full_renaming}\n\n{repr rhs_renaming}\n\n{repr lhs_renaming}") s!"rev-{rinfo.name.getD "unknown"}" return ({ params := 0 pattern := λ _ => pure (rhsNodes', default), @@ -366,16 +379,18 @@ def reverse_rewrite' (def_rewrite : DefiniteRewrite String (String × Nat)) (rin -- TODO: These dictate ordering of nodes quite strictly. transformedNodes := rhsNodes_renamed.map some ++ rhsNodes_added.map (λ _ => none), addedNodes := lhsNodes.drop rhsNodes_renamed.length + fresh_types := 0 }) -/--rrrr +/-- Generate a reverse rewrite from a rewrite and the RuntimeEntry associated with the execution. -/ def reverse_rewrite (rw : Rewrite String (String × Nat)) (rinfo : RuntimeEntry) : RewriteResult (Rewrite String (String × Nat)) := do - let (_nodes, l) ← rw.pattern rinfo.input_graph |>.runWithState - let current_state ← EStateM.get - let def_rewrite := rw.rewrite l current_state.fresh_type - reverse_rewrite' def_rewrite rinfo + /- let (_nodes, l) ← rw.pattern rinfo.input_graph |>.runWithState -/ + if h : rinfo.matched_subgraph_types.toArray.size = rw.params then + let def_rewrite := rw.rewrite (Vector.mk rinfo.matched_subgraph_types.toArray h) rinfo.fresh_types + reverse_rewrite' def_rewrite rinfo + else throw <| .error s!"{rw.name}: size does not match {rinfo.matched_subgraph_types.toArray.size} != {rw.params}" /-- Abstract a subgraph into a separate node. One can imagine that the node type is then a node in the environment which is diff --git a/Graphiti/Core/RewriterLemmas.lean b/Graphiti/Core/RewriterLemmas.lean index a7150aa..4b78980 100644 --- a/Graphiti/Core/RewriterLemmas.lean +++ b/Graphiti/Core/RewriterLemmas.lean @@ -248,7 +248,7 @@ theorem run'_implies_wt_lhs {b} {ε_global : FinEnv String (String × Nat)} dsimp [RewriteResultSL.runWithState] at Hpattern cases Hpattern cases ‹EStateM.get _ = _› - rename_i wo wi _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + rename_i wo wi _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ have wi_wf : ExprLow.well_formed ε_global.toEnv wi = true := by apply ExprLow.refines_comm_connections'_well_formed2 · apply ExprLow.replacement_well_formed2; rotate_left 1 @@ -338,7 +338,7 @@ theorem run'_implies_wf_lhs {b} {ε_global : FinEnv String (String × Nat)} dsimp [RewriteResultSL.runWithState] at Hpattern cases Hpattern cases ‹EStateM.get _ = _› - rename_i wo wi _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + rename_i wo wi _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ have wi_wf : ExprLow.well_formed ε_global.toEnv wi = true := by apply ExprLow.refines_comm_connections'_well_formed2 · apply ExprLow.replacement_well_formed2; rotate_left 1 @@ -416,7 +416,7 @@ theorem run'_refines {b} {ε_global : FinEnv String (String × Nat)} dsimp [RewriteResultSL.runWithState] at Hpattern cases Hpattern cases ‹EStateM.get _ = _› - rename_i wo wi _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + rename_i wo wi _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ have wi_wf : ExprLow.well_formed ε_global.toEnv wi = true := by apply ExprLow.refines_comm_connections'_well_formed2 · apply ExprLow.replacement_well_formed2; rotate_left 1 @@ -577,7 +577,7 @@ theorem run'_preserves_well_formed {b} {ε_global : FinEnv String (String × Nat dsimp [RewriteResultSL.runWithState] at Hpattern cases Hpattern cases ‹EStateM.get _ = _› - rename_i wo wi _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + rename_i wo wi _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ have rw_output_wf : ExprLow.well_formed (ε_global ++ vrw.ε_ext).toEnv (rw.rewrite types st.fresh_type).output_expr = true := by apply ExprLow.refines_subset_well_formed apply FinEnv.independent_subset_of_union @@ -677,7 +677,7 @@ theorem run'_preserves_well_typed {b} {ε_global : FinEnv String (String × Nat) dsimp [RewriteResultSL.runWithState] at Hpattern cases Hpattern cases ‹EStateM.get _ = _› - rename_i wo wi _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + rename_i wo wi _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ have rw_output_wf : ExprLow.well_formed (ε_global ++ vrw.ε_ext).toEnv (rw.rewrite types st.fresh_type).output_expr = true := by apply ExprLow.refines_subset_well_formed apply FinEnv.independent_subset_of_union diff --git a/Graphiti/Core/Rewrites.lean b/Graphiti/Core/Rewrites.lean index e103479..db5f3d1 100644 --- a/Graphiti/Core/Rewrites.lean +++ b/Graphiti/Core/Rewrites.lean @@ -54,6 +54,8 @@ def rewrite_index := , JoinSplitLoopCond.rewrite , JoinSplitLoopCondAlt.rewrite , ReduceSplitJoin.rewrite + , PureRewrites.ConstantNat.rewrite + , PureRewrites.ConstantBool.rewrite , PureRewrites.Constant.rewrite , PureRewrites.Operator1.rewrite , PureRewrites.Operator2.rewrite @@ -91,37 +93,8 @@ def rewrite_index := ] def reverse_rewrite_with_index (rinfo : RuntimeEntry) : RewriteResult (Rewrite String (String × Nat)) := do - let rw ← ofOption (.error s!"{decl_name%}: rewrite generation failed") <| do - let name ← rinfo.name - match name with - | "join-split-elim" => - let s ← rinfo.matched_subgraph[0]? - return JoinSplitElim.targetedRewrite s - | "join-comm" => - let s ← rinfo.matched_subgraph[0]? - return JoinComm.targetedRewrite s - | "join-assoc-right" => - let s ← rinfo.matched_subgraph[0]? - return JoinAssocR.targetedRewrite s - | "join-assoc-left" => - let s ← rinfo.matched_subgraph[0]? - return JoinAssocL.targetedRewrite s - | "pure-fork" => - let s ← rinfo.matched_subgraph[0]? - return {PureRewrites.Fork.rewrite with pattern := PureRewrites.Fork.match_node s} - | "pure-operator3" => - let s ← rinfo.matched_subgraph[0]? - return {PureRewrites.Operator3.rewrite with pattern := PureRewrites.Operator3.match_node s} - | "pure-operator2" => - let s ← rinfo.matched_subgraph[0]? - return {PureRewrites.Operator2.rewrite with pattern := PureRewrites.Operator2.match_node s} - | "pure-operator1" => - let s ← rinfo.matched_subgraph[0]? - return {PureRewrites.Operator1.rewrite with pattern := PureRewrites.Operator1.match_node s} - | "pure-constant" => - let s ← rinfo.matched_subgraph[0]? - return {PureRewrites.Constant.rewrite with pattern := PureRewrites.Constant.match_node s} - | _ => rewrite_index.find? name + let name ← ofOption (.error s!"{decl_name%}: no rinfo report") rinfo.name + let rw ← ofOption (.error s!"{decl_name%}: '{name}' reverse rewrite generation failed") <| rewrite_index.find? name reverse_rewrite rw rinfo /-- diff --git a/Graphiti/Core/Rewrites/Fork6Rewrite.lean b/Graphiti/Core/Rewrites/Fork6Rewrite.lean index 4308116..4b15d04 100644 --- a/Graphiti/Core/Rewrites/Fork6Rewrite.lean +++ b/Graphiti/Core/Rewrites/Fork6Rewrite.lean @@ -15,7 +15,7 @@ def matcher : Pattern String (String × Nat) 1 := fun g => do let (.some list) ← g.modules.foldlM (λ s inst (pmap, typ) => do if s.isSome then return s - unless "fork" == typ.1 do return none + unless "fork6" == typ.1 do return none return some ([inst], #v[typ.2]) ) none | throw .done diff --git a/Graphiti/Core/Rewrites/ForkJoin.lean b/Graphiti/Core/Rewrites/ForkJoin.lean index 1fcff57..f70cbe4 100644 --- a/Graphiti/Core/Rewrites/ForkJoin.lean +++ b/Graphiti/Core/Rewrites/ForkJoin.lean @@ -20,7 +20,7 @@ def matcher : Pattern String (String × Nat) 2 := fun g => do unless "join" == typ.1 do return none let (.some p) := followOutput g inst "out1" | return none - unless "fork" == p.typ.1 do return none + unless "fork2" == p.typ.1 do return none return some ([inst, p.inst], #v[typ.2, p.typ.2]) ) none | MonadExceptOf.throw RewriteError.done diff --git a/Graphiti/Core/Rewrites/ForkPure.lean b/Graphiti/Core/Rewrites/ForkPure.lean index 708187a..32d7c17 100644 --- a/Graphiti/Core/Rewrites/ForkPure.lean +++ b/Graphiti/Core/Rewrites/ForkPure.lean @@ -20,7 +20,7 @@ def matcher : Pattern String (String × Nat) 2 := fun g => do unless "pure" == typ.1 do return none let (.some p) := followOutput g inst "out1" | return none - unless "fork" == p.typ.1 do return none + unless "fork2" == p.typ.1 do return none return some ([inst, p.inst], #v[typ.2, p.typ.2]) ) none | MonadExceptOf.throw RewriteError.done @@ -32,7 +32,7 @@ def lhs : ExprHigh String (String × Nat) := [graph| o2 [type = "io"]; pure [type = "pure", arg = $(T[0])]; - fork [type = "fork", arg = $(T[1])]; + fork [type = "fork2", arg = $(T[1])]; i -> pure [to="in1"]; pure -> fork [from="out1",to="in1"]; @@ -49,7 +49,7 @@ def rhs : ExprHigh String (String × Nat) := [graph| o1 [type = "io"]; o2 [type = "io"]; - fork [type = "fork", arg = $(M+1)]; + fork [type = "fork2", arg = $(M+1)]; pure1 [type = "pure", arg = $(M+2)]; pure2 [type = "pure", arg = $(M+3)]; diff --git a/Graphiti/Core/Rewrites/JoinAssocL.lean b/Graphiti/Core/Rewrites/JoinAssocL.lean index 04731c4..86bca04 100644 --- a/Graphiti/Core/Rewrites/JoinAssocL.lean +++ b/Graphiti/Core/Rewrites/JoinAssocL.lean @@ -16,13 +16,13 @@ open StringModule variable (T : Vector Nat 2) variable (M : Nat) -def identMatcher (s : String) : Pattern String (String × Nat) 2 := fun g => do - let n ← ofOption' (.error s!"{decl_name%}: could not find '{s}'") <| g.modules.find? s - unless "join" == n.2.1 do throw (.error s!"{decl_name%}: type of '{s}' is '{n.2}' instead of 'join'") - let next ← ofOption' (.error s!"{decl_name%}: could not find next node") <| followInput g s "in2" - unless "join" == next.typ.1 do throw (.error s!"{decl_name%}: type of '{next.inst}' is '{next.typ}' instead of 'join'") +def identMatcher (join2 : String) : Pattern String (String × Nat) 2 := fun g => do + let join2_typ ← ofOption' (.error s!"{decl_name%}: could not find '{join2}'") <| g.modules.find? join2 + unless "join" == join2_typ.2.1 do throw (.error s!"{decl_name%}: type of '{join2}' is '{join2_typ.2}' instead of 'join'") + let join1 ← ofOption' (.error s!"{decl_name%}: could not find next node") <| followInput g join2 "in2" + unless "join" == join1.typ.1 do throw (.error s!"{decl_name%}: type of '{join1.inst}' is '{join1.typ}' instead of 'join' 2") - return ([s, next.inst], #v[n.2.2, next.typ.2]) + return ([join1.inst, join2], #v[join1.typ.2, join2_typ.2.2]) def matcher : Pattern String (String × Nat) 2 := fun g => do throw (.error s!"{decl_name%}: matcher not implemented") @@ -59,8 +59,8 @@ def rhs : ExprHigh String (String × Nat) := [graph| i_2 [type = "io"]; o_out [type = "io"]; - join2 [type = "split", arg = $(M+1)]; - join1 [type = "split", arg = $(M+2)]; + join2 [type = "join", arg = $(M+1)]; + join1 [type = "join", arg = $(M+2)]; pure [type = "pure", arg = $(M+3)]; i_0 -> join2 [to = "in1"]; @@ -73,7 +73,7 @@ def rhs : ExprHigh String (String × Nat) := [graph| pure -> o_out [from = "out1"]; ] -def rhs_extract := (rhs M).extract ["join2", "join1", "pure"] |>.get rfl +def rhs_extract := (rhs M).extract ["join1", "join2", "pure"] |>.get rfl def rhsLower := (rhs_extract M).fst.lower.get rfl def findRhs mod := (rhs_extract 0).fst.modules.find? mod |>.map Prod.fst @@ -83,7 +83,7 @@ def rewrite : Rewrite String (String × Nat) := pattern := matcher, rewrite := λ l n => ⟨lhsLower l, rhsLower n⟩ name := "join-assoc-left" - transformedNodes := [findRhs "join2" |>.get rfl, findRhs "join1" |>.get rfl], + transformedNodes := [findRhs "join1" |>.get rfl, findRhs "join2" |>.get rfl], addedNodes := [findRhs "pure" |>.get rfl] fresh_types := 3 } diff --git a/Graphiti/Core/Rewrites/JoinAssocR.lean b/Graphiti/Core/Rewrites/JoinAssocR.lean index 15c7775..339ebc9 100644 --- a/Graphiti/Core/Rewrites/JoinAssocR.lean +++ b/Graphiti/Core/Rewrites/JoinAssocR.lean @@ -16,13 +16,13 @@ open StringModule variable (T : Vector Nat 2) variable (M : Nat) -def identMatcher (s : String) : Pattern String (String × Nat) 2 := fun g => do - let n ← ofOption' (.error s!"{decl_name%}: could not find '{s}'") <| g.modules.find? s - unless "join" == n.2.1 do throw (.error s!"type of '{s}' is '{n.2}' instead of 'join'") - let next ← ofOption' (.error s!"{decl_name%}: could not find next node") <| followInput g s "in1" - unless "join" == next.typ.1 do throw (.error s!"type of '{next.inst}' is '{next.typ}' instead of 'join'") +def identMatcher (join2 : String) : Pattern String (String × Nat) 2 := fun g => do + let join2_typ ← ofOption' (.error s!"{decl_name%}: could not find '{join2}'") <| g.modules.find? join2 + unless "join" == join2_typ.2.1 do throw (.error s!"{decl_name%}: type of '{join2}' is '{join2_typ.2}' instead of 'join'") + let join1 ← ofOption' (.error s!"{decl_name%}: could not find next node") <| followInput g join2 "in1" + unless "join" == join1.typ.1 do throw (.error s!"{decl_name%}: type of '{join1.inst}' is '{join1.typ}' instead of 'join' 2") - return ([s, next.inst], #v[n.2.2, next.typ.2]) + return ([join1.inst, join2], #v[join1.typ.2, join2_typ.2.2]) def matcher : Pattern String (String × Nat) 2 := fun g => do throw (.error s!"{decl_name%}: matcher not implemented") @@ -59,8 +59,8 @@ def rhs : ExprHigh String (String × Nat) := [graph| i_2 [type = "io"]; o_out [type = "io"]; - join2 [type = "split", arg = $(M+1)]; - join1 [type = "split", arg = $(M+2)]; + join2 [type = "join", arg = $(M+1)]; + join1 [type = "join", arg = $(M+2)]; pure [type = "pure", arg = $(M+3)]; i_1 -> join2 [to = "in1"]; @@ -73,7 +73,7 @@ def rhs : ExprHigh String (String × Nat) := [graph| pure -> o_out [from = "out1"]; ] -def rhs_extract := (rhs M).extract ["join2", "join1", "pure"] |>.get rfl +def rhs_extract := (rhs M).extract ["join1", "join2", "pure"] |>.get rfl def rhsLower := (rhs_extract M).fst.lower.get rfl def findRhs mod := (rhs_extract 0).fst.modules.find? mod |>.map Prod.fst @@ -83,7 +83,7 @@ def rewrite : Rewrite String (String × Nat) := pattern := matcher, rewrite := λ l n => ⟨lhsLower l, rhsLower n⟩ name := "join-assoc-right" - transformedNodes := [findRhs "join2" |>.get rfl, findRhs "join1" |>.get rfl], + transformedNodes := [findRhs "join1" |>.get rfl, findRhs "join2" |>.get rfl], addedNodes := [findRhs "pure" |>.get rfl] fresh_types := 3 } diff --git a/Graphiti/Core/Rewrites/JoinSplitElim.lean b/Graphiti/Core/Rewrites/JoinSplitElim.lean index 249613f..b674240 100644 --- a/Graphiti/Core/Rewrites/JoinSplitElim.lean +++ b/Graphiti/Core/Rewrites/JoinSplitElim.lean @@ -25,7 +25,7 @@ def identMatcher (s : String) : Pattern String (String × Nat) 2 := fun g => do unless next1.inputPort == "out1" do throw (.error s!"{decl_name%}: output port of split is incorrect") let next2 ← ofOption' (.error s!"{decl_name%}: could not find next node") <| followInput g s "in2" unless "split" == next2.typ.1 do - throw (.error s!"{decl_name%}: type of '{next2.inst}' is '{next2.typ}' instead of 'split'") + throw (.error s!"{decl_name%}: type of '{next2.inst}' is '{next2.typ}' instead of 'split' 2") unless next2.inputPort == "out2" do throw (.error s!"{decl_name%}: output port of split is incorrect") return ([s, next1.inst], #v[n.2.2, next1.typ.2]) @@ -40,8 +40,8 @@ def lhs : ExprHigh String (String × Nat) := [graph| i_0 [type = "io"]; o_out [type = "io"]; - split [type = "split", arg = $(T[0])]; - join [type = "join", arg = $(T[1])]; + split [type = "split", arg = $(T[1])]; + join [type = "join", arg = $(T[0])]; i_0 -> split [to = "in1"]; @@ -51,7 +51,7 @@ def lhs : ExprHigh String (String × Nat) := [graph| join -> o_out [from = "out1"]; ] -def lhs_extract := (lhs T).extract ["split", "join"] |>.get rfl +def lhs_extract := (lhs T).extract ["join", "split"] |>.get rfl theorem double_check_empty_snd : (lhs_extract T).snd = ExprHigh.mk ∅ ∅ := by rfl def lhsLower := (lhs_extract T).fst.lower.get rfl diff --git a/Graphiti/Core/Rewrites/JoinSplitLoopCond.lean b/Graphiti/Core/Rewrites/JoinSplitLoopCond.lean index 0445155..044b0d5 100644 --- a/Graphiti/Core/Rewrites/JoinSplitLoopCond.lean +++ b/Graphiti/Core/Rewrites/JoinSplitLoopCond.lean @@ -45,8 +45,8 @@ def lhs : ExprHigh String (String × Nat) := [graph| o_init [type = "io"]; branch [type = "branch", arg = $(T[0])]; - condFork [type = "fork2", arg = $(T[1])]; - init [type = "initBool", arg = $(T[2])]; + condFork [type = "fork2", arg = $(T[2])]; + init [type = "initBool", arg = $(T[1])]; c_i -> condFork [to="in1"]; d_i -> branch [to="in1"]; diff --git a/Graphiti/Core/Rewrites/JoinSplitLoopCondAlt.lean b/Graphiti/Core/Rewrites/JoinSplitLoopCondAlt.lean index 8b0869c..085f0d0 100644 --- a/Graphiti/Core/Rewrites/JoinSplitLoopCondAlt.lean +++ b/Graphiti/Core/Rewrites/JoinSplitLoopCondAlt.lean @@ -45,8 +45,8 @@ def lhs : ExprHigh String (String × Nat) := [graph| o_init [type = "io"]; branch [type = "branch", arg = $(T[0])]; - condFork [type = "fork2", arg = $(T[1])]; - init [type = "initBool", arg = $(T[2])]; + condFork [type = "fork2", arg = $(T[2])]; + init [type = "initBool", arg = $(T[1])]; c_i -> condFork [to="in1"]; d_i -> branch [to="in1"]; diff --git a/Graphiti/Core/Rewrites/LoadRewrite.lean b/Graphiti/Core/Rewrites/LoadRewrite.lean index 141a32d..668b3bc 100644 --- a/Graphiti/Core/Rewrites/LoadRewrite.lean +++ b/Graphiti/Core/Rewrites/LoadRewrite.lean @@ -20,7 +20,7 @@ def matcher : Pattern String (String × Nat) 2 := fun g => do unless typ.1 == "load" do return none let (.some mc) := followOutput g inst "out2" | return none - unless "operator1" == mc.typ.1 do return none + unless "mc" == mc.typ.1 do return none let (.some load) := followOutput g mc.inst "out1" | return none unless load.inst = inst do return none @@ -34,7 +34,7 @@ def lhs : ExprHigh String (String × Nat) := [graph| o_out [type = "io"]; load [type = "load", arg = $(T[0])]; - mc [type = "operator1", arg = $(T[1])]; + mc [type = "mc", arg = $(T[1])]; i_in -> load [to = "in2"]; load -> o_out [from = "out1"]; diff --git a/Graphiti/Core/Rewrites/LoopRewrite2.lean b/Graphiti/Core/Rewrites/LoopRewrite2.lean index bcde2a7..3c4b8e1 100644 --- a/Graphiti/Core/Rewrites/LoopRewrite2.lean +++ b/Graphiti/Core/Rewrites/LoopRewrite2.lean @@ -79,7 +79,7 @@ def rhs : ExprHigh String (String × Nat) := [graph| split_bool [type = "split", arg = $(M+6)]; join_tag [type = "join", arg = $(M+7)]; join_bool [type = "join", arg = $(M+8)]; - mod [type = "pure", arg = $(M+9)]; + mod [type = "pure", arg = $(T[4])]; i_in -> tagger [to="in2"]; tagger -> o_out [from="out2"]; @@ -99,21 +99,21 @@ def rhs : ExprHigh String (String × Nat) := [graph| branch -> tagger [from="out2", to="in1"]; ] -def rhs_extract := (rhs M).extract ["merge", "branch", "tag_split", "mod", "tagger", "split_tag", "split_bool", "join_tag", "join_bool"] |>.get rfl -def rhsLower := (rhs_extract M).fst.lower.get rfl -def findRhs mod := (rhs_extract 0).fst.modules.find? mod |>.map Prod.fst +def rhs_extract := (rhs T M).extract ["merge", "branch", "tag_split", "mod", "tagger", "split_tag", "split_bool", "join_tag", "join_bool"] |>.get rfl +def rhsLower := (rhs_extract T M).fst.lower.get rfl +def findRhs mod := (rhs_extract #v[0, 0, 0, 0, 0, 0] 0).fst.modules.find? mod |>.map Prod.fst def rewrite : Rewrite String (String × Nat) := { params := 6 pattern := matcher, - rewrite := λ l n => ⟨lhsLower l, rhsLower n⟩ + rewrite := λ l n => ⟨lhsLower l, rhsLower l n⟩ name := .some "loop-rewrite" transformedNodes := [ findRhs "merge" |>.get rfl, .none, findRhs "branch" |>.get rfl , findRhs "tag_split" |>.get rfl, findRhs "mod" |>.get rfl, .none] addedNodes := [ findRhs "tagger" |>.get rfl, findRhs "split_tag" |>.get rfl, findRhs "split_bool" |>.get rfl , findRhs "join_tag" |>.get rfl, findRhs "join_bool" |>.get rfl] - fresh_types := 9 + fresh_types := 8 } end Graphiti.LoopRewrite2 diff --git a/Graphiti/Core/Rewrites/PureRewrites.lean b/Graphiti/Core/Rewrites/PureRewrites.lean index 33ac5b1..10bd12f 100644 --- a/Graphiti/Core/Rewrites/PureRewrites.lean +++ b/Graphiti/Core/Rewrites/PureRewrites.lean @@ -70,6 +70,110 @@ def rewrite : Rewrite String (String × Nat) where end Constant +namespace ConstantNat + +def extract_type (typ : String × Nat) : RewriteResultSL (Vector Nat 1) := do + unless typ.1 == "constantNat" do throw .done + return #v[typ.2] + +def match_node := Graphiti.match_node extract_type + +def matcher : Pattern String (String × Nat) 1 := fun g => do + throw (.error s!"{decl_name%}: matcher not implemented") + +variable (T : Vector Nat 1) +variable (M : Nat) + +def lhs : ExprHigh String (String × Nat) := [graph| + i [type = "io"]; + o [type = "io"]; + + const [type = "constantNat", arg = $(T[0])]; + + i -> const [to="in1"]; + const -> o [from="out1"]; + ] + +def lhs_extract := (lhs T).extract ["const"] |>.get rfl +theorem double_check_empty_snd : (lhs_extract T).snd = ExprHigh.mk ∅ ∅ := by rfl +def lhsLower := (lhs_extract T).fst.lower.get rfl + +def rhs : ExprHigh String (String × Nat) := [graph| + i [type = "io"]; + o [type = "io"]; + + const [type = "pure", arg = $(M+1)]; + + i -> const [to="in1"]; + const -> o [from="out1"]; + ] + +def rhsLower := (rhs M).lower.get rfl +def findRhs mod := (rhs 0).modules.find? mod |>.map Prod.fst + +def rewrite : Rewrite String (String × Nat) where + abstractions := [] + params := 1 + pattern := matcher + rewrite := λ l n => ⟨lhsLower l, rhsLower n⟩ + name := .some "pure-constant-nat" + transformedNodes := [findRhs "const" |>.get rfl] + fresh_types := 1 + +end ConstantNat + +namespace ConstantBool + +def extract_type (typ : String × Nat) : RewriteResultSL (Vector Nat 1) := do + unless typ.1 == "constantBool" do throw .done + return #v[typ.2] + +def match_node := Graphiti.match_node extract_type + +def matcher : Pattern String (String × Nat) 1 := fun g => do + throw (.error s!"{decl_name%}: matcher not implemented") + +variable (T : Vector Nat 1) +variable (M : Nat) + +def lhs : ExprHigh String (String × Nat) := [graph| + i [type = "io"]; + o [type = "io"]; + + const [type = "constantBool", arg = $(T[0])]; + + i -> const [to="in1"]; + const -> o [from="out1"]; + ] + +def lhs_extract := (lhs T).extract ["const"] |>.get rfl +theorem double_check_empty_snd : (lhs_extract T).snd = ExprHigh.mk ∅ ∅ := by rfl +def lhsLower := (lhs_extract T).fst.lower.get rfl + +def rhs : ExprHigh String (String × Nat) := [graph| + i [type = "io"]; + o [type = "io"]; + + const [type = "pure", arg = $(M+1)]; + + i -> const [to="in1"]; + const -> o [from="out1"]; + ] + +def rhsLower := (rhs M).lower.get rfl +def findRhs mod := (rhs 0).modules.find? mod |>.map Prod.fst + +def rewrite : Rewrite String (String × Nat) where + abstractions := [] + params := 1 + pattern := matcher + rewrite := λ l n => ⟨lhsLower l, rhsLower n⟩ + name := .some "pure-constant-nat" + transformedNodes := [findRhs "const" |>.get rfl] + fresh_types := 1 + +end ConstantBool + namespace Operator1 def extract_type (typ : String × Nat) : RewriteResultSL (Vector Nat 1) := do @@ -293,7 +397,7 @@ def rhs : ExprHigh String (String × Nat) := [graph| split [type = "split", arg = $(M+2)]; i -> op [to="in1"]; - op -> split [from="out1",to="in1"]; + op -> split [from="out1", to="in1"]; split -> o1 [from="out1"]; split -> o2 [from="out2"]; ] @@ -322,6 +426,16 @@ def specialisedPureRewrites {n} (p : Pattern String (String × Nat) n) := let (s :: _, t) ← p g | throw RewriteError.done Constant.match_node s g } + , { ConstantNat.rewrite with + pattern := fun g => do + let (s :: _, t) ← p g | throw RewriteError.done + ConstantNat.match_node s g + } + , { ConstantBool.rewrite with + pattern := fun g => do + let (s :: _, t) ← p g | throw RewriteError.done + ConstantBool.match_node s g + } , { Operator1.rewrite with pattern := fun g => do let (s :: _, t) ← p g | throw RewriteError.done @@ -346,6 +460,8 @@ def specialisedPureRewrites {n} (p : Pattern String (String × Nat) n) := def singleNodePureRewrites (s : String) := [ {Constant.rewrite with pattern := Constant.match_node s } + , {ConstantNat.rewrite with pattern := ConstantNat.match_node s } + , {ConstantBool.rewrite with pattern := ConstantBool.match_node s } , {Operator1.rewrite with pattern := Operator1.match_node s } , {Operator2.rewrite with pattern := Operator2.match_node s } , {Operator3.rewrite with pattern := Operator3.match_node s } diff --git a/Graphiti/Core/Rewrites/PureSink.lean b/Graphiti/Core/Rewrites/PureSink.lean index 720781f..d77a85c 100644 --- a/Graphiti/Core/Rewrites/PureSink.lean +++ b/Graphiti/Core/Rewrites/PureSink.lean @@ -59,7 +59,6 @@ def findRhs mod := (rhs_extract 0).1.modules.find? mod |>.map Prod.fst def rewrite : Rewrite String (String × Nat) := { params := 2 - abstractions := [], pattern := matcher, rewrite := λ l n => ⟨lhsLower l, rhsLower n⟩ name := "pure-sink" diff --git a/Graphiti/Core/Rewrites/PureSplit.lean b/Graphiti/Core/Rewrites/PureSplit.lean deleted file mode 100644 index 3a79b54..0000000 --- a/Graphiti/Core/Rewrites/PureSplit.lean +++ /dev/null @@ -1,79 +0,0 @@ -/- -Copyright (c) 2025 VCA Lab, EPFL. All rights reserved. -Released under Apache 2.0 license as described in the file LICENSE. -Authors: Yann Herklotz --/ - -import Graphiti.Core.Rewriter -import Graphiti.Core.ExprHighElaborator - -namespace Graphiti.PureJoinRight - -open StringModule - -variable (T : Vector Nat 2) -variable (M : Nat) - -def matcher : Pattern String (String × Nat) 2 := fun g => do - let (.some list) ← g.modules.foldlM (λ s inst (pmap, typ) => do - if s.isSome then return s - unless "pure" == typ.1 do return none - - let (.some join) := followOutput g inst "out1" | return none - unless "join" == join.typ.1 ∧ join.inputPort = "in2" do return none - - return some ([inst, join.inst], #v[typ.2, join.typ.2]) - ) none | MonadExceptOf.throw RewriteError.done - return list - -def lhs : ExprHigh String (String × Nat) := [graph| - i_0 [type = "io"]; - i_1 [type = "io"]; - o_out [type = "io"]; - - pure [type = "pure", arg = $(T[0])]; - join [type = "join", arg = $(T[1])]; - - i_0 -> join [to="in1"]; - i_1 -> pure [to="in1"]; - - pure -> join [from="out1", to="in2"]; - - join -> o_out [from="out1"]; - ] - -def lhs_extract := (lhs T).extract ["pure", "join"] |>.get rfl -theorem double_check_empty_snd : (lhs_extract T).snd = ExprHigh.mk ∅ ∅ := by rfl -def lhsLower := (lhs_extract T).fst.lower.get rfl - -def rhs : ExprHigh String (String × Nat) := [graph| - i_0 [type = "io"]; - i_1 [type = "io"]; - o_out [type = "io"]; - - join [type = "join", arg = $(M+1)]; - pure [type = "pure", arg = $(M+2)]; - - i_0 -> join [to="in1"]; - i_1 -> join [to="in2"]; - - join -> pure [from="out1", to="in1"]; - - pure -> o_out [from="out1"]; - ] - -def rhs_extract := (rhs M).extract ["pure", "join"] |>.get rfl -def rhsLower := (rhs_extract M).fst.lower.get rfl -def findRhs mod := (rhs_extract 0).fst.modules.find? mod |>.map Prod.fst - -def rewrite : Rewrite String (String × Nat) := - { abstractions := [], - params := 2 - pattern := matcher, - rewrite := λ l n => ⟨lhsLower l, rhsLower n⟩ - transformedNodes := [findRhs "pure" |>.get rfl, findRhs "join" |>.get rfl] - name := "pure-join-right" - fresh_types := 2 - } - -end Graphiti.PureJoinRight diff --git a/Graphiti/Core/Rewrites/ReduceSplitJoin.lean b/Graphiti/Core/Rewrites/ReduceSplitJoin.lean index 9fd30f0..74322e1 100644 --- a/Graphiti/Core/Rewrites/ReduceSplitJoin.lean +++ b/Graphiti/Core/Rewrites/ReduceSplitJoin.lean @@ -28,7 +28,7 @@ def matcher : Pattern String (String × Nat) 2 := fun g => do unless join_nn.inst = join_nn'.inst do return none unless join_nn.inputPort = "in1" && join_nn'.inputPort = "in2" do return none - return some ([join_nn.inst, inst], #v[typ.2, join_nn.typ.2]) + return some ([join_nn.inst, inst], #v[join_nn.typ.2, typ.2]) ) none | throw .done return list @@ -36,7 +36,7 @@ def lhs (T : Vector Nat 2) : ExprHigh String (String × Nat) := [graph| i [type = "io"]; o [type = "io"]; - split [type = "split", arg = $(T[0])]; + split [type = "split", arg = $(T[1])]; join [type = "join", arg = $(T[0])]; i -> split [to="in1"]; diff --git a/Graphiti/Core/TypeExpr.lean b/Graphiti/Core/TypeExpr.lean index c20d3e6..b8ce58a 100644 --- a/Graphiti/Core/TypeExpr.lean +++ b/Graphiti/Core/TypeExpr.lean @@ -22,6 +22,7 @@ inductive TypeExpr where | bool | tag | unit +| var (n : Nat) | pair (left right : TypeExpr) deriving Repr, DecidableEq, Inhabited, Hashable @@ -32,6 +33,7 @@ def toBlueSpec : TypeExpr → String | .tag => "Token" | .bool => "Bool" | .unit => "Void" +| .var n => s!"Var#({n})" | .pair left right => s!"Tuple2#({toBlueSpec left}, {toBlueSpec right})" @@ -77,6 +79,7 @@ noncomputable def TypeExpr.denote : TypeExpr → Type | tag => Unit | bool => Bool | unit => Unit +| var n => Unit | pair t1 t2 => t1.denote × t2.denote def ValExpr.type : ValExpr → TypeExpr @@ -158,6 +161,7 @@ def toString' : TypeExpr → String | TypeExpr.tag => "TagT" -- Unclear how we want to display TagT at the end? | TypeExpr.bool => "Bool" | TypeExpr.unit => "Unit" + | TypeExpr.var n => s!"Var({n})" | TypeExpr.pair left right => let leftStr := toString' left let rightStr := toString' right @@ -168,6 +172,7 @@ def getSize: TypeExpr → Int | TypeExpr.tag => 0 | TypeExpr.bool => 1 | TypeExpr.unit => 0 + | TypeExpr.var n => 0 | TypeExpr.pair left right => let l := getSize left let r := getSize right @@ -201,6 +206,7 @@ def flatten_type (t : TypeExpr) : List TypeExpr := | TypeExpr.tag => [t] | TypeExpr.bool => [t] | TypeExpr.unit => [t] + | TypeExpr.var _ => [t] | TypeExpr.pair left right => flatten_type left ++ flatten_type right diff --git a/Graphiti/Core/WellTyped.lean b/Graphiti/Core/WellTyped.lean index 47c6746..ba53abc 100644 --- a/Graphiti/Core/WellTyped.lean +++ b/Graphiti/Core/WellTyped.lean @@ -140,11 +140,16 @@ theorem build_module_interface_build_module_interface'' {e : ExprLow Ident Typ} rw [h2'] at h1; dsimp at h1; cases h1; cases h3'; cases h4' grind +end BuildModule + +end ExprLow + inductive TypeExpr' where | nat | bool | tag | unit +| var (n : Nat) | pair (left right : Nat) deriving Repr, DecidableEq, Inhabited, Hashable @@ -174,109 +179,191 @@ def union (t : TypeUF) (t1 t2 : TypeConstraint) : TypeUF := {t'' with ufMap := t''.ufMap.union! v1 v2} def findConcrTE'' (t : TypeUF) (c : Nat) : Option TypeExpr' := do - let (.concr n1Concr, _) ← + let (.concr n1Concr, _) := t.typeMap.toList.filter (fun val => t.ufMap.root! val.2 == t.ufMap.root! c && val.1.isConcr?) |>.head? + |>.getD (.concr (.var (t.ufMap.root! c)), 0) | failure return n1Concr def findConcrTE' (t : TypeUF) (c : TypeConstraint) : Option TypeExpr' := findConcrTE'' t (t.insert c).1 -partial def toTypeExpr (t : TypeUF) (e : TypeExpr') : Option TypeExpr := +def toTypeExpr' (t : TypeUF) (e : TypeExpr') : Nat → Option TypeExpr +| 0 => none +| n+1 => match e with | .nat => .some .nat | .bool => .some .bool | .tag => .some .tag | .unit => .some .unit + | .var n => .some (.var n) | .pair n1 n2 => do - let concr := t.typeMap.toList.filter (match ·.fst with | .concr _ => true | _ => false) - let (.concr n1Concr, _) ← concr.filter (fun val => t.ufMap.root! val.2 == t.ufMap.root! n1) |>.head? | failure - let (.concr n2Concr, _) ← concr.filter (fun val => t.ufMap.root! val.2 == t.ufMap.root! n2) |>.head? | failure - let n1Concr ← toTypeExpr t n1Concr - let n2Concr ← toTypeExpr t n2Concr + let n1Concr ← findConcrTE'' t n1 >>= fun x => toTypeExpr' t x n + let n2Concr ← findConcrTE'' t n2 >>= fun x => toTypeExpr' t x n .some <| .pair n1Concr n2Concr +def toTypeExpr (t : TypeUF) (e : TypeExpr') : Option TypeExpr := toTypeExpr' t e 1000 + +def findConcr' (t : TypeUF) (c : Nat) : Option TypeExpr := t.findConcrTE'' c >>= t.toTypeExpr def findConcr (t : TypeUF) (c : TypeConstraint) : Option TypeExpr := t.findConcrTE' c >>= t.toTypeExpr end TypeUF -def toTypeConstraint (sn : String × Nat) (i : String) (t : TypeUF) : Except String (TypeConstraint × TypeUF) := +/-- +Generates the type constraint that corresponds to the type of the port in the module. +-/ +def toTypeConstraint (sn : String × Nat) (i : String) : Except String TypeConstraint := match sn.1 with | "mux" => match i with - | "in1" => .ok (.concr .bool, t) - | "in2" | "in3" | "out1" => .ok (.var sn.2, t) - | _ => .error s!"could not find port" + | "in1" => .ok (.concr .bool) + | "in2" | "in3" | "out1" => .ok (.var sn.2) + | _ => .error s!"could not find port: {sn}/{i}" | "branch" => match i with - | "in2" => .ok (.concr .bool, t) - | "in1" | "out1" | "out2" => .ok (.var sn.2, t) - | _ => .error s!"could not find port" + | "in2" => .ok (.concr .bool) + | "in1" | "out1" | "out2" => .ok (.var sn.2) + | _ => .error s!"could not find port: {sn}/{i}" | "split" => match i with - | "in1" => - let (vfst, t1) := t.insert (.uninterp "fst" (.var sn.2)) - let (vsnd, t2) := t1.insert (.uninterp "snd" (.var sn.2)) - let t3 := t2.union (.var sn.2) (.concr (.pair vfst vsnd)) - .ok (.var sn.2, t3) - | "out1" => .ok (.uninterp "fst" (.var sn.2), t) - | "out2" => .ok (.uninterp "snd" (.var sn.2), t) - | _ => .error s!"could not find port" + | "in1" => .ok (.var sn.2) + | "out1" => .ok (.uninterp "fst" (.var sn.2)) + | "out2" => .ok (.uninterp "snd" (.var sn.2)) + | _ => .error s!"could not find port: {sn}/{i}" | "join" => match i with - | "out1" => - let (vfst, t1) := t.insert (.uninterp "fst" (.var sn.2)) - let (vsnd, t2) := t1.insert (.uninterp "snd" (.var sn.2)) - let t3 := t2.union (.var sn.2) (.concr (.pair vfst vsnd)) - .ok (.var sn.2, t) - | "in1" => .ok (.uninterp "fst" (.var sn.2), t) - | "in2" => .ok (.uninterp "snd" (.var sn.2), t) - | _ => .error s!"could not find port" - | "operator" + | "out1" => .ok (.var sn.2) + | "in1" => .ok (.uninterp "fst" (.var sn.2)) + | "in2" => .ok (.uninterp "snd" (.var sn.2)) + | _ => .error s!"could not find port: {sn}/{i}" + | "operator1" + | "operator2" + | "operator3" + | "operator4" + | "operator5" | "pure" => match i with - | "in1" => .ok (.uninterp "dom1" (.var sn.2), t) - | "in2" => .ok (.uninterp "dom2" (.var sn.2), t) - | "in3" => .ok (.uninterp "dom3" (.var sn.2), t) - | "in4" => .ok (.uninterp "dom4" (.var sn.2), t) - | "in5" => .ok (.uninterp "dom5" (.var sn.2), t) - | "out1" => .ok (.uninterp "codom" (.var sn.2), t) - | _ => .error s!"could not find port" - | "initBool" => + | "in1" => .ok (.uninterp "dom1" (.var sn.2)) + | "in2" => .ok (.uninterp "dom2" (.var sn.2)) + | "in3" => .ok (.uninterp "dom3" (.var sn.2)) + | "in4" => .ok (.uninterp "dom4" (.var sn.2)) + | "in5" => .ok (.uninterp "dom5" (.var sn.2)) + | "out1" => .ok (.uninterp "codom" (.var sn.2)) + | _ => .error s!"could not find port: {sn}/{i}" + | "tag_untagger_val" => match i with - | "in1" => .ok (.concr .bool, t) - | "out1" => .ok (.concr .bool, t) - | _ => .error s!"could not find port" + | "in2" | "out2" => .ok (.uninterp "snd" (.var sn.2)) + | "in1" | "out1" => .ok (.var sn.2) + | _ => .error s!"could not find port: {sn}/{i}" | "fork2" + | "fork3" + | "fork4" + | "fork5" + | "fork6" + | "fork7" + | "fork8" + | "fork9" + | "fork10" | "fork" | "merge" | "merge2" | "sink" - | "queue" => .ok (.var sn.2, t) - | "output" + | "queue" => .ok (.var sn.2) + | "constantNat" + | "mc" + | "load" + | "inputNat" + | "outputNat0" + | "outputNat1" + | "outputNat2" + | "outputNat3" + | "outputNat4" + | "outputNat5" => .ok (.concr .nat) + | "initBool" + | "outputBool0" + | "outputBool1" + | "outputBool2" + | "outputBool3" + | "outputBool4" + | "outputBool5" + | "constantBool" => .ok (.concr .bool) + | "input" + | "output0" + | "output1" + | "output2" + | "output3" + | "output4" + | "output5" => .ok (.concr .unit) + | s => + -- If we don't know the type of the node, then we just return an uninterpreted function per port. + if "_graphiti_".isPrefixOf s then + .ok (.uninterp i (.var sn.2)) + else + .error s!"could not find node: {sn}/{i}" + +/-- +Generates additional typing constraints for type inference for each of the types. +-/ +def additionalConstraints (sn : String × Nat) (t : TypeUF) : TypeUF := + match sn.1 with + | "split" => + let (vfst, t1) := t.insert (.uninterp "fst" (.var sn.2)) + let (vsnd, t2) := t1.insert (.uninterp "snd" (.var sn.2)) + t2.union (.var sn.2) (.concr (.pair vfst vsnd)) + | "join" => + let (vfst, t1) := t.insert (.uninterp "fst" (.var sn.2)) + let (vsnd, t2) := t1.insert (.uninterp "snd" (.var sn.2)) + t2.union (.var sn.2) (.concr (.pair vfst vsnd)) + | "operator" + | "pure" => t.insert (.var sn.2) |>.snd + | "tagger_untagger_val" => + let (_, t) := t.insert (.var sn.2) + let (_, t) := t.insert (.uninterp "snd" (.var sn.2)) + t.union (.uninterp "fst" (.var sn.2)) (.concr .tag) | "constantNat" - | "input" => .ok (.concr .nat, t) - | "constantBool" => .ok (.concr .bool, t) - | _ => .error s!"could not find port" + | "load" + | "mc" + | "inputNat" + | "outputNat" => t.union (.var sn.2) (.concr .nat) + | "initBool" + | "constantBool" => t.union (.var sn.2) (.concr .bool) + /- | "output" + - | "input" -/ + | _ => t + +namespace ExprLow -def infer_types (e : ExprLow String (String × Nat)) (t : TypeUF) : Except String TypeUF := +def infer_equalities (e : ExprLow String (String × Nat)) (t : TypeUF) : Except String TypeUF := match e with | .base i e => .ok t | .connect c e => - match findInputInst c.input e, findOutputInst c.output e, e.infer_types t with + match findInputInst c.input e, findOutputInst c.output e, e.infer_equalities t with | some inp, some out, .ok t => match inp.1.input.inverse.find? c.input, out.1.output.inverse.find? c.output with | some ⟨_, inpV⟩, some ⟨_, outV⟩ => do - let (tinp, t') ← toTypeConstraint inp.2 inpV t - let (tout, t'') ← toTypeConstraint out.2 outV t' - .ok (t''.union tinp tout) + let tinp ← toTypeConstraint inp.2 inpV + let tout ← toTypeConstraint out.2 outV + .ok (additionalConstraints inp.2 t |> additionalConstraints out.2 |>.union tinp tout) | _, _ => .error s!"could not find I/O in portmap {c.input}/{c.output}" | none, _, _ => .error s!"findInputInst failed on inp: {c.input}" | _, none, _ => .error s!"findOutputInst failed on inp: {c.output}" | _, _, .error s => .error s - | .product e1 e2 => e1.infer_types t >>= e2.infer_types - -end BuildModule + | .product e1 e2 => e1.infer_equalities t >>= e2.infer_equalities + +def infer_types (e : ExprLow String (String × Nat)) : Except String (ExprLow String (String × TypeExpr)) := do + let eqs ← e.infer_equalities ⟨∅, Batteries.UnionFind.mkEmpty 100⟩ + go eqs e + where + go (t : TypeUF) : ExprLow String (String × Nat) → Except String (ExprLow String (String × TypeExpr)) + | .base inst typ => + match t.findConcr (.var typ.2) with + | some concr => .ok (.base inst (typ.1, concr)) + | none => .error s!"could not find concrete type for {typ}" + | .connect c e => .connect c <$> go t e + | .product e1 e2 => do + let e1' ← go t e1 + let e2' ← go t e2 + return .product e1' e2' end ExprLow @@ -295,6 +382,21 @@ def well_typed (h : ExprHigh Ident Typ) : Prop := end WellTyped +def infer_equalities (e : ExprHigh String (String × Nat)) (t : TypeUF) : Except String TypeUF := do + let e ← ofOption' "lowering failed" e.lower_TR + e.infer_equalities t + +def infer_types' (e : ExprHigh String (String × Nat)) : Except String (ExprHigh String (String × TypeExpr) × TypeUF) := do + let eqs ← e.infer_equalities ⟨∅, Batteries.UnionFind.mkEmpty 100⟩ + let res ← e.modules.foldlM (λ s k typ => + match eqs.findConcr (.var typ.2.2) with + | some concr => .ok (s.cons k (typ.1, (typ.2.1, concr))) + | none => .error s!"could not find concrete type for {typ}" + ) ∅ + return ({e with modules := res}, eqs) + +def infer_types e := Prod.fst <$> infer_types' e + end ExprHigh end Graphiti diff --git a/Graphiti/Projects/CombinationalStream.lean b/Graphiti/Projects/CombinationalStream.lean index abd9ebd..71dd8cc 100644 --- a/Graphiti/Projects/CombinationalStream.lean +++ b/Graphiti/Projects/CombinationalStream.lean @@ -337,7 +337,6 @@ def et_ms_flip_flop_m := [graphEnv| latch2 -> q_bar [from="q_bar"]; ] -#guard_msgs (drop info) in #eval IO.print <| build_verilog_module "d_latch_m" env d_latch_m.1 (simple_interface ["d", "clk"] ["q", "q_bar"]) #guard_msgs (drop info) in #eval IO.print <| build_verilog_module "et_flip_flop_m" env et_flip_flop_m.1 (simple_interface ["d", "clk"] ["q", "q_bar"]) diff --git a/GraphitiTest/Core/WellTyped.lean b/GraphitiTest/Core/WellTyped.lean index d7c37f4..89b8a6c 100644 --- a/GraphitiTest/Core/WellTyped.lean +++ b/GraphitiTest/Core/WellTyped.lean @@ -22,9 +22,10 @@ def lhs : ExprHigh String (String × Nat) := [graph| loop_init [type = "initBool", arg = $(6)]; queue [type = "queue", arg = $(7)]; queue_out [type = "queue", arg = $(8)]; + output [type = "output", arg = $(9)]; i_in -> mux [to="in2"]; - queue_out -> o_out [from="out1"]; + queue_out -> output [from="out1", to="in1"]; loop_init -> mux [from="out1", to="in1"]; condition_fork -> loop_init [from="out2", to="in1"]; @@ -42,10 +43,11 @@ def lhs : ExprHigh String (String × Nat) := [graph| #eval lhs.lower |>.get! |> (fun x => match (x : ExprLow String (String × Nat)) with | .connect _ e => e | _ => x) |>.findInputInst ⟨.internal "queue_out", "in1"⟩ -#eval lhs.lower |>.get! |>.infer_types ⟨∅, ∅⟩ |>.map (·.typeMap) +#eval lhs.lower |>.get! |>.infer_equalities ⟨∅, ∅⟩ |>.map (·.typeMap) -#eval lhs.lower |>.get! |>.infer_types ⟨∅, ∅⟩ |>.map (·.ufMap) |>.map (·.checkEquiv! 8 6 |>.snd) -#eval lhs.lower |>.get! |>.infer_types ⟨∅, ∅⟩ |>.map (fun x => x.typeMap |>.toList.map (λ y => (y.fst, x.ufMap.root! y.snd))) -#eval lhs.lower |>.get! |>.infer_types ⟨∅, ∅⟩ |>.map (fun x => (x.union (.var 7) (.concr .nat)).findConcr (Graphiti.ExprLow.TypeConstraint.var 4)) +#eval lhs.lower |>.get! |>.infer_equalities ⟨∅, ∅⟩ |>.map (·.ufMap) |>.map (·.checkEquiv! 8 6 |>.snd) +#eval lhs.lower |>.get! |>.infer_equalities ⟨∅, ∅⟩ |>.map (fun x => x.typeMap |>.toList.map (λ y => (y.fst, x.ufMap.root! y.snd))) +#eval lhs.lower |>.get! |>.infer_equalities ⟨∅, ∅⟩ |>.map (fun x => x.findConcr (Graphiti.TypeConstraint.var 4)) +#eval lhs.lower |>.get! |>.infer_types end Graphiti.ExprLow.Test diff --git a/Main.lean b/Main.lean index 235810f..a576d0b 100644 --- a/Main.lean +++ b/Main.lean @@ -168,8 +168,8 @@ def writeLogFile (parsed : CmdArgs) (st : RewriteState) := do def runRewriter {α} (parsed : CmdArgs) (g : α) (st : RewriteState) (r : RewriteResult α) : IO (α × RewriteState) := match r.run st with - | .ok a st' => writeLogFile parsed st' *> pure (a, st') - | .error .done st' => writeLogFile parsed st' *> pure (g, st') + | .ok a st' => pure (a, st') + | .error .done st' => pure (g, st') | .error p st' => do IO.eprintln p writeLogFile parsed st' @@ -177,7 +177,7 @@ def runRewriter {α} (parsed : CmdArgs) (g : α) (st : RewriteState) (r : Rewrit def runRewriter' {α} (parsed : CmdArgs) (st : RewriteState) (r : RewriteResult α) : IO (α × RewriteState) := match r.run st with - | .ok a st' => writeLogFile parsed st' *> pure (a, st') + | .ok a st' => pure (a, st') | .error p st' => do IO.eprintln p writeLogFile parsed st' @@ -200,13 +200,12 @@ def rewriteGraph (parsed : CmdArgs) (g : ExprHigh String (String × Nat)) (st : -- addRuntimeEntry <| {RuntimeEntry.debugEntry (toString rewrittenExprHigh) with name := "debug4"} -- pureGeneration rewrittenExprHigh <| toPattern LoopRewrite.boxLoopBody return rewrittenExprHigh - let (rewrittenExprHigh, st) ← eggPureGenerator 100 parsed BranchPureMuxLeft.matchPreAndPost rewrittenExprHigh st <* writeLogFile parsed st + let (rewrittenExprHigh, st) ← eggPureGenerator 100 parsed BranchPureMuxLeft.matchPreAndPost rewrittenExprHigh st let (_, st) ← runRewriter' parsed st <| addRuntimeEntry <| {RuntimeEntry.debugEntry (toString rewrittenExprHigh) with name := "debug5"} - writeLogFile parsed st - let (rewrittenExprHigh, st) ← eggPureGenerator 100 parsed BranchPureMuxRight.matchPreAndPost rewrittenExprHigh st <* writeLogFile parsed st + let (rewrittenExprHigh, st) ← eggPureGenerator 100 parsed BranchPureMuxRight.matchPreAndPost rewrittenExprHigh st let (rewrittenExprHigh, st) ← runRewriter parsed rewrittenExprHigh st <| withUndo <| rewrite_loop [BranchPureMuxLeft.rewrite, BranchPureMuxRight.rewrite, BranchMuxToPure.rewrite] rewrittenExprHigh let (rewrittenExprHigh, st) ← runRewriter parsed rewrittenExprHigh st <| withUndo <| pureGeneration rewrittenExprHigh <| toPattern (n := 0) LoopRewrite.boxLoopBody - let (rewrittenExprHigh, st) ← eggPureGenerator 100 parsed LoopRewrite.boxLoopBodyOther rewrittenExprHigh st <* writeLogFile parsed st + let (rewrittenExprHigh, st) ← eggPureGenerator 100 parsed LoopRewrite.boxLoopBodyOther rewrittenExprHigh st let (rewrittenExprHigh, st) ← runRewriter parsed rewrittenExprHigh st (LoopRewrite2.rewrite.run rewrittenExprHigh) return (rewrittenExprHigh, st, st) @@ -222,7 +221,7 @@ def rewriteGraphAbs (parsed : CmdArgs) (g : ExprHigh String (String × Nat)) (st let (g, st) ← runRewriter parsed g st <| pureGeneration g <| toPattern (n := 0) LoopRewrite.boxLoopBody - let (g, st) ← eggPureGenerator 100 parsed LoopRewrite.boxLoopBodyOther' g st <* writeLogFile parsed st + let (g, st) ← eggPureGenerator 100 parsed LoopRewrite.boxLoopBodyOther' g st let .some subexpr@(.base pmap typ) := g.lower | throw <| .userError s!"{decl_name%}: failed to lower graph" @@ -261,20 +260,25 @@ def main (args : List String) : IO Unit := do if !parsed.parseOnly then let (g', _, st') ← (if !parsed.fast then rewriteGraph else rewriteGraphAbs) parsed rewrittenExprHigh st - let (g', st') ← if parsed.reverse then runRewriter' parsed st' <| reverseRewrites g' else pure (g', st') + let (g', st') ← if parsed.reverse then runRewriter parsed g' st' <| reverseRewrites g' else pure (g', st') rewrittenExprHigh := g'; st := st' + writeLogFile parsed st + let .some g' := rewrittenExprHigh.renameModules name_mapping | throw <| .userError s!"{decl_name%}: failed to undo name_mapping" rewrittenExprHigh := g' - let some l := + /- IO.println (repr (renameAssocAll assoc st.1 rewrittenExprHigh)) -/ + + let uf ← IO.ofExcept <| rewrittenExprHigh.infer_equalities ⟨∅, ∅⟩ + + let l ← IO.ofExcept <| if parsed.noDynamaticDot then if parsed.blueSpecDot then pure rewrittenExprHigh.toBlueSpec else pure (toString rewrittenExprHigh) - else dynamaticString rewrittenExprHigh (renameAssocAll assoc st.1 rewrittenExprHigh) - | IO.eprintln s!"Failed to print ExprHigh: {rewrittenExprHigh}" + else dynamaticString rewrittenExprHigh uf (renameAssocAll assoc st.1 rewrittenExprHigh) match parsed.outputFile with | some ofile => IO.FS.writeFile ofile l diff --git a/benchmarks/post-processed/gcd.dot b/benchmarks/post-processed/gcd.dot index 0a60d73..ca25607 100644 --- a/benchmarks/post-processed/gcd.dot +++ b/benchmarks/post-processed/gcd.dot @@ -23,7 +23,7 @@ Digraph G { "icmp_4" [type = "Operator", bbID= 2, op = "icmp_eq_op", in = "in1:32 in2:32 ", out = "out1:1 ", delay=1.530, latency=0, II=1, tagged=false, taggers_num=0, tagger_id=-1]; "phiC_0" [type = "Mux", bbID= 2, in = "in1?:1 in2:0 in3:0 ", out = "out1:0", delay=0.166, tagged=false, taggers_num=0, tagger_id=-1]; "branch_0" [type = "Branch", bbID= 2, in = "in1:32 in2?:1", out = "out1+:32 out2-:32", tagged=false, taggers_num=0, tagger_id=-1]; - "phi_n1" [type = "init Bool false", bbID= 2, in = "in1:1 ", out = "out1:1", delay=0.366, tagged=false, taggers_num=0, tagger_id=-1]; + "phi_n1" [type = "Init", bbID= 2, in = "in1:1 ", out = "out1:1", delay=0.366, tagged=false, taggers_num=0, tagger_id=-1]; "branch_1" [type = "Branch", bbID= 2, in = "in1:32 in2?:1", out = "out1+:32 out2-:32", tagged=false, taggers_num=0, tagger_id=-1]; "branchC_3" [type = "Branch", bbID= 2, in = "in1:0 in2?:1", out = "out1+:0 out2-:0", tagged=false, taggers_num=0, tagger_id=-1]; "fork_0" [type = "Fork", bbID= 2, in = "in1:32", out = "out1:32 out2:32 ", tagged=false, taggers_num=0, tagger_id=-1]; diff --git a/benchmarks/post-processed/gemm.dot b/benchmarks/post-processed/gemm.dot index a938599..0127cf4 100644 --- a/benchmarks/post-processed/gemm.dot +++ b/benchmarks/post-processed/gemm.dot @@ -92,7 +92,7 @@ Digraph G { "phi_n9" [type = "Mux", bbID= 4, in = "in1?:1 in2:32 in3:32 ", out = "out1:32", delay=0.366, tagged=false, taggers_num=0, tagger_id=-1]; "phiC_13" [type = "Mux", bbID= 4, in = "in1?:1 in2:0 in3:0 ", out = "out1:0", delay=0.166, tagged=false, taggers_num=0, tagger_id=-1]; "branch_2" [type = "Branch", bbID= 4, in = "in1:32 in2?:1", out = "out1+:32 out2-:32", tagged=false, taggers_num=0, tagger_id=-1]; - "phi_n39" [type = "init Bool false", bbID= 4, in = "in1:1 ", out = "out1:1", delay=0.366, tagged=false, taggers_num=0, tagger_id=-1]; + "phi_n39" [type = "Init", bbID= 4, in = "in1:1 ", out = "out1:1", delay=0.366, tagged=false, taggers_num=0, tagger_id=-1]; "branch_3" [type = "Branch", bbID= 4, in = "in1:32 in2?:1", out = "out1+:32 out2-:32", tagged=false, taggers_num=0, tagger_id=-1]; "branch_7" [type = "Branch", bbID= 4, in = "in1:32 in2?:1", out = "out1+:32 out2-:32", tagged=false, taggers_num=0, tagger_id=-1]; "branch_13" [type = "Branch", bbID= 4, in = "in1:32 in2?:1", out = "out1+:32 out2-:32", tagged=false, taggers_num=0, tagger_id=-1]; diff --git a/benchmarks/post-processed/img_avg.dot b/benchmarks/post-processed/img_avg.dot index d057d52..789526c 100644 --- a/benchmarks/post-processed/img_avg.dot +++ b/benchmarks/post-processed/img_avg.dot @@ -56,7 +56,7 @@ Digraph G { "icmp_16" [type = "Operator", bbID= 4, op = "icmp_ult_op", in = "in1:32 in2:32 ", out = "out1:1 ", delay=1.530, latency=0, II=1, tagged=false, taggers_num=0, tagger_id=-1]; "phiC_3" [type = "Mux", bbID= 4, in = "in1?:1 in2:0 in3:0 ", out = "out1:0", delay=0.166, tagged=false, taggers_num=0, tagger_id=-1]; "branch_1" [type = "Branch", bbID= 4, in = "in1:32 in2?:1", out = "out1+:32 out2-:32", tagged=false, taggers_num=0, tagger_id=-1]; - "phi_n15" [type = "init Bool false", bbID= 4, in = "in1:1 ", out = "out1:1", delay=0.366, tagged=false, taggers_num=0, tagger_id=-1]; + "phi_n15" [type = "Init", bbID= 4, in = "in1:1 ", out = "out1:1", delay=0.366, tagged=false, taggers_num=0, tagger_id=-1]; "branch_2" [type = "Branch", bbID= 4, in = "in1:32 in2?:1", out = "out1+:32 out2-:32", tagged=false, taggers_num=0, tagger_id=-1]; "branchC_10" [type = "Branch", bbID= 4, in = "in1:0 in2?:1", out = "out1+:0 out2-:0", tagged=false, taggers_num=0, tagger_id=-1]; "fork_2" [type = "Fork", bbID= 4, in = "in1:32", out = "out1:32 out2:32 ", tagged=false, taggers_num=0, tagger_id=-1]; diff --git a/benchmarks/post-processed/matvec.dot b/benchmarks/post-processed/matvec.dot index 533d740..06c51d0 100644 --- a/benchmarks/post-processed/matvec.dot +++ b/benchmarks/post-processed/matvec.dot @@ -48,7 +48,7 @@ Digraph G { "phi_n0" [type = "Mux", bbID= 3, in = "in1?:1 in2:32 in3:32 ", out = "out1:32", delay=0.366, tagged=false, taggers_num=0, tagger_id=-1]; "phiC_3" [type = "Mux", bbID= 3, in = "in1?:1 in2:0 in3:0 ", out = "out1:0", delay=0.166, tagged=false, taggers_num=0, tagger_id=-1]; "branch_1" [type = "Branch", bbID= 3, in = "in1:32 in2?:1", out = "out1+:32 out2-:32", tagged=false, taggers_num=0, tagger_id=-1]; - "phi_n13" [type = "init Bool false", bbID= 3, in = "in1:1", out = "out1:1", delay=0.366, tagged=false, taggers_num=0, tagger_id=-1]; + "phi_n13" [type = "Init", bbID= 3, in = "in1:1", out = "out1:1", delay=0.366, tagged=false, taggers_num=0, tagger_id=-1]; "branch_2" [type = "Branch", bbID= 3, in = "in1:32 in2?:1", out = "out1+:32 out2-:32", tagged=false, taggers_num=0, tagger_id=-1]; "branch_5" [type = "Branch", bbID= 3, in = "in1:32 in2?:1", out = "out1+:32 out2-:32", tagged=false, taggers_num=0, tagger_id=-1]; "branchC_7" [type = "Branch", bbID= 3, in = "in1:0 in2?:1", out = "out1+:0 out2-:0", tagged=false, taggers_num=0, tagger_id=-1];