
Commit b25611d

Implement KernelAbstraction (#23)
1 parent dd6eb8c commit b25611d

21 files changed: +1917, -446 lines

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -12,3 +12,5 @@ coverage
 docs/build/
 env
 node_modules
+test/test_figs/activation.png
+test/test_figs/downsampling_upsampling.png

Project.toml

Lines changed: 21 additions & 10 deletions
@@ -4,54 +4,65 @@ authors = ["SCiarella <[email protected]>"]
 version = "0.1.0"

 [deps]
+AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
+Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
-ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
+CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
 JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
 LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"

 [sources]
-CoupledNODE = {rev = "main", url = "https://github.com/DEEPDIP-project/CoupledNODE.jl.git"}
+CoupledNODE= {rev = "main", url = "https://github.com/DEEPDIP-project/CoupledNODE.jl.git"}
 NeuralClosure = {rev = "main", url = "https://github.com/DEEPDIP-project/NeuralClosure.jl.git"}

 [compat]
+AbstractFFTs = "1.5.0"
+Atomix = "1.1.1"
 CUDA = "5"
-ChainRules = "1"
+CairoMakie = "0.12"
+ChainRulesCore = "1.25.1"
+ChainRulesTestUtils = "1.13.0"
 ComponentArrays = "0.15"
 DifferentialEquations = "7.16.0"
 FFTW = "1"
+Images = "0.26.2"
 JuliaFormatter = "1.0.62"
+KernelAbstractions = "0.9.34"
 Lux = "1"
+LuxCUDA = "0.3.3"
 LuxCore = "1"
 NNlib = "0.9"
 Optimization = "4.1.1"
 OptimizationOptimisers = "0.3.7"
-Plots = "1.40.10"
 TestImages = "1.9.0"
-Tullio = "0.3"
-julia = "1.10"
+julia = "1.11"

 [extras]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
+ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 CoupledNODE = "88291d29-22ea-41b1-bc0b-03785bffce48"
 DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
+Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
 IncompressibleNavierStokes = "5e318141-6589-402b-868d-77d7df8c442e"
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 NeuralClosure = "099dac27-d7f2-4047-93d5-0baee36b9c25"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
-OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
 OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
-Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"
 TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [targets]
-test = ["Test", "TestImages", "Adapt", "CoupledNODE", "IncompressibleNavierStokes", "JLD2", "NeuralClosure", "Optimisers", "OrdinaryDiffEqTsit5", "TestItemRunner", "Zygote", "Plots", "DifferentialEquations", "Optimization", "OptimizationOptimisers"]
+test = ["Test", "TestImages", "Images", "Adapt", "CoupledNODE", "IncompressibleNavierStokes", "JLD2", "NeuralClosure", "Optimisers", "OrdinaryDiffEqTsit5", "TestItemRunner", "Zygote", "CairoMakie", "DifferentialEquations", "Optimization", "OptimizationOptimisers", "ChainRulesTestUtils"]
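The `[sources]` table pins the two unregistered dependencies to their `main` branches; `[sources]` support requires Julia 1.11, which matches the `julia = "1.11"` compat bump above. A minimal sketch of setting up the environment with standard Pkg calls (nothing here is project-specific):

using Pkg
Pkg.activate(".")   # activate this Project.toml
Pkg.instantiate()   # resolves CoupledNODE and NeuralClosure from the [sources] URLs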

src/ConvolutionalNeuralOperators.jl

Lines changed: 4 additions & 0 deletions
@@ -3,6 +3,10 @@ module ConvolutionalNeuralOperators
 using CUDA: CUDA
 ArrayType = CUDA.functional() ? CUDA.CuArray : Array

+include("filters.jl")
+include("convolution.jl")
+include("downsample.jl")
+include("upsample.jl")
 include("utils.jl")
 include("models.jl")
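`ArrayType` selects the array backend once at load time, so downstream code can allocate GPU or CPU buffers without branching. A short usage sketch (the binding is not exported, so it is imported explicitly; illustrative, not from the commit):

using ConvolutionalNeuralOperators: ArrayType
x = ArrayType(rand(Float32, 16, 16, 2, 1))  # CuArray if CUDA is functional, Array otherwise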

src/convolution.jl

Lines changed: 224 additions & 0 deletions
using Lux: Lux, relu, leakyrelu
using LuxCUDA
using LuxCore: AbstractLuxLayer
using Random: AbstractRNG
using ComponentArrays: ComponentArray
using KernelAbstractions
using Atomix: @atomic
using AbstractFFTs: fft, ifft
using FFTW: fft, ifft
using ChainRulesCore: ChainRulesCore, NoTangent  # needed for the custom rrules below

@kernel inbounds = true function convolve_kernel(ffty_r, ffty_im, fft_x, fft_k, ch_x)
    i, j, c, b = @index(Global, NTuple)
    for ci = 1:ch_x
        y = fft_x[i, j, ci, b] * fft_k[c, i, j]
        # Atomic operations do not support complex values, so the real and
        # imaginary parts are accumulated separately.
        @atomic ffty_r[i, j, c, b] += real(y)
        @atomic ffty_im[i, j, c, b] += imag(y)
    end
end

function convolve(x, k)
    fft_x = fft(x, (1, 2))
    fft_k = fft(k, (2, 3))

    if CUDA.functional() && k isa CuArray
        # TODO: the element type is hardcoded to Float32
        ffty_r = CUDA.zeros(Float32, size(x, 1), size(x, 2), size(k, 1), size(x, 4))
        ffty_im = CUDA.zeros(Float32, size(x, 1), size(x, 2), size(k, 1), size(x, 4))
        backend = CUDABackend()
        workgroupsize = 256
    else
        ffty_r = zeros(Float32, size(x, 1), size(x, 2), size(k, 1), size(x, 4))
        ffty_im = zeros(Float32, size(x, 1), size(x, 2), size(k, 1), size(x, 4))
        backend = CPU()
        workgroupsize = 64
    end

    # Launch the kernel
    convolve_kernel(backend, workgroupsize)(
        ffty_r,
        ffty_im,
        fft_x,
        fft_k,
        size(x, 3);
        ndrange = size(ffty_r),
    )

    real(ifft(ComplexF32.(ffty_r, ffty_im), (1, 2)))
end
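# A quick shape/value sanity check for `convolve` (an illustrative sketch,
# not in the original file, assuming CPU Float32 arrays). It recomputes the
# channel contraction that `convolve_kernel` performs, directly in the
# Fourier domain: x is (nx, ny, ch_in, batch) and k is (ch_out, nx, ny).
let
    x = rand(Float32, 8, 8, 2, 1)
    k = rand(Float32, 3, 8, 8)
    y = convolve(x, k)                        # -> (8, 8, 3, 1)
    fx, fk = fft(x, (1, 2)), fft(k, (2, 3))
    yref = zeros(ComplexF32, 8, 8, 3, 1)
    for c = 1:3, ci = 1:2
        yref[:, :, c, 1] .+= fx[:, :, ci, 1] .* fk[c, :, :]
    end
    @assert isapprox(y, real(ifft(yref, (1, 2))); atol = 1.0f-3)
end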
function ChainRulesCore.rrule(::typeof(convolve), x, k)
    # Given Y = X * K (where * denotes convolution),
    # the gradients for backpropagation are:
    #
    # 1. Gradient w.r.t. X:
    #    ∂L/∂X = (∂L/∂Y) * flip(K)
    #    In the Fourier domain: ℱ(∂L/∂X) = ℱ(∂L/∂Y) * conj(ℱ(K))
    #
    # 2. Gradient w.r.t. K:
    #    ∂L/∂K = flip(X * (∂L/∂Y))
    #    In the Fourier domain: ℱ(∂L/∂K) = conj(ℱ(X)) * ℱ(∂L/∂Y)
    #
    # Here, flip(K) represents a 180-degree rotation (flipping in both dimensions),
    # and conj() denotes the complex conjugate in the Fourier domain.

    y = convolve(x, k)
    fft_x = fft(x, (1, 2))
    fft_k = fft(k, (2, 3))

    function convolve_pb(y_bar)
        ffty_bar = fft(y_bar, (1, 2))

        if CUDA.functional() && k isa CuArray
            x_bar_re = CUDA.zeros(Float32, size(x))
            x_bar_im = CUDA.zeros(Float32, size(x))
            k_bar_re = CUDA.zeros(Float32, size(k))
            k_bar_im = CUDA.zeros(Float32, size(k))
            backend = CUDABackend()
            workgroupsize = 256
        else
            x_bar_re = zeros(Float32, size(x))
            x_bar_im = zeros(Float32, size(x))
            k_bar_re = zeros(Float32, size(k))
            k_bar_im = zeros(Float32, size(k))
            backend = CPU()
            workgroupsize = 64
        end

        # Launch the adjoint kernel for x
        convolve_adjoint_x_kernel(backend, workgroupsize)(
            x_bar_re,
            x_bar_im,
            ffty_bar,
            fft_k;
            ndrange = size(x),
        )
        # Launch the adjoint kernel for k
        convolve_adjoint_k_kernel(backend, workgroupsize)(
            k_bar_re,
            k_bar_im,
            fft_x,
            ffty_bar,
            size(x, 3);
            ndrange = size(k),
        )

        x_bar = ComplexF32.(x_bar_re, x_bar_im)
        k_bar = ComplexF32.(k_bar_re, k_bar_im)

        x_bar = real(ifft(x_bar, (1, 2)))
        k_bar = real(ifft(k_bar, (2, 3)))

        return NoTangent(), x_bar, k_bar
    end
    return y, convolve_pb
end
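# A sketch of how the hand-written rrule above could be checked against
# finite differences (not in the original file; ChainRulesTestUtils is
# listed under [extras] in this commit). Tolerances are loose because the
# primal is hardcoded to Float32.
using ChainRulesTestUtils: test_rrule
test_rrule(convolve, rand(Float32, 4, 4, 2, 1), rand(Float32, 2, 4, 4);
    rtol = 1.0f-2, atol = 1.0f-2)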
@kernel inbounds = true function convolve_adjoint_x_kernel(
    x_bar_re,
    x_bar_im,
    ffty_bar,
    fft_k,
)
    i, j, ci, b = @index(Global, NTuple)
    for c = 1:size(fft_k, 1)
        # Use the complex conjugate to backprop the convolution
        y = ffty_bar[i, j, c, b] * conj(fft_k[c, i, j])
        @atomic x_bar_re[i, j, ci, b] += real(y)
        @atomic x_bar_im[i, j, ci, b] += imag(y)
    end
end

@kernel inbounds = true function convolve_adjoint_k_kernel(
    k_bar_re,
    k_bar_im,
    fft_x,
    ffty_bar,
    ch_x,
)
    c, i, j = @index(Global, NTuple)
    for b = 1:size(fft_x, 4)
        for ci = 1:ch_x
            y = conj(fft_x[i, j, ci, b]) * ffty_bar[i, j, c, b]
            @atomic k_bar_re[c, i, j] += real(y)
            @atomic k_bar_im[c, i, j] += imag(y)
        end
    end
end
function apply_masked_convolution(y, k, mask)
    # To get the correct kernel, it has to be reshaped, masked, and trimmed.
    # TODO: this is not ideal.
    # ! Zygote does not allow reusing variable names, which makes this even
    # ! uglier with the separate definitions of k2 and k3.
    # ! Zygote also wants the mask explicitly defined as a vector, so the
    # ! caller has to pull it out of the tuple via mask = masks[i].

    # Apply the mask to the kernel
    k2 = mask_kernel(k, mask)

    # Adjust the kernel size to match the input dimensions
    k3 = trim_kernel(k2, size(y))

    # Apply the convolution
    y = convolve(y, k3)

    return y
end

function trim_kernel(k, sizex)
    xx, xy, _, _ = sizex
    # Trim the kernel to match the input dimensions
    if k isa CuArray
        return CUDA.@allowscalar(k[:, 1:xx, 1:xy])
    else
        return @view k[:, 1:xx, 1:xy]
    end
end

function ChainRulesCore.rrule(::typeof(trim_kernel), k, sizex)
    y = trim_kernel(k, sizex)
    if k isa CuArray
        k_bar = CUDA.zeros(Float32, size(k))
    else
        k_bar = zeros(Float32, size(k))
    end

    function trim_kernel_pullback(y_bar)
        k_bar[:, 1:size(y_bar)[2], 1:size(y_bar)[3]] .= y_bar
        return NoTangent(), k_bar, NoTangent()
    end
    return y, trim_kernel_pullback
end

function mask_kernel(k, mask)
    permutedims(permutedims(k, [2, 3, 1]) .* mask, [3, 1, 2])
end

function get_kernel(ks, chrange)
    if ks isa CuArray
        return CUDA.@allowscalar(ks[chrange, :, :])
    else
        return @view(ks[chrange, :, :])
    end
end

function ChainRulesCore.rrule(::typeof(get_kernel), ks, chrange)
    result = get_kernel(ks, chrange)

    function get_kernel_pullback(result_bar)
        if ks isa CuArray
            k_bar = CUDA.zeros(Float32, size(ks))
            k_bar[chrange, :, :] .= CUDA.@allowscalar(result_bar)
        else
            k_bar = zeros(Float32, size(ks))
            k_bar[chrange, :, :] .= result_bar
        end

        return NoTangent(), k_bar, NoTangent()
    end

    return result, get_kernel_pullback
end
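Putting the last few helpers together: `apply_masked_convolution` masks an oversized kernel spatially, trims it to the input's spatial size, and convolves. A small end-to-end sketch, assuming shapes compatible with the functions above (the names and sizes are illustrative, not from the commit):

y = rand(Float32, 8, 8, 2, 1)     # input: (nx, ny, ch, batch)
k = rand(Float32, 2, 16, 16)      # oversized kernel: (ch_out, kx, ky)
mask = rand(Float32, 16, 16)      # spatial mask broadcast over the kernel
out = apply_masked_convolution(y, k, mask)  # -> (8, 8, 2, 1)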
