Update deps & bump to 0.16.1 (#2574)
* Update deps
* [AMDGPU] Correct batchnorm rrule
* Mark test as unbroken
pxl-th authored Jan 21, 2025
1 parent 44695a0 commit b1a3a93
Showing 4 changed files with 6 additions and 11 deletions.
Project.toml (2 changes: 1 addition & 1 deletion)

```diff
@@ -64,6 +64,6 @@ Reexport = "1.0"
 Setfield = "1.1"
 SpecialFunctions = "2.1.2"
 Statistics = "1"
-Zygote = "0.6.67"
+Zygote = "0.6.67, 0.7"
 cuDNN = "1"
 julia = "1.10"
```
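The widened compat entry uses Julia's caret semantics: `"0.6.67, 0.7"` accepts any Zygote in [0.6.67, 0.7.0) or [0.7.0, 0.8.0). A quick sketch of confirming the resolver can now pick the 0.7 series (illustrative, not part of the commit):

```julia
using Pkg

# With Flux's compat now listing "0.6.67, 0.7", both series are admissible;
# pinning the newer one checks that resolution succeeds.
Pkg.add(name = "Zygote", version = "0.7")
Pkg.status("Zygote")
```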
ext/FluxAMDGPUExt/FluxAMDGPUExt.jl (9 changes: 2 additions & 7 deletions)

```diff
@@ -1,10 +1,11 @@
 module FluxAMDGPUExt
 
 import ChainRulesCore
-import ChainRulesCore: NoTangent
+import ChainRulesCore: NoTangent, unthunk
 import Flux
 import Flux: fmap, DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
 import NNlib
 
+using MLDataDevices
 using AMDGPU
 using Adapt
@@ -13,14 +14,8 @@ using Zygote
 
 const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat
 
-
 include("functor.jl")
 include("batchnorm.jl")
 include("conv.jl")
 
-
-# TODO
-# fail early if input to the model is not on the device (e.g. on the host)
-# otherwise we get very cryptic errors & segfaults at the rocBLAS level
-
 end
```
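The new `using MLDataDevices` pulls in the device-abstraction layer the extension now relies on. A rough sketch of the API it provides (assuming a working GPU runtime; the array values are illustrative):

```julia
using MLDataDevices

dev = gpu_device()           # selects an available backend, e.g. AMDGPU
x = dev(rand(Float32, 4))    # device objects are callable and move data over
x_host = cpu_device()(x)     # round-trip back to the host
```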
ext/FluxAMDGPUExt/batchnorm.jl (2 changes: 1 addition & 1 deletion)

```diff
@@ -17,7 +17,7 @@ function ChainRulesCore.rrule(
 )
     y, μ_saved, ν_saved = _amdgpu_batchnorm(x, γ, β; μ, σ², ϵ, within_grad)
     function _batchnorm_pullback(Δ)
-        dx, dγ, dβ = AMDGPU.MIOpen.∇batchnorm(Δ, x, γ, β, μ_saved, ν_saved)
+        dx, dγ, dβ = AMDGPU.MIOpen.∇batchnorm(unthunk(Δ), x, γ, β, μ_saved, ν_saved)
         (NoTangent(), dx, dγ, dβ)
     end
     y, _batchnorm_pullback
```
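This is the rrule fix from the commit message: with Zygote 0.7, the incoming cotangent `Δ` can arrive as a lazy `ChainRulesCore.Thunk`, while `AMDGPU.MIOpen.∇batchnorm` needs a materialized array, so the pullback must `unthunk` it first. A minimal standalone sketch of the same pattern (the `double` op is hypothetical, purely to illustrate):

```julia
using ChainRulesCore

double(x) = 2x  # stand-in for an op backed by a foreign kernel

function ChainRulesCore.rrule(::typeof(double), x)
    y = double(x)
    function double_pullback(Δ)
        # Δ may be a lazy Thunk; unthunk forces it to a concrete value
        # and is a no-op on plain numbers/arrays, so it is always safe.
        Δ̄ = unthunk(Δ)
        return (NoTangent(), 2 * Δ̄)
    end
    return y, double_pullback
end
```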
test/ext_cuda/cuda.jl (4 changes: 2 additions & 2 deletions)

```diff
@@ -106,7 +106,7 @@ end
     # Trivial functions
     @test gradient(x -> sum(abs, gpu(x)), a)[1] isa Matrix
     @test gradient(x -> sum(gpu(x)), a)[1] isa Matrix
-    @test_broken gradient(x -> sum(gpu(x)), a')[1] isa Matrix # sum(::Adjoint{T,CuArray}) makes a Fill
+    @test gradient(x -> sum(gpu(x)), a')[1] isa Matrix # sum(::Adjoint{T,CuArray}) makes a Fill
     @test gradient(x -> sum(abs, cpu(x)), ca)[1] isa CuArray
     # This test should really not go through indirections and pull out Fills for efficiency
     # but we forcefully materialise. TODO: remove materialising CuArray here
```
```diff
@@ -207,4 +207,4 @@
     @test collect(post2) isa Vector{<:NamedTuple{(:x, :y)}} # collect makes no sense, but check eltype?
 
     # @test_throws Exception gpu(((x = Flux.DataLoader(X), y = Y),))
-end
+end
```
