From 6dbebb0a16acafe3ec00af3108fbb329727a62f2 Mon Sep 17 00:00:00 2001 From: jverzani Date: Fri, 29 Jul 2022 19:29:10 -0400 Subject: [PATCH] rework round and trunc --- Project.toml | 2 +- src/generic.jl | 30 ++++++++++++++++++++++++------ test/test-ode.jl | 2 +- test/tests.jl | 15 +++++++++++++++ 4 files changed, 41 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 785e1ad..60af4dd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "SymPy" uuid = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6" -version = "1.1.6" +version = "1.1.7" [deps] CommonEq = "3709ef60-1bee-4518-9f2f-acd86f176c50" diff --git a/src/generic.jl b/src/generic.jl index b27e56e..53fd60c 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -38,16 +38,34 @@ Base.mod(x::SymbolicObject, args...)= Mod(x, args...) # so we can compare numbers with ≈ Base.rtoldefault(::Type{<:SymbolicObject}) = eps() -function Base.round(x::Sym; kwargs...) - length(free_symbols(x)) > 0 && throw(ArgumentError("can't round a symbolic expression")) - round(N(x); kwargs...) + +# SymPy round only has a digits argument and errors on symbols +# here we relax it a git +function Base.round(x::Sym, r::RoundingMode=RoundNearest; digits=nothing, kwargs...) + x.is_number != true && return x + digits === nothing && return _round(x, r) + return x.round(digits) +end +_round(x::Sym, r::RoundingMode{:Nearest}) = x.round() +_round(x::Sym, r::RoundingMode) = begin + Sym(round(N(x), r)) +end +function Base.round(::Type{T}, x::Sym, r::RoundingMode=RoundNearest) where {T <: Integer} + r != RoundToZero && (x = round(x, r)) + convert(T, trunc(x)) end -function Base.trunc(x::Sym; kwargs...) - length(free_symbols(x)) > 0 && throw(ArgumentError("can't truncate a symbolic expression")) - trunc(N(x); kwargs...) + +#trunc(x) returns the nearest integral value of the same type as x whose absolute value is less than or equal to the absolute value of x. +function Base.trunc(x::Sym) + x.is_real == true || return x + y = x.evalf().floor() + x.is_positive && return y + return 1 + y end + + # check on type of number # these are boolean: true/false; not tru/false/nothing,as in SymPy Base.isfinite(x::Sym) = !is_(:infinite, x) diff --git a/test/test-ode.jl b/test/test-ode.jl index 7fa57b3..6e2db0f 100644 --- a/test/test-ode.jl +++ b/test/test-ode.jl @@ -1,7 +1,7 @@ using SymPy using Test -@testset "ODes" begin +@testset "ODEs" begin ## ODEs x, a = Sym("x, a") F = SymFunction("F") diff --git a/test/tests.jl b/test/tests.jl index bb46140..eb197b3 100644 --- a/test/tests.jl +++ b/test/tests.jl @@ -772,6 +772,21 @@ end @syms x ex = integrate(sqrt(1 + (1/x)^2), (x, 1/sympy.E, sympy.E)) @test N(ex) ≈ 3.196198513599507 + + ## Issue #472 + @syms x + a = Sym(1)/2 + @test round(x) == x + @test round(a) == 0 + @test round(a, RoundUp) == 1 + @test round(PI; digits = 2) * 100 == 314 # this one is odd + @test trunc(x) == x + @test trunc(PI) == 3 + @test trunc(-PI) == -3 + @test typeof(trunc(PI)) == typeof(PI) + M = [PI -PI; a x] + @test round.(M) == [3 -3; 0 x] + end @testset "generic programming, issue 223" begin