From 99fabf50f986f2bda840024fc899f66d08a27a89 Mon Sep 17 00:00:00 2001 From: Willow Ahrens <willow@csail.mit.edu> Date: Tue, 18 Mar 2025 13:48:46 -0400 Subject: [PATCH 1/2] small tweaks to rules --- src/interface/lazy.jl | 4 +++- src/symbolic/simplify.jl | 3 +++ src/symbolic/symbolic.jl | 22 ++++++++++++++++------ test/suites/interface_tests.jl | 31 +++++++++++++++++++++++++++++++ 4 files changed, 53 insertions(+), 7 deletions(-) diff --git a/src/interface/lazy.jl b/src/interface/lazy.jl index 56cfe2008..51bcd68fd 100644 --- a/src/interface/lazy.jl +++ b/src/interface/lazy.jl @@ -396,12 +396,13 @@ struct Square{T,S} scale::S end -@inline square(x) = Square(sign(x)^2, norm(x)) +@inline square(x) = Square(sign(x)^2 / one(x), norm(x)) @inline root(x::Square) = sqrt(x.arg) * x.scale @inline Base.zero(::Type{Square{T,S}}) where {T,S} = Square{T,S}(zero(T), zero(S)) @inline Base.zero(::Square{T,S}) where {T,S} = Square{T,S}(zero(T), zero(S)) +@inline Base.isone(x::Square) = isone(root(x)) @inline Base.isinf(x::Finch.Square) = isinf(x.arg) || isinf(x.scale) @@ -458,6 +459,7 @@ end Power{T,S,E}(zero(T), zero(S), one(E)) @inline Base.zero(x::Power) = Power(zero(x.arg), zero(x.scale), x.exponent) @inline Base.isinf(x::Finch.Power) = isinf(x.arg) || isinf(x.scale) || isinf(x.exponent) +@inline Base.isone(x::Power) = isone(root(x)) function Base.promote_rule( ::Type{Power{T1,S1,E1}}, ::Type{Power{T2,S2,E2}} diff --git a/src/symbolic/simplify.jl b/src/symbolic/simplify.jl index b38824ed1..42bee09eb 100644 --- a/src/symbolic/simplify.jl +++ b/src/symbolic/simplify.jl @@ -119,6 +119,9 @@ function get_simplify_rules(alg, shash) (@rule call(norm, ~x::isliteral, ~y) => if iszero(x.val) x end), + (@rule call(^, ~x, ~p::isliteral) => if isone(p.val) + x + end), (@rule block(~a1..., sieve(~c, ~b1), sieve(~c, ~b2), ~a2...) => block(a1..., sieve(~c, block(b1, b2)), a2...) ), diff --git a/src/symbolic/symbolic.jl b/src/symbolic/symbolic.jl index 09abebeb0..38b6083d0 100644 --- a/src/symbolic/symbolic.jl +++ b/src/symbolic/symbolic.jl @@ -93,8 +93,12 @@ isidentity(::AbstractAlgebra, ::typeof(|), x) = !ismissing(x) && iszero(x) isidentity(::AbstractAlgebra, ::typeof(&), x) = !ismissing(x) && x == ~(zero(x)) isidentity(::AbstractAlgebra, ::typeof(min), x) = !ismissing(x) && isinf(x) && x > 0 isidentity(::AbstractAlgebra, ::typeof(max), x) = !ismissing(x) && isinf(x) && x < 0 -isidentity(::Finch.AbstractAlgebra, ::InitMax{D}, x) where {D} = isequal(x, D) -isidentity(::Finch.AbstractAlgebra, ::InitMin{D}, x) where {D} = isequal(x, D) +function isidentity(::Finch.AbstractAlgebra, ::InitMax{D}, x) where {D} + isequal(x, D) || isequal(x == D, true) +end +function isidentity(::Finch.AbstractAlgebra, ::InitMin{D}, x) where {D} + isequal(x, D) || isequal(x == D, true) +end function isidentity_by_fn(alg::AbstractAlgebra, ::typeof(minby), x::FinchNode) if @capture x call(tuple, ~a::isliteral, ~b) @@ -112,8 +116,12 @@ function isidentity_by_fn(alg::AbstractAlgebra, ::typeof(maxby), x::FinchNode) end return false end -isidentity(::AbstractAlgebra, ::Chooser{Vf}, x) where {Vf} = isequal(x, Vf) -isidentity(::AbstractAlgebra, ::InitWriter{Vf}, x) where {Vf} = isequal(x, Vf) +function isidentity(::AbstractAlgebra, ::Chooser{Vf}, x) where {Vf} + isequal(x, Vf) || isequal(x == Vf, true) +end +function isidentity(::AbstractAlgebra, ::InitWriter{Vf}, x) where {Vf} + isequal(x, Vf) || isequal(x == Vf, true) +end isannihilator(alg) = (f, x) -> isannihilator(alg, f, x) function isannihilator(alg, f::FinchNode, x::FinchNode) @@ -150,8 +158,10 @@ function isannihilator_by_fn(alg::AbstractAlgebra, ::typeof(maxby), x::FinchNode end return false end -isannihilator(::AbstractAlgebra, ::Chooser{Vf}, x) where {Vf} = !isequal(x, Vf) -#isannihilator(::AbstractAlgebra, ::InitWriter{Vf}, x) where {Vf} = !isequal(x, Vf) +function isannihilator(::AbstractAlgebra, ::Chooser{Vf}, x) where {Vf} + !(isequal(x, Vf) || isequal(x == Vf, true)) +end +#isannihilator(::AbstractAlgebra, ::InitWriter{Vf}, x) where {Vf} = !(isequal(x, Vf) || isequal(x == Vf, true)) isinverse(alg) = (f, g) -> isinverse(alg, f, g) function isinverse(alg, f::FinchNode, g::FinchNode) diff --git a/test/suites/interface_tests.jl b/test/suites/interface_tests.jl index 3f2b77a80..dcd90209c 100644 --- a/test/suites/interface_tests.jl +++ b/test/suites/interface_tests.jl @@ -806,6 +806,37 @@ end @test A * x == C * x @test A * x == compute(lazy(C) * x) end + + # https://github.com/finch-tensor/Finch.jl/issues/708 + let + for p in [-Inf, -3, -2, -1, 0, 1, 2, 3, Inf] + @testset "$p-norm" begin + A_ref = rand(4, 3) + A = Tensor(A_ref) + @test norm(A_ref, p) ≈ norm(A, p) + + A_ref = sprand(4, 3, 1.0) + A = Tensor(A_ref) + @test norm(A_ref, p) ≈ norm(A, p) + end + end + end + + # https://github.com/finch-tensor/Finch.jl/pull/709 + let + A_ref = [1 2 0 4 0; 0 -2 1 0 1] + A = swizzle(Tensor(Dense(SparseList(Element(0))), A_ref), 1, 2) + + @test norm(A, 1) == norm(A_ref, 1) + @test norm(A, 2) == norm(A_ref, 2) + @test norm(A, 3) == norm(A_ref, 3) + end + + # https://github.com/finch-tensor/Finch.jl/issues/686 + let + A = fsprand(5, 5, 3) + @test countstored(A - A) == 3 skip=(key!="default") + end end end end From 421894f3b9fc966426befa78192b33a531933124 Mon Sep 17 00:00:00 2001 From: Willow Ahrens <willow@csail.mit.edu> Date: Tue, 18 Mar 2025 15:26:32 -0400 Subject: [PATCH 2/2] style --- test/suites/interface_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/suites/interface_tests.jl b/test/suites/interface_tests.jl index dcd90209c..293919d43 100644 --- a/test/suites/interface_tests.jl +++ b/test/suites/interface_tests.jl @@ -835,7 +835,7 @@ end # https://github.com/finch-tensor/Finch.jl/issues/686 let A = fsprand(5, 5, 3) - @test countstored(A - A) == 3 skip=(key!="default") + @test countstored(A - A) == 3 skip = (key != "default") end end end