Skip to content

Commit

Permalink
Add option to disable partial convolution
Browse files Browse the repository at this point in the history
  • Loading branch information
eplesiat committed Aug 21, 2024
1 parent 3c1ccf7 commit 20068c8
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
1 change: 1 addition & 0 deletions climatereconstructionai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def set_common_args():
arg_parser.add_argument('--masked-bn', action='store_true',
help="Use masked batch normalization instead of standard BN")
arg_parser.add_argument('--lazy-load', action='store_true', help="Use lazy loading for large datasets")
arg_parser.add_argument('--standard-conv', action='store_true', help="Disable partial convolution")
arg_parser.add_argument('--global-padding', action='store_true', help="Use a custom padding for global dataset")
arg_parser.add_argument('--normalize-data', action='store_true',
help="Normalize the input climate data to 0 mean and 1 std")
Expand Down
6 changes: 5 additions & 1 deletion climatereconstructionai/utils/netcdfloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def __init__(self, data_root, img_names, mask_root, mask_names, split, data_type
super(NetCDFLoader, self).__init__()

self.random = random.Random(cfg.loop_random_seed)
self.standard_conv = cfg.standard_conv

self.data_types = data_types
self.time_steps = time_steps
Expand Down Expand Up @@ -260,7 +261,10 @@ def __getitem__(self, index):
if cfg.n_target_data == 0 and i < cfg.n_output_data:
images.append(image[cfg.out_steps])
out_masks.append(self.create_out_mask(mask, i))
in_masks.append(mask[cfg.in_steps])
if self.standard_conv:
in_masks.append(torch.ones_like(mask[cfg.in_steps]))
else:
in_masks.append(mask[cfg.in_steps])
masked.append(image[cfg.in_steps] * mask[cfg.in_steps])

if cfg.channel_steps:
Expand Down

0 comments on commit 20068c8

Please sign in to comment.