diff --git a/.azure-pipelines/scripts/codeScan/pylint/pylint.sh b/.azure-pipelines/scripts/codeScan/pylint/pylint.sh
index a1ed65ad4..7fb4d813d 100644
--- a/.azure-pipelines/scripts/codeScan/pylint/pylint.sh
+++ b/.azure-pipelines/scripts/codeScan/pylint/pylint.sh
@@ -24,7 +24,7 @@ pip install pipdeptree
 pipdeptree
 python -m pylint -f json --disable=R,C,W,E1129 --enable=line-too-long --max-line-length=120 --extension-pkg-whitelist=numpy --ignored-classes=TensorProto,NodeProto \
---ignored-modules=tensorflow,keras,torch,torch.quantization,torch.tensor,torchvision,fairseq,mxnet,onnx,onnxruntime,intel_extension_for_pytorch,intel_extension_for_tensorflow,torchinfo,horovod,transformers \
+--ignored-modules=tensorflow,keras,torch,torch.quantization,torch.tensor,torchvision,fairseq,mxnet,onnx,onnxruntime,intel_extension_for_pytorch,intel_extension_for_tensorflow,torchinfo,horovod,transformers,deepspeed,deepspeed.module_inject \
 /auto-round/${scan_module} > $log_dir/pylint.json
 exit_code=$?
diff --git a/auto_round/utils.py b/auto_round/utils.py
index b9462a500..3aebd06b4 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -38,6 +38,11 @@
 SHARED_CACHE_KEYS = ("position_ids", "cache_position", "position_embeddings")
 
+deepspeed_exists = False
+if importlib.util.find_spec("deepspeed"):  # check if deepspeed is installed
+    deepspeed_exists = True
+
+
 class SupportedFormats:
 
     def __init__(self):
@@ -67,16 +72,18 @@
     def __getitem__(self, key):
         return self._support_list[key]
 
 
+SUPPORTED_DTYPES = ("int", "mx_fp", "fp", "nv_fp")
 SUPPORTED_FORMATS = SupportedFormats()
-
 SUPPORTED_LAYER_TYPES = (torch.nn.Linear, transformers.pytorch_utils.Conv1D)
 ##changed to str as it relies triton or others lib to load this
 INNER_SUPPORTED_LAYER_TYPES = ("FP8Linear",)
-
 # INNER_SUPPORTED_LAYER_TYPES = (transformers.integrations.finegrained_fp8.FP8Linear,)
-SUPPORTED_DTYPES = ("int", "mx_fp", "fp", "nv_fp")
+if deepspeed_exists:
+    from deepspeed.module_inject import LinearAllreduce, LinearLayer
+
+    SUPPORTED_LAYER_TYPES = SUPPORTED_LAYER_TYPES + (LinearLayer, LinearAllreduce)
 
 
 def infer_bits_by_data_type(data_type: str):
@@ -1172,6 +1179,8 @@ def get_layer_features(layer):
         return layer.weight.shape[0], layer.weight.shape[1]
     elif isinstance(layer, torch.nn.Embedding):
         return layer.num_embeddings, layer.embedding_dim
+    elif deepspeed_exists and isinstance(layer, (LinearLayer, LinearAllreduce)):
+        return layer.weight.shape[1], layer.weight.shape[0]  # (input_dim, output_dim)
     return None, None  # Unsupported layer type
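Aside (not part of the patch): a minimal runnable sketch of the optional-dependency pattern the utils.py hunks use — probe for deepspeed with importlib.util.find_spec without importing it, and extend the isinstance tuple only inside the guard. The condensed get_layer_features below keeps just the DeepSpeed branch; the (out_features, in_features) weight layout it assumes is the one implied by the transposed matmul added in wrapper.py.

import importlib.util

import torch
import transformers

# find_spec returns None when the package is absent, so this probe is safe
# on environments (e.g. CI runners) that do not ship deepspeed.
deepspeed_exists = importlib.util.find_spec("deepspeed") is not None

SUPPORTED_LAYER_TYPES = (torch.nn.Linear, transformers.pytorch_utils.Conv1D)
if deepspeed_exists:
    # Imported only inside the guard, so importing this module never fails.
    from deepspeed.module_inject import LinearAllreduce, LinearLayer

    SUPPORTED_LAYER_TYPES = SUPPORTED_LAYER_TYPES + (LinearLayer, LinearAllreduce)


def get_layer_features(layer):
    """Condensed version of the patched helper: return (in_features, out_features).

    DeepSpeed's LinearLayer/LinearAllreduce keep weight as (out_features,
    in_features), so the shape is read back flipped.
    """
    # Short-circuit keeps the isinstance check from evaluating undefined names
    # when deepspeed is absent.
    if deepspeed_exists and isinstance(layer, (LinearLayer, LinearAllreduce)):
        return layer.weight.shape[1], layer.weight.shape[0]
    return None, None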
diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py
index c3dfc0765..8942a2c0d 100644
--- a/auto_round/wrapper.py
+++ b/auto_round/wrapper.py
@@ -18,7 +18,20 @@
 
 from auto_round.data_type import get_quant_func
 
-from .utils import SUPPORTED_LAYER_TYPES, check_to_quantized, get_scale_shape, is_mx_fp, is_nv_fp, logger, set_module
+from .utils import (
+    SUPPORTED_LAYER_TYPES,
+    check_to_quantized,
+    deepspeed_exists,
+    get_scale_shape,
+    is_mx_fp,
+    is_nv_fp,
+    logger,
+    set_module,
+)
+
+if deepspeed_exists:
+    from deepspeed import comm as dist
+    from deepspeed.module_inject import LinearAllreduce, LinearLayer
 
 
 def reshape_and_pad_tensor(v, group_size=-1):
@@ -94,7 +107,18 @@ def __init__(
         else:
             self.q_scale_thresh = 1e-5
         self._init_tuning_params_and_quant_func()
-        self.orig_forward = self.linear_forward if isinstance(self.orig_layer, torch.nn.Linear) else self.conv1d_forward
+        if deepspeed_exists:
+            if isinstance(self.orig_layer, (torch.nn.Linear, LinearLayer)):
+                self.orig_forward = self.linear_forward
+            elif isinstance(self.orig_layer, LinearAllreduce):
+                self.orig_forward = self.all_reduce_linear_forward
+                self.mp_group = self.orig_layer.mp_group
+            else:
+                self.orig_forward = self.conv1d_forward
+        else:
+            self.orig_forward = (
+                self.linear_forward if isinstance(self.orig_layer, torch.nn.Linear) else self.conv1d_forward
+            )
 
     def _init_tuning_params_and_quant_func(self):
         """Initializes tuning parameters and quantization functions.
@@ -367,6 +391,24 @@ def linear_forward(self, x, weight, bias):
         """
         return F.linear(x, weight, bias)
 
+    def all_reduce_linear_forward(self, x, weight, bias):
+        """Performs the forward pass for a tensor-parallel linear layer with all-reduce.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            weight (torch.Tensor): Weight tensor for the linear layer.
+            bias (torch.Tensor): Bias tensor for the linear layer.
+
+        Returns:
+            torch.Tensor: Output tensor, all-reduced across the model-parallel group with the bias applied once.
+        """
+        output = torch.matmul(x, weight.transpose(-1, -2))
+        if self.mp_group is not None:
+            dist.inference_all_reduce(output, group=self.mp_group)
+        if bias is not None:
+            output += bias
+        return output
+
     def conv1d_forward(self, x, weight, bias):
         """Performs the forward pass for a Conv1D layer.
 
@@ -414,7 +456,10 @@ class WrapperWALayer(torch.nn.Module):
     def __init__(self, orig_layer):
         super(WrapperWALayer, self).__init__()
         self.orig_layer = orig_layer
+        self.data_type = orig_layer.data_type if hasattr(orig_layer, "data_type") else None
+        self.act_data_type = orig_layer.act_data_type if hasattr(orig_layer, "act_data_type") else None
         self.act_quant_func = self.orig_layer.act_quant_func
+        self.extra_repr_org = orig_layer.extra_repr
 
     def forward(self, x):
         act_max = self.orig_layer.act_max if hasattr(self.orig_layer, "act_max") else None
@@ -429,6 +474,9 @@ def forward(self, x):
         )
         return self.orig_layer.forward(x)
 
+    def extra_repr(self):
+        return f"{self.extra_repr_org()}, weight_type={self.data_type}, act_data_type={self.act_data_type}"
+
 
 class WrapperLayerNorm(torch.nn.Module):
     """A wrapper for layer normalization with quantized weights.
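Aside (not part of the patch): a self-contained sketch of what all_reduce_linear_forward computes, with plain torch.distributed standing in for deepspeed.comm (the patch itself calls deepspeed's inference_all_reduce); the function name below is hypothetical. Each rank produces a partial matmul result, the partials are summed across the model-parallel group, and the bias is added once after the reduction.

import torch
import torch.distributed as dist


def all_reduce_linear_forward_sketch(x, weight, bias=None, mp_group=None):
    # weight follows the (out_features, in_features) layout of nn.Linear,
    # hence the transposed matmul.
    output = torch.matmul(x, weight.transpose(-1, -2))
    if mp_group is not None and dist.is_initialized():
        # The patch uses deepspeed.comm.inference_all_reduce here; the
        # generic torch.distributed equivalent is a SUM all_reduce.
        dist.all_reduce(output, op=dist.ReduceOp.SUM, group=mp_group)
    if bias is not None:
        output = output + bias  # bias applied once, after the reduction
    return output


# Single-process usage (no process group): the reduction is skipped and the
# result matches F.linear, which is why the wrapper can share its tuning
# logic between the plain and all-reduce forwards.
x = torch.randn(2, 8)
w = torch.randn(4, 8)
print(all_reduce_linear_forward_sketch(x, w).shape)  # torch.Size([2, 4])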