Skip to content

Commit 712c15f

Browse files
committed
pino new training code
1 parent 021074e commit 712c15f

16 files changed

+639
-79
lines changed

Diff for: configs/operator/Re500-1_4-2000-FNO.yaml

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
data:
2+
name: KF
3+
paths: ['../data/NS-Re500_T3000_id0.npy']
4+
Re: 500
5+
total_num: 3000
6+
offset: 0
7+
n_samples: 600
8+
testoffset: 2500
9+
n_test_samples: 400
10+
t_duration: 0.25
11+
raw_res: [256, 256, 257]
12+
data_res: [256, 256, 257] # resolution in 1 second
13+
pde_res: [256, 256, 257] # resolution in 1 second
14+
shuffle: True
15+
16+
model:
17+
layers: [64, 64, 64, 64, 64]
18+
modes1: [8, 8, 8, 8]
19+
modes2: [8, 8, 8, 8]
20+
modes3: [8, 8, 8, 8]
21+
fc_dim: 128
22+
act: gelu
23+
num_pad: 4
24+
25+
train:
26+
batchsize: 2
27+
epochs: 401
28+
milestones: [100, 300]
29+
base_lr: 0.001
30+
scheduler_gamma: 0.5
31+
ic_loss: 0.0
32+
f_loss: 0.0
33+
xy_loss: 1.0
34+
save_step: 50
35+
36+
test:
37+
batchsize: 1
38+
data_res: [256, 256, 257]
39+
ckpt: model-400.pt
40+
41+
log:
42+
logdir: Re500-1_4s-2000-FNO
43+
entity: hzzheng-pino
44+
project: PINO-NS
45+
group: Re500-1_4s-2000-FNO

Diff for: configs/operator/Re500-1_8-1200-PINO.yaml

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
data:
2+
name: KF
3+
paths: ['../data/NS-Re500_T300_id0.npy']
4+
Re: 500
5+
total_num: 300
6+
offset: 0
7+
n_samples: 200
8+
testoffset: 200
9+
n_test_samples: 400
10+
t_duration: 0.125
11+
raw_res: [256, 256, 513]
12+
data_res: [64, 64, 257] # resolution in 1 second
13+
pde_res: [256, 256, 513] # resolution in 1 second
14+
shuffle: True
15+
16+
model:
17+
layers: [64, 64, 64, 64, 64]
18+
modes1: [8, 8, 8, 8]
19+
modes2: [8, 8, 8, 8]
20+
modes3: [8, 8, 8, 8]
21+
fc_dim: 128
22+
act: gelu
23+
num_pad: 4
24+
pad_ratio: 0.0625
25+
26+
train:
27+
batchsize: 2
28+
epochs: 201
29+
milestones: [50, 100, 150]
30+
base_lr: 0.001
31+
scheduler_gamma: 0.5
32+
ic_loss: 1.0
33+
f_loss: 1.0
34+
xy_loss: 5.0
35+
save_step: 25
36+
37+
test:
38+
batchsize: 1
39+
data_res: [256, 256, 257]
40+
ckpt: model-400.pt
41+
42+
log:
43+
logdir: Re500-1_8s-1200-PINO
44+
entity: hzzheng-pino
45+
project: PINO-NS
46+
group: Re500-1_8s-1200-PINO

Diff for: configs/operator/Re500-1_8-2000-FNO-xl.yaml

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
data:
2+
name: KF
3+
paths: ['../data/NS-Re500_T3000_id0.npy']
4+
Re: 500
5+
total_num: 3000
6+
offset: 0
7+
n_samples: 350
8+
testoffset: 2500
9+
n_test_samples: 400
10+
t_duration: 0.125
11+
raw_res: [256, 256, 257]
12+
data_res: [256, 256, 257] # resolution in 1 second
13+
pde_res: [256, 256, 257] # resolution in 1 second
14+
shuffle: True
15+
16+
model:
17+
layers: [64, 64, 64, 64, 64]
18+
modes1: [12, 12, 12, 12]
19+
modes2: [12, 12, 12, 12]
20+
modes3: [12, 12, 12, 12]
21+
fc_dim: 128
22+
act: gelu
23+
num_pad: 4
24+
25+
train:
26+
batchsize: 2
27+
epochs: 201
28+
milestones: [50, 100, 150]
29+
base_lr: 0.001
30+
scheduler_gamma: 0.5
31+
ic_loss: 0.0
32+
f_loss: 0.0
33+
xy_loss: 1.0
34+
save_step: 20
35+
36+
test:
37+
batchsize: 1
38+
data_res: [256, 256, 257]
39+
ckpt: model-400.pt
40+
41+
log:
42+
logdir: Re500-1_8s-2400-FNO
43+
entity: hzzheng-pino
44+
project: PINO-NS
45+
group: Re500-1_8s-2400-FNO

