feat: mixtral (#1328)
OlivierDehaene authored Dec 11, 2023
1 parent 9ecfa16 commit 3a521c9
Showing 11 changed files with 959 additions and 231 deletions.
9 changes: 7 additions & 2 deletions Dockerfile
@@ -154,6 +154,11 @@ COPY server/Makefile-vllm Makefile
# Build specific version of vllm
RUN make build-vllm-cuda

# Build megablocks
FROM kernel-builder as megablocks-builder

RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e

# Text Generation Inference base image
FROM nvidia/cuda:12.1.0-base-ubuntu20.04 as base

@@ -175,8 +180,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
curl \
&& rm -rf /var/lib/apt/lists/*

# Copy conda with PyTorch installed
COPY --from=pytorch-install /opt/conda /opt/conda
# Copy conda with PyTorch and Megablocks installed
COPY --from=megablocks-builder /opt/conda /opt/conda

# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
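Because the final image now copies its conda environment from the megablocks-builder stage, the MoE kernels ship alongside PyTorch. A quick import probe such as the one below can confirm they made it into a built image (a minimal sketch, not part of this commit; run it with the image's Python, and `stk` is checked as well since the Mixtral path requires it, per the error message added in models/__init__.py):

# check_moe_kernels.py: illustrative sanity check, not part of this commit.
import importlib.util

for pkg in ("megablocks", "stk"):
    found = importlib.util.find_spec(pkg) is not None
    print(f"{pkg}: {'ok' if found else 'MISSING'}")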
5 changes: 5 additions & 0 deletions router/src/server.rs
@@ -629,6 +629,9 @@ pub async fn run(
// Batch size buckets
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
// Speculated tokens buckets
let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();

// Prometheus handler
let builder = PrometheusBuilder::new()
@@ -641,6 +644,8 @@
.set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets)
.unwrap()
.set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
.unwrap()
.set_buckets_for_metric(skipped_matcher, &skipped_buckets)
.unwrap();
let prom_handle = builder
.install_recorder()
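The new `tgi_request_skipped_tokens` histogram gets one bucket per possible speculated-token count, from 0 up to `shard_info.speculate` inclusive. In Python terms the bucket layout looks roughly like this (a sketch assuming the `prometheus_client` library and a speculate value of 2; the router itself uses the Rust metrics crates shown above):

from prometheus_client import Histogram

SPECULATE = 2  # assumption: number of speculative tokens configured on the shard

# One bucket per possible skipped-token count, mirroring `skipped_buckets` above.
skipped_tokens = Histogram(
    "tgi_request_skipped_tokens",
    "Speculated tokens per request",
    buckets=[float(x) for x in range(SPECULATE + 1)],
)

skipped_tokens.observe(1.0)  # a request that skipped one speculated token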
3 changes: 3 additions & 0 deletions server/Makefile
@@ -16,6 +16,9 @@ gen-server:
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation_server/pb/__init__.py

install-megablocks:
pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e

install: gen-server
pip install pip --upgrade
pip install -r requirements_cuda.txt
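For local (non-Docker) development, the new target can presumably be run as `make install-megablocks` before `make install`, mirroring what the megablocks-builder stage does in the Dockerfile.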
27 changes: 24 additions & 3 deletions server/text_generation_server/models/__init__.py
@@ -1,4 +1,3 @@
import os
import torch

from loguru import logger
@@ -78,6 +77,18 @@
if MISTRAL:
__all__.append(FlashMistral)

MIXTRAL = True
try:
from text_generation_server.models.flash_mixtral import FlashMixtral
except ImportError as e:
logger.warning(f"Could not import Mixtral model: {e}")
MIXTRAL = False

if MIXTRAL:
__all__.append(FlashMixtral)



def get_model(
model_id: str,
revision: Optional[str],
@@ -141,7 +152,6 @@ def get_model(
use_medusa = None
if "medusa_num_heads" in config_dict:
use_medusa = model_id
medusa_config = config_dict
model_id = config_dict["base_model_name_or_path"]
revision = "main"
speculate_medusa = config_dict["medusa_num_heads"]
@@ -292,7 +302,18 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError("Mistral model requires flash attention v2")
raise NotImplementedError("Mistral models requires flash attention v2")

if model_type == "mixtral":
if MIXTRAL:
return FlashMixtral(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks")

if model_type == "opt":
return OPTSharded(
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -34,14 +34,8 @@
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
FastRMSNorm
)
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops


class LlamaConfig(PretrainedConfig):
def __init__(
@@ -95,75 +89,6 @@ def __init__(
)


class LlamaRMSNorm(nn.Module):
def __init__(self, prefix, weights, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()

weight = weights.get_tensor(f"{prefix}.weight")
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps

def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states

hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)

# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)

return self.weight * hidden_states, residual
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states

return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states

out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")


def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
@@ -363,10 +288,8 @@ def __init__(self, layer_id, config, weights):
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

self.input_layernorm = LlamaRMSNorm(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = LlamaRMSNorm(
self.input_layernorm = FastRMSNorm.load(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
@@ -430,7 +353,7 @@ def __init__(self, config, weights):
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(
self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)
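Both modeling files drop their private RMSNorm classes (LlamaRMSNorm here, MistralRMSNorm in the file below) in favour of the shared FastRMSNorm layer loaded via `FastRMSNorm.load(...)`. For reference, the eager computation the removed classes performed, and which FastRMSNorm presumably reproduces (dispatching to the fused CUDA/ROCm kernels that previously lived in each file), is roughly:

import torch

def rms_norm_reference(hidden_states: torch.Tensor,
                       weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # Same math as the removed LlamaRMSNorm/MistralRMSNorm forward
    # (residual handling omitted): root-mean-square normalization in
    # float32, then rescaling by the learned weight.
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(weight.dtype)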

server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -35,13 +35,9 @@
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
FastRMSNorm
)
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops

if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
raise ImportError("Mistral model requires flash attn v2")
@@ -100,76 +96,6 @@ def __init__(
**kwargs,
)


class MistralRMSNorm(nn.Module):
def __init__(self, prefix, weights, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()

weight = weights.get_tensor(f"{prefix}.weight")
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps

def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states

hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)

# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)

return self.weight * hidden_states, residual
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states

return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states

out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")


def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
@@ -371,10 +297,10 @@ def __init__(self, layer_id, config, weights):
)
self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

self.input_layernorm = MistralRMSNorm(
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = MistralRMSNorm(
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
@@ -440,7 +366,7 @@ def __init__(self, config, weights):
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = MistralRMSNorm(
self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)

(Diffs for the remaining changed files are not shown here.)
