support transformation on CUDA gpu? #363

Open
bgctw opened this issue Jan 20, 2025 · 0 comments

bgctw commented Jan 20, 2025

Will Bijectors support transformations on CuArrays?
Perhaps at least for a subset of transformations?

Currently, applying a `Stacked` transformation to a CuArray in the MWE below works, but computing the log absolute determinant of the Jacobian fails:

# pkg> activate --temp
# pkg> add Bijectors Zygote CUDA
using Bijectors
using Zygote
x = [-2, 3.0, 2.1, 2.2]
ranges = [1:1, 2:2, 3:4]
tr = Stacked((elementwise(identity), elementwise(exp), elementwise(exp)), ranges)

using CUDA
xg = CuArray(x)
y = tr(xg)
gr = Zygote.jacobian(tr, xg) # works nicely

y, logjac = Bijectors.with_logabsdet_jacobian(tr, xg)
gr = Zygote.jacobian(x ->  Bijectors.with_logabsdet_jacobian(tr, x)[1], xg)

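With error (in this REPL call the closure captures `xg` and the Jacobian is taken with respect to the CPU vector `x`):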

julia> gr = Zygote.jacobian(x ->  Bijectors.with_logabsdet_jacobian(tr, xg)[1], x)
ERROR: GPU compilation of MethodInstance for (::GPUArrays.var"#gpu_broadcast_kernel_linear#38")(::KernelAbstractions.CompilerMetadata{…}, ::CuDeviceVector{…}, ::Base.Broadcast.Broadcasted{…}) failed
KernelError: passing non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.DeviceMemory}, Tuple{Base.OneTo{Int64}}, Zygote.var"#683#687"{Zygote.Context{false}, typeof(first)}, Tuple{Base.Broadcast.Extruded{CuDeviceVector{Tuple{ForwardDiff.Dual{Nothing, Float64, 1}, ForwardDiff.Dual{Nothing, Float64, 1}}, 1}, Tuple{Bool}, Tuple{Int64}}}}, which is not a bitstype:
  .f is of type Zygote.var"#683#687"{Zygote.Context{false}, typeof(first)} which is not isbits.
    .cx is of type Zygote.Context{false} which is not isbits.
      .cache is of type Union{Nothing, IdDict{Any, Any}} which is not isbits.


Only bitstypes, which are "plain data" types that are immutable
and contain no references to other values, can be used in GPU kernels.
For more information, see the `Base.isbitstype` function.

Stacktrace:
  [1] check_invocation(job::GPUCompiler.CompilerJob)
    @ GPUCompiler ~/scratch/twutz/julia_cluster_depots/packages/GPUCompiler/Nxf8r/src/validation.jl:108
...
 [63] jacobian(f::Function, args::Vector{Float64})
    @ Zygote ~/scratch/twutz/julia_cluster_depots/packages/Zygote/59YyM/src/lib/grad.jl:168
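The kernel error points to `Base.isbitstype`. As a rough illustration of why the broadcast fails, the sketch below (not taken from Bijectors, Zygote or CUDA; the struct `PlainParams` and the closures `g` and `h` are hypothetical) shows that plain immutable data and closures capturing only bits values pass this check, while a closure that captures a mutable dictionary, similar to the `cache::Union{Nothing, IdDict{Any, Any}}` field of `Zygote.Context` reported above, does not.

```julia
# Hypothetical example types and closures; nothing here is part of Bijectors, Zygote or CUDA.
struct PlainParams          # immutable struct with only bits fields
    a::Float64
    b::Int
end

# Plain data that may be passed as a GPU kernel argument.
isbitstype(Float64)                  # true
isbitstype(PlainParams)              # true
isbitstype(Tuple{Float64, Float64})  # true

# A mutable dictionary (like the `cache` field of Zygote.Context) is not plain data.
isbitstype(IdDict{Any, Any})                  # false
isbitstype(Union{Nothing, IdDict{Any, Any}})  # false

# Closures are structs whose fields are the captured variables.
g = let c = 2.0
    x -> x + c              # captures only a Float64
end
isbitstype(typeof(g))       # true

h = let cache = IdDict{Any, Any}()
    x -> (cache[x] = x)     # captures a mutable IdDict, similar to Zygote's context
end
isbitstype(typeof(h))       # false
```

So any function broadcast over a CuArray must not close over such a mutable context, which is what the `Zygote.var"#683#687"{Zygote.Context{false}, ...}` closure in the trace does.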

Environment:

  [76274a88] Bijectors v0.15.4
  [052768ef] CUDA v5.6.1
  [e88e6eb3] Zygote v0.7.2
