Skip to content

Commit

Permalink
Only loads, nothing else yet
Browse files Browse the repository at this point in the history
  • Loading branch information
danieldk committed Jun 26, 2024
1 parent be2d380 commit 49ac515
Show file tree
Hide file tree
Showing 3 changed files with 885 additions and 1 deletion.
33 changes: 32 additions & 1 deletion server/text_generation_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
FLASH_ATTENTION = True

try:
from text_generation_server.models.flash_deepseek_v2 import FlashDeepseekV2
from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_gpt2 import FlashGPT2
from text_generation_server.models.flash_neox import FlashNeoXSharded
Expand Down Expand Up @@ -89,6 +90,7 @@
FLASH_ATTENTION = False

if FLASH_ATTENTION:
__all__.append(FlashDeepseekV2)
__all__.append(FlashGPT2)
__all__.append(FlashNeoXSharded)
__all__.append(FlashRWSharded)
Expand Down Expand Up @@ -116,6 +118,11 @@


class ModelType(enum.Enum):
DEEPSEEK_V2 = {
"type": "deepseek_v2",
"name": "Deepseek V2",
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
}
IDEFICS2 = {
"type": "idefics2",
"name": "Idefics 2",
Expand Down Expand Up @@ -424,7 +431,31 @@ def get_model(
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
)

if model_type == MAMBA:
if model_type == DEEPSEEK_V2:
if FLASH_ATTENTION:
return FlashDeepseekV2(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
)
else:
return CausalLM(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

elif model_type == MAMBA:
return Mamba(
model_id,
revision,
Expand Down
Loading

0 comments on commit 49ac515

Please sign in to comment.