Diff for: configs/operator/Re500-1_8-2000-FNO.yaml

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
data:
2+
name: KF
3+
paths: ['../data/NS-Re500_T3000_id0.npy']
4+
Re: 500
5+
total_num: 3000
6+
offset: 0
7+
n_samples: 350
8+
testoffset: 2500
9+
n_test_samples: 400
10+
t_duration: 0.125
11+
raw_res: [256, 256, 257]
12+
data_res: [256, 256, 257] # resolution in 1 second
13+
pde_res: [256, 256, 257] # resolution in 1 second
14+
shuffle: True
15+
16+
model:
17+
layers: [64, 64, 64, 64, 64]
18+
modes1: [8, 8, 8, 8]
19+
modes2: [8, 8, 8, 8]
20+
modes3: [8, 8, 8, 8]
21+
fc_dim: 128
22+
act: gelu
23+
num_pad: 4
24+
pad_ratio: 0.0625
25+
26+
train:
27+
batchsize: 2
28+
epochs: 201
29+
milestones: [50, 100, 150]
30+
base_lr: 0.001
31+
scheduler_gamma: 0.5
32+
ic_loss: 0.0
33+
f_loss: 0.0
34+
xy_loss: 1.0
35+
save_step: 25
36+
37+
test:
38+
batchsize: 1
39+
data_res: [256, 256, 257]
40+
ckpt: model-400.pt
41+
42+
log:
43+
logdir: Re500-1_8s-2000-FNO
44+
entity: hzzheng-pino
45+
project: PINO-NS
46+
group: Re500-1_8s-2000-FNO

Diff for: configs/operator/Re500-1_8-2000-PINO.yaml

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
data:
2+
name: KF
3+
paths: ['../data/NS-Re500_T300_id0.npy']
4+
Re: 500
5+
offset: 0
6+
total_num: 300
7+
raw_res: [256, 256, 513]
8+
n_data_samples: 150
9+
data_res: [64, 64, 257] # resolution in 1 second
10+
pde_res: [256, 256, 513] # resolution in 1 second
11+
a_offset: 0
12+
n_a_samples: 250
13+
testoffset: 200
14+
n_test_samples: 50
15+
t_duration: 0.125
16+
shuffle: True
17+
18+
model:
19+
layers: [64, 64, 64, 64, 64]
20+
modes1: [8, 8, 8, 8]
21+
modes2: [8, 8, 8, 8]
22+
modes3: [8, 8, 8, 8]
23+
fc_dim: 128
24+
act: gelu
25+
num_pad: 4
26+
pad_ratio: 0.0625
27+
28+
train:
29+
batchsize: 2
30+
epochs: 201
31+
num_iter: 300_001
32+
milestones: [100_000, 200_000]
33+
base_lr: 0.001
34+
scheduler_gamma: 0.5
35+
ic_loss: 1.0
36+
f_loss: 1.0
37+
xy_loss: 5.0
38+
save_step: 20_000
39+
eval_step: 5000
40+
41+
test:
42+
batchsize: 1
43+
data_res: [256, 256, 257]
44+
ckpt: model-400.pt
45+
46+
log:
47+
logdir: Re500-1_8s-2000-PINO
48+
entity: hzzheng-pino
49+
project: PINO-NS
50+
group: Re500-1_8s-2000-PINO

Diff for: configs/scratch/Re500-scratch-1s.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ train:
3131
save_name: 'PINO-scratch128-1s.pt'
3232

3333
log:
34-
entity: hzzheng-pino
35-
project: PINO-NavierStokes
34+
entity: 'hzzheng-pino'
35+
project: 'PINO-NavierStokes'
3636
group: 'Re500-scratch-1s'
3737

3838

Diff for: generate_data.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ def gen_data(args):
114114
parser.add_argument('--re', type=float, default=40.0)
115115
parser.add_argument('--x_res', type=int, default=512)
116116
parser.add_argument('--x_sub', type=int, default=2)
117-
parser.add_argument('--T', type=int, default=2000)
117+
parser.add_argument('--T', type=int, default=300)
118118
parser.add_argument('--outdir', type=str, default='../data')
119-
parser.add_argument('--t_res', type=int, default=256)
119+
parser.add_argument('--t_res', type=int, default=512)
120120
parser.add_argument('--batchsize', type=int, default=1)
121121
parser.add_argument('--num_batchs', type=int, default=1)
122122
args = parser.parse_args()

