Spurious "Output is complex, so the gradient is not defined" error #1471

Closed

seadra opened this issue Nov 14, 2023 · 2 comments
seadra commented Nov 14, 2023

(Cross-posting this DiffEqFlux issue here, since it seems to be Zygote-related.)

With the code

using Lux, Random, DifferentialEquations, ComponentArrays, Optimization, OptimizationFlux, Zygote, LinearAlgebra, DiffEqFlux

const T = 10.0;
const ω = π/T;

ann = Lux.Chain(Lux.Dense(1,32), Lux.Dense(32,32,tanh), Lux.Dense(32,1));
rng = Random.default_rng();
ip, st = Lux.setup(rng, ann);

function f_nn(u, p, t)
    a = ann([t],p,st)[1];
    A = [1.0 a; a -1.0];
    return -im*A*u;
end


u0 = [Complex{Float64}(1) 0; 0 1];
tspan = (0.0, T)
prob_ode = ODEProblem(f_nn, u0, tspan, ip);

utarget = [Complex{Float64}(0) im; im 0];

function predict_adjoint(p)
  return solve(prob_ode, Tsit5(), p=Complex{Float64}.(p), abstol=1e-12, reltol=1e-12)
end

function loss_adjoint(p)
    prediction = predict_adjoint(p)
    usol = last(prediction)
    loss = 1.0 - abs(tr(usol*utarget')/2)^2
    return loss
end

opt_f = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), Optimization.AutoZygote());
opt_prob = Optimization.OptimizationProblem(opt_f, ComponentArray(ip));
optimized_sol_nn = Optimization.solve(opt_prob, ADAM(0.1), maxiters = 100, progress=true);

I'm getting the following error

┌ Warning: ZygoteVJP tried and failed in the automated AD choice algorithm with the following error. (To turn off this printing, add `verbose = false` to the `solve` call)
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/M3EmS/src/concrete_solve.jl:94
Output is complex, so the gradient is not defined.
┌ Warning: ReverseDiffVJP tried and failed in the automated AD choice algorithm with the following error. (To turn off this printing, add `verbose = false` to the `solve` call)
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/M3EmS/src/concrete_solve.jl:111
MethodError: no method matching gradient(::SciMLSensitivity.var"#258#264"{ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:1153, Axis(weight = ViewAxis(1:32, ShapedAxis((1, 32), NamedTuple())), bias = ViewAxis(33:33, ShapedAxis((1, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(f_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}}, ::Matrix{ComplexF64}, ::ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:1153, Axis(weight = ViewAxis(1:32, ShapedAxis((1, 32), NamedTuple())), bias = ViewAxis(33:33, ShapedAxis((1, 1), NamedTuple())))))}}})

Closest candidates are:
  gradient(::Any, ::Any)
   @ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/api/gradients.jl:21
  gradient(::Any, ::Any, ::ReverseDiff.GradientConfig)
   @ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/api/gradients.jl:21
┌ Warning: TrackerVJP tried and failed in the automated AD choice algorithm with the following error. (To turn off this printing, add `verbose = false` to the `solve` call)
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/M3EmS/src/concrete_solve.jl:129
Function output is not scalar
┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/M3EmS/src/concrete_solve.jl:139

MethodError: no method matching default_relstep(::Nothing, ::Type{ComplexF64})

Closest candidates are:
  default_relstep(::Type, ::Any)
   @ FiniteDiff ~/.julia/packages/FiniteDiff/grio1/src/epsilons.jl:25
  default_relstep(::Val{fdtype}, ::Type{T}) where {fdtype, T<:Number}
   @ FiniteDiff ~/.julia/packages/FiniteDiff/grio1/src/epsilons.jl:26

I tried explicitly casting the loss to real with `return real(loss)`, but that didn't help either.

mcabbott (Member) commented
A phrase in the error does appear in this package:

sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.")

Can you post the full stack trace, showing lines where any Zygote code is being called?

Or better, can you reproduce this error with just Zygote?
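
For example, the error is easy to trigger with Zygote alone whenever the differentiated function returns a complex number. A minimal illustrative sketch:

using Zygote

# The function's output is complex, so Zygote raises the error
# via the `sensitivity(y::Complex)` method quoted above:
Zygote.gradient(x -> im * x, 1.0)
# ERROR: Output is complex, so the gradient is not defined.

# Mapping the output back to a real scalar (e.g. with abs2) makes
# the gradient well-defined again:
Zygote.gradient(x -> abs2(im * x), 1.0)  # (2.0,)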

seadra commented Nov 15, 2023

It turns out that I made a couple of mistakes. First, I needed to add `using DiffEqSensitivity`; second, the lines

    a = ann([t],p,st)[1];
    A = [1.0 a; a -1.0];

should read

    local a, _ = ann([t/T],p,st);
    local A = [a[1] 0.0; 0.0 -a[1]];

After these changes, it worked!
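
For reference, a sketch of the corrected right-hand side with both fixes applied (assuming the rest of the original script, plus `using DiffEqSensitivity`, stays unchanged):

function f_nn(u, p, t)
    # Lux models return (output, state); keep only the output
    a, _ = ann([t/T], p, st)
    # diagonal generator built from the scalar network output a[1]
    A = [a[1] 0.0; 0.0 -a[1]]
    return -im*A*u
end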

seadra closed this as completed Nov 15, 2023