support deepspeed LinearLayer and LinearAllreduce #698

Merged · merged 1 commit on Aug 22, 2025
2 changes: 1 addition & 1 deletion .azure-pipelines/scripts/codeScan/pylint/pylint.sh
@@ -24,7 +24,7 @@ pip install pipdeptree
pipdeptree

python -m pylint -f json --disable=R,C,W,E1129 --enable=line-too-long --max-line-length=120 --extension-pkg-whitelist=numpy --ignored-classes=TensorProto,NodeProto \
--ignored-modules=tensorflow,keras,torch,torch.quantization,torch.tensor,torchvision,fairseq,mxnet,onnx,onnxruntime,intel_extension_for_pytorch,intel_extension_for_tensorflow,torchinfo,horovod,transformers \
--ignored-modules=tensorflow,keras,torch,torch.quantization,torch.tensor,torchvision,fairseq,mxnet,onnx,onnxruntime,intel_extension_for_pytorch,intel_extension_for_tensorflow,torchinfo,horovod,transformers,deepspeed,deepspeed.module_inject \
/auto-round/${scan_module} > $log_dir/pylint.json

exit_code=$?
15 changes: 12 additions & 3 deletions auto_round/utils.py
@@ -38,6 +38,11 @@
SHARED_CACHE_KEYS = ("position_ids", "cache_position", "position_embeddings")


deepspeed_exists = False
if importlib.util.find_spec("deepspeed"): # check if deepspeed is installed
deepspeed_exists = True


class SupportedFormats:

def __init__(self):
@@ -67,16 +72,18 @@ def __getitem__(self, key):
return self._support_list[key]


SUPPORTED_DTYPES = ("int", "mx_fp", "fp", "nv_fp")
SUPPORTED_FORMATS = SupportedFormats()

SUPPORTED_LAYER_TYPES = (torch.nn.Linear, transformers.pytorch_utils.Conv1D)

## changed to str, as it relies on triton or other libs to load this
INNER_SUPPORTED_LAYER_TYPES = ("FP8Linear",)

# INNER_SUPPORTED_LAYER_TYPES = (transformers.integrations.finegrained_fp8.FP8Linear,)

SUPPORTED_DTYPES = ("int", "mx_fp", "fp", "nv_fp")
if deepspeed_exists:
from deepspeed.module_inject import LinearAllreduce, LinearLayer

SUPPORTED_LAYER_TYPES = SUPPORTED_LAYER_TYPES + (LinearLayer, LinearAllreduce)


def infer_bits_by_data_type(data_type: str):
@@ -1172,6 +1179,8 @@ def get_layer_features(layer):
return layer.weight.shape[0], layer.weight.shape[1]
elif isinstance(layer, torch.nn.Embedding):
return layer.num_embeddings, layer.embedding_dim
elif deepspeed_exists and isinstance(layer, (LinearLayer, LinearAllreduce)):
return layer.weight.shape[1], layer.weight.shape[0] # (input_dim, output_dim)
return None, None # Unsupported layer type


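For reference, the optional-dependency pattern added to auto_round/utils.py can be tried on its own. The sketch below is not part of the PR diff; it only mirrors the importlib.util.find_spec probe and the conditional extension of SUPPORTED_LAYER_TYPES shown above, and the is_supported_layer helper is a hypothetical illustration.

```python
# Standalone sketch of the optional-deepspeed pattern: probe for the package
# with importlib, and only if it is present import its tensor-parallel layers
# and extend the supported-layer tuple.
import importlib.util

import torch

SUPPORTED_LAYER_TYPES = (torch.nn.Linear,)

deepspeed_exists = importlib.util.find_spec("deepspeed") is not None
if deepspeed_exists:
    from deepspeed.module_inject import LinearAllreduce, LinearLayer

    SUPPORTED_LAYER_TYPES = SUPPORTED_LAYER_TYPES + (LinearLayer, LinearAllreduce)


def is_supported_layer(layer: torch.nn.Module) -> bool:
    # Hypothetical helper: isinstance works against the tuple whether or not
    # deepspeed was found, so callers need no extra branching.
    return isinstance(layer, SUPPORTED_LAYER_TYPES)
```

Because the deepspeed import only happens after the find_spec check, environments without deepspeed still import auto_round.utils cleanly; the pylint change in .azure-pipelines adds deepspeed and deepspeed.module_inject to --ignored-modules for the same reason, so the lint run does not try to resolve members of the optional dependency.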
52 changes: 50 additions & 2 deletions auto_round/wrapper.py
@@ -18,7 +18,20 @@

from auto_round.data_type import get_quant_func

from .utils import SUPPORTED_LAYER_TYPES, check_to_quantized, get_scale_shape, is_mx_fp, is_nv_fp, logger, set_module
from .utils import (
SUPPORTED_LAYER_TYPES,
check_to_quantized,
deepspeed_exists,
get_scale_shape,
is_mx_fp,
is_nv_fp,
logger,
set_module,
)

if deepspeed_exists:
from deepspeed import comm as dist
from deepspeed.module_inject import LinearAllreduce, LinearLayer


def reshape_and_pad_tensor(v, group_size=-1):
@@ -94,7 +107,18 @@ def __init__(
else:
self.q_scale_thresh = 1e-5
self._init_tuning_params_and_quant_func()
self.orig_forward = self.linear_forward if isinstance(self.orig_layer, torch.nn.Linear) else self.conv1d_forward
if deepspeed_exists:
if isinstance(self.orig_layer, (torch.nn.Linear, LinearLayer)):
self.orig_forward = self.linear_forward
elif isinstance(self.orig_layer, LinearAllreduce):
self.orig_forward = self.all_reduce_linear_forward
self.mp_group = self.orig_layer.mp_group
else:
self.orig_forward = self.conv1d_forward
else:
self.orig_forward = (
self.linear_forward if isinstance(self.orig_layer, torch.nn.Linear) else self.conv1d_forward
)

def _init_tuning_params_and_quant_func(self):
"""Initializes tuning parameters and quantization functions.
@@ -367,6 +391,24 @@ def linear_forward(self, x, weight, bias):
"""
return F.linear(x, weight, bias)

def all_reduce_linear_forward(self, x, weight, bias):
"""Performs the forward pass for a linear layer.

Args:
x (torch.Tensor): Input tensor.
weight (torch.Tensor): Weight tensor for the linear layer.
bias (torch.Tensor): Bias tensor for the linear layer.

Returns:
torch.Tensor: Output tensor after applying the linear layer.
"""
output = torch.matmul(x, weight.transpose(-1, -2))
if self.mp_group is not None:
dist.inference_all_reduce(output, group=self.mp_group)
if bias is not None:
output += bias
return output

def conv1d_forward(self, x, weight, bias):
"""Performs the forward pass for a Conv1D layer.

@@ -414,7 +456,10 @@ class WrapperWALayer(torch.nn.Module):
def __init__(self, orig_layer):
super(WrapperWALayer, self).__init__()
self.orig_layer = orig_layer
self.data_type = orig_layer.data_type if hasattr(orig_layer, "data_type") else None
self.act_data_type = orig_layer.act_data_type if hasattr(orig_layer, "act_data_type") else None
self.act_quant_func = self.orig_layer.act_quant_func
self.extra_repr_org = orig_layer.extra_repr

def forward(self, x):
act_max = self.orig_layer.act_max if hasattr(self.orig_layer, "act_max") else None
@@ -429,6 +474,9 @@ def forward(self, x):
)
return self.orig_layer.forward(x)

def extra_repr(self):
return f"{self.extra_repr_org()}, weight_type={self.data_type}, act_data_type={self.act_data_type}"


class WrapperLayerNorm(torch.nn.Module):
"""A wrapper for layer normalization with quantized weights.
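As a standalone illustration of the all_reduce_linear_forward path added to WrapperLinear, the sketch below reproduces the same three steps: a partial matmul on the local weight shard, a sum across the model-parallel group, and a single bias add after the reduce. It is a hedged sketch rather than the PR's code: torch.distributed.all_reduce stands in for deepspeed.comm.inference_all_reduce, and the name sharded_linear_forward is illustrative only.

```python
# Sketch of the tensor-parallel linear forward wrapped by this PR.
# Assumes a process group is already initialized when mp_group is passed;
# with mp_group=None it degenerates to a plain linear layer.
import torch
import torch.distributed as dist


def sharded_linear_forward(x, weight_shard, bias=None, mp_group=None):
    # Each rank multiplies by its own weight shard, giving a partial output.
    output = torch.matmul(x, weight_shard.transpose(-1, -2))
    # Summing the partial outputs across the model-parallel group recovers
    # the full linear result.
    if mp_group is not None:
        dist.all_reduce(output, op=dist.ReduceOp.SUM, group=mp_group)
    # Bias is added once, after the reduce, so it is not accumulated per rank.
    if bias is not None:
        output = output + bias
    return output
```

Adding the bias only after the partial results have been reduced matches the wrapper code above, which applies self.orig_layer's bias after dist.inference_all_reduce.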