diff --git a/src/rewrite.jl b/src/rewrite.jl index 811570f..4b3b338 100644 --- a/src/rewrite.jl +++ b/src/rewrite.jl @@ -28,7 +28,7 @@ macro rewrite(args...) return rewrite_and_return(args[1]; move_factors_into_sums = args[2].args[2]) end -struct Zero end +struct Zero <: Number end # This method is called in various `promote_operation_fallback` methods if one # of the arguments is `::Zero`. @@ -59,13 +59,19 @@ broadcast!!(::typeof(add_mul), ::Zero, x, y) = x * y # Needed in `@rewrite(1 .+ sum(1 for i in 1:0) * 1^2)` Base.:*(z::Zero, ::Any) = z +Base.:*(z::Zero, ::Number) = z Base.:*(::Any, z::Zero) = z +Base.:*(::Number, z::Zero) = z Base.:*(z::Zero, ::Zero) = z Base.:+(::Zero, x::Any) = x +Base.:+(::Zero, x::Number) = x Base.:+(x::Any, ::Zero) = x +Base.:+(x::Number, ::Zero) = x Base.:+(z::Zero, ::Zero) = z Base.:-(::Zero, x::Any) = -x +Base.:-(::Zero, x::Number) = -x Base.:-(x::Any, ::Zero) = x +Base.:-(x::Number, ::Zero) = x Base.:-(z::Zero, ::Zero) = z Base.:-(z::Zero) = z Base.:+(z::Zero) = z @@ -79,6 +85,14 @@ function Base.:/(z::Zero, x::Any) end end +function Base.:/(z::Zero, x::Number) + if iszero(x) + throw(DivideError()) + else + return z + end +end + # These methods are used to provide an efficient implementation for the common # case like `x^2 * sum(f for i in 1:0)`, which lowers to # `_MA.operate!!(*, x^2, _MA.Zero())`. We don't need the method with reversed