From 99fabf50f986f2bda840024fc899f66d08a27a89 Mon Sep 17 00:00:00 2001
From: Willow Ahrens <willow@csail.mit.edu>
Date: Tue, 18 Mar 2025 13:48:46 -0400
Subject: [PATCH 1/2] small tweaks to rules

---
 src/interface/lazy.jl          |  4 +++-
 src/symbolic/simplify.jl       |  3 +++
 src/symbolic/symbolic.jl       | 22 ++++++++++++++++------
 test/suites/interface_tests.jl | 31 +++++++++++++++++++++++++++++++
 4 files changed, 53 insertions(+), 7 deletions(-)

diff --git a/src/interface/lazy.jl b/src/interface/lazy.jl
index 56cfe2008..51bcd68fd 100644
--- a/src/interface/lazy.jl
+++ b/src/interface/lazy.jl
@@ -396,12 +396,13 @@ struct Square{T,S}
     scale::S
 end
 
-@inline square(x) = Square(sign(x)^2, norm(x))
+@inline square(x) = Square(sign(x)^2 / one(x), norm(x))
 
 @inline root(x::Square) = sqrt(x.arg) * x.scale
 
 @inline Base.zero(::Type{Square{T,S}}) where {T,S} = Square{T,S}(zero(T), zero(S))
 @inline Base.zero(::Square{T,S}) where {T,S} = Square{T,S}(zero(T), zero(S))
+@inline Base.isone(x::Square) = isone(root(x))
 
 @inline Base.isinf(x::Finch.Square) = isinf(x.arg) || isinf(x.scale)
 
@@ -458,6 +459,7 @@ end
     Power{T,S,E}(zero(T), zero(S), one(E))
 @inline Base.zero(x::Power) = Power(zero(x.arg), zero(x.scale), x.exponent)
 @inline Base.isinf(x::Finch.Power) = isinf(x.arg) || isinf(x.scale) || isinf(x.exponent)
