diff --git a/Graphiti/Core/ExprLow.lean b/Graphiti/Core/ExprLow.lean index b8bb32d..12d9592 100644 --- a/Graphiti/Core/ExprLow.lean +++ b/Graphiti/Core/ExprLow.lean @@ -518,19 +518,100 @@ def comm_connection_inv (e : ExprLow Ident Typ) := fix_point_opt comm_connection def comm_base (binst : PortMapping Ident) (btyp : Typ) e := fix_point_opt (comm_base_ binst btyp) e 10000 -def comm_connections' {Ident} [DecidableEq Ident] (conn : List (Connection Ident)) (e : ExprLow Ident Typ): ExprLow Ident Typ := - conn.foldr comm_connection' e - def comm_connections {Ident} [DecidableEq Ident] (conn : List (Connection Ident)) (e : ExprLow Ident Typ): ExprLow Ident Typ := conn.foldr comm_connection e -def comm_bases {Ident} [DecidableEq Ident] (bases : List (PortMapping Ident × Typ)) (e : ExprLow Ident Typ): ExprLow Ident Typ := - bases.foldr (Function.uncurry ExprLow.comm_base) e - def getPortMaps {α β} : ExprLow α β → List (PortMapping α) | .base inst typ => [inst] | .connect c e => getPortMaps e | .product e₁ e₂ => getPortMaps e₁ ++ getPortMaps e₂ +def comm_base_only_conn {Ident Typ} [DecidableEq Ident] [DecidableEq Typ] + (binst : PortMapping Ident) (btyp : Typ) (e : ExprLow Ident Typ) : Nat → ExprLow Ident Typ +| 0 => e +| n+1 => + match e with + | .product e₁ e₂ => + if e₁ = .base binst btyp then + match e₂ with + | .connect c e => + if c.output ∉ binst.output.valsList ∧ c.input ∉ binst.input.valsList then + .connect c <| (comm_base_only_conn binst btyp (.product e₁ e) n) + else .product e₁ e₂ + | _ => .product e₁ e₂ + else .product e₁ <| comm_base_only_conn binst btyp e₂ n + | .connect c e => .connect c <| comm_base_only_conn binst btyp e n + | e => e + +def comm_bases_only_conn {Ident Typ} [DecidableEq Ident] [DecidableEq Typ] + (bases : List (PortMapping Ident × Typ)) (e : ExprLow Ident Typ) : ExprLow Ident Typ := + bases.foldr (Function.uncurry (fun a b c => ExprLow.comm_base_only_conn a b c 10000000)) e + +def comm_base_fast {Ident Typ} [DecidableEq Ident] [DecidableEq Typ] + (binst : PortMapping Ident) (btyp : Typ) (e : ExprLow Ident Typ) : Nat → ExprLow Ident Typ +| 0 => e +| n+1 => + match e with + | .product e₁ e₂ => + if e₁ = .base binst btyp then + match e₂ with + | .connect c e => + if c.output ∉ binst.output.valsList ∧ c.input ∉ binst.input.valsList then + .connect c <| (comm_base_fast binst btyp (.product e₁ e) n) + else .product e₁ e₂ + | .product (.base binst' btyp') e₂' => + if inst_disjoint binst' binst then + .product (.base binst' btyp') (comm_base_fast binst btyp (.product e₁ e₂') n) + else .product e₁ e₂ + | .base binst' btyp' => + if inst_disjoint binst' binst then .product e₂ e₁ else .product e₁ e₂ + | _ => .product e₁ e₂ + else .product e₁ <| comm_base_fast binst btyp e₂ n + | .connect c e => .connect c <| comm_base_fast binst btyp e n + | e => e + +def comm_bases_fast + (bases : List (PortMapping Ident × Typ)) (e : ExprLow Ident Typ) : ExprLow Ident Typ := + bases.foldr (Function.uncurry (fun a b c => ExprLow.comm_base_fast a b c 10000000)) e + +def comm_connection'_fast (conn : Connection Ident) (e_strt : ExprLow Ident Typ) : Nat → ExprLow Ident Typ +| 0 => e_strt +| n+1 => + match e_strt with + | .connect c e => + if c.output = conn.output ∧ c.input = conn.input then + match e with + | .connect c' e' => + if c.output ≠ c'.output ∧ c.input ≠ c'.input then + .connect c' (comm_connection'_fast conn (.connect c e') n) + else e_strt + | .product e₁ e₂ => + let a := e₁.findInput c.input + let b := e₁.findOutput c.output + if a ∧ b then + -- .product (comm_connection' conn <| .connect o i e₁) e₂ + -- We actually don't want to commute (assuming we are left associative) + e_strt + else if ¬ a ∧ ¬ b ∧ e₂.findInput c.input ∧ e₂.findOutput c.output then + .product e₁ (comm_connection'_fast conn (.connect c e₂) n) + else e_strt + | _ => e_strt + else .connect c (comm_connection'_fast conn e n) + | .product e₁ e₂ => + .product (comm_connection'_fast conn e₁ n) (comm_connection'_fast conn e₂ n) + | e => e + +def comm_connections'_fast + (conn : List (Connection Ident)) (e : ExprLow Ident Typ) : ExprLow Ident Typ := + conn.foldr (fun a b => ExprLow.comm_connection'_fast a b 10000000) e + +@[implemented_by comm_connections'_fast] +def comm_connections' (conn : List (Connection Ident)) (e : ExprLow Ident Typ): ExprLow Ident Typ := + conn.foldr comm_connection' e + +@[implemented_by comm_bases_fast] +def comm_bases (bases : List (PortMapping Ident × Typ)) (e : ExprLow Ident Typ): ExprLow Ident Typ := + bases.foldr (Function.uncurry ExprLow.comm_base) e + end ExprLow end Graphiti