Diff for: models/basics.py

-17
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,6 @@
22

33
import torch
44
import torch.nn as nn
5-
import torch.nn.functional as F
6-
7-
8-
def _get_act(act):
9-
if act == 'tanh':
10-
func = F.tanh
11-
elif act == 'gelu':
12-
func = F.gelu
13-
elif act == 'relu':
14-
func = F.relu_
15-
elif act == 'elu':
16-
func = F.elu_
17-
elif act == 'leaky_relu':
18-
func = F.leaky_relu_
19-
else:
20-
raise ValueError(f'{act} is not supported')
21-
return func
225

236

247
@torch.jit.script

Diff for: models/fourier1d.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch.nn as nn
2-
from .basics import SpectralConv1d, _get_act
2+
from .basics import SpectralConv1d
3+
from .utils import _get_act
34

45

56
class FNN1d(nn.Module):

Diff for: models/fourier2d.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch.nn as nn
2-
from .basics import SpectralConv2d, _get_act
2+
from .basics import SpectralConv2d
3+
from .utils import _get_act
34

45

56
class FNN2d(nn.Module):

Diff for: models/fourier3d.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,34 @@
11
import torch.nn as nn
2-
from .basics import SpectralConv3d, _get_act
2+
from .basics import SpectralConv3d
3+
from .utils import add_padding, remove_padding, _get_act
34

45

56
class FNN3d(nn.Module):
6-
def __init__(self, modes1, modes2, modes3,
7-
width=16, fc_dim=128,
7+
def __init__(self,
8+
modes1, modes2, modes3,
9+
width=16,
10+
fc_dim=128,
811
layers=None,
912
in_dim=4, out_dim=1,
10-
act='tanh'):
13+
act='tanh',
14+
pad_ratio=0):
1115
'''
1216
Args:
1317
modes1: list of int, first dimension maximal modes for each layer
1418
modes2: list of int, second dimension maximal modes for each layer
1519
modes3: list of int, third dimension maximal modes for each layer
1620
layers: list of int, channels for each layer
21+
fc_dim: dimension of fully connected layers
1722
in_dim: int, input dimension
1823
out_dim: int, output dimension
24+
act: {tanh, gelu, relu, leaky_relu}, activation function
25+
pad_ratio: the ratio of the extended domain
1926
'''
2027
super(FNN3d, self).__init__()
2128
self.modes1 = modes1
2229
self.modes2 = modes2
2330
self.modes3 = modes3
31+
self.pad_ratio = pad_ratio
2432

2533
if layers is None:
2634
self.layers = [width] * 4
@@ -49,6 +57,7 @@ def forward(self, x):
4957
u: (batchsize, x_grid, y_grid, t_grid, 1)
5058
5159
'''
60+
x = add_padding(x, pad_ratio=self.pad_ratio)
5261
length = len(self.ws)
5362
batchsize = x.shape[0]
5463
size_x, size_y, size_z = x.shape[1], x.shape[2], x.shape[3]
@@ -66,4 +75,5 @@ def forward(self, x):
6675
x = self.fc1(x)
6776
x = self.act(x)
6877
x = self.fc2(x)
78+
x = remove_padding(x, pad_ratio=self.pad_ratio)
6979
return x

Diff for: models/utils.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import torch.nn.functional as F
2+
3+
4+
def add_padding(x, pad_ratio):
5+
if pad_ratio > 0:
6+
num_pad = int(pad_ratio * x.shape[-2])
7+
res = F.pad(x, (0, 0, 0, num_pad), 'constant', 0)
8+
else:
9+
res = x
10+
return res
11+
12+
13+
def remove_padding(x, pad_ratio):
14+
if pad_ratio > 0:
15+
num_pad = int(pad_ratio * x.shape[-2])
16+
res = x[:, :, :, :-num_pad, 0]
17+
else:
18+
res = x
19+
return res
20+
21+
22+
def _get_act(act):
23+
if act == 'tanh':
24+
func = F.tanh
25+
elif act == 'gelu':
26+
func = F.gelu
27+
elif act == 'relu':
28+
func = F.relu_
29+
elif act == 'elu':
30+
func = F.elu_
31+
elif act == 'leaky_relu':
32+
func = F.leaky_relu_
33+
else:
34+
raise ValueError(f'{act} is not supported')
35+
return func
36+

0 commit comments

Comments
 (0)