Skip to content

Commit 246fb65

Browse files
authored
Stop the _adapt_batch() from changing the batch in-place (#306)
1 parent 14c7249 commit 246fb65

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed
Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""Base model class for multimodal model and unimodal teacher"""
2+
import copy
3+
24
from torchvision.transforms.functional import center_crop
35

46
from pvnet.models.base_model import BaseModel
@@ -8,44 +10,47 @@ class MultimodalBaseModel(BaseModel):
810
"""Base model class for multimodal model and unimodal teacher"""
911

1012
def _adapt_batch(self, batch):
11-
"""Slice batches into appropriate shapes for model
13+
"""Slice batches into appropriate shapes for model.
1214
15+
Returns a new batch dictionary with adapted data, leaving the original batch unchanged.
1316
We make some specific assumptions about the original batch and the derived sliced batch:
1417
- We are only limiting the future projections. I.e. we are never shrinking the batch from
1518
the left hand side of the time axis, only slicing it from the right
1619
- We are only shrinking the spatial crop of the satellite and NWP data
1720
1821
"""
22+
# Create a copy of the batch to avoid modifying the original
23+
new_batch = {key: copy.deepcopy(value) for key, value in batch.items()}
1924

20-
if "gsp" in batch.keys():
25+
if "gsp" in new_batch.keys():
2126
# Slice off the end of the GSP data
2227
gsp_len = self.forecast_len + self.history_len + 1
23-
batch["gsp"] = batch["gsp"][:, :gsp_len]
24-
batch["gsp_time_utc"] = batch["gsp_time_utc"][:, :gsp_len]
28+
new_batch["gsp"] = new_batch["gsp"][:, :gsp_len]
29+
new_batch["gsp_time_utc"] = new_batch["gsp_time_utc"][:, :gsp_len]
2530

2631
if self.include_sat:
2732
# Slice off the end of the satellite data and spatially crop
2833
# Shape: batch_size, seq_length, channel, height, width
29-
batch["satellite_actual"] = center_crop(
30-
batch["satellite_actual"][:, : self.sat_sequence_len],
34+
new_batch["satellite_actual"] = center_crop(
35+
new_batch["satellite_actual"][:, : self.sat_sequence_len],
3136
output_size=self.sat_encoder.image_size_pixels,
3237
)
3338

3439
if self.include_nwp:
3540
# Slice off the end of the NWP data and spatially crop
3641
for nwp_source in self.nwp_encoders_dict:
3742
# shape: batch_size, seq_len, n_chans, height, width
38-
batch["nwp"][nwp_source]["nwp"] = center_crop(
39-
batch["nwp"][nwp_source]["nwp"],
43+
new_batch["nwp"][nwp_source]["nwp"] = center_crop(
44+
new_batch["nwp"][nwp_source]["nwp"],
4045
output_size=self.nwp_encoders_dict[nwp_source].image_size_pixels,
4146
)[:, : self.nwp_encoders_dict[nwp_source].sequence_length]
4247

4348
if self.include_sun:
4449
# Slice off the end of the solar coords data
4550
for s in ["solar_azimuth", "solar_elevation"]:
4651
key = f"{self._target_key}_{s}"
47-
if key in batch.keys():
52+
if key in new_batch.keys():
4853
sun_len = self.forecast_len + self.history_len + 1
49-
batch[key] = batch[key][:, :sun_len]
54+
new_batch[key] = new_batch[key][:, :sun_len]
5055

51-
return batch
56+
return new_batch

0 commit comments

Comments
 (0)