We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
supports_gradient_checkpointing
NemotronHPreTrainedModel
1 parent c472755 commit ded2b74 (Copy full SHA for ded2b74)
2 files changed
src/transformers/models/nemotron_h/modeling_nemotron_h.py
@@ -952,6 +952,7 @@ def forward(
952
class NemotronHPreTrainedModel(PreTrainedModel):
953
config: NemotronHConfig
954
base_model_prefix = "model"
955
+ supports_gradient_checkpointing = True
956
_no_split_modules = ["NemotronHBlock"]
957
_skip_keys_device_placement = ["past_key_values"]
958
_supports_flash_attn = True
src/transformers/models/nemotron_h/modular_nemotron_h.py
@@ -305,6 +305,7 @@ def forward(
305
306
307
308
309
310
311
0 commit comments