+@inline Base.isone(x::Power) = isone(root(x))
 
 function Base.promote_rule(
     ::Type{Power{T1,S1,E1}}, ::Type{Power{T2,S2,E2}}
diff --git a/src/symbolic/simplify.jl b/src/symbolic/simplify.jl
index b38824ed1..42bee09eb 100644
--- a/src/symbolic/simplify.jl
+++ b/src/symbolic/simplify.jl
@@ -119,6 +119,9 @@ function get_simplify_rules(alg, shash)
         (@rule call(norm, ~x::isliteral, ~y) => if iszero(x.val)
             x
         end),
+        (@rule call(^, ~x, ~p::isliteral) => if isone(p.val)
+            x
+        end),
         (@rule block(~a1..., sieve(~c, ~b1), sieve(~c, ~b2), ~a2...) =>
             block(a1..., sieve(~c, block(b1, b2)), a2...)
     ),
diff --git a/src/symbolic/symbolic.jl b/src/symbolic/symbolic.jl
index 09abebeb0..38b6083d0 100644
--- a/src/symbolic/symbolic.jl
+++ b/src/symbolic/symbolic.jl
@@ -93,8 +93,12 @@ isidentity(::AbstractAlgebra, ::typeof(|), x) = !ismissing(x) && iszero(x)
 isidentity(::AbstractAlgebra, ::typeof(&), x) = !ismissing(x) && x == ~(zero(x))
 isidentity(::AbstractAlgebra, ::typeof(min), x) = !ismissing(x) && isinf(x) && x > 0
 isidentity(::AbstractAlgebra, ::typeof(max), x) = !ismissing(x) && isinf(x) && x < 0
-isidentity(::Finch.AbstractAlgebra, ::InitMax{D}, x) where {D} = isequal(x, D)
-isidentity(::Finch.AbstractAlgebra, ::InitMin{D}, x) where {D} = isequal(x, D)
+function isidentity(::Finch.AbstractAlgebra, ::InitMax{D}, x) where {D}
+    isequal(x, D) || isequal(x == D, true)
+end
+function isidentity(::Finch.AbstractAlgebra, ::InitMin{D}, x) where {D}
+    isequal(x, D) || isequal(x == D, true)
+end
 
 function isidentity_by_fn(alg::AbstractAlgebra, ::typeof(minby), x::FinchNode)
     if @capture x call(tuple, ~a::isliteral, ~b)
@@ -112,8 +116,12 @@ function isidentity_by_fn(alg::AbstractAlgebra, ::typeof(maxby), x::FinchNode)
     end
     return false
 end
-isidentity(::AbstractAlgebra, ::Chooser{Vf}, x) where {Vf} = isequal(x, Vf)
-isidentity(::AbstractAlgebra, ::InitWriter{Vf}, x) where {Vf} = isequal(x, Vf)
+function isidentity(::AbstractAlgebra, ::Chooser{Vf}, x) where {Vf}
+    isequal(x, Vf) || isequal(x == Vf, true)
+end
+function isidentity(::AbstractAlgebra, ::InitWriter{Vf}, x) where {Vf}
+    isequal(x, Vf) || isequal(x == Vf, true)
+end
 
 isannihilator(alg) = (f, x) -> isannihilator(alg, f, x)
 function isannihilator(alg, f::FinchNode, x::FinchNode)
@@ -150,8 +158,10 @@ function isannihilator_by_fn(alg::AbstractAlgebra, ::typeof(maxby), x::FinchNode
     end
     return false
 end
-isannihilator(::AbstractAlgebra, ::Chooser{Vf}, x) where {Vf} = !isequal(x, Vf)
-#isannihilator(::AbstractAlgebra, ::InitWriter{Vf}, x) where {Vf} = !isequal(x, Vf)
+function isannihilator(::AbstractAlgebra, ::Chooser{Vf}, x) where {Vf}
+    !(isequal(x, Vf) || isequal(x == Vf, true))
+end
+#isannihilator(::AbstractAlgebra, ::InitWriter{Vf}, x) where {Vf} = !(isequal(x, Vf) || isequal(x == Vf, true))
 
 isinverse(alg) = (f, g) -> isinverse(alg, f, g)
 function isinverse(alg, f::FinchNode, g::FinchNode)
diff --git a/test/suites/interface_tests.jl b/test/suites/interface_tests.jl
index 3f2b77a80..dcd90209c 100644
--- a/test/suites/interface_tests.jl
+++ b/test/suites/interface_tests.jl
@@ -806,6 +806,37 @@ end
                     @test A * x == C * x
                     @test A * x == compute(lazy(C) * x)
                 end
+
+                # https://github.com/finch-tensor/Finch.jl/issues/708
+                let
+                    for p in [-Inf, -3, -2, -1, 0, 1, 2, 3, Inf]
+                        @testset "$p-norm" begin
+                            A_ref = rand(4, 3)
+                            A = Tensor(A_ref)
+                            @test norm(A_ref, p) ≈ norm(A, p)
+
+                            A_ref = sprand(4, 3, 1.0)
+                            A = Tensor(A_ref)
+                            @test norm(A_ref, p) ≈ norm(A, p)
+                        end
+                    end
+                end
+
+                # https://github.com/finch-tensor/Finch.jl/pull/709
+                let
+                    A_ref = [1 2 0 4 0; 0 -2 1 0 1]
+                    A = swizzle(Tensor(Dense(SparseList(Element(0))), A_ref), 1, 2)
+
+                    @test norm(A, 1) == norm(A_ref, 1)
+                    @test norm(A, 2) == norm(A_ref, 2)
+                    @test norm(A, 3) == norm(A_ref, 3)
+                end
+
+                # https://github.com/finch-tensor/Finch.jl/issues/686
+                let
+                    A = fsprand(5, 5, 3)
+                    @test countstored(A - A) == 3 skip=(key!="default")
+                end
             end
         end
     end

From 421894f3b9fc966426befa78192b33a531933124 Mon Sep 17 00:00:00 2001
From: Willow Ahrens <willow@csail.mit.edu>
Date: Tue, 18 Mar 2025 15:26:32 -0400
Subject: [PATCH 2/2] style

---
 test/suites/interface_tests.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/suites/interface_tests.jl b/test/suites/interface_tests.jl
index dcd90209c..293919d43 100644
--- a/test/suites/interface_tests.jl
+++ b/test/suites/interface_tests.jl
@@ -835,7 +835,7 @@ end
                 # https://github.com/finch-tensor/Finch.jl/issues/686
                 let
                     A = fsprand(5, 5, 3)
-                    @test countstored(A - A) == 3 skip=(key!="default")
+                    @test countstored(A - A) == 3 skip = (key != "default")
                 end
             end
         end