Skip to content

Commit cd58360

Browse files
committed
support transformers>=4.51
Signed-off-by: xin3he <[email protected]>
1 parent e9bd2e7 commit cd58360

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

neural_compressor/torch/algorithms/weight_only/save_load.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import tempfile
2222

2323
import torch
24+
import transformers
2425

2526
from neural_compressor.common.utils import AWQ, TEQ, save_config_mapping
2627
from neural_compressor.torch.utils import (
@@ -846,7 +847,8 @@ def _init_hf_model(self, model_class, config):
846847

847848
dtype_orig = model_class._set_default_torch_dtype(torch_dtype)
848849

849-
init_contexts = [no_init_weights(_enable=_fast_init)]
850+
init_contexts = [no_init_weights(_enable=_fast_init)] if transformers.__version__ < "4.51" else\
851+
[no_init_weights()]
850852
init_contexts.append(init_empty_weights())
851853

852854
with ContextManagers(init_contexts):

neural_compressor/transformers/models/modeling_auto.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,8 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
678678
quantization_config.weight_dtype = "int4"
679679
logger.warning("int4 weight_dtype is used, please change the config.json if you don't want to use it.")
680680

681-
init_contexts = [no_init_weights(_enable=_fast_init)]
681+
init_contexts = [no_init_weights(_enable=_fast_init)] if transformers.__version__ < "4.51" else\
682+
[no_init_weights()]
682683
init_contexts.append(init_empty_weights())
683684

684685
with ContextManagers(init_contexts):

0 commit comments

Comments
 (0)