Skip to content

Commit b2d5e0e

Browse files
authored
1 parent 3f63c92 commit b2d5e0e

12 files changed

+70
-76
lines changed

Project.toml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,20 +33,20 @@ CairoMakie = "0.12, 0.13"
3333
ChainRulesCore = "1.25.1"
3434
ChainRulesTestUtils = "1.13.0"
3535
ComponentArrays = "0.15"
36-
DifferentialEquations = "7.16.0"
36+
DifferentialEquations = "7"
3737
FFTW = "1"
38-
Images = "0.26.2"
38+
Images = "0.26"
3939
JuliaFormatter = "2"
40-
KernelAbstractions = "0.9.34"
40+
KernelAbstractions = "0.9"
4141
Lux = "1"
42-
LuxCUDA = "0.3.3"
42+
LuxCUDA = "0.3"
4343
LuxCore = "1"
4444
NNlib = "0.9"
4545
OpenSSL_jll = "3.0.13"
46-
Optimization = "4.1.1"
47-
OptimizationOptimisers = "0.3.7"
48-
TestImages = "1.9.0"
49-
Zygote = "0.6.76"
46+
Optimization = "4"
47+
OptimizationOptimisers = "0.3"
48+
TestImages = "1.9"
49+
Zygote = "0.7"
5050
julia = "1.11"
5151

5252
[extras]

src/convolution.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ function ChainRulesCore.rrule(::typeof(convolve), x, k)
6666
fft_k = fft(k, (2, 3))
6767

6868
function convolve_pb(y_bar)
69-
ffty_bar = fft(y_bar, (1, 2))
69+
yb = unthunk(y_bar)
70+
ffty_bar = fft(yb, (1, 2))
7071

7172
if CUDA.functional() && k isa CuArray
7273
x_bar_re = CUDA.zeros(Float32, size(x))
@@ -150,15 +151,15 @@ end
150151

151152
function apply_masked_convolution(y, k, mask)
152153
# to get the correct k i have to reshape+mask+trim
153-
# TODO: i don't like this...
154-
# ! Zygote does not like that you reuse variable names so, this makes it even uglier with the definition of k2 and k3
155-
# ! also Zygote wants the mask to be explicitely defined as a vector so i have to pull it out from the tuple via mask=masks[i]
154+
# ! Zygote does not like reused variable names, so k2 and k3 need to be defined
155+
# ! also Zygote wants the mask to be explicitly defined as a vector, so mask_kernel is needed
156156

157157
# Apply the mask to the kernel
158158
k2 = mask_kernel(k, mask)
159159

160-
# Adjust the kernel size to match the input dimensions
160+
## Adjust the kernel size to match the input dimensions
161161
k3 = trim_kernel(k2, size(y))
162+
#k3 = k2
162163

163164
# Apply the convolution
164165
y = convolve(y, k3)
@@ -178,20 +179,20 @@ end
178179

179180
function ChainRulesCore.rrule(::typeof(trim_kernel), k, sizex)
180181
y = trim_kernel(k, sizex)
181-
if k isa CuArray
182-
k_bar = CUDA.zeros(Float32, size(k))
183-
else
184-
k_bar = zeros(Float32, size(k))
185-
end
182+
k_bar = similar(k, Float32)
186183

187184
function trim_kernel_pullback(y_bar)
188-
k_bar[:, 1:size(y_bar)[2], 1:size(y_bar)[3]] .= y_bar
185+
yb = unthunk(y_bar)
186+
sz2, sz3 = size(yb, 2), size(yb, 3)
187+
k_bar .= 0 # clear first to be safe
188+
k_bar[:, 1:sz2, 1:sz3] .= yb
189189
return NoTangent(), k_bar, NoTangent()
190190
end
191191
return y, trim_kernel_pullback
192192
end
193193

194194

195+
195196
function mask_kernel(k, mask)
196197
permutedims(permutedims(k, [2, 3, 1]) .* mask, [3, 1, 2])
197198
end

