Draft · Changes from 11 commits
@@ -22,6 +22,7 @@
from .gpt_bigcode import GPTBigCodeGPTQ
from .gpt_neox import GPTNeoXGPTQ
from .granite import GraniteGPTQ
from .granitemoe import GraniteMoeGPTQ
from .llama import LlamaGPTQ
from .mistral import MistralGPTQ
from .mixtral import MixtralGPTQ
@@ -28,6 +28,7 @@
"granite",
"gemma",
"dbrx_converted",
"granitemoe"
]

EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048
@@ -29,6 +29,7 @@
from .gpt_bigcode import GPTBigCodeGPTQ
from .gpt_neox import GPTNeoXGPTQ
from .granite import GraniteGPTQ
from .granitemoe import GraniteMoeGPTQ
from .llama import LlamaGPTQ
from .mistral import MistralGPTQ
from .mixtral import MixtralGPTQ
@@ -43,6 +44,7 @@
"granite": GraniteGPTQ,
"dbrx": DbrxGPTQ,
"dbrx_converted": DbrxConvertedGPTQ,
"granitemoe": GraniteMoeGPTQ
}
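Presumably the new "granitemoe" key has to match the model_type string reported by Hugging Face GraniteMoe configs, since that is what the registry is indexed by. A quick check of that assumption (not part of this PR; assumes a transformers version that ships GraniteMoe):

from transformers import GraniteMoeConfig

# Sanity check (assumption, not from this PR): the registry key equals the
# GraniteMoe config's model_type string.
assert GraniteMoeConfig().model_type == "granitemoe"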

at_least_one_cuda_v6 = any(
@@ -558,7 +558,7 @@ def save_quantized(
        self.quantize_config.meta_set_versionable(
            key=META_FIELD_QUANTIZER,
            value=META_QUANTIZER_GPTQMODEL,
            version=__version__,
            version="1.0.0",
Contributor: why does this need to be changed?
        )

        # The config, quantize_config and model may be edited in place in save_quantized.
@@ -0,0 +1,31 @@
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Local
from .base import BaseGPTQModel


class GraniteMoeGPTQ(BaseGPTQModel):
    base_modules = ["model.embed_tokens", "model.norm"]
    convert3dToModuleList = ["block_sparse_moe.input_linear", "block_sparse_moe.output_linear"]

    layers_node = "model.layers"
    layer_type = "GraniteMoeDecoderLayer"
Contributor: I suggest you add some simple key to inform the format of input_linear and output_linear, i.e. that these are 3D tensors.

Also in the granitemoe case, another complication is that input_linear fuses w1 and w3. It might be OK for a first cut just to leave them as fused.
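For context, roughly what one expert of the fused input_linear computes in GraniteMoe: the first half of the fused output acts as the gate (w1-like) projection and the second half as the up (w3-like) projection. A paraphrased sketch, not the exact transformers code; the sizes below are illustrative:

import torch
import torch.nn.functional as F

# One expert's fused input_linear weight has output dimension 2 * intermediate_size;
# the MoE block chunks the result into gate (w1-like) and up (w3-like) halves before
# the activation, then applies that expert's output_linear.
hidden_size, intermediate_size, tokens = 64, 128, 4
fused_input_w = torch.randn(2 * intermediate_size, hidden_size)  # one expert's input_linear weight
output_w = torch.randn(hidden_size, intermediate_size)           # one expert's output_linear weight

x = torch.randn(tokens, hidden_size)
gate, up = F.linear(x, fused_input_w).chunk(2, dim=-1)
y = F.linear(F.silu(gate) * up, output_w)  # (tokens, hidden_size)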

@fabianlim (Contributor), Jan 28, 2025: so basically the simple key needs to know what to look for to convert it to a 3D tensor, and then when you write layer_modules you write it as though they have been converted:

class GraniteMoeGPTQ(BaseGPTQModel):

    convert3dToModuleList = ["block_sparse_moe.input_linear", "block_sparse_moe.output_linear"]

    layer_modules = [
        [
            "block_sparse_moe.input_linear.0.weight",
            "block_sparse_moe.input_linear.1.weight",
            ...
        ], [
            "block_sparse_moe.output_linear.0.weight",
            "block_sparse_moe.output_linear.1.weight",
            ...
        ]
    ]

    layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
        ["self_attn.o_proj"],
        [f"block_sparse_moe.input_linear.{i}" for i in range(40)],
        [f"block_sparse_moe.output_linear.{i}" for i in range(40)],
    ]
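As an aside, the 40 above hard-codes the expert count. A hedged sketch of deriving it from the checkpoint config instead (assumes the GraniteMoe config exposes num_local_experts; the names below are illustrative and not part of this PR):

from transformers import AutoConfig

# Hypothetical alternative (not in this PR): read the expert count from the
# checkpoint config rather than hard-coding 40; fall back to 40 if absent.
model_name_or_path = "<your GraniteMoe checkpoint>"  # placeholder
cfg = AutoConfig.from_pretrained(model_name_or_path)
num_experts = getattr(cfg, "num_local_experts", 40)
moe_input_modules = [f"block_sparse_moe.input_linear.{i}" for i in range(num_experts)]
moe_output_modules = [f"block_sparse_moe.output_linear.{i}" for i in range(num_experts)]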
@@ -30,7 +30,9 @@
import threadpoolctl as tctl
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import transformers.models.granitemoe.modeling_granitemoe as MOE

# Local
from ..models._const import (
@@ -52,6 +54,24 @@
logger.addHandler(handler)
logger.setLevel(logging.INFO)

class ThreeDTensorModuleList(nn.ModuleList):
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Shape of input: (num_selected_experts * batch_size (expert_size), input_features_size)
Contributor: this module is called ThreeDTensorModuleList but it's written quite specifically for MoE. I think it's fine, but then its design should assume that it is an MoE, and not a generic module with 3D tensors.
        expert_size = len(self)
        input_list = inputs.split(expert_size, dim=0)
        output_list = []

        # Iterate over the number of selected experts and apply each expert to the corresponding input
        for i in range(len(self)):
            # Shape of input_list[i]: (batch_size, input_features_size); Shape of self[i].weight: (output_features_size, input_features_size)
            # Shape of output: (batch_size, output_features_size)
            expert_output = F.linear(input_list[i], self[i].weight)
            output_list.append(expert_output)

        # Concatenate the outputs along the first dimension
        results = torch.cat(output_list, dim=0)  # Shape: (num_selected_experts * batch_size, output_features_size)
        return results
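For orientation, a small shape check of the class above (illustrative only; it assumes four experts with four rows per expert, so that inputs.split(len(self)) yields exactly one chunk per expert):

import torch
import torch.nn as nn

# Illustrative usage (not part of this PR): 4 experts, 4 rows per expert, features 8 -> 16.
experts = ThreeDTensorModuleList([nn.Linear(8, 16, bias=False) for _ in range(4)])
stacked_inputs = torch.randn(4 * 4, 8)  # (num_selected_experts * batch_size, input_features_size)
out = experts(stacked_inputs)
print(out.shape)  # torch.Size([16, 16])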


def recurse_getattr(obj, attr: str):
"""
@@ -100,14 +120,40 @@ def nested_move_to(v, device):
return v


def check3DTensor(module, name, convert3dToModuleList=["block_sparse_moe.input_linear", "block_sparse_moe.output_linear"]):
    if convert3dToModuleList and name in convert3dToModuleList:
        # print("INSIDE check3DTensor module, name, convert3dToModuleList", module, name, convert3dToModuleList)
        num_experts = module.num_experts
        input_size = module.input_size
        output_size = module.output_size
        module = ThreeDTensorModuleList([
            nn.Linear(input_size, output_size, bias=False) for _ in range(num_experts)
        ])

    return module
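Note that check3DTensor builds fresh, randomly initialized nn.Linear layers. A hedged sketch of also copying the original 3D expert weight into them (assumes the transformers GraniteMoeParallelExperts layout of weight.shape == (num_experts, output_size, input_size); the helper name is hypothetical and not part of this PR):

def convert_parallel_experts(module):
    # Hypothetical helper (not in this PR): split a (num_experts, output_size, input_size)
    # expert weight into per-expert nn.Linear modules, copying the weights over so the
    # converted layers reproduce the original experts.
    experts = ThreeDTensorModuleList([
        nn.Linear(module.input_size, module.output_size, bias=False)
        for _ in range(module.num_experts)
    ])
    with torch.no_grad():
        for i in range(module.num_experts):
            experts[i].weight.copy_(module.weight[i])
    return experts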


def find_layers(module, layers=None, name=""):
    # print("1- INSIDE find_layers module", module)
    module = check3DTensor(module, name)
Contributor: what is missing here is the logic for having the model recipe pass in the convert3dToModuleList. How do you plan to handle this?

One alternative is to handle the module swap completely outside of this function; then you don't need this logic.
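A minimal sketch of that alternative (the helper name and wiring are hypothetical, not part of this PR or the reviewer's comment): swap the expert modules up front, driven by the recipe's convert3dToModuleList, and leave find_layers untouched.

def swap_3d_experts(model, convert3dToModuleList):
    # Hypothetical pre-processing step (not in this PR): replace each module whose
    # qualified name ends with an entry of convert3dToModuleList before find_layers
    # runs, so find_layers itself needs no special-case logic.
    swaps = []
    for parent_name, parent in model.named_modules():
        for child_name, child in parent.named_children():
            full_name = f"{parent_name}.{child_name}" if parent_name else child_name
            matched = [k for k in convert3dToModuleList if full_name.endswith(k)]
            if matched:
                swaps.append((parent, child_name, child, matched[0]))
    for parent, child_name, child, key in swaps:
        setattr(parent, child_name, check3DTensor(child, key, convert3dToModuleList))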

# print("2- AFTER check3DTensor module", module)
if not layers:
layers = [transformers.pytorch_utils.Conv1D, nn.Conv2d, nn.Linear]
layers = [transformers.pytorch_utils.Conv1D, nn.Conv2d, nn.Linear]

# print("2- INFO: type(module), name", type(module), name)
# if hasattr(module, "weight"):
# print("3- type(module.weight), module.weight.shape, module.weight.ndim", type(module.weight), module.weight.shape, module.weight.ndim)
for layer in layers:
if isinstance(module, layer):
return {name: module}

res = {}
# if isinstance(module, MOE.GraniteMoeParallelExperts):
# print("Print GraniteMoeParallelExperts Layer children")
# for name1, child in module.named_children():
# print("4- name1, child", name1, child)
for name1, child in module.named_children():
# print("PROCESS- name, name1, child", name, name1, child)
res.update(
find_layers(
child, layers=layers, name=name + "." + name1 if name != "" else name1