Skip to content

Commit 7b4a912

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 3fb72bb commit 7b4a912

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

neural_compressor/torch/algorithms/layer_wise/utils.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,7 @@ def load_value(model, param_name, path, device="cpu"):
221221
os.path.join(path, "model.safetensors"), param_name, prefix=prefix, device=device
222222
)
223223
elif len(safetensors_files) >= 2:
224-
value = load_tensor_from_safetensors_shard(
225-
path, param_name, prefix=prefix, device=device
226-
)
224+
value = load_tensor_from_safetensors_shard(path, param_name, prefix=prefix, device=device)
227225
elif "pytorch_model.bin.index.json" in files:
228226
value = load_tensor_from_shard(path, param_name, prefix)
229227
else:

neural_compressor/torch/algorithms/weight_only/save_load.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
import torch
2424
import transformers
25-
2625
from packaging.version import parse
2726

2827
from neural_compressor.common.utils import AWQ, TEQ, save_config_mapping
@@ -849,8 +848,11 @@ def _init_hf_model(self, model_class, config):
849848

850849
dtype_orig = model_class._set_default_torch_dtype(torch_dtype)
851850

852-
init_contexts = [no_init_weights(_enable=_fast_init)] if parse(transformers.__version__) < parse("4.51") else\
853-
[no_init_weights()]
851+
init_contexts = (
852+
[no_init_weights(_enable=_fast_init)]
853+
if parse(transformers.__version__) < parse("4.51")
854+
else [no_init_weights()]
855+
)
854856
init_contexts.append(init_empty_weights())
855857

856858
with ContextManagers(init_contexts):

neural_compressor/transformers/models/modeling_auto.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@
3838
import transformers
3939
from accelerate import init_empty_weights
4040
from accelerate.utils import is_xpu_available
41+
from packaging.version import parse
4142
from transformers import AutoConfig
4243
from transformers.configuration_utils import PretrainedConfig
4344
from transformers.modeling_utils import load_state_dict
4445
from transformers.utils import has_file, is_safetensors_available
45-
from packaging.version import parse
4646

4747
from neural_compressor.common.utils import CpuInfo, logger
4848
from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
@@ -679,8 +679,11 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
679679
quantization_config.weight_dtype = "int4"
680680
logger.warning("int4 weight_dtype is used, please change the config.json if you don't want to use it.")
681681

682-
init_contexts = [no_init_weights(_enable=_fast_init)] if parse(transformers.__version__) < parse("4.51") else\
683-
[no_init_weights()]
682+
init_contexts = (
683+
[no_init_weights(_enable=_fast_init)]
684+
if parse(transformers.__version__) < parse("4.51")
685+
else [no_init_weights()]
686+
)
684687
init_contexts.append(init_empty_weights())
685688

686689
with ContextManagers(init_contexts):

0 commit comments

Comments (0)