Add tensor parallel support to T5 via NxD #697

Merged: 29 commits, Oct 24, 2024
2 changes: 1 addition & 1 deletion optimum/commands/export/neuronx.py
@@ -116,7 +116,7 @@ def parse_args_neuronx(parser: "ArgumentParser"):
"--tensor_parallel_size",
type=int,
default=1,
help="Tensor parallelism degree, the number of neuron cores on which to shard the model.",
help="Tensor parallelism size, the number of neuron cores on which to shard the model.",
)
optional_group.add_argument(
"--dynamic-batch-size",
12 changes: 6 additions & 6 deletions optimum/exporters/neuron/base.py
@@ -175,7 +175,7 @@ def __init__(
self._config = config
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
self.mandatory_axes = ()
- self.tp_degree = tensor_parallel_size
+ self.tensor_parallel_size = tensor_parallel_size
self.task = task
self._axes: Dict[str, int] = {}
self.dynamic_batch_size = dynamic_batch_size
@@ -230,12 +230,12 @@ def task(self, value: str):
self.mandatory_axes = self.get_mandatory_axes_for_task(self.task)

@property
- def tp_degree(self) -> int:
- return self._tp_degree
+ def tensor_parallel_size(self) -> int:
+ return self._tensor_parallel_size

- @tp_degree.setter
- def tp_degree(self, value: int):
- self._tp_degree = value
+ @tensor_parallel_size.setter
+ def tensor_parallel_size(self, value: int):
+ self._tensor_parallel_size = value

def __getattr__(self, attr_name) -> Any:
if attr_name != "_axes" and attr_name in self._axes:
10 changes: 5 additions & 5 deletions optimum/exporters/neuron/convert.py
@@ -393,7 +393,7 @@ def export_models(
input_names=neuron_inputs,
output_names=neuron_outputs,
dynamic_batch_size=sub_neuron_config.dynamic_batch_size,
- tensor_parallel_size=sub_neuron_config.tp_degree,
+ tensor_parallel_size=sub_neuron_config.tensor_parallel_size,
compiler_type=NEURON_COMPILER_TYPE,
compiler_version=NEURON_COMPILER_VERSION,
inline_weights_to_neff=inline_weights_to_neff,
@@ -531,10 +531,10 @@ def export_neuronx(

# Prepare the model / function(tp) to trace
aliases = {}
- tp_degree = config.tp_degree
+ tensor_parallel_size = config.tensor_parallel_size
if isinstance(config, TextSeq2SeqNeuronConfig):
checked_model = config.patch_model_for_export(model_or_path, **input_shapes)
- if tp_degree == 1:
+ if tensor_parallel_size == 1:
aliases = config.generate_io_aliases(checked_model)
else:
checked_model = config.patch_model_for_export(model_or_path, dummy_inputs)
@@ -562,15 +562,15 @@ def export_neuronx(
inline_weights_to_neff = True

# Start trace
- if tp_degree > 1:
+ if tensor_parallel_size > 1:
# 1. use NxD to trace for parallel
neuron_model = neuronx_distributed.trace.parallel_model_trace(
Collaborator: It is ok as a first step, but the Llama example no longer uses this; it now uses the ModelBuilder class, which wraps the model into NxDModel classes containing several sub-models with different input shapes (bucketing).

Collaborator (Author): Is the use of bucketing mature and justified? I think we can start with parallel_model_trace anyway.

Collaborator: It goes a bit beyond that: prefill and decode already use two different input shapes, not even mentioning bucketing, and using the builder allows all the alternate graphs to share the same weights.

checked_model,
dummy_inputs_tuple,
compiler_args=compiler_args,
inline_weights_to_neff=inline_weights_to_neff,
compiler_workdir=compiler_workdir,
- tp_degree=tp_degree,
+ tp_degree=tensor_parallel_size,
)
neuronx_distributed.trace.parallel_model_save(neuron_model, output)
else:
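For context on the trace call above, here is a minimal sketch of the NxD tracing round trip, assuming the `neuronx_distributed.trace` API used in this PR (`parallel_model_trace`, `parallel_model_save`) and its `parallel_model_load` counterpart. The callable, dummy inputs, work directory, and output path are illustrative placeholders, not the actual exporter code.

```python
import torch
import neuronx_distributed


def get_model_and_aliases():
    # `parallel_model_trace` spawns one worker per rank, so the model must be
    # built/loaded inside this callable rather than captured from the parent process.
    model = torch.nn.Linear(8, 8)  # placeholder for the wrapped T5 encoder/decoder
    aliases = {}  # tensor -> output index aliases (e.g. for KV-cache buffers)
    return model, aliases


dummy_inputs = (torch.zeros(1, 8),)  # placeholder static-shape example inputs

neuron_model = neuronx_distributed.trace.parallel_model_trace(
    get_model_and_aliases,          # callable returning (model, aliases)
    dummy_inputs,
    compiler_workdir="./workdir",   # illustrative
    tp_degree=2,                    # number of Neuron cores to shard across
)
neuronx_distributed.trace.parallel_model_save(neuron_model, "./t5_encoder_neuron")

# At inference time the sharded artifacts are loaded back with the matching helper:
neuron_model = neuronx_distributed.trace.parallel_model_load("./t5_encoder_neuron")
```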
24 changes: 13 additions & 11 deletions optimum/exporters/neuron/model_configs.py
@@ -804,7 +804,7 @@ def patch_model_for_export(self, model_or_path, device="xla", **kwargs):
sequence_length = kwargs.pop("sequence_length", None)
batch_size = kwargs.pop("batch_size", None)

- if self.tp_degree > 1:
+ if self.tensor_parallel_size > 1:
# `torch.nn.modules` objects not eligible for pickling, the model needs to be loaded within the func.
return partial(
self.get_parallel_callable,
@@ -813,7 +813,7 @@ def patch_model_for_export(self, model_or_path, device="xla", **kwargs):
batch_size,
num_beams,
device,
- self.tp_degree,
+ self.tensor_parallel_size,
)
else:
return self.CUSTOM_MODEL_WRAPPER(
@@ -822,10 +822,12 @@ def patch_model_for_export(self, model_or_path, device="xla", **kwargs):
batch_size=batch_size,
num_beams=num_beams,
device=device,
- tp_degree=self.tp_degree,
+ tensor_parallel_size=self.tensor_parallel_size,
)

- def get_parallel_callable(self, model_name_or_path, sequence_length, batch_size, num_beams, device, tp_degree):
+ def get_parallel_callable(
+ self, model_name_or_path, sequence_length, batch_size, num_beams, device, tensor_parallel_size
+ ):
"""Unlike `torch_neuronx.trace`, `parallel_model_trace` requires a function returning a model object and a dictionary of states."""
model = TasksManager.get_model_from_task(
model_name_or_path=model_name_or_path,
@@ -846,15 +848,15 @@ def get_parallel_callable(self, model_name_or_path, sequence_length, batch_size,
batch_size=batch_size,
num_beams=num_beams,
device=device,
- tp_degree=tp_degree,
+ tensor_parallel_size=tensor_parallel_size,
)
encoder.eval()
aliases = self.generate_io_aliases(encoder)
return encoder, aliases

def generate_io_aliases(self, encoder=None):
aliases = {}
- if self.tp_degree > 1:
+ if self.tensor_parallel_size > 1:
for i in range(len(encoder.past_key_values_sa)):
aliases[encoder.past_key_values_sa[i]] = i
for i in range(len(encoder.past_key_values_ca)):
@@ -912,9 +914,9 @@ def patch_model_for_export(self, model, device="xla", **kwargs):
"output_hidden_states": self.output_hidden_states,
"output_attentions": self.output_attentions,
"device": device,
"tp_degree": self.tp_degree,
"tensor_parallel_size": self.tensor_parallel_size,
}
- if self.tp_degree > 1:
+ if self.tensor_parallel_size > 1:
return partial(
self.get_parallel_callable,
model,
Expand All @@ -924,7 +926,7 @@ def patch_model_for_export(self, model, device="xla", **kwargs):
self.output_hidden_states,
self.output_attentions,
device,
- self.tp_degree,
+ self.tensor_parallel_size,
)
else:
return self.CUSTOM_MODEL_WRAPPER(**trace_args)
@@ -938,7 +940,7 @@ def get_parallel_callable(
output_hidden_states,
output_attentions,
device,
- tp_degree,
+ tensor_parallel_size,
):
"""Unlike `torch_neuronx.trace`, `parallel_model_trace` requires a function returning a model object and a dictionary of states."""
model = TasksManager.get_model_from_task(
@@ -963,7 +965,7 @@ def get_parallel_callable(
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
device=device,
- tp_degree=tp_degree,
+ tensor_parallel_size=tensor_parallel_size,
)
decoder.eval()
aliases = self.generate_io_aliases(decoder)
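The `generate_io_aliases` helper above maps the wrapper's pre-allocated KV-cache tensors to output indices so that the traced graph updates them in place. A minimal sketch of that pattern follows; the tiny module, shapes, and the cross-attention index offset are illustrative stand-ins, not the actual T5 wrapper.

```python
import torch


class TinyEncoderWrapper(torch.nn.Module):
    """Illustrative stand-in for the traced T5 encoder wrapper."""

    def __init__(self, num_layers: int = 2, hidden: int = 8):
        super().__init__()
        # Pre-allocated self-attention / cross-attention cache buffers.
        self.past_key_values_sa = torch.nn.ParameterList(
            torch.nn.Parameter(torch.zeros(1, hidden), requires_grad=False) for _ in range(num_layers)
        )
        self.past_key_values_ca = torch.nn.ParameterList(
            torch.nn.Parameter(torch.zeros(1, hidden), requires_grad=False) for _ in range(num_layers)
        )

    def forward(self, x):
        return x


def generate_io_aliases(encoder):
    # Alias each cache tensor to an output position so the compiler treats the
    # cache as an in-place buffer instead of copying it between host and device.
    aliases = {}
    for i, param in enumerate(encoder.past_key_values_sa):
        aliases[param] = i
    # Illustrative: continue numbering for the cross-attention cache.
    offset = len(encoder.past_key_values_sa)
    for i, param in enumerate(encoder.past_key_values_ca):
        aliases[param] = offset + i
    return aliases


encoder = TinyEncoderWrapper()
aliases = generate_io_aliases(encoder)  # 4 cache tensors mapped to output indices 0..3
```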
16 changes: 8 additions & 8 deletions optimum/exporters/neuron/model_wrappers.py
@@ -132,7 +132,7 @@ def __init__(
batch_size: Optional[int] = None,
num_beams: int = 1,
device: str = "xla",
- tp_degree: int = 1,
+ tensor_parallel_size: int = 1,
):
super().__init__()
self.model = model
@@ -141,10 +141,10 @@ def __init__(
self.sequence_length = sequence_length
self.batch_size = batch_size
self.device = device
- self.tp_degree = tp_degree
- self.num_attention_heads_per_partition = self.config.num_heads # when tp_degree=1
+ self.tensor_parallel_size = tensor_parallel_size
+ self.num_attention_heads_per_partition = self.config.num_heads # when tensor_parallel_size=1

- if self.tp_degree > 1:
+ if self.tensor_parallel_size > 1:
self.num_attention_heads_per_partition = (
self.num_attention_heads_per_partition
// neuronx_distributed.parallel_layers.parallel_state.get_tensor_model_parallel_size()
@@ -227,7 +227,7 @@ def shape(states):
key_states = shape(attention.k(encoder_hidden_states))
value_states = shape(attention.v(encoder_hidden_states))

- if not self.tp_degree > 1:
+ if not self.tensor_parallel_size > 1:
# cross_attn_kv_state
present_key_value_states_ca.append(key_states)
present_key_value_states_ca.append(value_states)
@@ -296,7 +296,7 @@ def __init__(
output_hidden_states: bool = False,
output_attentions: bool = False,
device: str = "xla",
- tp_degree: int = 1,
+ tensor_parallel_size: int = 1,
):
super().__init__()
self.model = model
@@ -307,10 +307,10 @@ def __init__(
self.output_hidden_states = output_hidden_states
self.output_attentions = output_attentions
self.device = device
- self.tp_degree = tp_degree
+ self.tensor_parallel_size = tensor_parallel_size

self.num_attention_heads_per_partition = self.config.num_heads
- if tp_degree > 1:
+ if tensor_parallel_size > 1:
self.num_attention_heads_per_partition = (
self.num_attention_heads_per_partition
// neuronx_distributed.parallel_layers.parallel_state.get_tensor_model_parallel_size()
2 changes: 1 addition & 1 deletion optimum/exporters/neuron/utils.py
@@ -485,7 +485,7 @@ def get_encoder_decoder_models_for_export(
task (`str`):
The task to export the model for. If not specified, the task will be auto-inferred based on the model.
tensor_parallel_size (`int`):
- Tensor parallelism degree, the number of devices on which to shard the model.
+ Tensor parallelism size, the number of Neuron cores on which to shard the model.
input_shapes (`Dict[str, int]`):
Static shapes used for compiling the encoder and the decoder.
dynamic_batch_size (`bool`, defaults to `False`):
9 changes: 5 additions & 4 deletions optimum/neuron/modeling_seq2seq.py
@@ -135,7 +135,7 @@ def __init__(
if generation_config is None:
generation_config = GenerationConfig.from_model_config(self.configs[DECODER_NAME])
self.generation_config = generation_config
- self.tp_degree = self.neuron_configs[DECODER_NAME].tp_degree
+ self.tensor_parallel_size = self.neuron_configs[DECODER_NAME].tensor_parallel_size

def _save_pretrained(self, save_directory: Union[str, Path]):
"""
@@ -517,13 +517,14 @@ def generate(
axis=1,
)

- if not self.tp_degree > 1:
+ if self.tensor_parallel_size == 1:
# copy the new cache state to the decoder
for state, tensor in zip(self.decoder.model.parameters(), past_key_values):
state.copy_(tensor)
else:
- # Encoder returns cache as device tensors, we assign them to decoder's cache to avoid the copy.
- # The KV cache always use pre-allocated memory, no host-device communication overhead.
+ # Here we iterate sharded encoders and decoders since the encoder on each rank will return cache as device tensors,
+ # we want to assign them to the cache of the sharded decoder on the same rank to avoid the copy. The KV cache always
+ # use pre-allocated memory, no host-device communication overhead.
Collaborator: Nice

for decoder_tp, encoder_tp in zip(self.decoder.model.models, self.encoder.model.models):
decoder_tp.load_state_dict(encoder_tp.state_dict(), strict=False)

4 changes: 3 additions & 1 deletion tests/generation/test_generate.py
@@ -197,6 +197,8 @@ def test_general_seq2seq_generation(export_seq2seq_id, export_seq2seq_model_clas
_test_model_generation_trn(model, tokenizer, 1, 10, **gen_kwargs)


- # Mandatory for multiprocessing tests eg. tensor parallel tracing
+ # Compulsory for multiprocessing tests, since we want children processes to be spawned only in the main program.
+ # eg. tensor parallel tracing, `neuronx_distributed.parallel_model_trace` will spawn multiple processes to trace
+ # and compile the model.
if __name__ == "__main__":
pytest.main([__file__])