11"""Base model class for multimodal model and unimodal teacher"""
2+ import copy
3+
24from torchvision .transforms .functional import center_crop
35
46from pvnet .models .base_model import BaseModel
@@ -8,44 +10,47 @@ class MultimodalBaseModel(BaseModel):
810 """Base model class for multimodal model and unimodal teacher"""
911
1012 def _adapt_batch (self , batch ):
11- """Slice batches into appropriate shapes for model
13+ """Slice batches into appropriate shapes for model.
1214
15+ Returns a new batch dictionary with adapted data, leaving the original batch unchanged.
1316 We make some specific assumptions about the original batch and the derived sliced batch:
1417 - We are only limiting the future projections. I.e. we are never shrinking the batch from
1518 the left hand side of the time axis, only slicing it from the right
1619 - We are only shrinking the spatial crop of the satellite and NWP data
1720
1821 """
22+ # Create a copy of the batch to avoid modifying the original
23+ new_batch = {key : copy .deepcopy (value ) for key , value in batch .items ()}
1924
20- if "gsp" in batch .keys ():
25+ if "gsp" in new_batch .keys ():
2126 # Slice off the end of the GSP data
2227 gsp_len = self .forecast_len + self .history_len + 1
23- batch ["gsp" ] = batch ["gsp" ][:, :gsp_len ]
24- batch ["gsp_time_utc" ] = batch ["gsp_time_utc" ][:, :gsp_len ]
28+ new_batch ["gsp" ] = new_batch ["gsp" ][:, :gsp_len ]
29+ new_batch ["gsp_time_utc" ] = new_batch ["gsp_time_utc" ][:, :gsp_len ]
2530
2631 if self .include_sat :
2732 # Slice off the end of the satellite data and spatially crop
2833 # Shape: batch_size, seq_length, channel, height, width
29- batch ["satellite_actual" ] = center_crop (
30- batch ["satellite_actual" ][:, : self .sat_sequence_len ],
34+ new_batch ["satellite_actual" ] = center_crop (
35+ new_batch ["satellite_actual" ][:, : self .sat_sequence_len ],
3136 output_size = self .sat_encoder .image_size_pixels ,
3237 )
3338
3439 if self .include_nwp :
3540 # Slice off the end of the NWP data and spatially crop
3641 for nwp_source in self .nwp_encoders_dict :
3742 # shape: batch_size, seq_len, n_chans, height, width
38- batch ["nwp" ][nwp_source ]["nwp" ] = center_crop (
39- batch ["nwp" ][nwp_source ]["nwp" ],
43+ new_batch ["nwp" ][nwp_source ]["nwp" ] = center_crop (
44+ new_batch ["nwp" ][nwp_source ]["nwp" ],
4045 output_size = self .nwp_encoders_dict [nwp_source ].image_size_pixels ,
4146 )[:, : self .nwp_encoders_dict [nwp_source ].sequence_length ]
4247
4348 if self .include_sun :
4449 # Slice off the end of the solar coords data
4550 for s in ["solar_azimuth" , "solar_elevation" ]:
4651 key = f"{ self ._target_key } _{ s } "
47- if key in batch .keys ():
52+ if key in new_batch .keys ():
4853 sun_len = self .forecast_len + self .history_len + 1
49- batch [key ] = batch [key ][:, :sun_len ]
54+ new_batch [key ] = new_batch [key ][:, :sun_len ]
5055
51- return batch
56+ return new_batch
0 commit comments