
Commit

Merge branch 'main' into fix-emb-model-with-trfrs
JingyaHuang committed Jan 7, 2025
2 parents f6202aa + 6639b1d commit 659ee3d
Showing 6 changed files with 29 additions and 13 deletions.
benchmark/text-generation-inference/performance/benchmark.sh (6 changes: 5 additions & 1 deletion)
@@ -7,9 +7,13 @@ output_path="${model//\//_}#${date_str}_guidellm_report.json"

export HF_TOKEN=$(cat ~/.cache/huggingface/token)

export GUIDELLM__NUM_SWEEP_PROFILES=1
export GUIDELLM__MAX_CONCURRENCY=128
export GUIDELLM__REQUEST_TIMEOUT=60

guidellm \
--target "http://localhost:8080/v1" \
--model ${model} \
--data-type emulated \
--data "prompt_tokens=1500,prompt_tokens_variance=150,generated_tokens=250,generated_tokens_variance=20" \
--output-path ${output_path}
--output-path ${output_path} \
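
Note on the new settings: guidellm reads its configuration from GUIDELLM__-prefixed environment variables, so the three exports above appear to cap the run at a single sweep profile, at most 128 concurrent requests, and a 60-second per-request timeout (an interpretation of the variable names; the defaults they override are not shown in this diff). The --output-path line now ends with a line continuation, which suggests the full script passes at least one more flag after the excerpt shown here.
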
optimum/neuron/generation/token_selector.py (12 changes: 12 additions & 0 deletions)
@@ -8,6 +8,9 @@
    GenerationMixin,
    LogitsProcessorList,
    StoppingCriteriaList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)
from transformers.generation.utils import GenerationMode

@@ -154,6 +157,15 @@ def create(

        logits_warper = None
        if generation_mode == GenerationMode.SAMPLE:
            # Remove transformers TopK, TopP and Temperature processors
            logits_processor = LogitsProcessorList(
                [
                    p
                    for p in logits_processor
                    if not isinstance(p, (TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper))
                ]
            )
            # We use a fused logits warper instead
            logits_warper = FusedLogitsWarper.from_config(generation_config)

        return cls(
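
In sampling mode, the transformers TemperatureLogitsWarper, TopKLogitsWarper and TopPLogitsWarper instances are filtered out of logits_processor so that temperature/top-k/top-p are not applied twice once FusedLogitsWarper takes over. A minimal sketch of what a fused warper of this kind does follows; the class name, the defaults and the assumption of 2-D (batch, vocab) logits are illustrative and do not reproduce the actual optimum.neuron implementation.

import torch


class FusedTopKTopPTemperatureSketch:
    """Illustrative only: temperature, top-k and top-p applied in one pass,
    sharing a single sort instead of one pass per transformers processor."""

    def __init__(self, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0):
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        # Assumes logits of shape (batch, vocab_size).
        logits = logits / self.temperature
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        if self.top_k > 0:
            # Mask everything beyond the k highest logits.
            sorted_logits[:, self.top_k:] = float("-inf")
        if self.top_p < 1.0:
            probs = torch.softmax(sorted_logits, dim=-1)
            cumulative = probs.cumsum(dim=-1)
            # Drop tokens once the cumulative probability before them exceeds top_p;
            # this always keeps at least the most likely token.
            sorted_logits[cumulative - probs > self.top_p] = float("-inf")
        # Scatter the filtered logits back to their original vocabulary positions.
        return torch.full_like(logits, float("-inf")).scatter(-1, sorted_indices, sorted_logits)
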
optimum/neuron/trainers.py (4 changes: 2 additions & 2 deletions)
@@ -401,14 +401,14 @@ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[s
        self._update_input_specs_in_model_cache_entry(input_specs_for_cache_entry)
        return inputs

    def compute_loss(self, model, inputs, return_outputs: bool = False):
    def compute_loss(self, model, inputs, num_items_in_batch):
        from neuronx_distributed.pipeline import NxDPPModel

        if isinstance(model, NxDPPModel):
            inputs = self._prepare_inputs(inputs)
            loss = model.run_train(**inputs)
        else:
            loss = super().compute_loss(model, inputs, return_outputs=return_outputs)
            loss = super().compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
        return loss

    def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
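
The compute_loss signature change tracks recent transformers releases, where the Trainer passes num_items_in_batch into compute_loss so the loss can be normalized correctly under gradient accumulation; forwarding the keyword to super().compute_loss keeps the Neuron trainer on that call path. If a subclass had to tolerate both the older return_outputs-style call and the newer keyword, one option is to forward whatever extras the parent supplies; the sketch below is illustrative (CompatTrainer is not part of this repository).

from typing import Any, Dict, Union

import torch
from transformers import Trainer


class CompatTrainer(Trainer):
    def compute_loss(self, model, inputs: Dict[str, Union[torch.Tensor, Any]], *args, **kwargs):
        # Pass through return_outputs (older transformers) or num_items_in_batch
        # (newer transformers) untouched, so a single override works on both versions.
        return super().compute_loss(model, inputs, *args, **kwargs)
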
text-generation-inference/tests/integration/test_generate.py (4 changes: 2 additions & 2 deletions)
@@ -48,8 +48,8 @@ async def test_model_single_request(tgi_service):
    )
    sample_expectations = {
        "gpt2": "Deep Learning",
        "llama": "Deep learning",
        "mistral": "Deep Learning",
        "llama": "Deep Learning",
        "mistral": "Deep learning",
        "qwen2": "Deep Learning",
        "granite": "Deep learning",
    }
text-generation-inference/tests/server/test_decode.py (10 changes: 5 additions & 5 deletions)
@@ -36,11 +36,11 @@ def _test_decode(config_name, generator, do_sample):
    assert output.finish_reason == 0
    if do_sample:
        expected_text = {
            "gpt2": " the wind was blowing",
            "llama": "George Orwell",
            "mistral": "The sky is black",
            "qwen2": " I stood in the back yard",
            "granite": "Aldous Huxley, Brave New World",
            "gpt2": " The sun was set",
            "llama": "George Orwell, 1984",
            "mistral": "The sky was",
            "qwen2": " A young woman with",
            "granite": "1984, George Orwell",
        }[config_name]
        assert expected_text in output.text
    else:
text-generation-inference/tests/server/test_prefill.py (6 changes: 3 additions & 3 deletions)
@@ -35,11 +35,11 @@ def _test_prefill(config_name, generator, batch_size, do_sample):
    assert len(generations) == batch_size
    if do_sample:
        expectations = {
            "gpt2": [632, " It"],
            "gpt2": [383, " The"],
            "llama": [10058, " George"],
            "mistral": [450, " The"],
            "qwen2": [358, " I"],
            "granite": [429, " -"],
            "qwen2": [362, " A"],
            "granite": [308, " ("],
        }[config_name]
    else:
        expectations = {
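
The expectation tables updated in test_generate.py, test_decode.py and test_prefill.py are reference outputs for sampled generation; reading the diff, they are refreshed because the sampling path now goes through the fused logits warper shown above, which selects different tokens for the same seed. The commit message does not state this rationale, so treat it as an interpretation of the change.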
