@@ -82,42 +82,41 @@ batch = 4
82
82
pairs = @. Symbol (k) => v
83
83
(; pairs... )
84
84
end
85
- data_train = []
86
- data_i = namedtupleload (" data_train.jld2" )
87
- push! (data_train, hcat (data_i))
85
+ data_train = load (" data_train.jld2" , " data_train" )
88
86
89
87
# Create the io array
90
88
NS = Base. get_extension (CoupledNODE, :NavierStokes )
91
- io_train = NS. create_io_arrays_posteriori (data_train, setup)
89
+ io_train = NS. create_io_arrays_posteriori (data_train, setup[ 1 ] )
92
90
93
91
# Create the dataloader
94
92
θ = device (copy (θ_start))
95
93
nunroll = 2
96
94
nunroll_valid = 2
97
95
dataloader_post = NS. create_dataloader_posteriori (
98
- io_train[ 1 ] ;
96
+ io_train;
99
97
nunroll = nunroll,
100
98
rng = Random. Xoshiro (24 ),
101
- device = device,
102
99
)
100
+ u, t = dataloader_post ()
103
101
104
102
# Create the right hand side and the loss
105
103
dudt_nn = NS. create_right_hand_side_with_closure (setup[1 ], psolver, closure, st)
104
+ griddims = ((:) for _ = 1 : D)
106
105
loss = CoupledNODE. create_loss_post_lux (
107
- dudt_nn;
106
+ dudt_nn,
107
+ griddims,
108
+ griddims;
108
109
sciml_solver = Tsit5 (),
109
- dt = T (conf[" params" ][" Δt" ]),
110
- use_cuda = false ,
110
+ force_cpu = true ,
111
111
)
112
112
callbackstate = trainstate = nothing
113
113
114
114
115
115
# For testing reason, explicitely set up the probelm
116
116
# Notice that this is automatically done in CoupledNODE
117
117
u, t = dataloader_post ()
118
- griddims = ((:) for _ = 1 : (ndims (u)- 2 ))
119
- x = u[griddims... , :, 1 ]
120
- y = u[griddims... , :, 2 : end ] # remember to discard sol at the initial time step
118
+ x = u[griddims... , :, 1 , 1 ]
119
+ y = u[griddims... , :, 1 , 2 : end ] # remember to discard sol at the initial time step
121
120
tspan, dt, prob, pred = nothing , nothing , nothing , nothing # initialize variable outside allowscalar do.
122
121
dt = @views t[2 : 2 ] .- t[1 : 1 ]
123
122
dt = only (Array (dt))
@@ -126,9 +125,7 @@ batch = 4
126
125
end
127
126
tspan = get_tspan (t)
128
127
prob = ODEProblem (dudt_nn, x, tspan, θ)
129
- pred = Array (
130
- solve (prob, Tsit5 (); u0 = x, p = θ, adaptive = false , saveat = Array (t), dt = dt),
131
- )
128
+ pred = Array (solve (prob, Tsit5 (); u0 = x, p = θ, adaptive = true , saveat = Array (t)))
132
129
133
130
# Test the forward pass
134
131
@test size (pred[:, :, :, 2 : end ]) == size (y)
@@ -226,54 +223,50 @@ end
226
223
pairs = @. Symbol (k) => v
227
224
(; pairs... )
228
225
end
229
- data_train = []
230
- data_i = namedtupleload (" data_train.jld2" )
231
- push! (data_train, hcat (data_i))
226
+ data_train = load (" data_train.jld2" , " data_train" )
232
227
233
228
# Create the io array
234
229
NS = Base. get_extension (CoupledNODE, :NavierStokes )
235
- io_train = NS. create_io_arrays_posteriori (data_train, setup)
230
+ io_train = NS. create_io_arrays_posteriori (data_train, setup[ 1 ], device )
236
231
237
232
# Create the dataloader
238
233
θ = device (copy (θ_start))
239
234
nunroll = 2
240
235
nunroll_valid = 2
241
236
dataloader_post = NS. create_dataloader_posteriori (
242
- io_train[ 1 ] ;
237
+ io_train;
243
238
nunroll = nunroll,
244
239
rng = Random. Xoshiro (24 ),
245
240
device = device,
246
241
)
242
+ u, t = dataloader_post ()
247
243
248
244
# Create the right hand side and the loss
249
245
dudt_nn = NS. create_right_hand_side_with_closure (setup[1 ], psolver, closure, st)
246
+ griddims = ((:) for _ = 1 : D)
250
247
loss = CoupledNODE. create_loss_post_lux (
251
- dudt_nn;
248
+ dudt_nn,
249
+ griddims,
250
+ griddims;
252
251
sciml_solver = Tsit5 (),
253
- dt = T (conf[" params" ][" Δt" ]),
254
- use_cuda = true ,
255
252
)
256
253
callbackstate = trainstate = nothing
257
254
258
255
259
256
# For testing reason, explicitely set up the probelm
260
257
# Notice that this is automatically done in CoupledNODE
261
- u, t = dataloader_post ()
262
- griddims = ((:) for _ = 1 : (ndims (u)- 2 ))
263
- x = u[griddims... , :, 1 ]
264
- y = u[griddims... , :, 2 : end ] # remember to discard sol at the initial time step
265
- tspan, dt, prob, pred = nothing , nothing , nothing , nothing # initialize variable outside allowscalar do.
266
- dt = CUDA. allowscalar () do
267
- t[2 ] .- t[1 ]
258
+ x, y = nothing , nothing
259
+ CUDA. allowscalar () do
260
+ x = u[griddims... , :, 1 , 1 ]
261
+ y = u[griddims... , :, 1 , 2 : end ] # remember to discard sol at the initial time step
268
262
end
263
+ tspan, dt, prob, pred = nothing , nothing , nothing , nothing # initialize variable outside allowscalar do.
269
264
function get_tspan (t)
270
265
return (Array (t)[1 ], Array (t)[end ])
271
266
end
272
267
tspan = get_tspan (t)
273
268
prob = ODEProblem (dudt_nn, x, tspan, θ)
274
- pred = Array (
275
- solve (prob, Tsit5 (); u0 = x, p = θ, adaptive = false , saveat = Array (t), dt = dt),
276
- )
269
+ pred = Array (solve (prob, Tsit5 (); u0 = x, p = θ, adaptive = true , saveat = Array (t)))
277
270
278
271
# Test the forward pass
279
272
@test size (pred[:, :, :, 2 : end ]) == size (y)
0 commit comments