diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 10fe6419..30d28884 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -87,7 +87,7 @@ def __init__( state_feature_weights = get_state_feature_weighting( config=config, datastore=datastore ) - self.feature_weights = torch.tensor( + self.register_buffer("feature_weights", torch.tensor(...)) state_feature_weights, dtype=torch.float32 )