diff --git a/src/forward/lib.jl b/src/forward/lib.jl index fecbd90e2..227083c2d 100644 --- a/src/forward/lib.jl +++ b/src/forward/lib.jl @@ -60,10 +60,8 @@ function _pushforward(dargs, ::typeof(Core._apply), f, args...) Core._apply(_pushforward, ((df, dargs...), f), args...) end -if VERSION >= v"1.4.0-DEV.304" - _pushforward(dargs, ::typeof(Core._apply_iterate), ::typeof(iterate), f, args...) = - _pushforward((first(args), tail(tail(dargs))...), Core._apply, f, args...) -end +_pushforward(dargs, ::typeof(Core._apply_iterate), ::typeof(iterate), f, args...) = + _pushforward((first(args), tail(tail(dargs))...), Core._apply, f, args...) using ..Zygote: literal_getproperty, literal_getfield, literal_getindex diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 4559237e7..1d9678807 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -198,15 +198,13 @@ unapply(t, xs) = _unapply(t, xs)[1] end end -if VERSION >= v"1.4.0-DEV.304" - @adjoint! function Core._apply_iterate(::typeof(iterate), f, args...) - y, back = Core._apply(_pullback, (__context__, f), args...) - st = map(_empty, args) - y, function (Δ) - Δ = back(Δ) - Δ === nothing ? nothing : - (nothing, first(Δ), unapply(st, Base.tail(Δ))...) - end +@adjoint! function Core._apply_iterate(::typeof(iterate), f, args...) + y, back = Core._apply(_pullback, (__context__, f), args...) + st = map(_empty, args) + y, function (Δ) + Δ = back(Δ) + Δ === nothing ? nothing : + (nothing, first(Δ), unapply(st, Base.tail(Δ))...) end end diff --git a/test/chainrules.jl b/test/chainrules.jl index 3d5fcb035..3017a9e18 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -163,14 +163,10 @@ using Zygote: ZygoteRuleConfig @test (1,) == h(1) - if VERSION >= v"1.6-" - @test begin - a3, pb3 = Zygote.pullback(h, 1) - ((1,),) == pb3(1) - end - else + + @test begin a3, pb3 = Zygote.pullback(h, 1) - @test ((1,),) == pb3(1) + ((1,),) == pb3(1) end end diff --git a/test/compiler.jl b/test/compiler.jl index 66ab79248..07d498ecb 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -33,7 +33,7 @@ y, back = pullback(badly, 2) bt = try back(1) catch e stacktrace(catch_backtrace()) end @test trace_contains(bt, nothing, "compiler.jl", bad_def_line) -if VERSION <= v"1.6-" || VERSION >= v"1.10-" +if VERSION >= v"1.10-" @test trace_contains(bt, :badly, "compiler.jl", bad_call_line) else @test_broken trace_contains(bt, :badly, "compiler.jl", bad_call_line) @@ -319,14 +319,17 @@ end @test res == 12. @test_throws ErrorException pull(1.) err = try pull(1.) catch ex; ex end - @test occursin("Can't differentiate function execution in catch block", - string(err)) + if VERSION >= v"1.11" + @test_broken occursin("Can't differentiate function execution in catch block", string(err)) + else + @test occursin("Can't differentiate function execution in catch block", string(err)) + end end if VERSION >= v"1.8" @testset "try/catch/else" begin @test Zygote.gradient(try_catch_else, false, 1.0) == (nothing, 8.0) - @test_throws "Can't differentiate function execution in catch block" Zygote.gradient(try_catch_else, true, 1.0) + @test_throws ErrorException Zygote.gradient(try_catch_else, true, 1.0) end end @@ -348,6 +351,9 @@ end @test_throws ErrorException pull(1.) err = try pull(1.) catch ex; ex end - @test occursin("Can't differentiate function execution in catch block", - string(err)) + if VERSION >= v"1.11" + @test_broken occursin("Can't differentiate function execution in catch block", string(err)) + else + @test occursin("Can't differentiate function execution in catch block", string(err)) + end end diff --git a/test/features.jl b/test/features.jl index ff777dd48..e7ca22316 100644 --- a/test/features.jl +++ b/test/features.jl @@ -95,286 +95,279 @@ using FillArrays: Fill end end +@testset "misc" begin + add(a, b) = a+b + _relu(x) = x > 0 ? x : 0 + f(a, b...) = +(a, b...) + + y, back = pullback(identity, 1) + dx = back(2) + @test y == 1 + @test dx == (2,) + + mul(a, b) = a*b + y, back = pullback(mul, 2, 3) + dx = back(4) + @test y == 6 + @test dx == (12, 8) -add(a, b) = a+b -_relu(x) = x > 0 ? x : 0 -f(a, b...) = +(a, b...) - -y, back = pullback(identity, 1) -dx = back(2) -@test y == 1 -@test dx == (2,) - -mul(a, b) = a*b -y, back = pullback(mul, 2, 3) -dx = back(4) -@test y == 6 -@test dx == (12, 8) - -@test gradient(mul, 2, 3) == (3, 2) -@test withgradient(mul, 2, 3) == (val = 6, grad = (3, 2)) + @test gradient(mul, 2, 3) == (3, 2) + @test withgradient(mul, 2, 3) == (val = 6, grad = (3, 2)) -bool = true -b(x) = bool ? 2x : x + bool = true + b(x) = bool ? 2x : x -y, back = pullback(b, 3) -dx = back(4) -@test y == 6 -@test dx == (8,) + y, back = pullback(b, 3) + dx = back(4) + @test y == 6 + @test dx == (8,) -fglobal = x -> 5x -gglobal = x -> fglobal(x) + fglobal = x -> 5x + gglobal = x -> fglobal(x) -@test gradient(gglobal, 2) == (5,) + @test gradient(gglobal, 2) == (5,) -global bool = false + global bool = false -y, back = pullback(b, 3) -dx = back(4) -@test y == 3 -@test getindex.(dx) == (4,) + y, back = pullback(b, 3) + dx = back(4) + @test y == 3 + @test getindex.(dx) == (4,) -y, back = pullback(x -> sum(x.*x), [1, 2, 3]) -dxs = back(1) -@test y == 14 -@test dxs == ([2,4,6],) + y, back = pullback(x -> sum(x.*x), [1, 2, 3]) + dxs = back(1) + @test y == 14 + @test dxs == ([2,4,6],) -function pow(x, n) - r = 1 - while n > 0 - n -= 1 - r *= x + function pow(x, n) + r = 1 + while n > 0 + n -= 1 + r *= x + end + return r end - return r -end -@test gradient(pow, 2, 3) == (12,nothing) + @test gradient(pow, 2, 3) == (12,nothing) -function pow_mut(x, n) - r = Ref(one(x)) - while n > 0 - n -= 1 - r[] = r[] * x + function pow_mut(x, n) + r = Ref(one(x)) + while n > 0 + n -= 1 + r[] = r[] * x + end + return r[] end - return r[] -end -@test gradient(pow_mut, 2, 3) == (12,nothing) + @test gradient(pow_mut, 2, 3) == (12,nothing) -r = 1 -function pow_global(x, n) - global r - while n > 0 - r *= x - n -= 1 - end - return r -end + @test gradient(x -> 1, 2) == (nothing,) -@test gradient(pow_global, 2, 3) == (12,nothing) + @test gradient(t -> t[1]*t[2], (2, 3)) == ((3, 2),) -@test gradient(x -> 1, 2) == (nothing,) + @test gradient(x -> x.re, 2+3im) === (1.0 + 0.0im,) # one NamedTuple + @test gradient(x -> x.re*x.im, 2+3im) == (3.0 + 2.0im,) # two, different fields + @test gradient(x -> x.re*x.im + x.re, 2+3im) == (4.0 + 2.0im,) # three, with accumulation -@test gradient(t -> t[1]*t[2], (2, 3)) == ((3, 2),) + @test gradient(x -> abs2(x * x.re), 4+5im) == (456.0 + 160.0im,) # gradient participates + @test gradient(x -> abs2(x * real(x)), 4+5im) == (456.0 + 160.0im,) # function not getproperty + @test gradient(x -> abs2(x * getfield(x, :re)), 4+5im) == (456.0 + 160.0im,) -@test gradient(x -> x.re, 2+3im) === (1.0 + 0.0im,) # one NamedTuple -@test gradient(x -> x.re*x.im, 2+3im) == (3.0 + 2.0im,) # two, different fields -@test gradient(x -> x.re*x.im + x.re, 2+3im) == (4.0 + 2.0im,) # three, with accumulation + struct Bar{T} + a::T + b::T + end -@test gradient(x -> abs2(x * x.re), 4+5im) == (456.0 + 160.0im,) # gradient participates -@test gradient(x -> abs2(x * real(x)), 4+5im) == (456.0 + 160.0im,) # function not getproperty -@test gradient(x -> abs2(x * getfield(x, :re)), 4+5im) == (456.0 + 160.0im,) + function mul_struct(a, b) + c = Bar(a, b) + c.a * c.b + end -struct Bar{T} - a::T - b::T -end + @test gradient(mul_struct, 2, 3) == (3, 2) -function mul_struct(a, b) - c = Bar(a, b) - c.a * c.b -end + function mul_tuple(a, b) + c = (a, b) + c[1] * c[2] + end -@test gradient(mul_struct, 2, 3) == (3, 2) + @test gradient(mul_tuple, 2, 3) == (3, 2) -function mul_tuple(a, b) - c = (a, b) - c[1] * c[2] -end + function mul_lambda(x, y) + g = z -> x * z + g(y) + end -@test gradient(mul_tuple, 2, 3) == (3, 2) + @test gradient(mul_lambda, 2, 3) == (3, 2) -function mul_lambda(x, y) - g = z -> x * z - g(y) -end + @test gradient((a, b...) -> *(a, b...), 2, 3) == (3, 2) -@test gradient(mul_lambda, 2, 3) == (3, 2) + @test gradient((x, a...) -> x, 1) == (1,) + @test gradient((x, a...) -> x, 1, 1) == (1,nothing) + @test gradient((x, a...) -> x == a, 1) == (nothing,) + @test gradient((x, a...) -> x == a, 1, 2) == (nothing,nothing) -@test gradient((a, b...) -> *(a, b...), 2, 3) == (3, 2) + kwmul(; a = 1, b) = a*b -@test gradient((x, a...) -> x, 1) == (1,) -@test gradient((x, a...) -> x, 1, 1) == (1,nothing) -@test gradient((x, a...) -> x == a, 1) == (nothing,) -@test gradient((x, a...) -> x == a, 1, 2) == (nothing,nothing) + mul_kw(a, b) = kwmul(a = a, b = b) -kwmul(; a = 1, b) = a*b + @test gradient(mul_kw, 2, 3) == (3, 2) -mul_kw(a, b) = kwmul(a = a, b = b) + function myprod(xs) + s = 1 + for x in xs + s *= x + end + return s + end -@test gradient(mul_kw, 2, 3) == (3, 2) + @test gradient(myprod, [1,2,3])[1] == [6,3,2] -function myprod(xs) - s = 1 - for x in xs - s *= x + function mul_vec(a, b) + xs = [a, b] + xs[1] * xs[2] end - return s -end -@test gradient(myprod, [1,2,3])[1] == [6,3,2] + @test gradient(mul_vec, 2, 3) == (3, 2) -function mul_vec(a, b) - xs = [a, b] - xs[1] * xs[2] -end + @test gradient(2) do x + d = Dict() + d[:x] = x + x * d[:x] + end == (4,) -@test gradient(mul_vec, 2, 3) == (3, 2) + f(args...;a=nothing,kwargs...) = g(a,args...;kwargs...) + g(args...;x=1,idx=Colon(),kwargs...) = x[idx] + @test gradient(x->sum(f(;x=x,idx=1:1)),ones(2))[1] == [1., 0.] -@test gradient(2) do x - d = Dict() - d[:x] = x - x * d[:x] -end == (4,) + pow_rec(x, n) = n == 0 ? 1 : x*pow_rec(x, n-1) -f(args...;a=nothing,kwargs...) = g(a,args...;kwargs...) -g(args...;x=1,idx=Colon(),kwargs...) = x[idx] -@test gradient(x->sum(f(;x=x,idx=1:1)),ones(2))[1] == [1., 0.] + @test gradient(pow_rec, 2, 3) == (12, nothing) -pow_rec(x, n) = n == 0 ? 1 : x*pow_rec(x, n-1) + f(x) = throw(DimensionMismatch("fubar")) -@test gradient(pow_rec, 2, 3) == (12, nothing) + @test_throws DimensionMismatch gradient(f, 1) -# For nested AD, until we support errors -function grad(f, args...) - y, back = pullback(f, args...) - return back(1) -end + struct Layer{T} + W::T + end -D(f, x) = grad(f, x)[1] + (f::Layer)(x) = f.W * x -@test D(x -> D(sin, x), 0.5) == -sin(0.5) -@test D(x -> x*D(y -> x+y, 1), 1) == 1 -@test D(x -> x*D(y -> x*y, 1), 4) == 8 + W = [1 0; 0 1] + x = [1, 2] -@test sin''(1.0) == -sin(1.0) -@test sin'''(1.0) == -cos(1.0) + y, back = pullback(() -> W * x, Params([W])) + @test y == [1, 2] + @test back([1, 1])[W] == [1 2; 1 2] -f(x) = throw(DimensionMismatch("fubar")) + layer = Layer(W) -@test_throws DimensionMismatch gradient(f, 1) + y, back = pullback(() -> layer(x), Params([W])) + @test y == [1, 2] + @test back([1, 1])[W] == [1 2; 1 2] -struct Layer{T} - W::T -end + @test gradient(() -> sum(W * x), Params([W]))[W] == [1 2; 1 2] + y, gr = withgradient(() -> sum(W * x), Params([W])) + @test y == 3 + @test gr[W] == [1 2; 1 2] -(f::Layer)(x) = f.W * x + let + p = [1] + θ = Zygote.Params([p]) + θ̄ = gradient(θ) do + p′ = (p,)[1] + p′[1] + end + @test θ̄[p][1] == 1 + end -W = [1 0; 0 1] -x = [1, 2] + @test gradient(2) do x + H = [1 x; 3 4] + sum(H) + end[1] == 1 -y, back = pullback(() -> W * x, Params([W])) -@test y == [1, 2] -@test back([1, 1])[W] == [1 2; 1 2] + @test gradient(2) do x + if x < 0 + throw("foo") + end + return x*5 + end[1] == 5 -layer = Layer(W) + @test gradient(x -> one(eltype(x)), rand(10))[1] === nothing -y, back = pullback(() -> layer(x), Params([W])) -@test y == [1, 2] -@test back([1, 1])[W] == [1 2; 1 2] + # Three-way control flow merge + @test gradient(1) do x + if x > 0 + x *= 2 + elseif x < 0 + x *= 3 + end + x + end[1] == 2 -@test gradient(() -> sum(W * x), Params([W]))[W] == [1 2; 1 2] -y, gr = withgradient(() -> sum(W * x), Params([W])) -@test y == 3 -@test gr[W] == [1 2; 1 2] + # Gradient of closure + grad_closure(x) = 2x -let - p = [1] - θ = Zygote.Params([p]) - θ̄ = gradient(θ) do - p′ = (p,)[1] - p′[1] - end - @test θ̄[p][1] == 1 -end + Zygote.@adjoint (f::typeof(grad_closure))(x) = f(x), Δ -> (1, 2) -@test gradient(2) do x - H = [1 x; 3 4] - sum(H) -end[1] == 1 + @test gradient((f, x) -> f(x), grad_closure, 5) == (1, 2) -@test gradient(2) do x - if x < 0 - throw("foo") - end - return x*5 -end[1] == 5 + invokable(x) = 2x + invokable(x::Integer) = 3x + @test gradient(x -> invoke(invokable, Tuple{Any}, x), 5) == (2,) -@test gradient(x -> one(eltype(x)), rand(10))[1] === nothing + y, back = Zygote.pullback(x->tuple(x...), [1, 2, 3]) + @test back((1, 1, 1)) == ((1,1,1),) +end -# Three-way control flow merge -@test gradient(1) do x - if x > 0 - x *= 2 - elseif x < 0 - x *= 3 +@testset "nested AD" begin + # For nested AD, until we support errors + function grad(f, args...) + y, back = pullback(f, args...) + return back(1) end - x -end[1] == 2 - -# Gradient of closure -grad_closure(x) = 2x -Zygote.@adjoint (f::typeof(grad_closure))(x) = f(x), Δ -> (1, 2) + D(f, x) = grad(f, x)[1] -@test gradient((f, x) -> f(x), grad_closure, 5) == (1, 2) + @test D(x -> D(sin, x), 0.5) == -sin(0.5) + @test D(x -> x*D(y -> x+y, 1), 1) == 1 + @test D(x -> x*D(y -> x*y, 1), 4) == 8 -invokable(x) = 2x -invokable(x::Integer) = 3x -@test gradient(x -> invoke(invokable, Tuple{Any}, x), 5) == (2,) - -y, back = Zygote.pullback(x->tuple(x...), [1, 2, 3]) -@test back((1, 1, 1)) == ((1,1,1),) + @test sin''(1.0) == -sin(1.0) + @test sin'''(1.0) == -cos(1.0) +end -# Test for some compiler errors on complex CFGs -function f(x) - while true - true && return - foo(x) && break +@testset "compiler error" begin + # Test for some compiler errors on complex CFGs + function f(x) + while true + true && return + foo(x) && break + end end -end -@test Zygote.@code_adjoint(f(1)) isa Zygote.Adjoint + @test Zygote.@code_adjoint(f(1)) isa Zygote.Adjoint -@test_throws ErrorException Zygote.gradient(1) do x - push!([], x) - return x -end + @test_throws ErrorException Zygote.gradient(1) do x + push!([], x) + return x + end -@test gradient(1) do x - stk = [] - Zygote._push!(stk, x) - stk = Zygote.Stack(stk) - pop!(stk) -end == (1,) + @test gradient(1) do x + stk = [] + Zygote._push!(stk, x) + stk = Zygote.Stack(stk) + pop!(stk) + end == (1,) -@test gradient(x -> [x][1].a, Bar(1, 1)) == ((a=1, b=nothing),) + @test gradient(x -> [x][1].a, Bar(1, 1)) == ((a=1, b=nothing),) -@test gradient((a, b) -> Zygote.hook(-, a)*b, 2, 3) == (-3, 2) + @test gradient((a, b) -> Zygote.hook(-, a)*b, 2, 3) == (-3, 2) -@test gradient(5) do x - forwarddiff(x -> x^2, x) -end == (10,) + @test gradient(5) do x + forwarddiff(x -> x^2, x) + end == (10,) +end @testset "Gradient chunking" begin for chunk_threshold in 1:10:100 @@ -395,6 +388,7 @@ end end == (2,) global_param = 3 +global_r = 1 @testset "Global Params" begin cx = Zygote.Context() @@ -406,27 +400,41 @@ global_param = 3 @test ref.mod == Main @test ref.name == :global_param @test Zygote.cache(cx)[ref] == 2 -end -function pow_try(x) - try - 2x - catch e - println("error") + function pow_global(x, n) + global global_r + while n > 0 + global_r *= x + n -= 1 + end + return global_r end + + @test gradient(pow_global, 2, 3) == (12, nothing) end -@test gradient(pow_try, 1) == (2,) -function pow_simd(x, n) - r = 1 - @simd for i = 1:n - r *= x +@testset "pow" begin + function pow_try(x) + try + 2x + catch e + println("error") + end + end + + @test gradient(pow_try, 1) == (2,) + + function pow_simd(x, n) + r = 1 + @simd for i = 1:n + r *= x + end + return r end - return r -end -@test gradient(pow_simd, 2, 3) == (12,nothing) + @test gradient(pow_simd, 2, 3) == (12,nothing) +end @testset "tuple getindex" begin @test gradient(x -> size(x)[2], ones(2,2,2)) == (nothing,) @@ -444,7 +452,11 @@ end end @testset "@timed" begin - @test gradient(x -> first(@timed x), 0) == (1,) + if VERSION >= v"1.11" + @test_broken gradient(x -> first(@timed x), 0) == (1,) + else + @test gradient(x -> first(@timed x), 0) == (1,) + end end mutable struct MyMutable @@ -616,19 +628,17 @@ end end == ([1, 2 * 10^1, 3 * 100^2],) # zip - if VERSION >= v"1.5" - # On Julia 1.4 and earlier, [x/y for (x,y) in zip(10:14, 1:10)] is a DimensionMismatch, - # while on 1.5 - 1.7 it stops early. - - @test gradient(10:14, 1:10) do xs, ys - sum([x/y for (x,y) in zip(xs, ys)]) - end[2] ≈ vcat(.-(10:14) ./ (1:5).^2, zeros(5)) - - @test_broken gradient(10:14, 1:10) do xs, ys - sum(x/y for (x,y) in zip(xs, ys)) # same without collect - # Here @adjoint function Iterators.Zip(xs) gets dy = (is = (nothing, nothing),) - end[2] ≈ vcat(.-(10:14) ./ (1:5).^2, zeros(5)) - end + # On Julia 1.4 and earlier, [x/y for (x,y) in zip(10:14, 1:10)] is a DimensionMismatch, + # while on 1.5 - 1.7 it stops early. + + @test gradient(10:14, 1:10) do xs, ys + sum([x/y for (x,y) in zip(xs, ys)]) + end[2] ≈ vcat(.-(10:14) ./ (1:5).^2, zeros(5)) + + @test_broken gradient(10:14, 1:10) do xs, ys + sum(x/y for (x,y) in zip(xs, ys)) # same without collect + # Here @adjoint function Iterators.Zip(xs) gets dy = (is = (nothing, nothing),) + end[2] ≈ vcat(.-(10:14) ./ (1:5).^2, zeros(5)) bk_z = pullback((xs,ys) -> sum([abs2(x*y) for (x,y) in zip(xs,ys)]), [1,2], [3im,4im])[2] @test bk_z(1.0)[1] isa AbstractVector{<:Real} # projection diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 133383048..561163c6b 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -405,10 +405,8 @@ end @testset "vararg map" begin # early stop - if VERSION >= v"1.5" # In Julia 1.4 and earlier, map(*,rand(5),[1,2,3]) is a DimensionMismatch - @test gradient(x -> sum(map(*,x,[1,2,3])), rand(5)) == ([1,2,3,0,0],) - end + @test gradient(x -> sum(map(*,x,[1,2,3])), rand(5)) == ([1,2,3,0,0],) @test gradient(x -> sum(map(*,x,(1,2,3))), rand(5)) == ([1,2,3,0,0],) @test gradient(x -> sum(map(*,x,[1,2,3])), Tuple(rand(5))) == ((1.0, 2.0, 3.0, nothing, nothing),) @@ -1063,7 +1061,13 @@ _randmatseries(rng, ::typeof(atanh), T, n, domain::Type{Complex}) = nothing @testset "similar eigenvalues" begin λ[1] = λ[3] + sqrt(eps(eltype(λ))) / 10 A2 = U * Diagonal(λ) * U' - @test _gradtest_hermsym(f, ST, A2) + @static if VERSION >= v"1.11" + broken = f == sqrt && MT <: Symmetric{Float64} && domain == Real + # @show f MT domain + @test _gradtest_hermsym(f, ST, A2) broken=broken + else + @test _gradtest_hermsym(f, ST, A2) + end end if f ∉ (log, sqrt) # only defined for invertible matrices @@ -1168,10 +1172,20 @@ end return sum(sin.(vcat(vec.(_splitreim(B))...))) end === map(_->nothing, _splitreim(A)) else - @test gradtest(_splitreim(collect(A))...) do (args...) - A = ST(_joinreim(_dropimaggrad.(args)...)) - B = A^p - return vcat(vec.(_splitreim(B))...) + @static if VERSION >= v"1.11" + # @show MT p + broken = MT <: Symmetric{Float64} && p == -3 + @test gradtest(_splitreim(collect(A))...) do (args...) + A = ST(_joinreim(_dropimaggrad.(args)...)) + B = A^p + return vcat(vec.(_splitreim(B))...) + end broken=broken + else + @test gradtest(_splitreim(collect(A))...) do (args...) + A = ST(_joinreim(_dropimaggrad.(args)...)) + B = A^p + return vcat(vec.(_splitreim(B))...) + end end end @@ -1862,15 +1876,13 @@ end @test gradient(x -> sum(randexp(Random.GLOBAL_RNG, Float32, 1,1)), 1) == (nothing,) @test gradient(x -> sum(randexp(Random.GLOBAL_RNG, Float32, (1,1))), 1) == (nothing,) - @static if VERSION > v"1.3" - @test gradient(x -> sum(rand(Random.default_rng(), 4)), 1) == (nothing,) - @test gradient(x -> sum(rand(Random.default_rng(), Float32, 1,1)), 1) == (nothing,) - @test gradient(x -> sum(rand(Random.default_rng(), Float32, (1,1))), 1) == (nothing,) - @test gradient(x -> sum(randn(Random.default_rng(), Float32, 1,1)), 1) == (nothing,) - @test gradient(x -> sum(randn(Random.default_rng(), Float32, (1,1))), 1) == (nothing,) - @test gradient(x -> sum(randexp(Random.default_rng(), Float32, 1,1)), 1) == (nothing,) - @test gradient(x -> sum(randexp(Random.default_rng(), Float32, (1,1))), 1) == (nothing,) - end + @test gradient(x -> sum(rand(Random.default_rng(), 4)), 1) == (nothing,) + @test gradient(x -> sum(rand(Random.default_rng(), Float32, 1,1)), 1) == (nothing,) + @test gradient(x -> sum(rand(Random.default_rng(), Float32, (1,1))), 1) == (nothing,) + @test gradient(x -> sum(randn(Random.default_rng(), Float32, 1,1)), 1) == (nothing,) + @test gradient(x -> sum(randn(Random.default_rng(), Float32, (1,1))), 1) == (nothing,) + @test gradient(x -> sum(randexp(Random.default_rng(), Float32, 1,1)), 1) == (nothing,) + @test gradient(x -> sum(randexp(Random.default_rng(), Float32, (1,1))), 1) == (nothing,) end @testset "broadcasted($op, Array, Bool)" for op in (+,-,*) diff --git a/test/tools.jl b/test/tools.jl index 77b268646..0cf7deefa 100644 --- a/test/tools.jl +++ b/test/tools.jl @@ -45,14 +45,15 @@ end struct Tester cpu_offload::Float64 - end - function Tester(p) + function Tester(p) # @show Zygote.isderiving(p) cpu_offload = Zygote.isderiving(p) ? 0.0 : 0.2 - Tester(cpu_offload) + new(cpu_offload) + end end + function f56(p) sum(Tester(p).cpu_offload .* p) end diff --git a/test/utils.jl b/test/utils.jl index 4e7f4929b..cfe383f49 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -12,22 +12,20 @@ using Zygote: hessian_dual, hessian_reverse @test_throws Exception hess(identity, randn(2)) end -VERSION > v"1.6-" && @testset "diagonal hessian" begin +@testset "diagonal hessian" begin @test diaghessian(x -> x[1]*x[2]^2, [1, pi]) == ([0, 2],) - if VERSION > v"1.6-" - # Gradient of ^ may contain log(complex(...)), which interacts badly with Dual below Julia 1.6: - # julia> log(ForwardDiff.Dual(1,0) + 0im) # ERROR: StackOverflowError: - # https://github.com/JuliaDiff/ChainRules.jl/issues/525 - # Fixed in 1.6 by: https://github.com/JuliaLang/julia/pull/36030 - xs, y = randn(2,3), rand() - f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments - - dx, dy = diaghessian(f34, xs, y) - @test size(dx) == size(xs) - @test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs)) - @test dy ≈ hessian(y -> f34(xs,y), y) - end + # Gradient of ^ may contain log(complex(...)), which interacts badly with Dual below Julia 1.6: + # julia> log(ForwardDiff.Dual(1,0) + 0im) # ERROR: StackOverflowError: + # https://github.com/JuliaDiff/ChainRules.jl/issues/525 + # Fixed in 1.6 by: https://github.com/JuliaLang/julia/pull/36030 + xs, y = randn(2,3), rand() + f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments + + dx, dy = diaghessian(f34, xs, y) + @test size(dx) == size(xs) + @test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs)) + @test dy ≈ hessian(y -> f34(xs,y), y) zs = randn(7,13) # test chunk mode @test length(zs) > ForwardDiff.DEFAULT_CHUNK_THRESHOLD