using Lux: Lux, relu, leakyrelu
using LuxCUDA
using CUDA: CUDA, CuArray, CUDABackend
using LuxCore: AbstractLuxLayer
using Random: AbstractRNG
using ComponentArrays: ComponentArray
using KernelAbstractions
using Atomix: @atomic
# FFTW provides the fft/ifft implementations behind the AbstractFFTs API
using FFTW: fft, ifft
using ChainRulesCore: ChainRulesCore, NoTangent

@kernel inbounds = true function convolve_kernel(ffty_r, ffty_im, fft_x, fft_k, ch_x)
    i, j, c, b = @index(Global, NTuple)
    for ci = 1:ch_x
        y = fft_x[i, j, ci, b] * fft_k[c, i, j]
        # Atomix's @atomic does not support complex numbers, so the
        # accumulation is split into real and imaginary parts
        @atomic ffty_r[i, j, c, b] += real(y)
        @atomic ffty_im[i, j, c, b] += imag(y)
    end
end

function convolve(x, k)
    fft_x = fft(x, (1, 2))
    fft_k = fft(k, (2, 3))

    if CUDA.functional() && k isa CuArray
        # TODO: the element type is hardcoded to Float32
        ffty_r = CUDA.zeros(Float32, size(x, 1), size(x, 2), size(k, 1), size(x, 4))
        ffty_im = CUDA.zeros(Float32, size(x, 1), size(x, 2), size(k, 1), size(x, 4))
        backend = CUDABackend()
        workgroupsize = 256
    else
        ffty_r = zeros(Float32, size(x, 1), size(x, 2), size(k, 1), size(x, 4))
        ffty_im = zeros(Float32, size(x, 1), size(x, 2), size(k, 1), size(x, 4))
        backend = CPU()
        workgroupsize = 64
    end

    # Launch the kernel and wait for it to finish before the inverse FFT
    convolve_kernel(backend, workgroupsize)(
        ffty_r,
        ffty_im,
        fft_x,
        fft_k,
        size(x, 3);
        ndrange = size(ffty_r),
    )
    KernelAbstractions.synchronize(backend)

    real(ifft(ComplexF32.(ffty_r, ffty_im), (1, 2)))
end

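# A minimal usage sketch for `convolve` (the 16×16 grid, 2 input channels,
# 3 output channels, and batch size 4 are arbitrary illustration values).
# The kernel's spatial dimensions must match the input's:
#
#   x = rand(Float32, 16, 16, 2, 4)   # (nx, ny, ch_in, batch)
#   k = rand(Float32, 3, 16, 16)      # (ch_out, nx, ny)
#   y = convolve(x, k)                # -> (16, 16, 3, 4)
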
function ChainRulesCore.rrule(::typeof(convolve), x, k)
    # Given Y = X ⊛ K (circular convolution via FFT),
    # the gradients for backpropagation are:
    #
    # 1. Gradient w.r.t. X:
    #    ∂L/∂X = (∂L/∂Y) ⊛ flip(K)
    #    In the Fourier domain: ℱ(∂L/∂X) = ℱ(∂L/∂Y) ⋅ conj(ℱ(K))
    #
    # 2. Gradient w.r.t. K:
    #    ∂L/∂K = flip(X) ⊛ (∂L/∂Y)
    #    In the Fourier domain: ℱ(∂L/∂K) = conj(ℱ(X)) ⋅ ℱ(∂L/∂Y)
    #
    # Here flip(·) denotes a 180-degree rotation (reversal in both spatial
    # dimensions); multiplying by the complex conjugate in the Fourier
    # domain implements convolution with the flipped signal.

    y = convolve(x, k)
    fft_x = fft(x, (1, 2))
    fft_k = fft(k, (2, 3))

    function convolve_pb(y_bar)
        ffty_bar = fft(y_bar, (1, 2))

        if CUDA.functional() && k isa CuArray
            x_bar_re = CUDA.zeros(Float32, size(x))
            x_bar_im = CUDA.zeros(Float32, size(x))
            k_bar_re = CUDA.zeros(Float32, size(k))
            k_bar_im = CUDA.zeros(Float32, size(k))
            backend = CUDABackend()
            workgroupsize = 256
        else
            x_bar_re = zeros(Float32, size(x))
            x_bar_im = zeros(Float32, size(x))
            k_bar_re = zeros(Float32, size(k))
            k_bar_im = zeros(Float32, size(k))
            backend = CPU()
            workgroupsize = 64
        end

        # Launch the adjoint kernel for x
        convolve_adjoint_x_kernel(backend, workgroupsize)(
            x_bar_re,
            x_bar_im,
            ffty_bar,
            fft_k;
            ndrange = size(x),
        )
        # Launch the adjoint kernel for k
        convolve_adjoint_k_kernel(backend, workgroupsize)(
            k_bar_re,
            k_bar_im,
            fft_x,
            ffty_bar,
            size(x, 3);
            ndrange = size(k),
        )
        # Wait for both kernels before the inverse FFTs
        KernelAbstractions.synchronize(backend)

        x_bar = real(ifft(ComplexF32.(x_bar_re, x_bar_im), (1, 2)))
        k_bar = real(ifft(ComplexF32.(k_bar_re, k_bar_im), (2, 3)))

        return NoTangent(), x_bar, k_bar
    end
    return y, convolve_pb
end

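# A sketch of how the rrule above can be exercised (assumes Zygote is
# available, e.g. in a test environment; it is not a dependency of this file):
#
#   using Zygote
#   x = rand(Float32, 8, 8, 2, 3)
#   k = rand(Float32, 2, 8, 8)
#   loss(x, k) = sum(abs2, convolve(x, k))
#   x_bar, k_bar = Zygote.gradient(loss, x, k)   # dispatches to convolve_pb
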
@kernel inbounds = true function convolve_adjoint_x_kernel(
    x_bar_re,
    x_bar_im,
    ffty_bar,
    fft_k,
)
    i, j, ci, b = @index(Global, NTuple)
    for c = 1:size(fft_k, 1)
        # Use the complex conjugate to backprop the convolution
        y = ffty_bar[i, j, c, b] * conj(fft_k[c, i, j])
        @atomic x_bar_re[i, j, ci, b] += real(y)
        @atomic x_bar_im[i, j, ci, b] += imag(y)
    end
end

@kernel inbounds = true function convolve_adjoint_k_kernel(
    k_bar_re,
    k_bar_im,
    fft_x,
    ffty_bar,
    ch_x,
)
    c, i, j = @index(Global, NTuple)
    for b = 1:size(fft_x, 4)
        for ci = 1:ch_x
            y = conj(fft_x[i, j, ci, b]) * ffty_bar[i, j, c, b]
            @atomic k_bar_re[c, i, j] += real(y)
            @atomic k_bar_im[c, i, j] += imag(y)
        end
    end
end

function apply_masked_convolution(y, k, mask)
    # To get the correct kernel we have to reshape + mask + trim.
    # TODO: this is not pretty...
    # ! Zygote does not like reused variable names, which makes this even
    # ! uglier (hence the definitions of k2 and k3). Zygote also wants the
    # ! mask explicitly defined as a vector, so it has to be pulled out of
    # ! the tuple beforehand via mask = masks[i].

    # Apply the mask to the kernel
    k2 = mask_kernel(k, mask)

    # Adjust the kernel size to match the input dimensions
    k3 = trim_kernel(k2, size(y))

    # Apply the convolution
    y = convolve(y, k3)

    return y
end

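# A usage sketch for `apply_masked_convolution` (hypothetical sizes; here
# the mask is a (kx, ky) array that broadcasts over the spatial dimensions
# of the permuted kernel):
#
#   y = rand(Float32, 8, 8, 2, 3)    # (nx, ny, ch, batch)
#   k = rand(Float32, 2, 12, 12)     # oversized kernel, trimmed to 8×8 below
#   mask = rand(Float32, 12, 12)
#   y2 = apply_masked_convolution(y, k, mask)   # -> (8, 8, 2, 3)
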
function trim_kernel(k, sizex)
    xx, xy, _, _ = sizex
    # Trim the kernel to match the input dimensions
    if k isa CuArray
        return CUDA.@allowscalar(k[:, 1:xx, 1:xy])
    else
        return @view k[:, 1:xx, 1:xy]
    end
end

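# For illustration (hypothetical sizes): a (2, 12, 12) kernel trimmed
# against an (8, 8, ch, batch) input keeps only the leading 8×8 block:
#
#   k = rand(Float32, 2, 12, 12)
#   size(trim_kernel(k, (8, 8, 2, 3)))   # -> (2, 8, 8)
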
function ChainRulesCore.rrule(::typeof(trim_kernel), k, sizex)
    y = trim_kernel(k, sizex)
    if k isa CuArray
        k_bar = CUDA.zeros(Float32, size(k))
    else
        k_bar = zeros(Float32, size(k))
    end

    function trim_kernel_pullback(y_bar)
        # Scatter the cotangent back into the untrimmed kernel shape;
        # the trimmed-away region keeps a zero gradient
        k_bar[:, 1:size(y_bar, 2), 1:size(y_bar, 3)] .= y_bar
        return NoTangent(), k_bar, NoTangent()
    end
    return y, trim_kernel_pullback
end

function mask_kernel(k, mask)
    # Bring the kernel to (kx, ky, ch) layout, apply the mask by
    # broadcasting, then restore the (ch, kx, ky) layout
    permutedims(permutedims(k, [2, 3, 1]) .* mask, [3, 1, 2])
end

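# For illustration (hypothetical sizes): the mask is applied in the
# (kx, ky, ch) layout, so a spatial (kx, ky) mask broadcasts across channels:
#
#   k = rand(Float32, 4, 5, 5)       # (ch, kx, ky)
#   mask = rand(Float32, 5, 5)
#   size(mask_kernel(k, mask))       # -> (4, 5, 5)
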
function get_kernel(ks, chrange)
    # Select the kernels for a given channel range
    if ks isa CuArray
        return CUDA.@allowscalar(ks[chrange, :, :])
    else
        return @view(ks[chrange, :, :])
    end
end

function ChainRulesCore.rrule(::typeof(get_kernel), ks, chrange)
    result = get_kernel(ks, chrange)

    function get_kernel_pullback(result_bar)
        # Scatter the cotangent back into the full kernel array; channels
        # outside chrange keep a zero gradient
        if ks isa CuArray
            k_bar = CUDA.zeros(Float32, size(ks))
            k_bar[chrange, :, :] .= CUDA.@allowscalar(result_bar)
        else
            k_bar = zeros(Float32, size(ks))
            k_bar[chrange, :, :] .= result_bar
        end

        return NoTangent(), k_bar, NoTangent()
    end

    return result, get_kernel_pullback
end
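
# A sketch of the pullback's behaviour (assumes Zygote; hypothetical sizes):
# gradients flow only into the selected channel range, the rest stays zero.
#
#   using Zygote
#   ks = rand(Float32, 6, 5, 5)
#   g = Zygote.gradient(ks -> sum(get_kernel(ks, 1:3)), ks)[1]
#   all(iszero, g[4:6, :, :])        # -> true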