diff --git a/src/transformers/models/nemotron_h/modeling_nemotron_h.py b/src/transformers/models/nemotron_h/modeling_nemotron_h.py index 6af7fd477564..93bd47f2c3f4 100644 --- a/src/transformers/models/nemotron_h/modeling_nemotron_h.py +++ b/src/transformers/models/nemotron_h/modeling_nemotron_h.py @@ -952,6 +952,7 @@ def forward( class NemotronHPreTrainedModel(PreTrainedModel): config: NemotronHConfig base_model_prefix = "model" + supports_gradient_checkpointing = True _no_split_modules = ["NemotronHBlock"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True diff --git a/src/transformers/models/nemotron_h/modular_nemotron_h.py b/src/transformers/models/nemotron_h/modular_nemotron_h.py index f49597f43140..803e5c638239 100644 --- a/src/transformers/models/nemotron_h/modular_nemotron_h.py +++ b/src/transformers/models/nemotron_h/modular_nemotron_h.py @@ -305,6 +305,7 @@ def forward( class NemotronHPreTrainedModel(PreTrainedModel): config: NemotronHConfig base_model_prefix = "model" + supports_gradient_checkpointing = True _no_split_modules = ["NemotronHBlock"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True