Add support for transformers-neuronx continuous batching #488
@@ -35,6 +35,7 @@
 if is_transformers_neuronx_available():
+    from transformers_neuronx.config import ContinuousBatchingConfig, NeuronConfig
     from transformers_neuronx.module import save_split
@@ -131,16 +132,26 @@ def __init__(

         exporter = get_exporter(config, task)

-        # transformers-neuronx uses f32/f16 instead of fp32/fp16
-        auto_cast_type = auto_cast_type.replace("p", "")
+        tnx_kwargs = {
+            "batch_size": batch_size,
+            "tp_degree": num_cores,
+            # transformers-neuronx uses f32/f16 instead of fp32/fp16
+            "amp": auto_cast_type.replace("p", ""),
+        }
+        if batch_size > 1 and exporter.continuous_batching:
+            # Continuous batching is always enabled for models that support it because static batching
+            # is broken for these models: see https://github.com/aws-neuron/transformers-neuronx/issues/79
+            tnx_kwargs["neuron_config"] = NeuronConfig(
+                continuous_batching=ContinuousBatchingConfig(batch_size_for_shared_caches=batch_size)
+            )
+            tnx_kwargs["n_positions"] = [sequence_length]
+            tnx_kwargs["context_length_estimate"] = [sequence_length]
+        else:
+            tnx_kwargs["n_positions"] = sequence_length

         # Instantiate neuronx model
         checkpoint_path = checkpoint_dir.name if isinstance(checkpoint_dir, TemporaryDirectory) else checkpoint_dir
-        neuronx_model = exporter.neuronx_class.from_pretrained(
-            checkpoint_path,
-            batch_size=batch_size,
-            n_positions=sequence_length,
-            tp_degree=num_cores,
-            amp=auto_cast_type,
-        )
+        neuronx_model = exporter.neuronx_class.from_pretrained(checkpoint_path, **tnx_kwargs)

         if compiled_dir is not None:
             # Specify the path where compiled artifacts are stored before conversion

Review comment (on the tp_degree entry above):
The TP degree is always exactly the number of neuron cores used?

Author reply:
Yes. I use the num_cores name in optimum-neuron because I use it for two things: to set the TP degree (here) and also to restrict the number of cores I reserve at initialization (otherwise the TNX runtime takes them all).
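To illustrate the kwargs assembled in this hunk, here is a minimal sketch of an equivalent direct transformers-neuronx call with continuous batching enabled. It is not part of the diff: the LlamaForSampling class, the checkpoint path, and the use of the NEURON_RT_NUM_CORES runtime variable to limit reserved cores are assumptions made for the example.

import os

from transformers_neuronx.config import ContinuousBatchingConfig, NeuronConfig
from transformers_neuronx.llama.model import LlamaForSampling  # assumed neuronx_class for a llama checkpoint

batch_size, sequence_length, num_cores = 4, 2048, 2

# Restrict how many NeuronCores the runtime reserves before loading the model
# (assumption: the standard NEURON_RT_NUM_CORES variable); the same value is
# reused as the tensor-parallel degree below.
os.environ.setdefault("NEURON_RT_NUM_CORES", str(num_cores))

neuron_config = NeuronConfig(
    continuous_batching=ContinuousBatchingConfig(batch_size_for_shared_caches=batch_size)
)

model = LlamaForSampling.from_pretrained(
    "./llama-split-checkpoint",         # hypothetical path to a checkpoint saved with save_split
    batch_size=batch_size,
    n_positions=[sequence_length],      # list form when continuous batching is enabled
    context_length_estimate=[sequence_length],
    tp_degree=num_cores,
    amp="f16",                          # transformers-neuronx expects f16/f32, not fp16/fp32
    neuron_config=neuron_config,
)
model.to_neuron()                       # compile for the reserved NeuronCores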
Review comment:
nit: maybe raise an explicit error with a message saying that the shapes should match in this case

Author reply:
I am using an assert because these methods are always called internally, so it is rather meant to catch internal programming errors.
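For context, a generic sketch of the two options discussed above; the function and argument names are hypothetical and the actual shapes being compared are not shown in this excerpt.

# Hypothetical illustration only; the real method and shapes are not part of this diff.
def _check_shapes(input_ids, attention_mask):
    # Option kept in the PR: an assert, suited to internal-only call sites where
    # a mismatch indicates a programming error rather than bad user input.
    assert input_ids.shape == attention_mask.shape, "input_ids and attention_mask shapes must match"


def _check_shapes_explicit(input_ids, attention_mask):
    # Reviewer's alternative: raise an explicit error with a descriptive message.
    if input_ids.shape != attention_mask.shape:
        raise ValueError(
            f"Shape mismatch: input_ids {tuple(input_ids.shape)} vs attention_mask {tuple(attention_mask.shape)}"
        )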