diff --git a/src/chainrules.jl b/src/chainrules.jl index 97d4d22..7d1bb8d 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -1,121 +1,152 @@ # ffts -function ChainRulesCore.frule((_, Δx, _), ::typeof(fft), x::AbstractArray, dims) - y = fft(x, dims) - Δy = fft(Δx, dims) +# we explicitly handle both unprovided and provided dims arguments in all rules, which +# results in some additional complexity here but means no assumptions are made on what +# signatures downstream implementations support. +function ChainRulesCore.frule(Δargs, ::typeof(fft), x::AbstractArray, dims=nothing) + Δx = Δargs[2] + dims_args = (dims === nothing) ? () : (dims,) + y = fft(x, dims_args...) + Δy = fft(Δx, dims_args...) return y, Δy end -function ChainRulesCore.rrule(::typeof(fft), x::AbstractArray, dims) - y = fft(x, dims) +function ChainRulesCore.rrule(::typeof(fft), x::AbstractArray, dims=nothing) + dims_args = (dims === nothing) ? () : (dims,) + y = fft(x, dims_args...) project_x = ChainRulesCore.ProjectTo(x) function fft_pullback(ȳ) - x̄ = project_x(bfft(ChainRulesCore.unthunk(ȳ), dims)) - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() + x̄ = project_x(bfft(ChainRulesCore.unthunk(ȳ), dims_args...)) + dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),) + return ChainRulesCore.NoTangent(), x̄, dims_args_tangent... end return y, fft_pullback end -function ChainRulesCore.frule((_, Δx, _), ::typeof(rfft), x::AbstractArray{<:Real}, dims) - y = rfft(x, dims) - Δy = rfft(Δx, dims) +function ChainRulesCore.frule(Δargs, ::typeof(rfft), x::AbstractArray{<:Real}, dims=nothing) + Δx = Δargs[2] + dims_args = (dims === nothing) ? () : (dims,) + y = rfft(x, dims_args...) + Δy = rfft(Δx, dims_args...) return y, Δy end -function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims) - y = rfft(x, dims) +function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims=nothing) + dims_args = (dims === nothing) ? () : (dims,) + true_dims = (dims === nothing) ? (1:ndims(x)) : dims + y = rfft(x, dims_args...) # compute scaling factors - halfdim = first(dims) + halfdim = first(true_dims) d = size(x, halfdim) n = size(y, halfdim) scale = reshape( [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), + ntuple(i -> i == first(true_dims) ? n : 1, Val(ndims(x))), ) project_x = ChainRulesCore.ProjectTo(x) function rfft_pullback(ȳ) - x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims)) - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() + x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims_args...)) + dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),) + return ChainRulesCore.NoTangent(), x̄, dims_args_tangent... end return y, rfft_pullback end -function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dims) - y = ifft(x, dims) - Δy = ifft(Δx, dims) +function ChainRulesCore.frule(Δargs, ::typeof(ifft), x::AbstractArray, dims=nothing) + Δx = Δargs[2] + args = (dims === nothing) ? () : (dims,) + y = ifft(x, args...) + Δy = ifft(Δx, args...) return y, Δy end -function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims) - y = ifft(x, dims) - invN = normalization(y, dims) +function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims=nothing) + dims_args = (dims === nothing) ? () : (dims,) + true_dims = (dims === nothing) ? (1:ndims(x)) : dims + y = ifft(x, dims_args...) + invN = normalization(y, true_dims) project_x = ChainRulesCore.ProjectTo(x) function ifft_pullback(ȳ) - x̄ = project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims)) - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() + x̄ = project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims_args...)) + dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),) + return ChainRulesCore.NoTangent(), x̄, dims_args_tangent... end return y, ifft_pullback end -function ChainRulesCore.frule((_, Δx, _, _), ::typeof(irfft), x::AbstractArray, d::Int, dims) - y = irfft(x, d, dims) - Δy = irfft(Δx, d, dims) +function ChainRulesCore.frule(Δargs, ::typeof(irfft), x::AbstractArray, d::Int, dims=nothing) + Δx = Δargs[2] + dims_args = (dims === nothing) ? () : (dims,) + y = irfft(x, d, dims_args...) + Δy = irfft(Δx, d, dims_args...) return y, Δy end -function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) - y = irfft(x, d, dims) +function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims=nothing) + dims_args = (dims === nothing) ? () : (dims,) + true_dims = (dims === nothing) ? (1:ndims(x)) : dims + y = irfft(x, d, dims_args...) # compute scaling factors - halfdim = first(dims) + halfdim = first(true_dims) n = size(x, halfdim) - invN = normalization(y, dims) + invN = normalization(y, true_dims) twoinvN = 2 * invN scale = reshape( [i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n], - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), + ntuple(i -> i == first(true_dims) ? n : 1, Val(ndims(x))), ) project_x = ChainRulesCore.ProjectTo(x) function irfft_pullback(ȳ) - x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() + x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims_args...)) + dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),) + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), dims_args_tangent... end return y, irfft_pullback end -function ChainRulesCore.frule((_, Δx, _), ::typeof(bfft), x::AbstractArray, dims) - y = bfft(x, dims) - Δy = bfft(Δx, dims) +function ChainRulesCore.frule(Δargs, ::typeof(bfft), x::AbstractArray, dims=nothing) + Δx = Δargs[2] + dims_args = (dims === nothing) ? () : (dims,) + y = bfft(x, dims_args...) + Δy = bfft(Δx, dims_args...) return y, Δy end -function ChainRulesCore.rrule(::typeof(bfft), x::AbstractArray, dims) - y = bfft(x, dims) +function ChainRulesCore.rrule(::typeof(bfft), x::AbstractArray, dims=nothing) + dims_args = (dims === nothing) ? () : (dims,) + y = bfft(x, dims_args...) project_x = ChainRulesCore.ProjectTo(x) function bfft_pullback(ȳ) - x̄ = project_x(fft(ChainRulesCore.unthunk(ȳ), dims)) - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() + x̄ = project_x(fft(ChainRulesCore.unthunk(ȳ), dims_args...)) + dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),) + return ChainRulesCore.NoTangent(), x̄, dims_args_tangent... end return y, bfft_pullback end -function ChainRulesCore.frule((_, Δx, _, _), ::typeof(brfft), x::AbstractArray, d::Int, dims) - y = brfft(x, d, dims) - Δy = brfft(Δx, d, dims) +function ChainRulesCore.frule(Δargs, ::typeof(brfft), x::AbstractArray, d::Int, dims=nothing) + Δx = Δargs[2] + dims_args = (dims === nothing) ? () : (dims,) + y = brfft(x, d, dims_args...) + Δy = brfft(Δx, d, dims_args...) return y, Δy end -function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims) - y = brfft(x, d, dims) +function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims=nothing) + dims_args = (dims === nothing) ? () : (dims,) + true_dims = (dims === nothing) ? (1:ndims(x)) : dims + y = brfft(x, d, dims_args...) # compute scaling factors - halfdim = first(dims) + halfdim = first(true_dims) n = size(x, halfdim) scale = reshape( [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), + ntuple(i -> i == first(true_dims) ? n : 1, Val(ndims(x))), ) project_x = ChainRulesCore.ProjectTo(x) function brfft_pullback(ȳ) - x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) - return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() + x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims_args...)) + dims_args_tangent = (dims === nothing) ? () : (ChainRulesCore.NoTangent(),) + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), dims_args_tangent... end return y, brfft_pullback end diff --git a/test/runtests.jl b/test/runtests.jl index 4d402c5..b8998bd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -216,14 +216,14 @@ end @testset "ChainRules" begin @testset "shift functions" begin for x in (randn(3), randn(3, 4), randn(3, 4, 5)) + # type inference checks of `rrule` fail on old Julia versions + # for higher-dimensional arrays: + # https://github.com/JuliaMath/AbstractFFTs.jl/pull/58#issuecomment-916530016 + check_inferred = ndims(x) < 3 || VERSION >= v"1.6" + for dims in ((), 1, 2, (1,2), 1:2) any(d > ndims(x) for d in dims) && continue - # type inference checks of `rrule` fail on old Julia versions - # for higher-dimensional arrays: - # https://github.com/JuliaMath/AbstractFFTs.jl/pull/58#issuecomment-916530016 - check_inferred = ndims(x) < 3 || VERSION >= v"1.6" - test_frule(AbstractFFTs.fftshift, x, dims) test_rrule(AbstractFFTs.fftshift, x, dims; check_inferred=check_inferred) @@ -237,23 +237,27 @@ end for x in (randn(3), randn(3, 4), randn(3, 4, 5)) N = ndims(x) complex_x = complex.(x) - for dims in unique((1, 1:N, N)) + for dims in unique((1, 1:N, N, nothing)) + # if dims=nothing, test handling of default dims argument + dims_args = (dims === nothing) ? () : (dims,) + true_dims = (dims === nothing) ? (1:N) : dims + for f in (fft, ifft, bfft) - test_frule(f, x, dims) - test_rrule(f, x, dims) - test_frule(f, complex_x, dims) - test_rrule(f, complex_x, dims) + test_frule(f, x, dims_args...) + test_rrule(f, x, dims_args...) + test_frule(f, complex_x, dims_args...) + test_rrule(f, complex_x, dims_args...) end - test_frule(rfft, x, dims) - test_rrule(rfft, x, dims) + test_frule(rfft, x, dims_args...) + test_rrule(rfft, x, dims_args...) for f in (irfft, brfft) - for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2) - test_frule(f, x, d, dims) - test_rrule(f, x, d, dims) - test_frule(f, complex_x, d, dims) - test_rrule(f, complex_x, d, dims) + for d in (2 * size(x, first(true_dims)) - 1, 2 * size(x, first(true_dims)) - 2) + test_frule(f, x, d, dims_args...) + test_rrule(f, x, d, dims_args...) + test_frule(f, complex_x, d, dims_args...) + test_rrule(f, complex_x, d, dims_args...) end end end