src/downsample.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ function ChainRulesCore.rrule(
5656

5757
dk_pb!(backend, workgroupsize)(
5858
x_filter_bar,
59-
result_bar,
59+
unthunk(result_bar),
6060
down_factor;
6161
ndrange = downsampled_size,
6262
)

src/models.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ function ((;)::CNO)(x, params, state)
383383
)
384384
# concatenate with the corresponding bottleneck
385385
y = cat(y, bottlenecks_out[i], dims = D + 1)
386-
# apply the last bottleneck
386+
# apply the last bottleneck that combines the two branches
387387
# ! do not forget to reverse the bottleneck ranges
388388
y = apply_masked_convolution(
389389
y,

test/data_train.jld2

-404 KB
Binary file not shown.

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ Don't add your tests to runtests.jl. Instead, create files named
88
99
The file will be automatically included inside a `@testset` with title "Title For My Test".
1010
=#
11+
12+
# Helper function to check if a variable is on the GPU
13+
function is_on_gpu(x)
14+
return x isa CuArray || (x isa SubArray && is_on_gpu(x.parent))
15+
end
16+
1117
for (root, dirs, files) in walkdir(@__DIR__)
1218
for file in files
1319
if isnothing(match(r"^test-.*\.jl$", file))

test/test-couplednode_posterior.jl

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -82,42 +82,41 @@ batch = 4
8282
pairs = @. Symbol(k) => v
8383
(; pairs...)
8484
end
85-
data_train = []
86-
data_i = namedtupleload("data_train.jld2")
87-
push!(data_train, hcat(data_i))
85+
data_train = load("data_train.jld2", "data_train")
8886

8987
# Create the io array
9088
NS = Base.get_extension(CoupledNODE, :NavierStokes)
91-
io_train = NS.create_io_arrays_posteriori(data_train, setup)
89+
io_train = NS.create_io_arrays_posteriori(data_train, setup[1])
9290

9391
# Create the dataloader
9492
θ = device(copy(θ_start))
9593
nunroll = 2
9694
nunroll_valid = 2
9795
dataloader_post = NS.create_dataloader_posteriori(
98-
io_train[1];
96+
io_train;
9997
nunroll = nunroll,
10098
rng = Random.Xoshiro(24),
101-
device = device,
10299
)
100+
u, t = dataloader_post()
103101

104102
# Create the right hand side and the loss
105103
dudt_nn = NS.create_right_hand_side_with_closure(setup[1], psolver, closure, st)
104+
griddims = ((:) for _ = 1:D)
106105
loss = CoupledNODE.create_loss_post_lux(
107-
dudt_nn;
106+
dudt_nn,
107+
griddims,
108+
griddims;
108109
sciml_solver = Tsit5(),
109-
dt = T(conf["params"]["Δt"]),
110-
use_cuda = false,
110+
force_cpu = true,
111111
)
112112
callbackstate = trainstate = nothing
113113

114114

115115
# For testing purposes, explicitly set up the problem
116116
# Notice that this is automatically done in CoupledNODE
117117
u, t = dataloader_post()
118-
griddims = ((:) for _ = 1:(ndims(u)-2))
119-
x = u[griddims..., :, 1]
120-
y = u[griddims..., :, 2:end] # remember to discard sol at the initial time step
118+
x = u[griddims..., :, 1, 1]
119+
y = u[griddims..., :, 1, 2:end] # remember to discard sol at the initial time step
121120
tspan, dt, prob, pred = nothing, nothing, nothing, nothing # initialize variable outside allowscalar do.
122121
dt = @views t[2:2] .- t[1:1]
123122
dt = only(Array(dt))
@@ -126,9 +125,7 @@ batch = 4
126125
end
127126
tspan = get_tspan(t)
128127
prob = ODEProblem(dudt_nn, x, tspan, θ)
129-
pred = Array(
130-
solve(prob, Tsit5(); u0 = x, p = θ, adaptive = false, saveat = Array(t), dt = dt),
131-
)
128+
pred = Array(solve(prob, Tsit5(); u0 = x, p = θ, adaptive = true, saveat = Array(t)))
132129

133130
# Test the forward pass
134131
@test size(pred[:, :, :, 2:end]) == size(y)
@@ -226,54 +223,50 @@ end
226223
pairs = @. Symbol(k) => v
227224
(; pairs...)
228225
end
229-
data_train = []
230-
data_i = namedtupleload("data_train.jld2")
231-
push!(data_train, hcat(data_i))
226+
data_train = load("data_train.jld2", "data_train")
232227

233228
# Create the io array
234229
NS = Base.get_extension(CoupledNODE, :NavierStokes)
235-
io_train = NS.create_io_arrays_posteriori(data_train, setup)
230+
io_train = NS.create_io_arrays_posteriori(data_train, setup[1], device)
236231

237232
# Create the dataloader
238233
θ = device(copy(θ_start))
239234
nunroll = 2
240235
nunroll_valid = 2
241236
dataloader_post = NS.create_dataloader_posteriori(
242-
io_train[1];
237+
io_train;
243238
nunroll = nunroll,
244239
rng = Random.Xoshiro(24),
245240
device = device,
246241
)
242+
u, t = dataloader_post()
247243

248244
# Create the right hand side and the loss
249245
dudt_nn = NS.create_right_hand_side_with_closure(setup[1], psolver, closure, st)
246+
griddims = ((:) for _ = 1:D)
250247
loss = CoupledNODE.create_loss_post_lux(
251-
dudt_nn;
248+
dudt_nn,
249+
griddims,
250+
griddims;
252251
sciml_solver = Tsit5(),
253-
dt = T(conf["params"]["Δt"]),
254-
use_cuda = true,
255252
)
256253
callbackstate = trainstate = nothing
257254

258255

259256
# For testing purposes, explicitly set up the problem
260257
# Notice that this is automatically done in CoupledNODE
261-
u, t = dataloader_post()
262-
griddims = ((:) for _ = 1:(ndims(u)-2))
263-
x = u[griddims..., :, 1]
264-
y = u[griddims..., :, 2:end] # remember to discard sol at the initial time step
265-
tspan, dt, prob, pred = nothing, nothing, nothing, nothing # initialize variable outside allowscalar do.
266-
dt = CUDA.allowscalar() do
267-
t[2] .- t[1]
258+
x, y = nothing, nothing
259+
CUDA.allowscalar() do
260+
x = u[griddims..., :, 1, 1]
261+
y = u[griddims..., :, 1, 2:end] # remember to discard sol at the initial time step
268262
end
263+
tspan, dt, prob, pred = nothing, nothing, nothing, nothing # initialize variable outside allowscalar do.
269264
function get_tspan(t)
270265
return (Array(t)[1], Array(t)[end])
271266
end
272267
tspan = get_tspan(t)
273268
prob = ODEProblem(dudt_nn, x, tspan, θ)
274-
pred = Array(
275-
solve(prob, Tsit5(); u0 = x, p = θ, adaptive = false, saveat = Array(t), dt = dt),
276-
)
269+
pred = Array(solve(prob, Tsit5(); u0 = x, p = θ, adaptive = true, saveat = Array(t)))
277270

278271
# Test the forward pass
279272
@test size(pred[:, :, :, 2:end]) == size(y)

test/test-couplednode_prior.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -83,18 +83,16 @@ batch = 4
8383
pairs = @. Symbol(k) => v
8484
(; pairs...)
8585
end
86-
data_train = []
87-
data_i = namedtupleload("data_train.jld2")
88-
push!(data_train, hcat(data_i))
86+
data_train = load("data_train.jld2", "data_train")
8987

9088
# Create the io array
9189
NS = Base.get_extension(CoupledNODE, :NavierStokes)
92-
io_train = NS.create_io_arrays_priori(data_train, setup)
90+
io_train = NS.create_io_arrays_priori(data_train, setup[1])
9391

9492
# Create the dataloader
9593
θ = device(copy(θ_start))
9694
dataloader_prior = NS.create_dataloader_prior(
97-
io_train[1];
95+
io_train;
9896
batchsize = 4,
9997
rng = Random.Xoshiro(24),
10098
device = device,
@@ -186,25 +184,23 @@ end
186184
pairs = @. Symbol(k) => v
187185
(; pairs...)
188186
end
189-
data_train = []
190-
data_i = namedtupleload("data_train.jld2")
191-
push!(data_train, hcat(data_i))
187+
data_train = load("data_train.jld2", "data_train")
192188

193189
# Create the io array
194190
NS = Base.get_extension(CoupledNODE, :NavierStokes)
195-
io_train = NS.create_io_arrays_priori(data_train, setup)
191+
io_train = NS.create_io_arrays_priori(data_train, setup[1], device)
196192

197193
# Create the dataloader
198194
θ = device(copy(θ_start))
199195
dataloader_prior = NS.create_dataloader_prior(
200-
io_train[1];
196+
io_train;
201197
batchsize = 4,
202198
rng = Random.Xoshiro(24),
203199
device = device,
204200
)
205201
train_data_priori = dataloader_prior()
206-
@test isa(train_data_priori[1], CuArray)
207-
@test isa(train_data_priori[2], CuArray)
202+
@test is_on_gpu(train_data_priori[1])
203+
@test is_on_gpu(train_data_priori[2])
208204

209205
l0 = CoupledNODE.loss_priori_lux(closure, θ, st, train_data_priori)[1]
210206
@test isnan(l0) == false

test/test-fullmodel.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,8 @@ end
112112
@test isa(y, CuArray)
113113

114114

115-
return
116-
u_in = rand(T, size(u))
117-
tgt = rand(T, size(u))
115+
u_in = CUDA.rand(T, size(u))
116+
tgt = CUDA.rand(T, size(u))
118117
function loss(θ, batch = 16)
119118
yout = model(u_in, θ, st)[1]
120119
return sum(abs2, (yout .- tgt))

test/test-maskedconvolution.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using Lux: Lux
88
using CUDA
99
using LuxCUDA
1010
using ConvolutionalNeuralOperators:
11-
convolve, apply_masked_convolution, trim_kernel, get_kernel
11+
convolve, apply_masked_convolution, trim_kernel, get_kernel, mask_kernel
1212
using Zygote: Zygote
1313
using Test # Importing the Test module for @test statements
1414
using AbstractFFTs: fft, ifft

0 commit comments

Comments
 (0)