diff --git a/Project.toml b/Project.toml
index 07a6486dc8..566e3b4c69 100644
--- a/Project.toml
+++ b/Project.toml
@@ -64,6 +64,6 @@ Reexport = "1.0"
 Setfield = "1.1"
 SpecialFunctions = "2.1.2"
 Statistics = "1"
-Zygote = "0.6.67"
+Zygote = "0.6.67, 0.7"
 cuDNN = "1"
 julia = "1.10"
diff --git a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl
index 2398b322d5..2eb29c7ec6 100644
--- a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl
+++ b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl
@@ -1,10 +1,11 @@
 module FluxAMDGPUExt
 
 import ChainRulesCore
-import ChainRulesCore: NoTangent
+import ChainRulesCore: NoTangent, unthunk
 import Flux
 import Flux: fmap, DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
 import NNlib
+
 using MLDataDevices
 using AMDGPU
 using Adapt
@@ -13,14 +14,8 @@ using Zygote
 
 const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat
 
-
 include("functor.jl")
 include("batchnorm.jl")
 include("conv.jl")
-
-# TODO
-# fail early if input to the model is not on the device (e.g. on the host)
-# otherwise we get very cryptic errors & segfaults at the rocBLAS level
-
 end
 
diff --git a/ext/FluxAMDGPUExt/batchnorm.jl b/ext/FluxAMDGPUExt/batchnorm.jl
index 91dc61ba4d..4dab0b58ad 100644
--- a/ext/FluxAMDGPUExt/batchnorm.jl
+++ b/ext/FluxAMDGPUExt/batchnorm.jl
@@ -17,7 +17,7 @@ function ChainRulesCore.rrule(
 )
     y, μ_saved, ν_saved = _amdgpu_batchnorm(x, γ, β; μ, σ², ϵ, within_grad)
     function _batchnorm_pullback(Δ)
-        dx, dγ, dβ = AMDGPU.MIOpen.∇batchnorm(Δ, x, γ, β, μ_saved, ν_saved)
+        dx, dγ, dβ = AMDGPU.MIOpen.∇batchnorm(unthunk(Δ), x, γ, β, μ_saved, ν_saved)
         (NoTangent(), dx, dγ, dβ)
     end
     y, _batchnorm_pullback
diff --git a/test/ext_cuda/cuda.jl b/test/ext_cuda/cuda.jl
index fc5e2c7bde..7758ff3d10 100644
--- a/test/ext_cuda/cuda.jl
+++ b/test/ext_cuda/cuda.jl
@@ -106,7 +106,7 @@ end
   # Trivial functions
   @test gradient(x -> sum(abs, gpu(x)), a)[1] isa Matrix
   @test gradient(x -> sum(gpu(x)), a)[1] isa Matrix
-  @test_broken gradient(x -> sum(gpu(x)), a')[1] isa Matrix # sum(::Adjoint{T,CuArray}) makes a Fill
+  @test gradient(x -> sum(gpu(x)), a')[1] isa Matrix # sum(::Adjoint{T,CuArray}) makes a Fill
   @test gradient(x -> sum(abs, cpu(x)), ca)[1] isa CuArray
   # This test should really not go through indirections and pull out Fills for efficiency
   # but we forcefully materialise. TODO: remove materialising CuArray here
@@ -207,4 +207,4 @@
   @test collect(post2) isa Vector{<:NamedTuple{(:x, :y)}} # collect makes no sense, but check eltype?
 
   # @test_throws Exception gpu(((x = Flux.DataLoader(X), y = Y),))
-end
\ No newline at end of file
+end