
Commit

chore: formatting
OlivierDehaene committed Dec 11, 2023
1 parent 3a521c9 commit 72ee382
Showing 36 changed files with 715 additions and 450 deletions.
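
The changes below are consistent with running an auto-formatter over the repository: long calls are wrapped at 88 columns, trailing commas are added where argument lists are exploded, and top-level definitions get two blank lines. The commit message only says "chore: formatting", so the exact tool is an assumption; as an illustration, black's in-memory API produces exactly this kind of rewrap on one of the lines changed further down (black.format_str and black.Mode are the library's public API):

    import black

    # A long call taken from the cli.py hunk below; black wraps it because it
    # exceeds the default 88-character line length.
    src = 'medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt")\n'

    print(black.format_str(src, mode=black.Mode()))
    # medusa_head = hf_hub_download(
    #     model_id, revision=revision, filename="medusa_lm_head.pt"
    # )
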
13 changes: 9 additions & 4 deletions integration-tests/conftest.py
@@ -25,6 +25,7 @@

 class ResponseComparator(JSONSnapshotExtension):
     rtol = 0.2
+
     def serialize(
         self,
         data,
@@ -69,7 +70,9 @@ def eq_prefill_token(prefill_token: InputToken, other: InputToken) -> bool:
                 prefill_token.id == other.id
                 and prefill_token.text == other.text
                 and (
-                    math.isclose(prefill_token.logprob, other.logprob, rel_tol=self.rtol)
+                    math.isclose(
+                        prefill_token.logprob, other.logprob, rel_tol=self.rtol
+                    )
                     if prefill_token.logprob is not None
                     else prefill_token.logprob == other.logprob
                 )
@@ -153,6 +156,7 @@ class GenerousResponseComparator(ResponseComparator):
     # Needed for GPTQ with exllama which has serious numerical fluctuations.
     rtol = 0.75

+
 class LauncherHandle:
     def __init__(self, port: int):
         self.client = AsyncClient(f"http://localhost:{port}")
@@ -198,6 +202,7 @@ def _inner_health(self) -> bool:
 def response_snapshot(snapshot):
     return snapshot.use_extension(ResponseComparator)

+
 @pytest.fixture
 def generous_response_snapshot(snapshot):
     return snapshot.use_extension(GenerousResponseComparator)
@@ -219,7 +224,7 @@ def local_launcher(
         quantize: Optional[str] = None,
         trust_remote_code: bool = False,
         use_flash_attention: bool = True,
-        dtype: Optional[str] = None
+        dtype: Optional[str] = None,
     ):
         port = random.randint(8000, 10_000)
         master_port = random.randint(10_000, 20_000)
@@ -282,7 +287,7 @@ def docker_launcher(
         quantize: Optional[str] = None,
         trust_remote_code: bool = False,
         use_flash_attention: bool = True,
-        dtype: Optional[str] = None
+        dtype: Optional[str] = None,
     ):
         port = random.randint(8000, 10_000)

@@ -335,7 +340,7 @@ def docker_launcher(
             ],
             volumes=volumes,
             ports={"80/tcp": port},
-            shm_size="1G"
+            shm_size="1G",
         )

         yield ContainerLauncherHandle(client, container.name, port)
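
A note on the two rtol values touched above: ResponseComparator compares token logprobs with math.isclose at rel_tol=0.2, and GenerousResponseComparator loosens that to 0.75 to absorb the numerical noise of GPTQ with exllama. math.isclose is the standard-library function; the values below are made up purely to illustrate what the relative tolerance accepts:

    import math

    # Hypothetical logprobs for the same token from two runs of a model.
    reference, observed = -1.00, -1.15

    # rel_tol bounds the difference relative to the larger magnitude:
    # |a - b| <= rel_tol * max(|a|, |b|)
    print(math.isclose(reference, observed, rel_tol=0.2))   # True  (0.15 <= 0.2 * 1.15)
    print(math.isclose(reference, observed, rel_tol=0.05))  # False (0.15 >  0.05 * 1.15)
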
12 changes: 9 additions & 3 deletions integration-tests/models/test_flash_medusa.py
@@ -50,10 +50,16 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
-    responses = await generate_load(flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4)
+    responses = await generate_load(
+        flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4
+    )

     assert len(responses) == 4
-    assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
-    assert responses[0].generated_text == '\nDeep learning is a subset of machine learning'
+    assert all(
+        [r.generated_text == responses[0].generated_text for r in responses]
+    ), f"{[r.generated_text for r in responses]}"
+    assert (
+        responses[0].generated_text == "\nDeep learning is a subset of machine learning"
+    )

     assert responses == response_snapshot
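
For readers without the test harness in front of them: generate_load is a fixture from the integration-test conftest that fires several identical requests at the launched server, which is why the test can assert that all four responses carry the same text. A rough sketch of that pattern, assuming the text_generation AsyncClient these tests use; the real fixture may pass additional parameters:

    import asyncio
    from typing import List

    from text_generation import AsyncClient
    from text_generation.types import Response


    async def generate_load_sketch(
        client: AsyncClient, prompt: str, max_new_tokens: int, n: int
    ) -> List[Response]:
        # Issue n identical generation requests concurrently and collect the
        # responses so callers can compare the generated texts.
        futures = [
            client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
        ]
        return await asyncio.gather(*futures)
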
4 changes: 3 additions & 1 deletion integration-tests/models/test_flash_mistral.py
@@ -56,7 +56,9 @@ async def test_flash_mistral_load(flash_mistral, generate_load, response_snapsho
     )

     assert len(responses) == 4
-    assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
+    assert all(
+        [r.generated_text == responses[0].generated_text for r in responses]
+    ), f"{[r.generated_text for r in responses]}"
     assert responses[0].generated_text == ": Let n = 10 - 1"

     assert responses == response_snapshot
4 changes: 3 additions & 1 deletion integration-tests/models/test_idefics.py
@@ -3,7 +3,9 @@

 @pytest.fixture(scope="module")
 def idefics_handle(launcher):
-    with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16") as handle:
+    with launcher(
+        "HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16"
+    ) as handle:
         yield handle


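
The launcher fixture used here is the local_launcher/docker_launcher pair from the conftest diff at the top of this commit; its keyword arguments end up as flags on the text-generation-launcher command (or container) it starts. A hypothetical rendering of how this particular call maps to flags — the real fixture assembles a longer argument list:

    # Illustrative only: the flag names exist on text-generation-launcher, but the
    # exact command built by the fixture is not shown in this diff.
    args = [
        "text-generation-launcher",
        "--model-id", "HuggingFaceM4/idefics-9b-instruct",
        "--num-shard", "2",
        "--dtype", "float16",
        "--port", "8080",  # the fixture actually picks a random port
    ]
    print(" ".join(args))
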
16 changes: 14 additions & 2 deletions server/tests/models/test_bloom.py
@@ -133,8 +133,20 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     )
     assert all([generation.generated_text is None for generation in generations])
     assert all([len(generation.prefill_tokens) == 1 for generation in generations])
-    assert all([token_id.item() == 10264 for generation in generations for token_id in generation.tokens.token_ids])
-    assert all([token_text == "Test" for generation in generations for token_text in generation.tokens.texts])
+    assert all(
+        [
+            token_id.item() == 10264
+            for generation in generations
+            for token_id in generation.tokens.token_ids
+        ]
+    )
+    assert all(
+        [
+            token_text == "Test"
+            for generation in generations
+            for token_text in generation.tokens.texts
+        ]
+    )
     assert generations[0].request_id == 0


16 changes: 14 additions & 2 deletions server/tests/models/test_causal_lm.py
@@ -129,8 +129,20 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     )
     assert all([generation.generated_text is None for generation in generations])
     assert all([len(generation.prefill_tokens) == 1 for generation in generations])
-    assert all([token_id.item() == 13 for generation in generations for token_id in generation.tokens.token_ids])
-    assert all([token_text == "." for generation in generations for token_text in generation.tokens.texts])
+    assert all(
+        [
+            token_id.item() == 13
+            for generation in generations
+            for token_id in generation.tokens.token_ids
+        ]
+    )
+    assert all(
+        [
+            token_text == "."
+            for generation in generations
+            for token_text in generation.tokens.texts
+        ]
+    )
     assert generations[0].request_id == 0


16 changes: 14 additions & 2 deletions server/tests/models/test_seq2seq_lm.py
@@ -151,8 +151,20 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
     )
     assert all([generation.generated_text is None for generation in generations])
     assert all([len(generation.prefill_tokens) == 1 for generation in generations])
-    assert all([token_id.item() == 259 for generation in generations for token_id in generation.tokens.token_ids])
-    assert all([token_text == " " for generation in generations for token_text in generation.tokens.texts])
+    assert all(
+        [
+            token_id.item() == 259
+            for generation in generations
+            for token_id in generation.tokens.token_ids
+        ]
+    )
+    assert all(
+        [
+            token_text == " "
+            for generation in generations
+            for token_text in generation.tokens.texts
+        ]
+    )
     assert generations[0].request_id == 0


38 changes: 31 additions & 7 deletions server/text_generation_server/cli.py
@@ -77,12 +77,24 @@ def serve(
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
     dtype = None if dtype is None else dtype.value
-    if dtype is not None and quantize not in {None, "bitsandbytes", "bitsandbytes-nf4", "bitsandbytes-fp4"}:
+    if dtype is not None and quantize not in {
+        None,
+        "bitsandbytes",
+        "bitsandbytes-nf4",
+        "bitsandbytes-fp4",
+    }:
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code, uds_path
+        model_id,
+        revision,
+        sharded,
+        quantize,
+        speculate,
+        dtype,
+        trust_remote_code,
+        uds_path,
     )


@@ -140,23 +152,35 @@ def download_weights(

     try:
         import json
-        medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt")
+
+        medusa_head = hf_hub_download(
+            model_id, revision=revision, filename="medusa_lm_head.pt"
+        )
         if auto_convert:
-            medusa_sf = Path(medusa_head[:-len(".pt")] + ".safetensors")
+            medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors")
             if not medusa_sf.exists():
                 utils.convert_files([Path(medusa_head)], [medusa_sf], [])
-        medusa_config = hf_hub_download(model_id, revision=revision, filename="config.json")
+        medusa_config = hf_hub_download(
+            model_id, revision=revision, filename="config.json"
+        )
         with open(medusa_config, "r") as f:
             config = json.load(f)

         model_id = config["base_model_name_or_path"]
         revision = "main"
         try:
             utils.weight_files(model_id, revision, extension)
-            logger.info(f"Files for parent {model_id} are already present on the host. " "Skipping download.")
+            logger.info(
+                f"Files for parent {model_id} are already present on the host. "
+                "Skipping download."
+            )
             return
         # Local files not found
-        except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
+        except (
+            utils.LocalEntryNotFoundError,
+            FileNotFoundError,
+            utils.EntryNotFoundError,
+        ):
             pass
     except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
         pass
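
One aside on the medusa hunk above: medusa_head[: -len(".pt")] + ".safetensors" strips the .pt extension from the downloaded checkpoint path and substitutes .safetensors before conversion. pathlib expresses the same transformation directly; this is only an equivalence sketch, not a suggested change to the code:

    from pathlib import Path

    medusa_head = "medusa_lm_head.pt"  # hypothetical path returned by hf_hub_download

    # Slicing, as in the diff above:
    medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors")

    # The same result via pathlib:
    assert medusa_sf == Path(medusa_head).with_suffix(".safetensors")
    print(medusa_sf)  # medusa_lm_head.safetensors
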
13 changes: 8 additions & 5 deletions server/text_generation_server/models/__init__.py
@@ -88,7 +88,6 @@
     __all__.append(FlashMixtral)


-
 def get_model(
     model_id: str,
     revision: Optional[str],
@@ -157,7 +156,9 @@ def get_model(
         speculate_medusa = config_dict["medusa_num_heads"]
         if speculate is not None:
             if speculate > speculate_medusa:
-                raise RuntimeError("Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match")
+                raise RuntimeError(
+                    "Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
+                )
             else:
                 set_speculate(speculate)
         else:
@@ -249,7 +250,7 @@ def get_model(
                 quantize=quantize,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
-                use_medusa=use_medusa
+                use_medusa=use_medusa,
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@@ -313,7 +314,9 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-        raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks")
+        raise NotImplementedError(
+            "Mixtral models requires flash attention v2, stk and megablocks"
+        )

     if model_type == "opt":
         return OPTSharded(
@@ -354,7 +357,7 @@ def get_model(
             raise ValueError("awq quantization is not supported for AutoModel")
         elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
             raise ValueError("4bit quantization is not supported for AutoModel")
-        elif (quantize == "eetq"):
+        elif quantize == "eetq":
             raise ValueError("Eetq quantization is not supported for AutoModel")
         if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
             return CausalLM(
6 changes: 5 additions & 1 deletion server/text_generation_server/models/bloom.py
@@ -74,7 +74,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(
-            filenames, device=device, dtype=dtype, process_group=self.process_group, prefix="transformer",
+            filenames,
+            device=device,
+            dtype=dtype,
+            process_group=self.process_group,
+            prefix="transformer",
         )
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id)
21 changes: 14 additions & 7 deletions server/text_generation_server/models/causal_lm.py
@@ -510,7 +510,11 @@ def __init__(
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
-        if torch.cuda.is_available() and torch.cuda.device_count() == 1 and quantize != "bitsandbytes":
+        if (
+            torch.cuda.is_available()
+            and torch.cuda.device_count() == 1
+            and quantize != "bitsandbytes"
+        ):
             model = model.cuda()

         if tokenizer.pad_token_id is None:
@@ -676,7 +680,10 @@ def generate_token(
                     skip_special_tokens=False,
                 )
                 prefill_tokens = Tokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
+                    prefill_token_ids,
+                    prefill_logprobs,
+                    prefill_texts,
+                    is_special=[],
                 )
             else:
                 prefill_tokens = None
@@ -703,11 +710,11 @@ def generate_token(
                 request.id,
                 prefill_tokens,
                 Tokens(
-                        [next_token_id_squeezed],
-                        [next_token_logprob],
-                        [next_token_text],
-                        [next_token_id_squeezed.item() in self.all_special_ids],
-                        ),
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -34,9 +34,10 @@
     PositionRotaryEmbedding,
     TensorParallelHead,
     get_linear,
-    FastRMSNorm
+    FastRMSNorm,
 )

+
 class LlamaConfig(PretrainedConfig):
     def __init__(
         self,
@@ -202,7 +203,7 @@ def forward(
         )
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
-
+
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

         paged_attention.reshape_and_cache(
@@ -237,7 +238,7 @@ def forward(
             input_lengths,
             max_s,
         )
-
+
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


@@ -288,7 +289,9 @@ def __init__(self, layer_id, config, weights):
         )
         self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

-        self.input_layernorm = FastRMSNorm.load(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps)
+        self.input_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+        )
         self.post_attention_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.post_attention_layernorm",
             weights=weights,
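
The layer norms being loaded above are FastRMSNorm, the repository's fused RMSNorm. The underlying operation is small: divide the hidden states by their root mean square (plus an epsilon, config.rms_norm_eps in the hunk) and multiply by a learned weight. A minimal PyTorch reference of that computation — not the fused kernel the server actually runs:

    import torch


    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # Mean of squares over the hidden dimension, computed in float32 for stability.
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + eps)
        return weight * x.to(weight.dtype)


    hidden = torch.randn(2, 4096, dtype=torch.float16)
    weight = torch.ones(4096, dtype=torch.float16)
    print(rms_norm(hidden, weight).shape)  # torch.Size([2, 4096])
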

0 comments on commit 72ee382
