Add forecasting capabilities #28

Open · wants to merge 26 commits into base: main

Commits (26):

dae031f  Add multiple timesteps forecasting (eplesiat, Jan 4, 2024)
ebdee5f  Fix issue with multiple steady masks (eplesiat, Jan 12, 2024)
0008dee  Fix issue memory time out (eplesiat, Jan 12, 2024)
c707b1d  Adapt time axes of output file in evaluate for pred_steps (eplesiat, Jan 15, 2024)
8478238  Fix steady_mask issue when partitioning the data (eplesiat, Jan 17, 2024)
a4609da  Apply steady mask to gt and image in evaluate (eplesiat, Feb 1, 2024)
e228ffb  Fix issue with history attribute in NetCDF output file (eplesiat, Feb 15, 2024)
7f6c0e5  Add variational autoencoder option (eplesiat, Feb 27, 2024)
d949529  Fix img size issue when using VAE (eplesiat, Feb 29, 2024)
a78fd97  Add more options to split the output (eplesiat, Mar 22, 2024)
9a24351  Fix NaNs in infilled when GT contains missing values (eplesiat, Mar 25, 2024)
a5b8ef3  Add extreme loss (eplesiat, Jul 9, 2024)
11b33d5  Add forecasting and VAE capabilities (eplesiat, Jul 10, 2024)
54b1e6c  Select loss criterions using names (eplesiat, Jul 10, 2024)
e63a489  Fix normalizer issue with new dataloader (eplesiat, Jul 10, 2024)
4b439df  Add tests (eplesiat, Jul 10, 2024)
6a27ac8  Make PEP-8 compliant (eplesiat, Jul 10, 2024)
a985b43  Add VAE training test (eplesiat, Jul 10, 2024)
d4fff5f  Fix steady mask issue with new dataloader (eplesiat, Jul 10, 2024)
336240b  Add missing pytorch-gpu package (eplesiat, Jul 24, 2024)
3c1ccf7  Add info about installation times (eplesiat, Jul 25, 2024)
20068c8  Add option to disable partial convolution (eplesiat, Aug 21, 2024)
e8ff1f4  Add dlprof profiler (eplesiat, Sep 17, 2024)
62d37a9  Fix loss computation issue with multi GPUs (eplesiat, Sep 18, 2024)
37b658c  Fix loss computation issue with multi GPUs (eplesiat, Sep 19, 2024)
cc8edc1  Fix inpainting loss with multi GPUs (eplesiat, Nov 19, 2024)
4 changes: 4 additions & 0 deletions README.md
@@ -30,13 +30,17 @@ conda activate crai

`environment-cuda.yml` should be used when working with GPUs using CUDA.

The installation time of the required dependencies should not exceed 15 minutes over a stable, standard internet connection.

## Installation

`climatereconstructionAI` can be installed using `pip` in the current directory:
```bash
pip install .
```

The installation time of the Python package should not exceed 1 minute on a regular computer.

## Usage

The software can be used to:
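
As a quick sanity check of the installation described above (a sketch; the import name is assumed to match the package directory used throughout this PR):

```python
# Post-install sanity check (sketch): the import name is assumed to
# match the package directory name, climatereconstructionai.
import climatereconstructionai
print(climatereconstructionai.__file__)
```
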
130 changes: 92 additions & 38 deletions climatereconstructionai/config.py
@@ -54,29 +54,88 @@ def set_lambdas():

lambda_dict = {}

if loss_criterion == 0:
if loss_criterion in ("0", "inpainting"):
lambda_dict['valid'] = 1.
lambda_dict['hole'] = 6.
lambda_dict['tv'] = .1
lambda_dict['prc'] = .05
lambda_dict['style'] = 120.

elif loss_criterion == 1:
elif loss_criterion in ("1", "l1-hole"):
lambda_dict['hole'] = 1.

elif loss_criterion == 2:
elif loss_criterion in ("2", "downscaling"):
lambda_dict['valid'] = 7.
lambda_dict['hole'] = 0.
lambda_dict['tv'] = .1
lambda_dict['prc'] = .05
lambda_dict['style'] = 120.
elif loss_criterion == 3:

elif loss_criterion in ("3", "l1-valid"):
lambda_dict['valid'] = 1.

elif loss_criterion in ("4", "extreme"):
lambda_dict['-extreme'] = 1.
lambda_dict['+extreme'] = 1.

if vae_zdim != 0:
lambda_dict['kldiv'] = 1.

if lambda_loss is not None:
lambda_dict.update(lambda_loss)


def set_steps(evaluate=False):

assert sum(bool(x) for x in [lstm_steps, gru_steps, channel_steps]) < 2, \
"lstm, gru and channel options are mutually exclusive"

global recurrent_steps, n_recurrent_steps
global time_steps
time_steps = [0, 0]
if lstm_steps:
time_steps = lstm_steps
recurrent_steps = lstm_steps[0]
elif gru_steps:
time_steps = gru_steps
recurrent_steps = gru_steps[0]
else:
recurrent_steps = 0

n_recurrent_steps = sum(time_steps) + 1

global n_channel_steps, gt_channels
if channel_steps:
time_steps = channel_steps
n_channel_steps = sum(channel_steps) + 1
gt_channels = [i * n_channel_steps + channel_steps[0] for i in range(n_output_data)]
else:
n_channel_steps = 1
gt_channels = [0 for i in range(n_output_data)]

global n_time_steps, in_steps, out_steps, n_pred_steps, pred_timestep, out_channels

n_time_steps = sum(time_steps) + 1
pred_timestep = list(range(-pred_steps[0], pred_steps[1] + 1))
n_pred_steps = len(pred_timestep)

if evaluate:
in_steps = range(0, n_time_steps)
out_steps = [time_steps[0]]
out_channels = n_output_data * n_pred_steps
else:
in_step = max(pred_steps[0] - time_steps[0], 0)
in_steps = range(in_step, in_step + n_time_steps)
n_time_steps = len(in_steps)
interval = [max(time_steps[i], pred_steps[i]) for i in range(2)]
out_steps = range(interval[0] - pred_steps[0], interval[0] + pred_steps[1] + 1)
time_steps = interval

out_channels = n_output_data * len(out_steps)

assert len(time_steps) == 2


def global_args(parser, arg_file=None, prog_func=None):
import torch

@@ -109,33 +168,18 @@ def global_args(parser, arg_file=None, prog_func=None):
if not os.path.exists(log_dir):
os.makedirs(log_dir)

global recurrent_steps
global n_recurrent_steps
global time_steps
time_steps = [0, 0]
if lstm_steps:
recurrent_steps = lstm_steps[0]
time_steps = lstm_steps
elif gru_steps:
recurrent_steps = gru_steps[0]
time_steps = gru_steps
else:
recurrent_steps = 0

n_recurrent_steps = sum(time_steps) + 1

global n_channel_steps
global gt_channels
global n_output_data
if n_target_data > 0:
n_output_data = n_target_data

n_channel_steps = 1
gt_channels = [0 for i in range(out_channels)]
if channel_steps:
time_steps = channel_steps
n_channel_steps = sum(channel_steps) + 1
for i in range(out_channels):
gt_channels[i] = (i + 1) * channel_steps[0] + i * (channel_steps[1] + 1)
global min_bounds, max_bounds
if len(min_bounds) == 1:
min_bounds = [min_bounds[0] for i in range(n_output_data)]
if len(max_bounds) == 1:
max_bounds = [max_bounds[0] for i in range(n_output_data)]

assert len(time_steps) == 2
assert len(min_bounds) == n_output_data
assert len(max_bounds) == n_output_data

if all('.json' in data_name for data_name in data_names) and (lstm_steps or channel_steps):
print('Warning: Each input file defined in your ".json" files will be considered individually.'
@@ -163,21 +207,24 @@ def set_common_args():
help="Number of data-names (from last) to be used as target data")
arg_parser.add_argument('--device', type=str, default='cuda', help="Device used by PyTorch (cuda or cpu)")
arg_parser.add_argument('--shuffle-masks', action='store_true', help="Select mask indices randomly")
arg_parser.add_argument('--vae-zdim', type=int, default=0, help="Use VAE with latent space dimension")
arg_parser.add_argument('--channel-steps', type=int_list, default=None,
help="Comma separated number of considered sequences for channeled memory:"
"past_steps,future_steps")
arg_parser.add_argument('--lstm-steps', type=int_list, default=None,
help="Comma separated number of considered sequences for lstm: past_steps,future_steps")
arg_parser.add_argument('--gru-steps', type=int_list, default=None,
help="Comma separated number of considered sequences for gru: past_steps,future_steps")
arg_parser.add_argument('--pred-steps', type=int_list, default=[0, 0],
help="Comma separated number of considered sequences for pred: past_steps,future_steps")
arg_parser.add_argument('--encoding-layers', type=int_list, default='3',
help="Number of encoding layers in the CNN")
arg_parser.add_argument('--pooling-layers', type=int_list, default='0', help="Number of pooling layers in the CNN")
arg_parser.add_argument('--conv-factor', type=int, default=None, help="Number of channels in the deepest layer")
arg_parser.add_argument('--weights', type=str, default=None, help="Initialization weight")
arg_parser.add_argument('--steady-masks', type=str_list, default=None,
help="Comma separated list of netCDF files containing a single mask to be applied "
"to all timesteps. The number of steady-masks must be the same as out-channels")
"to all timesteps. The number of steady-masks must be the same as n-output-data")
arg_parser.add_argument('--loop-random-seed', type=int, default=None,
help="Random seed for iteration loop")
arg_parser.add_argument('--cuda-random-seed', type=int, default=None,
@@ -192,17 +239,18 @@
arg_parser.add_argument('--masked-bn', action='store_true',
help="Use masked batch normalization instead of standard BN")
arg_parser.add_argument('--lazy-load', action='store_true', help="Use lazy loading for large datasets")
arg_parser.add_argument('--standard-conv', action='store_true', help="Disable partial convolution")
arg_parser.add_argument('--global-padding', action='store_true', help="Use a custom padding for global dataset")
arg_parser.add_argument('--normalize-data', action='store_true',
help="Normalize the input climate data to 0 mean and 1 std")
arg_parser.add_argument('--n-filters', type=int, default=None, help="Number of filters for the first/last layer")
arg_parser.add_argument('--out-channels', type=int, default=1, help="Number of channels for the output data")
arg_parser.add_argument('--n-output-data', type=int, default=1, help="Number of output data")
arg_parser.add_argument('--dataset-name', type=str, default=None, help="Name of the dataset for format checking")
arg_parser.add_argument('--min-bounds', type=float_list, default="-inf",
help="Comma separated list of values defining the permitted lower-bound of output values")
arg_parser.add_argument('--max-bounds', type=float_list, default="inf",
help="Comma separated list of values defining the permitted upper-bound of output values")
arg_parser.add_argument('--profile', action='store_true', help="Profile code using tensorboard profiler")
arg_parser.add_argument('--profiler', type=str, default=None, help="Use specified profiler")
return arg_parser


@@ -231,9 +279,8 @@ def set_train_args(arg_file=None):
help="Number of final models to be saved")
arg_parser.add_argument('--final-models-interval', type=int, default=1000,
help="Iteration step interval at which the final models should be saved")
arg_parser.add_argument('--loss-criterion', type=int, default=0,
help="Index defining the loss function "
"(0=original from Liu et al., 1=MAE of the hole region)")
arg_parser.add_argument('--loss-criterion', type=str, default="l1-hole",
help="Index/string defining the loss function (inpainting/l1-hole/l1-valid/etc.)")
arg_parser.add_argument('--eval-timesteps', type=int_list, default=None,
help="Sample indices for which a snapshot is created at each iter defined by log-interval")
arg_parser.add_argument('-f', '--load-from-file', type=str, action=LoadFromFile,
@@ -257,6 +304,7 @@
help="Number of batch iterations used to average the validation loss")

args = global_args(arg_parser, arg_file)
set_steps()

global passed_args
passed_args = get_passed_arguments(args, arg_parser)
@@ -276,22 +324,28 @@
def set_evaluate_args(arg_file=None, prog_func=None):
arg_parser = set_common_args()
arg_parser.add_argument('--model-dir', type=str, default='snapshots/ckpt/', help="Directory of the trained models")
arg_parser.add_argument('--model-names', type=str_list, default='1000000.pth', help="Model names")
arg_parser.add_argument('--model-names', type=str_list, default='final.pth', help="Model names")
arg_parser.add_argument('--evaluation-dirs', type=str_list, default='evaluation/',
help="Directory where the output files will be stored")
arg_parser.add_argument('--eval-names', type=str_list, default='output',
help="Prefix used for the output filenames")
arg_parser.add_argument('--use-train-stats', action='store_true',
help="Use mean and std from training data for normalization")
arg_parser.add_argument('--n-evaluations', type=int, default=1, help="Number of evaluations")
arg_parser.add_argument('--create-graph', action='store_true', help="Create a Tensorboard graph of the NN")
arg_parser.add_argument('--plot-results', type=int_list, default=[],
help="Create plot images of the results for the comma separated list of time indices")
arg_parser.add_argument('--partitions', type=int, default=1,
help="Split the climate dataset into several partitions along the time coordinate")
arg_parser.add_argument('--maxmem', type=int, default=None,
help="Maximum available memory in MB (overwrite partitions parameter)")
arg_parser.add_argument('--split-outputs', action='store_true',
help="Do not merge the outputs when using multiple models and/or partitions")
arg_parser.add_argument('--time-freq', type=str, default=None,
help="Time frequency for pred-steps option (only for D,H,M,S,etc.)")
arg_parser.add_argument('--split-outputs', type=str, default="all", const=None, nargs='?',
help="Split the outputs according to members and/or partitions")
arg_parser.add_argument('-f', '--load-from-file', type=str, action=LoadFromFile,
help="Load all the arguments from a text file")
global_args(arg_parser, arg_file, prog_func)
set_steps(evaluate=True)
assert len(eval_names) == n_output_data
globals()["model_names"] *= globals()["n_evaluations"]
50 changes: 33 additions & 17 deletions climatereconstructionai/evaluate.py
@@ -15,6 +15,12 @@ def store_encoding(ds):
return ds


def format_time(ds):
ds['time'].encoding = encoding
ds['time'].encoding['original_shape'] = len(ds["time"])
return ds.transpose("time", ...).reset_coords(drop=True)


def evaluate(arg_file=None, prog_func=None):
cfg.set_evaluate_args(arg_file, prog_func)

@@ -36,7 +42,7 @@ def evaluate(arg_file=None, prog_func=None):
data_stats = None

dataset_val = NetCDFLoader(cfg.data_root_dir, cfg.data_names, cfg.mask_dir, cfg.mask_names, "infill",
cfg.data_types, cfg.time_steps, data_stats)
cfg.data_types, cfg.time_steps, cfg.steady_masks, data_stats)

n_samples = len(dataset_val)

@@ -79,28 +85,38 @@
batch_size = get_batch_size(model.parameters(), n_samples, image_sizes)
iterator_val = iter(DataLoader(dataset_val, batch_size=batch_size,
sampler=FiniteSampler(len(dataset_val)), num_workers=0))
infill(model, iterator_val, eval_path, output_names, data_stats, dataset_val.xr_dss, count)
infill(model, iterator_val, eval_path, output_names, dataset_val.steady_mask, data_stats,
dataset_val.xr_dss, count)

for name in output_names:
if len(output_names[name]) == 1 and len(output_names[name][1]) == 1:
os.rename(output_names[name][1][0], name + ".nc")
else:
if not cfg.split_outputs:
dss = []
for i_model in output_names[name]:
dss.append(xr.open_mfdataset(output_names[name][i_model], preprocess=store_encoding, autoclose=True,
combine='nested', data_vars='minimal', concat_dim="time", chunks={}))
dss[-1] = dss[-1].assign_coords({"member": i_model})

if len(dss) == 1:
ds = dss[-1].drop("member")
if cfg.split_outputs is not None:

if cfg.split_outputs == "time":
k = 0
for names in zip(*(output_names[name].values())):
k += 1
dss = [xr.open_dataset(names[i]).assign_coords({"member": i}) for i in range(len(names))]
xr.concat(dss, dim="member").to_netcdf("{}-{}.nc".format(name, k))
else:
ds = xr.concat(dss, dim="member")

ds['time'].encoding = encoding
ds['time'].encoding['original_shape'] = len(ds["time"])
ds = ds.transpose("time", ...).reset_coords(drop=True)
ds.to_netcdf(name + ".nc")
dss = []
for i_model in output_names[name]:
ds = xr.open_mfdataset(output_names[name][i_model], preprocess=store_encoding, autoclose=True,
combine='nested', data_vars='minimal', concat_dim="time", chunks={})
ds = ds.assign_coords({"member": i_model})
if cfg.split_outputs == "member":
format_time(ds).to_netcdf("{}.{}.nc".format(name, i_model))
else:
dss.append(ds)

if cfg.split_outputs != "member":
if len(dss) == 1:
ds = dss[-1].drop("member")
else:
ds = xr.concat(dss, dim="member")
format_time(ds).to_netcdf(name + ".nc")

for i_model in output_names[name]:
for output_name in output_names[name][i_model]:
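
For readers tracing the reworked output handling, a minimal xarray sketch of the "merge all members" branch above; the file names are hypothetical, and the files are assumed to carry a `time` dimension, as the NetCDF outputs here do:

```python
import xarray as xr

# Hypothetical per-member output files written by two trained models.
names = ["output.0.nc", "output.1.nc"]

# Tag each dataset with its member index, then concatenate along a new
# "member" dimension, mirroring the merge branch in evaluate() above.
dss = [xr.open_dataset(f).assign_coords({"member": i}) for i, f in enumerate(names)]
merged = xr.concat(dss, dim="member")
merged.transpose("time", ...).to_netcdf("output.nc")
```
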
28 changes: 28 additions & 0 deletions climatereconstructionai/loss/extreme_loss.py
@@ -0,0 +1,28 @@
import torch
import torch.nn as nn


class ExtremeLoss(nn.Module):
def __init__(self):
super().__init__()
self.l2 = nn.MSELoss()
self.sm = nn.Softmax(dim=0)

def forward(self, data_dict):
loss_dict = {
'-extreme': 0.0,
'+extreme': 0.0,
}

output = data_dict['output']
gt = data_dict['gt']

# calculate loss for all channels
for channel in range(output.shape[1]):

gt_ch = torch.unsqueeze(gt[:, channel, :, :], dim=1)
output_ch = torch.unsqueeze(output[:, channel, :, :], dim=1)
loss_dict['-extreme'] += self.l2(self.sm(-output_ch), self.sm(-gt_ch))
loss_dict['+extreme'] += self.l2(self.sm(output_ch), self.sm(gt_ch))

return loss_dict
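
A minimal usage sketch of the new loss; the import path follows the file location in this PR, and the tensor shapes are illustrative only:

```python
import torch
from climatereconstructionai.loss.extreme_loss import ExtremeLoss

# Smoke test (sketch): random tensors standing in for model output and ground truth.
loss_fn = ExtremeLoss()
data = {
    'output': torch.randn(4, 2, 32, 32),  # batch, channels, height, width
    'gt': torch.randn(4, 2, 32, 32),
}
losses = loss_fn(data)
print(losses['-extreme'].item(), losses['+extreme'].item())
```

Since the softmax is taken over dim=0 (the batch dimension), each term concentrates weight on the most extreme values across the batch, so '+extreme' and '-extreme' emphasize the positive and negative tails of the distribution, respectively.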