Skip to content

Commit ded2b74

Browse files
Add supports_gradient_checkpointing to NemotronHPreTrainedModel (#45625)
Add supports_gradient_checkpointing to NemotronHPreTrainedModel
1 parent c472755 commit ded2b74

2 files changed

Lines changed: 2 additions & 0 deletions

File tree

src/transformers/models/nemotron_h/modeling_nemotron_h.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -952,6 +952,7 @@ def forward(
952952
class NemotronHPreTrainedModel(PreTrainedModel):
953953
config: NemotronHConfig
954954
base_model_prefix = "model"
955+
supports_gradient_checkpointing = True
955956
_no_split_modules = ["NemotronHBlock"]
956957
_skip_keys_device_placement = ["past_key_values"]
957958
_supports_flash_attn = True

src/transformers/models/nemotron_h/modular_nemotron_h.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def forward(
305305
class NemotronHPreTrainedModel(PreTrainedModel):
306306
config: NemotronHConfig
307307
base_model_prefix = "model"
308+
supports_gradient_checkpointing = True
308309
_no_split_modules = ["NemotronHBlock"]
309310
_skip_keys_device_placement = ["past_key_values"]
310311
_supports_flash_attn = True

0 commit comments

Comments
 (0)