Skip to content

Commit 633a474

Browse files
committed
Use Adapt to keep type on GPU
Fix data creation
1 parent 7361cda commit 633a474

File tree

4 files changed

+19
-12
lines changed

4 files changed

+19
-12
lines changed

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Simone Ciarella <[email protected]>, Luisa Orozco <l.oroz
44
version = "0.0.3"
55

66
[deps]
7+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
78
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
89
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
910
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
@@ -34,7 +35,6 @@ YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"
3435
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3536

3637
[weakdeps]
37-
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3838
AttentionLayer = "3ee63b08-73c5-50c8-acc9-f395aa68c39a"
3939
CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
4040
ConvolutionalNeuralOperators = "d769ba41-1544-53e8-a779-241a28c31cef"
@@ -55,8 +55,8 @@ NeuralClosure = {rev = "main", url = "https://github.com/DEEPDIP-project/NeuralC
5555
AttentionCNN = ["AttentionLayer"]
5656
CNO = ["ConvolutionalNeuralOperators"]
5757
CoupledNODECUDA_ext = ["cuDNN", "CUDSS"]
58-
NavierStokes = ["IncompressibleNavierStokes", "NeuralClosure", "Adapt"]
59-
fno = ["NeuralOperators", "Adapt"]
58+
NavierStokes = ["IncompressibleNavierStokes", "NeuralClosure"]
59+
fno = ["NeuralOperators"]
6060

6161
[compat]
6262
Adapt = "4"
@@ -104,4 +104,4 @@ NeuralClosure = "099dac27-d7f2-4047-93d5-0baee36b9c25"
104104
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
105105

106106
[targets]
107-
test = ["Test", "IncompressibleNavierStokes", "NeuralClosure", "Adapt", "cuDNN", "NeuralOperators"]
107+
test = ["Test", "IncompressibleNavierStokes", "NeuralClosure", "cuDNN", "NeuralOperators"]

ext/NavierStokes/create_data.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using DifferentialEquations
2+
using JLD2
23
using IncompressibleNavierStokes: right_hand_side!, apply_bc_u!, momentum!, project!
34

45
function create_les_data_projected(;
@@ -75,9 +76,9 @@ function create_les_data_projected(;
7576
t in tdatapoint && return true
7677
return false
7778
end
78-
all_ules = Array{T}(undef, (nles[1] + 2, nles[1]+2, D, length(tdatapoint)-1))
79-
all_c = Array{T}(undef, (nles[1]+2, nles[1]+2, D, length(tdatapoint)-1))
80-
all_t = Array{T}(undef, (length(tdatapoint)-1))
79+
all_ules = Array{T}(undef, (nles[1] + 2, nles[1]+2, D, length(tdatapoint)))
80+
all_c = Array{T}(undef, (nles[1]+2, nles[1]+2, D, length(tdatapoint)))
81+
all_t = Array{T}(undef, (length(tdatapoint)))
8182
idx = Ref(1)
8283
Fdns = INS.create_right_hand_side(dns, psolver)
8384
p = scalarfield(les)
@@ -118,6 +119,9 @@ function create_les_data_projected(;
118119
u_current = u # Initial condition
119120
prob = ODEProblem(rhs!, u_current, nothing, nothing)
120121

122+
# Store the data at t=0
123+
filter_callback((; u = u, t = T(0)))
124+
121125
any(u -> any(isnan, u), u_current) &&
122126
@warn "Solution contains NaNs. Probably dt is too large."
123127

@@ -134,7 +138,7 @@ function create_les_data_projected(;
134138

135139
sol = solve(
136140
prob, sciml_solver; u0 = u_current, p = nothing,
137-
adaptive = false, dt = Δt, save_end = true, callback = cb,
141+
adaptive = true, save_end = true, callback = cb,
138142
tspan = tspan_chunk, tstops = tdatapoint
139143
)
140144

src/models/cnn.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ A tuple `(chain, params, state)` where
2424
- `state`: The state of the model.
2525
"""
2626
function cnn(;
27-
T = Float32,
27+
T,
2828
D,
2929
data_ch,
3030
radii,
@@ -37,12 +37,13 @@ function cnn(;
3737
r, c, σ, b = radii, channels, activations, use_bias
3838

3939
if use_cuda
40-
dev = Lux.gpu_device()
40+
dev = x -> adapt(CuArray, x)
4141
else
4242
dev = Lux.cpu_device()
4343
end
4444

45-
@warn "*** CNN is using the following device: $(dev) "
45+
T = eltype(T(0.0))
46+
@warn "*** CNN is using the following device: $(dev) and type $(T)"
4647

4748
# Weight initializer
4849
glorot_uniform_T(rng::Random.AbstractRNG, dims...) = glorot_uniform(rng, T, dims...)
@@ -71,6 +72,7 @@ function cnn(;
7172
)...,
7273
decollocate
7374
)
75+
7476
chain = Chain(layers...)
7577
params, state = Lux.setup(rng, chain)
7678
state = state |> dev

src/train.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Zygote: Zygote
44
using Optimization: Optimization
55
using OptimizationOptimisers: OptimizationOptimisers
66
using ChainRulesCore: ignore_derivatives
7+
using Adapt: adapt
78

89
function train(model, ps, st, train_dataloader, loss_function;
910
nepochs = 50,
@@ -12,7 +13,7 @@ function train(model, ps, st, train_dataloader, loss_function;
1213
cpu::Bool = false,
1314
λ = nothing,
1415
kwargs...)
15-
dev = cpu ? identity : Lux.gpu_device()
16+
dev = cpu ? identity : x -> adapt(CuArray, x)
1617
if !cpu
1718
ps, st = (ps, st) .|> dev
1819
end

0 commit comments

Comments
 (0)