From b9f13bd53fb081e0b5789812869337263ce16b62 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Mon, 16 Dec 2024 14:20:05 +0000 Subject: [PATCH 01/18] chore: use AWS Neuron 2.21.0 SDK --- optimum/neuron/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/neuron/version.py b/optimum/neuron/version.py index e0ce222fe..a387b9423 100644 --- a/optimum/neuron/version.py +++ b/optimum/neuron/version.py @@ -14,4 +14,4 @@ __version__ = "0.0.28.dev0" -__sdk_version__ = "2.20.2" +__sdk_version__ = "2.21.0b" From aca7d981ee453905bae2177f8a5b8c9c53fd1ece Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Mon, 16 Dec 2024 14:39:09 +0000 Subject: [PATCH 02/18] chore: use AWS Neuron SDK 2.21.0 pip packages --- setup.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index 079412c3e..58a638844 100644 --- a/setup.py +++ b/setup.py @@ -64,13 +64,13 @@ ], "neuronx": [ "wheel", - "neuronx-cc==2.15.143.0", - "torch-neuronx==2.1.2.2.3.2", - "transformers-neuronx==0.12.313", - "torch==2.1.2.*", - "torchvision==0.16.*", - "neuronx_distributed==0.9.0", - "libneuronxla==2.0.5347.0", + "neuronx-cc==2.16.345.0", + "torch-neuronx==2.5.1.2.4.0", + "transformers-neuronx==0.13.322", + "torch==2.5.1.*", + "torchvision==0.20.*", + "neuronx_distributed==0.10.0", + "libneuronxla==2.1.681.0", ], "diffusers": ["diffusers>=0.28.0, <=0.30.3", "peft"], "sentence-transformers": ["sentence-transformers >= 2.2.0"], From e5fd0625feddd38f846ba348352011ca9bfa9e65 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Mon, 16 Dec 2024 18:25:26 +0000 Subject: [PATCH 03/18] chore(tgi): use AWS Neuron SDK 2.21.0 --- text-generation-inference/Dockerfile | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/text-generation-inference/Dockerfile b/text-generation-inference/Dockerfile index 30aa71147..054d35eda 100644 --- a/text-generation-inference/Dockerfile +++ b/text-generation-inference/Dockerfile @@ -113,10 +113,10 @@ RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEU # Install neuronx packages RUN apt-get update -y \ && apt-get install -y --no-install-recommends \ - aws-neuronx-dkms=2.18.20.0 \ - aws-neuronx-collectives=2.22.33.0-d2128d1aa \ - aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 \ - aws-neuronx-tools=2.19.0.0 \ + aws-neuronx-dkms=2.19.64.0 \ + aws-neuronx-collectives=2.23.133.0-3e70920f2 \ + aws-neuronx-runtime-lib=2.23.110.0-9b5179492 \ + aws-neuronx-tools=2.20.204.0 \ libxml2 \ && rm -rf /var/lib/apt/lists/* \ && apt-get clean @@ -124,10 +124,10 @@ RUN apt-get update -y \ ENV PATH="/opt/bin/:/opt/aws/neuron/bin:${PATH}" RUN pip3 install \ - neuronx-cc==2.15.143.0 \ - torch-neuronx==2.1.2.2.3.2 \ - transformers-neuronx==0.12.313 \ - libneuronxla==2.0.5347.0 \ + neuronx-cc==2.16.345.0 \ + torch-neuronx==2.5.1.2.4.0 \ + transformers-neuronx==0.13.322 \ + libneuronxla==2.1.681.0 \ --extra-index-url=https://pip.repos.neuron.amazonaws.com # Install HuggingFace packages From 9f355293899273fa0aea9987642c1569cd14a813 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Tue, 24 Dec 2024 08:46:37 +0000 Subject: [PATCH 04/18] ci: use AWS Neuron SDK 2.21 system packages --- .github/workflows/inference_cache_llm.yml | 3 ++- .github/workflows/inference_cache_stable_diffusion.yml | 3 ++- .github/workflows/test_inf2.yml | 3 ++- .github/workflows/test_inf2_export.yml | 3 ++- .github/workflows/test_inf2_full_export.yml | 3 ++- .github/workflows/test_inf2_inference.yml | 3 ++- 
.github/workflows/test_inf2_tgi.yml | 3 ++- .github/workflows/test_trainium_common.yml | 3 ++- .github/workflows/test_trainium_distributed.yml | 3 ++- .github/workflows/test_trainium_examples.yml | 3 ++- 10 files changed, 20 insertions(+), 10 deletions(-) diff --git a/.github/workflows/inference_cache_llm.yml b/.github/workflows/inference_cache_llm.yml index 41598cd1c..cd6e6b1b6 100644 --- a/.github/workflows/inference_cache_llm.yml +++ b/.github/workflows/inference_cache_llm.yml @@ -39,7 +39,8 @@ jobs: EOF wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add - sudo apt-get update -y - sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y + sudo apt-get install aws-neuronx-tools=2.20.204.0 aws-neuronx-runtime-lib=2.23.110.0-9b5179492 aws-neuronx-collectives=2.23.133.0-3e70920f2 -y + dpkg -l | grep neuron export PATH=/opt/aws/neuron/bin:$PATH - name: Checkout uses: actions/checkout@v4 diff --git a/.github/workflows/inference_cache_stable_diffusion.yml b/.github/workflows/inference_cache_stable_diffusion.yml index 2bd83eae1..b760bc5b7 100644 --- a/.github/workflows/inference_cache_stable_diffusion.yml +++ b/.github/workflows/inference_cache_stable_diffusion.yml @@ -29,7 +29,8 @@ jobs: EOF wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add - sudo apt-get update -y - sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y + sudo apt-get install aws-neuronx-tools=2.20.204.0 aws-neuronx-runtime-lib=2.23.110.0-9b5179492 aws-neuronx-collectives=2.23.133.0-3e70920f2 -y + dpkg -l | grep neuron export PATH=/opt/aws/neuron/bin:$PATH - name: Checkout uses: actions/checkout@v4 diff --git a/.github/workflows/test_inf2.yml b/.github/workflows/test_inf2.yml index 7135c8d7d..e066d6c2f 100644 --- a/.github/workflows/test_inf2.yml +++ b/.github/workflows/test_inf2.yml @@ -32,7 +32,8 @@ jobs: EOF wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add - sudo apt-get update -y - sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y + sudo apt-get install aws-neuronx-tools=2.20.204.0 aws-neuronx-runtime-lib=2.23.110.0-9b5179492 aws-neuronx-collectives=2.23.133.0-3e70920f2 -y + dpkg -l | grep neuron export PATH=/opt/aws/neuron/bin:$PATH - name: Checkout uses: actions/checkout@v2 diff --git a/.github/workflows/test_inf2_export.yml b/.github/workflows/test_inf2_export.yml index a863652a0..45aecd4e3 100644 --- a/.github/workflows/test_inf2_export.yml +++ b/.github/workflows/test_inf2_export.yml @@ -32,7 +32,8 @@ jobs: EOF wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add - sudo apt-get update -y - sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y + sudo apt-get install aws-neuronx-tools=2.20.204.0 aws-neuronx-runtime-lib=2.23.110.0-9b5179492 aws-neuronx-collectives=2.23.133.0-3e70920f2 -y + dpkg -l | grep neuron export PATH=/opt/aws/neuron/bin:$PATH - name: Checkout uses: actions/checkout@v2 diff --git a/.github/workflows/test_inf2_full_export.yml b/.github/workflows/test_inf2_full_export.yml index 921596bfe..6cc274167 100644 --- a/.github/workflows/test_inf2_full_export.yml +++ 
b/.github/workflows/test_inf2_full_export.yml @@ -30,7 +30,8 @@ jobs: EOF wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add - sudo apt-get update -y - sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y + sudo apt-get install aws-neuronx-tools=2.20.204.0 aws-neuronx-runtime-lib=2.23.110.0-9b5179492 aws-neuronx-collectives=2.23.133.0-3e70920f2 -y + dpkg -l | grep neuron export PATH=/opt/aws/neuron/bin:$PATH - name: Checkout uses: actions/checkout@v2 diff --git a/.github/workflows/test_inf2_inference.yml b/.github/workflows/test_inf2_inference.yml index 1a37a23a2..f79f0c796 100644 --- a/.github/workflows/test_inf2_inference.yml +++ b/.github/workflows/test_inf2_inference.yml @@ -32,7 +32,8 @@ jobs: EOF wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add - sudo apt-get update -y - sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y + sudo apt-get install aws-neuronx-tools=2.20.204.0 aws-neuronx-runtime-lib=2.23.110.0-9b5179492 aws-neuronx-collectives=2.23.133.0-3e70920f2 -y + dpkg -l | grep neuron export PATH=/opt/aws/neuron/bin:$PATH - name: Install cv2 dependencies run: | diff --git a/.github/workflows/test_inf2_tgi.yml b/.github/workflows/test_inf2_tgi.yml index c8dad05c1..2f5d02cbd 100644 --- a/.github/workflows/test_inf2_tgi.yml +++ b/.github/workflows/test_inf2_tgi.yml @@ -34,7 +34,8 @@ jobs: EOF wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add - sudo apt-get update -y - sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y + sudo apt-get install aws-neuronx-tools=2.20.204.0 aws-neuronx-runtime-lib=2.23.110.0-9b5179492 aws-neuronx-collectives=2.23.133.0-3e70920f2 -y + dpkg -l | grep neuron export PATH=/opt/aws/neuron/bin:$PATH - name: Checkout uses: actions/checkout@v2 diff --git a/.github/workflows/test_trainium_common.yml b/.github/workflows/test_trainium_common.yml index 78233b641..6b881ec01 100644 --- a/.github/workflows/test_trainium_common.yml +++ b/.github/workflows/test_trainium_common.yml @@ -34,7 +34,8 @@ jobs: EOF wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add - sudo apt-get update -y - sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y + sudo apt-get install aws-neuronx-tools=2.20.204.0 aws-neuronx-runtime-lib=2.23.110.0-9b5179492 aws-neuronx-collectives=2.23.133.0-3e70920f2 -y + dpkg -l | grep neuron export PATH=/opt/aws/neuron/bin:$PATH - name: Install cv2 dependencies run: | diff --git a/.github/workflows/test_trainium_distributed.yml b/.github/workflows/test_trainium_distributed.yml index 1571ec7c1..f30235c1a 100644 --- a/.github/workflows/test_trainium_distributed.yml +++ b/.github/workflows/test_trainium_distributed.yml @@ -33,7 +33,8 @@ jobs: EOF wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add - sudo apt-get update -y - sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y + sudo apt-get install aws-neuronx-tools=2.20.204.0 aws-neuronx-runtime-lib=2.23.110.0-9b5179492 
aws-neuronx-collectives=2.23.133.0-3e70920f2 -y + dpkg -l | grep neuron export PATH=/opt/aws/neuron/bin:$PATH - name: Install cv2 dependencies run: | diff --git a/.github/workflows/test_trainium_examples.yml b/.github/workflows/test_trainium_examples.yml index d5a18d61d..8c75e973e 100644 --- a/.github/workflows/test_trainium_examples.yml +++ b/.github/workflows/test_trainium_examples.yml @@ -41,7 +41,8 @@ jobs: EOF wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add - sudo apt-get update -y - sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y + sudo apt-get install aws-neuronx-tools=2.20.204.0 aws-neuronx-runtime-lib=2.23.110.0-9b5179492 aws-neuronx-collectives=2.23.133.0-3e70920f2 -y + dpkg -l | grep neuron export PATH=/opt/aws/neuron/bin:$PATH - name: Install cv2 dependencies run: | From 7506d992dac8c3efd6de83aa955592667fc77b0d Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Wed, 18 Dec 2024 13:42:26 +0000 Subject: [PATCH 05/18] refactor(qwen2): adapt to latest TnX --- optimum/neuron/models/qwen2/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optimum/neuron/models/qwen2/model.py b/optimum/neuron/models/qwen2/model.py index 8ee60d9b4..f34a5bd63 100644 --- a/optimum/neuron/models/qwen2/model.py +++ b/optimum/neuron/models/qwen2/model.py @@ -287,6 +287,7 @@ def preprocess_and_embed(self, input_ids, cache_ids=None, start_ids=None, **kwar return padded_inputs, input_embeddings, *rst def forward(self, input_ids, cache_ids=None, start_ids=None, last_token_id=None, input_embeddings=None, **kwargs): + original_input_ids = input_ids if last_token_id is not None: # preprocess_and_embed() has already been invoked rst = cache_ids, start_ids, last_token_id else: # invoke preprocess_and_embed() @@ -294,5 +295,5 @@ def forward(self, input_ids, cache_ids=None, start_ids=None, last_token_id=None, # either input_embeddings are generated (off device embedding), or input_ids will be padded from preprocess_and_embed (on device embedding) inputs = input_embeddings if input_embeddings is not None else input_ids logits = self._forward(inputs, *rst) - logits = self._postprocess(logits, start_ids=start_ids, **kwargs) + logits = self._postprocess(original_input_ids, logits, start_ids=start_ids, **kwargs) return logits From 8dd93d0318f531edd7c97311f91bb3cd09836476 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Thu, 2 Jan 2025 11:59:15 +0000 Subject: [PATCH 06/18] refactor(granite): align modeling with latest TnX --- optimum/neuron/models/granite/config.py | 5 +- optimum/neuron/models/granite/hlo.py | 897 ++++++++++++++++++------ optimum/neuron/models/granite/model.py | 96 ++- 3 files changed, 771 insertions(+), 227 deletions(-) diff --git a/optimum/neuron/models/granite/config.py b/optimum/neuron/models/granite/config.py index 6eefd30a6..5cd2640b2 100644 --- a/optimum/neuron/models/granite/config.py +++ b/optimum/neuron/models/granite/config.py @@ -18,7 +18,10 @@ class GraniteConfig(LlamaConfig): - """The Granite model uses the same configuration as the TnX LLama model""" + """The Granite model uses the same base configuration as the TnX LLama model + + It simply includes in addition the granite specific scaling factors. 
+ """ def __init__( self, config: PretrainedConfig, n_positions: int, batch_size: int, amp: str, tp_degree: int, **kwargs diff --git a/optimum/neuron/models/granite/hlo.py b/optimum/neuron/models/granite/hlo.py index d66f12b8d..ae2152633 100644 --- a/optimum/neuron/models/granite/hlo.py +++ b/optimum/neuron/models/granite/hlo.py @@ -21,9 +21,14 @@ from transformers_neuronx.layers import attention, attention_utils, flash_decoding, rotary, transformer from transformers_neuronx.nki.compile import nki_call +from optimum.utils import logging + from .config import GraniteConfig +logger = logging.get_logger() + + def scale_mul(t, scale): """Multiply a tensor by a float scale""" dtype = t.dtype @@ -63,30 +68,69 @@ def inputs(self, scribe, dtype, n_active_tokens, batch_size): return tensors, dims - def token_tree_inputs(self, scribe, dtype, n_active_tokens, batch_size): + def eagle_draft_inputs( + self, + scribe, + dtype, + n_active_tokens, + batch_size, + token_tree=False, + k=0, + n_leaves=0, + depth=0, + n_entrees=0, + width=0, + ): tensors, dims = self.inputs(scribe, dtype, n_active_tokens, batch_size) + hidden_sizes = batch_size, n_active_tokens, self.config.hidden_size + prev_hidden = dtype[hidden_sizes].Parameter(parameter_number=6) + if not token_tree: + return (*tensors, prev_hidden), (*dims, 1) s32 = scribe.s32 - cache_2d = self.neuron_config and self.neuron_config.use_2d_cache_ids - # Allow tree based speculation inputs - if cache_2d: - position_sizes = batch_size, n_active_tokens - previous_cache_ids = s32[position_sizes].Parameter(parameter_number=4) - reorder_mapping = s32[position_sizes].Parameter(parameter_number=5) - else: - previous_cache_ids = s32[n_active_tokens].Parameter(parameter_number=4) - reorder_mapping = s32[n_active_tokens].Parameter(parameter_number=5) - seq_slice_dim = 1 if cache_2d else 0 - - return (*tensors, previous_cache_ids, reorder_mapping), (*dims, seq_slice_dim, seq_slice_dim) + tree_mask_sizes = k, k + tree_mask = s32[tree_mask_sizes].Parameter(parameter_number=7) + indices_sizes = batch_size, k - 1 + update_indices = s32[indices_sizes].Parameter(parameter_number=8) + hidden_update_sizes = batch_size, k - 1 + hidden_update_indices = s32[hidden_update_sizes].Parameter(parameter_number=9) + cache_update_sizes = batch_size, depth + cache_gather_indices = s32[cache_update_sizes].Parameter(parameter_number=10) + cache_scatter_indices = s32[cache_update_sizes].Parameter(parameter_number=11) + pos_sizes = batch_size, k + position_ids = s32[pos_sizes].Parameter(parameter_number=12) + path_sizes = n_leaves, depth + all_paths = s32[path_sizes].Parameter(parameter_number=13) + return ( + *tensors, + prev_hidden, + tree_mask, + update_indices, + hidden_update_indices, + cache_gather_indices, + cache_scatter_indices, + position_ids, + all_paths, + ), (*dims, 1, 1, 1, 1, 1, 1, 1, 1) - def embedding(self, input_ids, cache_ids, start_ids, last_token_id, *weights): - if self.neuron_config.shard_over_sequence and self.neuron_config.on_device_embedding: - *rst, embed_weight = weights + def embedding(self, input_ids, cache_ids, start_ids, last_token_id, block_tables, context_lens, *weights): + core_id = None + if ( + self.neuron_config.shard_over_sequence or self.neuron_config.sequence_parallel_norm + ) and self.neuron_config.on_device_embedding: + core_id, embed_weight, *rst = weights else: embed_weight, *rst = weights dtype = getattr(input_ids.scribe, self.config.amp) if self.neuron_config.on_device_embedding and self.neuron_config.sequence_parallel_norm: - hidden = 
hlo.embedding(embed_weight, input_ids, tp_degree=1, dtype=dtype) + hidden = hlo.embedding( + embed_weight, + input_ids, + tp_degree=self.config.tp_degree, + dim=0, + dtype=dtype, + core_id=core_id, + sequence_parallel=self.neuron_config.is_sequence_parallel, + ) else: hidden = hlo.embedding(embed_weight, input_ids, tp_degree=self.config.tp_degree, dtype=dtype) if self.config.hidden_size % self.config.tp_degree != 0: @@ -95,12 +139,9 @@ def embedding(self, input_ids, cache_ids, start_ids, last_token_id, *weights): hidden = hlo.transpose210(hidden) return hidden - def token_tree_embedding( - self, input_ids, cache_ids, start_ids, last_token_id, previous_cache_ids, reorder_mapping, *weights + def pre_layer( + self, hidden, cache_ids, start_ids, last_token_id, block_tables, context_lens, *weights, position_ids=None ): - return self.embedding(input_ids, cache_ids, start_ids, last_token_id, *weights) - - def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, *weights): # TODO: move this fallback calculation to decoder.py if self.num_active_blocks is None and self.neuron_config.optimized_paged_attention: max_model_len = self.neuron_config.continuous_batching.max_model_len @@ -108,6 +149,27 @@ def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, *weights): block_size = self.neuron_config.continuous_batching.block_size self.num_active_blocks = (max_model_len * max_num_seqs // block_size) - 2 + block_to_seq = None + cached_mask = None + cached_to_contexted = None + active_to_contexted = None + core_id = None + if self.neuron_config.shard_over_sequence or ( + self.neuron_config.sequence_parallel_norm and self.neuron_config.on_device_embedding + ): + core_id, *rst = weights + if self.neuron_config.shard_over_sequence: + n_kv_heads = ( + self.config.num_key_value_heads + if hasattr(self.config, "num_key_value_heads") + else self.config.num_attention_heads + ) + cores_per_kv_head = self.config.tp_degree // n_kv_heads + self.cores_per_kv_head = cores_per_kv_head if cores_per_kv_head > 1 else self.config.tp_degree + cores_per_q_head = self.config.tp_degree // self.config.num_attention_heads + self.cores_per_kv_head = ( + self.cores_per_kv_head // cores_per_q_head if cores_per_q_head else self.cores_per_kv_head + ) if self.neuron_config.optimized_paged_attention and len(last_token_id.sizes) == 2: # For decoding with multiple KV cache blocks: # - cache_ids are used as context_lens @@ -125,36 +187,90 @@ def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, *weights): block_to_seq = attention_utils.block_to_seq_indexing( context_lens=cache_ids, num_seqs=max_num_seqs, num_blocks=self.num_active_blocks, block_size=block_size ) - else: - block_to_seq = None + elif self.neuron_config.enable_chunked_prefill: + # - cache_ids are used as position_ids of each token + # - start_ids are used as slot_mapping + # - last_token_id is used as new token length for each sequence + context_lens_2d = hlo.unsqueeze(context_lens, 1) + seq_lens = hlo.add(context_lens, last_token_id) + block_size = self.neuron_config.continuous_batching.block_size + if self.neuron_config.shard_over_sequence: + core_sos_rank = hlo.remainder(core_id, cores_per_kv_head) + core_sos_rank = hlo.cast(core_sos_rank, seq_lens.scribe.s32) + sharded_block_size = block_size // cores_per_kv_head + block_tables = attention_utils.active_block_tables( + block_tables=block_tables, + context_lens=hlo.unsqueeze(seq_lens, 1), + num_active_blocks=self.num_active_blocks, + neuron_config=self.neuron_config, + ) + start_ids, 
active_token_mask = attention_utils.sharded_slot_mapping( + start_ids, cache_ids, block_size, core_sos_rank, sos_degree=cores_per_kv_head + ) + max_num_keys = (self.num_active_blocks + 1) * sharded_block_size + _, n_active_tokens = cache_ids.sizes + cached_to_contexted, cached_to_contexted_idx, active_to_contexted, sharded_seq_lens = ( + attention_utils.sharded_kv_indexing( + seq_lens, + last_token_id, + cache_ids, + max_num_keys, + n_active_tokens, + block_size, + block_tables, + core_sos_rank, + active_token_mask, + sos_degree=cores_per_kv_head, + ) + ) + else: + block_tables = attention_utils.active_block_tables( + block_tables=block_tables, + context_lens=context_lens_2d, + num_active_blocks=self.num_active_blocks, + neuron_config=self.neuron_config, + ) + max_num_keys = self.num_active_blocks * block_size + self.n_positions + cached_mask, cached_to_contexted, active_to_contexted = attention_utils.contexted_kv_indexing( + query_lens=last_token_id, key_lens=seq_lens, max_num_keys=max_num_keys, block_size=block_size + ) # Granite specific: embeddings are multiplied by embedding_multiplier hidden = scale_mul(hidden, self.config.embedding_multiplier) head_dim = self.config.attention_head_size + position_ids = cache_ids if position_ids is None else position_ids pos_embed = rotary.hlo_rotary_embedding( hidden.dtype, int(head_dim * self.config.rotary_percentage), - cache_ids, + position_ids, base=self.config.rope_theta, interpolation_factor=self.config.position_interpolation_factor, rope_scaling=self.config.rope_scaling, ) - core_id = None # flash decoding - if self.neuron_config.shard_over_sequence: - core_id, *rst = weights - n_kv_heads = ( - self.config.num_key_value_heads - if hasattr(self.config, "num_key_value_heads") - else self.config.num_attention_heads - ) - cores_per_kv_head = self.config.tp_degree // n_kv_heads - self.cores_per_kv_head = cores_per_kv_head if cores_per_kv_head > 1 else self.config.tp_degree + if self.neuron_config.shard_over_sequence and not self.neuron_config.enable_chunked_prefill: cache_ids, mask, active_mask = flash_decoding.convert_attn_mask_and_cache_id( cache_ids, start_ids, core_id, self.n_positions, cores_per_kv_head=self.cores_per_kv_head ) + elif self.neuron_config.shard_over_sequence and self.neuron_config.enable_chunked_prefill: + _, n_active_tokens = cache_ids.sizes + batch_size = self.neuron_config.continuous_batching.max_num_seqs + mask, active_mask = hlo.sharded_decoder_attention_block_diagonal_causal_from_bottomright_mask( + last_token_id, + seq_lens, + n_active_tokens, + max_num_keys, + batch_size, + cache_ids, + sharded_seq_lens, + cached_to_contexted_idx, + self.num_active_blocks + 1, + block_size, + core_sos_rank, + cores_per_kv_head, + ) else: mask, active_mask = hlo.attention_mask( cache_ids, @@ -163,21 +279,9 @@ def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, *weights): last_token_id=last_token_id, num_active_blocks=self.num_active_blocks, neuron_config=self.neuron_config, + context_lens=context_lens, ) - return hidden, last_token_id, pos_embed, cache_ids, start_ids, block_to_seq, mask, active_mask, core_id - - def token_tree_pre_layer( - self, hidden, cache_ids, start_ids, last_token_id, previous_cache_ids, reorder_mapping, *weights - ): - hidden, last_token_id, pos_embed, cache_ids, start_ids, block_to_seq, mask, active_mask, core_id = ( - self.pre_layer(hidden, cache_ids, start_ids, last_token_id, *weights) - ) - if self.neuron_config.on_device_embedding: - embed_weight, token_tree_mask = weights - else: - 
token_tree_mask, *rst = weights - active_mask = hlo.token_tree_attention_mask(token_tree_mask, active_mask) return ( hidden, last_token_id, @@ -185,11 +289,44 @@ def token_tree_pre_layer( cache_ids, start_ids, block_to_seq, - previous_cache_ids, - reorder_mapping, mask, active_mask, core_id, + block_tables, + cached_mask, + cached_to_contexted, + active_to_contexted, + ) + + def eagle_draft_pre_layer( + self, hidden, cache_ids, start_ids, last_token_id, block_tables, context_lens, *weights, position_ids=None + ): + + if ( + self.neuron_config.shard_over_sequence or self.neuron_config.sequence_parallel_norm + ) and self.neuron_config.on_device_embedding: + core_id, embed_weight, *rst = weights + else: + embed_weight, *rst = weights + + if self.config.bias: + fc_weight, fc_bias, *rst = rst + else: + fc_weight, *rst = rst + fc_bias = None + hidden = hlo.dot_add(fc_weight, hidden, fc_bias, 0, 2, 0) + hidden = hlo.permute(hidden, [1, 2, 0]) + hidden = hlo.all_gather(hidden, 2, self.config.tp_degree) + # hidden = hlo.dot_add(hidden, fc_weight, fc_bias, 2, 0, 2) + return self.pre_layer( + hidden, + cache_ids, + start_ids, + last_token_id, + block_tables, + context_lens, + *weights, + position_ids=position_ids, ) def layer( @@ -203,11 +340,93 @@ def layer( mask, active_mask, core_id, + block_tables, + cached_mask, + cached_to_contexted, + active_to_contexted, + attn_k_cache, + attn_v_cache, + pre_attn_ln_weight, + pre_attn_ln_bias, + attn_q_weight, + attn_q_scales, + attn_q_bias, + attn_k_weight, + attn_k_scales, + attn_k_bias, + attn_v_weight, + attn_v_scales, + attn_v_bias, + attn_out_weight, + attn_out_scales, + attn_out_bias, + post_attn_ln_weight, + post_attn_ln_bias, + pre_mlp_ln_weight, + pre_mlp_ln_bias, + mlp_in_weight, + mlp_in_scales, + mlp_in_bias, + mlp_out_weight, + mlp_out_scales, + mlp_out_bias, + post_mlp_ln_weight, + post_mlp_ln_bias, + in0_weight=None, + in0_scales=None, + in1_weight=None, + in1_scales=None, + out_weight=None, + out_scales=None, + is_first_last_layer=False, + ): + local_args = {**locals()} + local_args.pop("self") + + # Initialize with kernels + enable_qkv_kernel, enable_mlp_kernel = False, False + if self.neuron_config and self.neuron_config.fused_rmsnorm_qkv: + try: + from neuronxcc.nki._private_kernels.qkv import rmsnorm_qkv_isa_fused_add_kernel # noqa: F401 + + enable_qkv_kernel = True + except Exception: + logger.warning("No QKV kernel found") + if self.neuron_config and self.neuron_config.fused_rmsnorm_mlp: + try: + from neuronxcc.nki._private_kernels.mlp import mlp_isa_kernel # noqa: F401 + + enable_mlp_kernel = True + except Exception: + logger.warning("No MLP kernel found") + enable_mlp_kernel = True + + if (not enable_qkv_kernel and not enable_mlp_kernel) or active_mask is not None: + return self.flat_compiler_layer(**local_args) + + local_args["enable_qkv_kernel"] = enable_qkv_kernel + local_args["enable_mlp_kernel"] = enable_mlp_kernel + return self.native_kernel_layer(**local_args) + + def flat_compiler_layer( + self, + hidden, + last_token_id, + pos_embed, + cache_ids, + start_ids, + block_to_seq, + mask, + active_mask, + core_id, + block_tables, + cached_mask, + cached_to_contexted, + active_to_contexted, attn_k_cache, attn_v_cache, pre_attn_ln_weight, pre_attn_ln_bias, - fused_pre_attn_ln_qkv_weight, attn_q_weight, attn_q_scales, attn_q_bias, @@ -238,39 +457,11 @@ def layer( in1_scales=None, out_weight=None, out_scales=None, + is_first_last_layer=False, ): eps = self.config.rms_norm_eps is_bsh = self.neuron_config and 
self.neuron_config.attention_layout == LAYOUT_BSH - if self.neuron_config and self.neuron_config.fused_rmsnorm_qkv and active_mask is None: - assert fused_pre_attn_ln_qkv_weight is not None - attn_output, out_attn_k_cache, out_attn_v_cache = self.fused_rmsnorm_qkv( - hidden, - None, - eps, - cache_ids, - start_ids, - last_token_id, - block_to_seq, - pos_embed, - mask, - active_mask, - core_id, - attn_k_cache, - attn_v_cache, - fused_pre_attn_ln_qkv_weight, - attn_q_scales, - attn_q_bias, - attn_k_weight, - attn_k_scales, - attn_k_bias, # should be none - attn_v_weight, - attn_v_scales, - attn_v_bias, # should be none - attn_out_weight, - attn_out_scales, - attn_out_bias, - ) - else: + if self.neuron_config.has_pre_attention_norm: ln_hidden = ( hlo.rms_norm( hidden, pre_attn_ln_weight, eps, neuron_config=self.neuron_config, tp_degree=self.config.tp_degree @@ -285,31 +476,37 @@ def layer( tp_degree=self.config.tp_degree, ) ) - attn_output, out_attn_k_cache, out_attn_v_cache = self.attention( - ln_hidden, - cache_ids, - start_ids, - last_token_id, - block_to_seq, - pos_embed, - mask, - active_mask, - core_id, - attn_k_cache, - attn_v_cache, - attn_q_weight, - attn_q_scales, - attn_q_bias, - attn_k_weight, - attn_k_scales, - attn_k_bias, - attn_v_weight, - attn_v_scales, - attn_v_bias, - attn_out_weight, - attn_out_scales, - attn_out_bias, - ) + else: + ln_hidden = hidden + attn_output, out_attn_k_cache, out_attn_v_cache = self.attention( + ln_hidden, + cache_ids, + start_ids, + last_token_id, + block_to_seq, + pos_embed, + mask, + active_mask, + core_id, + block_tables, + cached_mask, + cached_to_contexted, + active_to_contexted, + attn_k_cache, + attn_v_cache, + attn_q_weight, + attn_q_scales, + attn_q_bias, + attn_k_weight, + attn_k_scales, + attn_k_bias, + attn_v_weight, + attn_v_scales, + attn_v_bias, + attn_out_weight, + attn_out_scales, + attn_out_bias, + ) # Granite specific: attention output is multiplied by residual multiplier attn_output = scale_mul(attn_output, self.config.residual_multiplier) hidden = hlo.add(attn_output, hidden) @@ -347,7 +544,7 @@ def layer( res_hidden = hlo.add(mlp_hidden, hidden) return res_hidden, out_attn_k_cache, out_attn_v_cache - def token_tree_layer( + def native_kernel_layer( self, hidden, last_token_id, @@ -355,16 +552,17 @@ def token_tree_layer( cache_ids, start_ids, block_to_seq, - previous_cache_ids, - reorder_mapping, mask, active_mask, core_id, + block_tables, + cached_mask, + cached_to_contexted, + active_to_contexted, attn_k_cache, attn_v_cache, pre_attn_ln_weight, pre_attn_ln_bias, - fused_pre_attn_ln_qkv_weight, attn_q_weight, attn_q_scales, attn_q_bias, @@ -389,82 +587,192 @@ def token_tree_layer( mlp_out_bias, post_mlp_ln_weight, post_mlp_ln_bias, - in0_weight, - in0_scales, - in1_weight, - in1_scales, - out_weight, - out_scales, + in0_weight=None, + in0_scales=None, + in1_weight=None, + in1_scales=None, + out_weight=None, + out_scales=None, + is_first_last_layer=False, + enable_qkv_kernel=False, + enable_mlp_kernel=False, ): eps = self.config.rms_norm_eps is_bsh = self.neuron_config and self.neuron_config.attention_layout == LAYOUT_BSH - ln_hidden = ( - hlo.rms_norm( - hidden, pre_attn_ln_weight, eps, neuron_config=self.neuron_config, tp_degree=self.config.tp_degree + assert is_bsh + rms_norm_dim = 2 if is_bsh else 0 + + from neuronxcc.nki._private_kernels.mlp import mlp_fused_add_isa_kernel, mlp_isa_kernel + + # lambda functions for calling kernels + def _mlp_fused_add_kernel(attn_output, hidden, ln_w, gate_w, up_w, down_w, out, 
fused_rmsnorm=True): + mlp_fused_add_isa_kernel( + attn_output, hidden, ln_w, gate_w, up_w, down_w, out, "MLP", fused_rmsnorm=fused_rmsnorm ) - if is_bsh - else hlo.rms_norm( + + def _mlp_kernel(hidden, ln_w, gate_w, up_w, down_w, out, fused_rmsnorm=False): + mlp_isa_kernel(hidden, ln_w, gate_w, up_w, down_w, out, "MLP", fused_rmsnorm=fused_rmsnorm) + + if enable_qkv_kernel: + fused_out = self.fused_rmsnorm_qkv( hidden, pre_attn_ln_weight, eps, - dim=0, + cache_ids, + start_ids, + last_token_id, + block_to_seq, + pos_embed, + mask, + active_mask, + core_id, + block_tables, + cached_mask, + cached_to_contexted, + active_to_contexted, + attn_k_cache, + attn_v_cache, + attn_q_weight, + attn_q_scales, + attn_q_bias, + attn_k_weight, + attn_k_scales, + attn_k_bias, # should be none + attn_v_weight, + attn_v_scales, + attn_v_bias, # should be none + attn_out_weight, + attn_out_scales, + attn_out_bias, + ) + if len(fused_out) == 3: + attn_output, out_attn_k_cache, out_attn_v_cache = fused_out + else: + attn_output, out_attn_k_cache, out_attn_v_cache, fused_added_hidden = fused_out + else: + ln_hidden = ( + hlo.rms_norm( + hidden, pre_attn_ln_weight, eps, neuron_config=self.neuron_config, tp_degree=self.config.tp_degree + ) + if is_bsh + else hlo.rms_norm( + hidden, + pre_attn_ln_weight, + eps, + dim=0, + neuron_config=self.neuron_config, + tp_degree=self.config.tp_degree, + ) + ) + attn_output, out_attn_k_cache, out_attn_v_cache = self.attention( + ln_hidden, + cache_ids, + start_ids, + last_token_id, + block_to_seq, + pos_embed, + mask, + active_mask, + core_id, + block_tables, + cached_mask, + cached_to_contexted, + active_to_contexted, + attn_k_cache, + attn_v_cache, + attn_q_weight, + attn_q_scales, + attn_q_bias, + attn_k_weight, + attn_k_scales, + attn_k_bias, + attn_v_weight, + attn_v_scales, + attn_v_bias, + attn_out_weight, + attn_out_scales, + attn_out_bias, + ) + + if isinstance(hidden, tuple): + hidden = hidden[0] + + if enable_mlp_kernel: + if self.neuron_config.is_sequence_parallel: + # In sequence parallel, we cannot fuse residual add and rms norm into the kernel + hidden = hlo.add(attn_output, hidden) + norm_hidden = hlo.rms_norm( + hidden, + pre_mlp_ln_weight, + eps, + dim=rms_norm_dim, + neuron_config=self.neuron_config, + tp_degree=self.config.tp_degree, + ) + mlp_result = nki_call( + _mlp_kernel, + norm_hidden, + pre_mlp_ln_weight, + in0_weight, + in1_weight, + out_weight, + output_HloShapes=[ + norm_hidden.dtype[norm_hidden.sizes[0], norm_hidden.sizes[1], norm_hidden.sizes[2]] + ], + ) + dtype, replica_groups = utils.parse_dtype_replica_groups(self.neuron_config, self.config.tp_degree) + mlp_hidden = hlo.reduce_scatter_sum( + mlp_result, tp_degree=self.config.tp_degree, dim=1, replica_groups=replica_groups, dtype=dtype + ) + return hlo.add(mlp_hidden, hidden), out_attn_k_cache, out_attn_v_cache + + # In TP, we can fuse residual add and rms norm into the kernel + if is_first_last_layer or not enable_qkv_kernel: + hidden_add = hlo.add(attn_output, hidden) + mlp_result = nki_call( + _mlp_fused_add_kernel, + attn_output, + hidden, + pre_mlp_ln_weight, + in0_weight, + in1_weight, + out_weight, + output_HloShapes=[hidden.dtype[hidden.sizes[0], hidden.sizes[1], hidden.sizes[2]]], + ) + dtype, replica_groups = utils.parse_dtype_replica_groups(self.neuron_config, self.config.tp_degree) + mlp_hidden = hlo.all_reduce_sum( + mlp_result, self.config.tp_degree, dtype=dtype, replica_groups=replica_groups + ) + if is_first_last_layer or not enable_qkv_kernel: + return hlo.add(mlp_hidden, 
hidden_add), out_attn_k_cache, out_attn_v_cache + + return (hidden, mlp_hidden, attn_output), out_attn_k_cache, out_attn_v_cache + else: + hidden = hlo.add(attn_output, hidden) + gated_mlp = hlo.gated_mlp_bsh if is_bsh else hlo.gated_mlp + norm_hidden = hlo.rms_norm( + hidden, + pre_mlp_ln_weight, + eps, + dim=rms_norm_dim, neuron_config=self.neuron_config, tp_degree=self.config.tp_degree, ) - ) - reordered_attn_k_cache, reordered_attn_v_cache = attention.reorder_kv_cache( - attn_k_cache, attn_v_cache, previous_cache_ids, reorder_mapping, neuron_config=self.neuron_config - ) - attn_output, out_attn_k_cache, out_attn_v_cache = self.attention( - ln_hidden, - cache_ids, - start_ids, - last_token_id, - block_to_seq, - pos_embed, - mask, - active_mask, - core_id, - reordered_attn_k_cache, - reordered_attn_v_cache, - attn_q_weight, - attn_q_scales, - attn_q_bias, - attn_k_weight, - attn_k_scales, - attn_k_bias, - attn_v_weight, - attn_v_scales, - attn_v_bias, - attn_out_weight, - attn_out_scales, - attn_out_bias, - ) - hidden = hlo.add(attn_output, hidden) - gated_mlp = hlo.gated_mlp_bsh if is_bsh else hlo.gated_mlp - rms_norm_dim = 2 if is_bsh else 0 - norm_hidden = hlo.rms_norm( - hidden, - pre_mlp_ln_weight, - eps, - dim=rms_norm_dim, - neuron_config=self.neuron_config, - tp_degree=self.config.tp_degree, - ) - mlp_hidden = gated_mlp( - norm_hidden, - in0_weight, - in1_weight, - out_weight, - in0_scales=in0_scales, - in1_scales=in1_scales, - out_scales=out_scales, - activation_function="silu", - tp_degree=self.config.tp_degree, - neuron_config=self.neuron_config, - ) - res_hidden = hlo.add(mlp_hidden, hidden) - return res_hidden, out_attn_k_cache, out_attn_v_cache + mlp_hidden = gated_mlp( + norm_hidden, + in0_weight, + in1_weight, + out_weight, + in0_scales=in0_scales, + in1_scales=in1_scales, + out_scales=out_scales, + activation_function="silu", + tp_degree=self.config.tp_degree, + neuron_config=self.neuron_config, + ) + if is_first_last_layer or not enable_qkv_kernel: + return hlo.add(mlp_hidden, hidden), out_attn_k_cache, out_attn_v_cache + return (hidden, mlp_hidden, attn_output), out_attn_k_cache, out_attn_v_cache def ln_lm_head( self, hidden, last_token_id, rms_weight, unused_bias, lm_head_weight, lm_head_bias, return_all_outputs=True @@ -495,6 +803,10 @@ def fused_rmsnorm_qkv( mask, active_mask, core_id, + block_tables, + cached_mask, + cached_to_contexted, + active_to_contexted, attn_k_cache, attn_v_cache, attn_q_weight, @@ -510,11 +822,20 @@ def fused_rmsnorm_qkv( attn_out_scales, attn_out_bias, ): - # TODO: refactor below - from neuronxcc.nki._private_kernels.fused_linear import fused_rms_norm_qkv + from neuronxcc.nki._private_kernels.qkv import rmsnorm_qkv_isa_fused_add_kernel, rmsnorm_qkv_isa_kernel + + def _kernel(h, w, ln_w, output): + return rmsnorm_qkv_isa_kernel(h, w, ln_w, output, "QKV") - def _kernel(h, w, output): - return fused_rms_norm_qkv(h, w, output, eps=eps) + def _fused_out_kernel(h0, h1, h2, w, ln_w, output): + # This kernel will perform h0 = h0 + h1 + h2 (writing results in-place to an input buffer + # FIXME: allow for multiple outputs + return rmsnorm_qkv_isa_fused_add_kernel(h0, h1, h2, w, ln_w, output, "QKV") + + fused_add = False + if isinstance(hidden, tuple): + fused_add = True + hidden, mlp_out, attn_out = hidden n_seqs, n_active_tokens, _ = hidden.sizes d_head = self.config.attention_head_size @@ -532,15 +853,25 @@ def _kernel(h, w, output): n_total_heads_tp = hidden_size_tp // d_head n_heads_tp = n_total_heads_tp - 2 * n_kv_heads_tp - # Q hidden size - 
hidden_size_tp = d_head * n_heads_tp - nki_output = nki_call( - _kernel, - hidden, - attn_q_weight, - output_HloShapes=[hidden.dtype[hidden.sizes[0], hidden.sizes[1], attn_q_weight.sizes[-1]]], - ) + if fused_add: + nki_output = nki_call( + _fused_out_kernel, + hidden, + mlp_out, + attn_out, + attn_q_weight, + pre_attn_ln_weight, + output_HloShapes=[hidden.dtype[n_seqs, n_active_tokens, hidden_size_tp]], + ) + else: + nki_output = nki_call( + _kernel, + hidden, + attn_q_weight, + pre_attn_ln_weight, + output_HloShapes=[hidden.dtype[n_seqs, n_active_tokens, hidden_size_tp]], + ) slice_lim = nki_output.sizes[-1] // (n_heads_tp + 2 * n_kv_heads_tp) query = hlo.slice_along(nki_output, -1, n_heads_tp * slice_lim, start=0) key = hlo.slice_along(nki_output, -1, (n_heads_tp + n_kv_heads_tp) * slice_lim, start=n_heads_tp * slice_lim) @@ -581,6 +912,10 @@ def _kernel(h, w, output): mask, active_mask, core_id, + block_tables, + cached_mask, + cached_to_contexted, + active_to_contexted, attn_k_cache, attn_v_cache, attn_q_weight, @@ -597,6 +932,8 @@ def _kernel(h, w, output): attn_out_bias, qkv_tuple=(query, key, value), ) + if fused_add: + return attn_output, out_attn_k_cache, out_attn_v_cache, hidden return attn_output, out_attn_k_cache, out_attn_v_cache def attention( @@ -610,6 +947,10 @@ def attention( mask, active_mask, core_id, + block_tables, + cached_mask, + cached_to_contexted, + active_to_contexted, cached_keys, cached_values, q_weight, @@ -634,7 +975,7 @@ def attention( if self.config.num_key_value_heads is not None: n_head = self.config.num_attention_heads n_kv_head = self.config.num_key_value_heads - n_head, n_kv_head_padded = utils.get_qkv_padding(n_head, n_kv_head, tp_degree, self.neuron_config) + n_head_padded, n_kv_head_padded = utils.get_qkv_padding(n_head, n_kv_head, tp_degree, self.neuron_config) n_kv_heads_tp = n_kv_head_padded // tp_degree # Q = (hidden @ wQ) + bQ @@ -663,6 +1004,25 @@ def attention( n_kv_heads_tp=n_kv_heads_tp, ) + if ( + (active_mask is None and not self.neuron_config.enable_chunked_prefill) + and self.neuron_config.shard_over_sequence + and self.neuron_config.duplicate_q_weight_sos + ): + # slice on computed qeury when sos and duplicate Q weights is on + + # q / kv -> number of q per core after replication + # core_id % tp/kv -> kv replication degree on cores + # q / tp -> actual q per core before replication + slice_start = hlo.remainder( + hlo.reshape(core_id, []), core_id.dtype.Constant(constant_value=self.neuron_config.kv_replication) + ) + slice_size = self.neuron_config.n_head_padded // tp_degree + + slice_start = hlo.multiply(slice_start, slice_start.dtype.Constant(constant_value=slice_size)) + + query = hlo.dynamic_slice_along(query, 2, start=slice_start, size=slice_size) + # Q = Rotate(Q) # K = Rotate(K) query, key = rotary.rotate_half( @@ -703,7 +1063,14 @@ def attention( cached_values_s = hlo.select( cached_values, batch_dim, hlo.reshape(start_ids, slice_sizes), keepdim=True ) + elif cached_keys.sizes[batch_dim] == start_ids.sizes[0]: + # For batched speculative decoding, we will select kv caches for all sequences. No need to do + # index select, which is slow + cached_keys_s = cached_keys + cached_values_s = cached_values else: + # for multi prompt use case, cached_keys.sizes[batch_dim] can still be larger than 1, so we + # need to use start_ids size to determine if we want to select kv cache. 
cached_keys_s = hlo.index_select(cached_keys, batch_dim, start_ids) cached_values_s = hlo.index_select(cached_values, batch_dim, start_ids) if self.neuron_config and self.neuron_config.kv_cache_quant: @@ -727,7 +1094,12 @@ def attention( cached_keys_s = cached_keys cached_values_s = cached_values # Communication 1: all-gather query from cores - if (n_active_tokens != self.n_positions) and self.neuron_config.shard_over_sequence: + # skip all-gather if query weight is already duplicated + if ( + (n_active_tokens != self.n_positions) + and self.neuron_config.shard_over_sequence + and not self.neuron_config.duplicate_q_weight_sos + ): query = flash_decoding.gather_query_group(query, self.cores_per_kv_head, n_head, tp_degree) # Sp = Q @ Kp @@ -774,7 +1146,7 @@ def attention( shard_over_batch=self.shard_over_batch, ) cache_ids, value, key = flash_decoding.select_values_within_bound( - cache_ids, value, key, self.cores_per_kv_head, core_id, dim=0 + cache_ids, value, key, self.cores_per_kv_head, core_id, dim=0, n_positions=self.n_positions ) else: @@ -798,8 +1170,12 @@ def attention( # Multi-Token Context Encoding else: - _, batch_size, _, _ = query.sizes - if self.neuron_config.lhs_aligned or batch_size == 1: + batch_size = query.sizes[batch_dim] + if ( + (self.neuron_config.lhs_aligned or batch_size == 1) + and not self.neuron_config.enable_chunked_prefill + and not self.neuron_config.bsh_cache_layout + ): context = attention.flash_attention(query, key, value) else: # do not use flash attention for lhs padded (right aligned) batch > 1 case @@ -807,27 +1183,123 @@ def attention( context = None if context is None: - # S = Q @ K + if self.neuron_config.enable_chunked_prefill: + if self.neuron_config.shard_over_sequence: + # Communication 1: all-gather query from cores + if not self.neuron_config.duplicate_q_weight_sos: + query = flash_decoding.gather_query_group(query, self.cores_per_kv_head, n_head, tp_degree) + # S = Q @ K (This matmul wastes some computation) + contexted_keys = attention_utils.gather_sharded_kv( + cached_keys, + active_idx=cached_to_contexted, + active_tokens=key, + active_token_idx=active_to_contexted, + ) + score = attention.score( + query, + contexted_keys, + n_kv_heads=self.config.num_key_value_heads, + tp_degree=tp_degree, + neuron_config=self.neuron_config, + ) + score = attention.mask(score, mask, tp_degree=tp_degree) + # FlashAttention-Style Communication + f32 = score.scribe.f32 + score = hlo.cast(score, f32) + max_score_local = hlo.reduce_max(score, dim=3) + max_score_local_br = hlo.broadcast(max_score_local, score.sizes, [0, 1, 2]) + score = hlo.exp(hlo.subtract(score, max_score_local_br)) + l_sum_score_local = hlo.reduce_sum(score, dim=3) - score = attention.score( - query, - key, - n_kv_heads=self.config.num_key_value_heads, - tp_degree=tp_degree, - neuron_config=self.neuron_config, - ) - score = attention.mask(score, mask, tp_degree=tp_degree, shard_over_batch=self.shard_over_batch) - context = attention.context_combined( - score, - value, - n_kv_heads=self.config.num_key_value_heads, - tp_degree=tp_degree, - neuron_config=self.neuron_config, - ) + # Value Combination + score = hlo.cast(score, cached_values.dtype) + contexted_values = attention_utils.gather_sharded_kv( + cached_values, + active_idx=cached_to_contexted, + active_tokens=value, + active_token_idx=active_to_contexted, + ) + context = attention.context_combined( + score, + contexted_values, + n_kv_heads=self.config.num_key_value_heads, + dtype=score.scribe.f32, + tp_degree=tp_degree, + 
neuron_config=self.neuron_config, + skip_softmax=True, + ) + # Communication 2: softmax correction + context = attention_utils.sharded_softmax_correction( + context, + max_score_local, + l_sum_score_local, + core_id, + tp_degree=tp_degree, + sos_degree=self.cores_per_kv_head, + ) + # Communication 3: reduce-scatter partial context + num_groups = tp_degree // self.cores_per_kv_head + replica_groups = utils.build_replica_groups( + num_groups=num_groups, group_size=self.cores_per_kv_head, interleave=False + ) + context = hlo.reduce_scatter_sum( + context, tp_degree=self.cores_per_kv_head, dim=2, replica_groups=replica_groups + ) + context = hlo.cast(context, hidden.dtype) + else: + # S = Q @ K + cached_keys_gathered = attention_utils.gather_blocks( + cached_keys, block_tables=block_tables, neuron_config=self.neuron_config + ) + contexted_keys = attention_utils.contexted_kv( + cached_keys_gathered, key, cached_mask, cached_to_contexted, active_to_contexted + ) + score = attention.score( + query, + contexted_keys, + n_kv_heads=self.config.num_key_value_heads, + tp_degree=tp_degree, + neuron_config=self.neuron_config, + ) - if self.neuron_config.shard_over_sequence: + score = attention.mask(score, mask, tp_degree=tp_degree) + + # C = softmax(Sa, Sp) @ (Va, Vp) + cached_values_gathered = attention_utils.gather_blocks( + cached_values, block_tables=block_tables, neuron_config=self.neuron_config + ) + contexted_values = attention_utils.contexted_kv( + cached_values_gathered, value, cached_mask, cached_to_contexted, active_to_contexted + ) + context = attention.context_combined( + score, + contexted_values, + n_kv_heads=self.config.num_key_value_heads, + tp_degree=tp_degree, + neuron_config=self.neuron_config, + ) + else: + # S = Q @ K + + score = attention.score( + query, + key, + n_kv_heads=self.config.num_key_value_heads, + tp_degree=tp_degree, + neuron_config=self.neuron_config, + ) + score = attention.mask(score, mask, tp_degree=tp_degree, shard_over_batch=self.shard_over_batch) + context = attention.context_combined( + score, + value, + n_kv_heads=self.config.num_key_value_heads, + tp_degree=tp_degree, + neuron_config=self.neuron_config, + ) + + if self.neuron_config.shard_over_sequence and not self.neuron_config.enable_chunked_prefill: cache_ids, value, key = flash_decoding.select_values_within_bound( - cache_ids, value, key, self.cores_per_kv_head, core_id, dim=0 + cache_ids, value, key, self.cores_per_kv_head, core_id, dim=0, n_positions=self.n_positions ) # KCache, VCache = K, V if cached_keys.sizes == key.sizes: @@ -843,4 +1315,7 @@ def attention( # O = (C @ wO) + bO output = attention.output(context, out_weight, out_scales, out_bias, tp_degree, self.neuron_config) + # we do zero padding so disable now + # if cores_per_attn_head and not self.neuron_config.shard_over_sequence: + # output = hlo.divide(output, cores_per_attn_head) return output, updated_keys, updated_values diff --git a/optimum/neuron/models/granite/model.py b/optimum/neuron/models/granite/model.py index ddd3aecf2..76e658ab3 100644 --- a/optimum/neuron/models/granite/model.py +++ b/optimum/neuron/models/granite/model.py @@ -16,7 +16,7 @@ import torch from transformers import PretrainedConfig -from transformers_neuronx import base, bucket, decoder, ops, utils +from transformers_neuronx import base, bucket, decoder, utils from transformers_neuronx.config import NeuronConfig from transformers_neuronx.constants import KV_SHARD_PAD, LAYOUT_HSB @@ -108,6 +108,20 @@ def __init__( [1] if self.neuron_config and 
self.neuron_config.continuous_batching else self.batch_sizes ) hlo_builder = GraniteForSamplingNoEmbeddingHlo(config, neuron_config=self.neuron_config) + if self.neuron_config.enable_chunked_prefill: + max_num_seqs = self.neuron_config.continuous_batching.max_num_seqs + block_size = self.neuron_config.continuous_batching.block_size + num_blocks = self.neuron_config.continuous_batching.num_blocks + + # define block buckets based on the n_positions + block_sizes = [n_pos * max_num_seqs // block_size for n_pos in self.token_buckets] + assert ( + max(block_sizes) <= num_blocks + ), "Too few blocks allocated, consider increasing gpu_memory_utilization or override" + # for chunked prefill we set the context batch sizes to the block sizes (we use the batch size bucketing + # for KV cache active blocks and the context length estimate bucket for number of queries) + self.context_batch_sizes = block_sizes + self.decoder_param_set = decoder.DecoderLmHeadForSamplingNoEmbedding( tp_degree=tp_degree, n_positions_list=self.token_buckets, @@ -127,14 +141,16 @@ def __init__( unroll=self.unroll, buckets=self.token_buckets, model_obj=self ) self.decoder_lm_head_for_context = self.decoder_param_set.init_context_decoder( - unroll=self.context_unroll, buckets=self.context_buckets, model_obj=self + unroll=self.context_unroll, + buckets=self.context_buckets, + model_obj=self, + context_batch_sizes=self.context_batch_sizes, ) self.decoder_lm_head_for_speculation = {} self.decoder_lm_head_for_window_context = {} def load_weights(self): self.materialize_embeddings() - ops.init() for layer_id, layer in enumerate(self.chkpt_model.model.layers): if layer_id not in self.layers_after_partition: @@ -147,7 +163,8 @@ def load_weights(self): else: is_unit_scale = False new_layer = self.decoder_lm_head.new_layer(is_unit_scale=is_unit_scale) - new_layer.add_pre_attention_layer_norm(layer.input_layernorm.weight.detach(), None) + if self.neuron_config.has_pre_attention_norm: + new_layer.add_pre_attention_layer_norm(layer.input_layernorm.weight.detach(), None) new_layer.add_attention_query(attn.q_proj.weight.detach().T, None) new_layer.add_attention_key(attn.k_proj.weight.detach().T, None) new_layer.add_attention_value(attn.v_proj.weight.detach().T, None) @@ -155,10 +172,30 @@ def load_weights(self): new_layer.add_attention_output(attn.o_proj.weight.T.detach(), None, sharding=0, transposed=True) else: new_layer.add_attention_output(attn.o_proj.weight.detach(), None, sharding=1, transposed=False) - new_layer.add_pre_mlp_layer_norm(layer.post_attention_layernorm.weight.detach(), None) + + if self.neuron_config.fused_rmsnorm_mlp: + dummy_post_attention_ln_weight = torch.ones_like(layer.post_attention_layernorm.weight.detach()) + new_layer.add_pre_mlp_layer_norm(dummy_post_attention_ln_weight, None) + else: + new_layer.add_pre_mlp_layer_norm(layer.post_attention_layernorm.weight.detach(), None) # Note: Automatic MLP padding is safe since zeros are *only* introduced to intermediary state - if self.neuron_config.fuse_mlp: + if self.neuron_config.fused_rmsnorm_mlp: + fused_pre_mlp_ln_gate_weight = ( + mlp.gate_proj.weight + * layer.post_attention_layernorm.weight.detach().to(dtype=mlp.gate_proj.weight.dtype) + ) + new_layer.add_parameter( + fused_pre_mlp_ln_gate_weight.T, sharding=1, allow_pad=True, allow_quantize=True + ) + fused_pre_mlp_ln_up_weight = mlp.up_proj.weight * layer.post_attention_layernorm.weight.detach().to( + dtype=mlp.up_proj.weight.dtype + ) + new_layer.add_parameter(fused_pre_mlp_ln_up_weight.T, sharding=1, 
allow_pad=True, allow_quantize=True) + new_layer.add_parameter( + mlp.down_proj.weight.T, sharding=0, allow_pad=True, allow_quantize=True, out_feature_dim=0 + ) + elif self.neuron_config.fuse_mlp: assert all( getattr(mlp, attr, None) for attr in ["gate_proj", "up_proj"] ), "fuse_mlp need to have gate and up proj weights" @@ -179,8 +216,6 @@ def load_weights(self): ) else: new_layer.add_mlp_output( - mlp.down_proj.weight.detach(), - None, sharding=1, transposed=False, ) @@ -206,38 +241,61 @@ def load_weights(self): ) new_layer.to_neuron() layer.nullify() - if self.neuron_config.shard_over_sequence: + + # Adding core_id for sos or seq-norm (vocab parallel is used as default with seq-par norm) + add_core_id = self.neuron_config.shard_over_sequence or ( + self.neuron_config.sequence_parallel_norm and self.neuron_config.on_device_embedding + ) + if add_core_id: self.decoder_lm_head.add_pre_layer_parameter(torch.arange(self.config.tp_degree), sharding=0) # For pipeline parallel, we need to load ln and lm_head for now even if the pipeline stage doesn't compute the, because # 1) we need the ln_lm_head hlo for pp0 to get the logits shape and dtype # 2) we don't needs these for intermediate pp stages, but to keep things simple, just include ln_lm_head for all pp stages for now # 3) to get ln_lm_head hlo, we need to do weight loading and sharding # 4) this will introduce extra memory allocation, but ln_lm_head i/o tensor is much smaller and we can get rid of it when we can construct hlo in init - ln_f = self.chkpt_model.model.norm - ln_f.materialize() - self.decoder_lm_head.add_final_layer_norm(ln_f.weight.detach(), None) + if not self.neuron_config.is_eagle_draft: + ln_f = self.chkpt_model.model.norm + ln_f.materialize() + self.decoder_lm_head.add_final_layer_norm(ln_f.weight.detach(), None) + ln_f.nullify() lm_head = self.chkpt_model.lm_head lm_head.materialize() self.decoder_lm_head.add_lm_head(lm_head.weight.detach().T) + lm_head.nullify() + if self.neuron_config.on_device_embedding: if self.neuron_config.sequence_parallel_norm: self.decoder_lm_head.add_pre_layer_parameter( - self.chkpt_model.model.embed_tokens.weight, sharding=None, allow_pad=True + self.chkpt_model.model.embed_tokens.weight, sharding=0, allow_pad=True ) else: self.decoder_lm_head.add_pre_layer_parameter( self.chkpt_model.model.embed_tokens.weight, sharding=1, allow_pad=True ) - lm_head.nullify() + if self.neuron_config.is_eagle_draft: + self.chkpt_model.model.fc.materialize() + self.decoder_lm_head.add_pre_layer_parameter( + self.chkpt_model.model.fc.weight.detach().T, sharding=1, allow_pad=True + ) + if self.chkpt_model.model.fc.bias is not None: + self.decoder_lm_head.add_pre_layer_parameter( + self.chkpt_model.model.fc.bias.detach(), sharding=0, allow_pad=True + ) + self.chkpt_model.model.fc.nullify() self.decoder_lm_head.to_neuron() self.init_rest_of_model() + self.maybe_nullify_embeddings() def materialize_embeddings(self): # Materialize the embedding to CPU self.chkpt_model.model.embed_tokens.materialize() + def maybe_nullify_embeddings(self): + if self.neuron_config.on_device_embedding: + self.chkpt_model.model.embed_tokens.nullify() + def init_rest_of_model(self): # Pipeline sparallel deosn't support executor right now if not self.neuron_config.is_pp(): @@ -290,14 +348,22 @@ def preprocess_and_embed(self, input_ids, cache_ids=None, start_ids=None, **kwar return padded_inputs, input_embeddings, *rst def forward(self, input_ids, cache_ids=None, start_ids=None, last_token_id=None, input_embeddings=None, **kwargs): + 
original_input_ids = input_ids if last_token_id is not None: # preprocess_and_embed() has already been invoked rst = cache_ids, start_ids, last_token_id else: # invoke preprocess_and_embed() input_ids, input_embeddings, *rst = self.preprocess_and_embed(input_ids, cache_ids, start_ids, **kwargs) # either input_embeddings are generated (off device embedding), or input_ids will be padded from preprocess_and_embed (on device embedding) inputs = input_embeddings if input_embeddings is not None else input_ids + if "prev_hidden" in kwargs: + rst = *rst, kwargs["prev_hidden"] logits = self._forward(inputs, *rst) # Granite specific: divide logits by scaling factor logits = logits / self.config.logits_scaling - logits = self._postprocess(logits, start_ids=start_ids, **kwargs) + if self.neuron_config.is_eagle_target: + logits, hidden = logits + logits = self._postprocess(original_input_ids, logits, start_ids=start_ids, **kwargs) + return logits, hidden + else: + return self._postprocess(original_input_ids, logits, start_ids=start_ids, **kwargs) return logits From 2454d3bf76fe3ec460d20043ee7e0ea452f6fd85 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Thu, 2 Jan 2025 14:56:31 +0000 Subject: [PATCH 07/18] test: fix neuronx_distributed import --- tests/distributed/test_model_parallelization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/distributed/test_model_parallelization.py b/tests/distributed/test_model_parallelization.py index ad0cbf241..62c08fb3f 100644 --- a/tests/distributed/test_model_parallelization.py +++ b/tests/distributed/test_model_parallelization.py @@ -64,8 +64,8 @@ import torch_xla.core.xla_model as xm if is_neuronx_distributed_available(): - from neuronx_distributed.modules.qkv_linear import get_kv_shared_group from neuronx_distributed.parallel_layers.parallel_state import ( + get_kv_shared_group, get_pipeline_model_parallel_rank, get_tensor_model_parallel_group, get_tensor_model_parallel_size, From c548a87c910609941f5a0c85acc1d2258efafbfc Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Thu, 2 Jan 2025 15:10:55 +0000 Subject: [PATCH 08/18] test(distributed): adapt to new torch_neuronx initialization --- tests/distributed_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/distributed_utils.py b/tests/distributed_utils.py index 1ae1e1319..a060157a3 100644 --- a/tests/distributed_utils.py +++ b/tests/distributed_utils.py @@ -207,8 +207,8 @@ def _dist_run(self, local_rank, num_procs, master_port, tp_size, pp_size): raise RuntimeError("self.torchelastic_run_id was not set, it is needed to run a distributed test.") os.environ["TORCHELASTIC_RUN_ID"] = self.torchelastic_run_id - # Now that the environment has been set, we can configure the PJRT environment. - torch_neuronx.xla.configure_pjrt_environment() + # Now that the environment has been set, we can initialize the XLA environment. 
+ torch_neuronx.initialization.initialize() if self.init_distributed: dist.init_process_group(backend=self.backend, rank=local_rank, world_size=num_procs) From 81610f48fa5393868e3b7078a371690df38cc7b3 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Fri, 3 Jan 2025 08:35:06 +0000 Subject: [PATCH 09/18] fix(distributed): skip non pickable attr duplication --- optimum/neuron/distributed/utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py index 7fbd8467b..9b7b65782 100644 --- a/optimum/neuron/distributed/utils.py +++ b/optimum/neuron/distributed/utils.py @@ -17,6 +17,7 @@ import contextlib import copy import functools +import inspect import itertools import json import os @@ -1355,9 +1356,16 @@ def duplicate_module_with_random_weights_on_cpu(module: torch.nn.Module) -> torc for name in dir(module): attr = getattr(module, name) + if inspect.ismethod(attr): + continue if name in (children_names | buffer_names | parameter_names) or name.startswith("__"): continue - setattr(clone, name, copy.deepcopy(attr)) + try: + cloned_attr = copy.deepcopy(attr) + except Exception: + # Attribute is not pickable or cannot be copied + continue + setattr(clone, name, cloned_attr) for name, mod in module.named_children(): clone.add_module(name, duplicate_module_with_random_weights_on_cpu(mod)) From 13a428c1a4b78dd9ed206a3ac2f5792576d51856 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Fri, 3 Jan 2025 12:55:57 +0000 Subject: [PATCH 10/18] test(sd): export to bf16 Export to float leads to compilation errors in AWS Neuron SDK 2.21.0 --- tests/exporters/test_export.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/exporters/test_export.py b/tests/exporters/test_export.py index 167ea2d6c..51d87389b 100644 --- a/tests/exporters/test_export.py +++ b/tests/exporters/test_export.py @@ -220,6 +220,7 @@ def test_export_for_stable_diffusion_models(self, model_id): input_shapes = build_stable_diffusion_components_mandatory_shapes( **{"batch_size": 1, "height": 64, "width": 64, "num_images_per_prompt": 4} ) + compiler_kwargs = {"auto_cast": "matmul", "auto_cast_type": "bf16"} with TemporaryDirectory() as tmpdirname: models_and_neuron_configs, output_model_names = get_submodels_and_neuron_configs( @@ -234,6 +235,7 @@ def test_export_for_stable_diffusion_models(self, model_id): models_and_neuron_configs=models_and_neuron_configs, output_dir=Path(tmpdirname), output_file_names=output_model_names, + compiler_kwargs=compiler_kwargs, ) validate_models_outputs( models_and_neuron_configs=models_and_neuron_configs, @@ -251,6 +253,7 @@ def test_export_for_stable_diffusion_xl_models(self, model_id): input_shapes = build_stable_diffusion_components_mandatory_shapes( **{"batch_size": 1, "height": 64, "width": 64, "num_images_per_prompt": 4} ) + compiler_kwargs = {"auto_cast": "matmul", "auto_cast_type": "bf16"} with TemporaryDirectory() as tmpdirname: models_and_neuron_configs, output_model_names = get_submodels_and_neuron_configs( @@ -265,6 +268,7 @@ def test_export_for_stable_diffusion_xl_models(self, model_id): models_and_neuron_configs=models_and_neuron_configs, output_dir=Path(tmpdirname), output_file_names=output_model_names, + compiler_kwargs=compiler_kwargs, ) validate_models_outputs( models_and_neuron_configs=models_and_neuron_configs, @@ -283,6 +287,7 @@ def test_export_sd_with_fused_lora_weights(self): input_shapes = build_stable_diffusion_components_mandatory_shapes( **{"batch_size": 1, "height": 
64, "width": 64, "num_images_per_prompt": 4} ) + compiler_kwargs = {"auto_cast": "matmul", "auto_cast_type": "bf16"} with TemporaryDirectory() as tmpdirname: models_and_neuron_configs, output_model_names = get_submodels_and_neuron_configs( @@ -301,6 +306,7 @@ def test_export_sd_with_fused_lora_weights(self): models_and_neuron_configs=models_and_neuron_configs, output_dir=Path(tmpdirname), output_file_names=output_model_names, + compiler_kwargs=compiler_kwargs, ) validate_models_outputs( models_and_neuron_configs=models_and_neuron_configs, From 466d6d986bf9676a820553459835fbaf2f2e4b88 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Tue, 7 Jan 2025 13:25:40 +0000 Subject: [PATCH 11/18] test(tgi): fix gpt2 sample expectations --- text-generation-inference/tests/integration/test_generate.py | 2 +- text-generation-inference/tests/server/test_decode.py | 2 +- text-generation-inference/tests/server/test_prefill.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/text-generation-inference/tests/integration/test_generate.py b/text-generation-inference/tests/integration/test_generate.py index db716be57..0d54f9df7 100644 --- a/text-generation-inference/tests/integration/test_generate.py +++ b/text-generation-inference/tests/integration/test_generate.py @@ -47,7 +47,7 @@ async def test_model_single_request(tgi_service): seed=42, ) sample_expectations = { - "gpt2": "Deep Learning", + "gpt2": "researchers", "llama": "Deep Learning", "mistral": "Deep learning", "qwen2": "Deep Learning", diff --git a/text-generation-inference/tests/server/test_decode.py b/text-generation-inference/tests/server/test_decode.py index 2ab4c2da0..ff8449f71 100644 --- a/text-generation-inference/tests/server/test_decode.py +++ b/text-generation-inference/tests/server/test_decode.py @@ -36,7 +36,7 @@ def _test_decode(config_name, generator, do_sample): assert output.finish_reason == 0 if do_sample: expected_text = { - "gpt2": " The sun was set", + "gpt2": "The only things", "llama": "George Orwell, 1984", "mistral": "The sky was", "qwen2": " A young woman with", diff --git a/text-generation-inference/tests/server/test_prefill.py b/text-generation-inference/tests/server/test_prefill.py index 2120e5c59..648eba7c0 100644 --- a/text-generation-inference/tests/server/test_prefill.py +++ b/text-generation-inference/tests/server/test_prefill.py @@ -35,7 +35,7 @@ def _test_prefill(config_name, generator, batch_size, do_sample): assert len(generations) == batch_size if do_sample: expectations = { - "gpt2": [383, " The"], + "gpt2": [198, "\n"], "llama": [10058, " George"], "mistral": [450, " The"], "qwen2": [362, " A"], From 56e03c56c1c46717ecf067921a07f0bd3ec1d386 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 10 Jan 2025 11:26:52 +0100 Subject: [PATCH 12/18] [WIP] fix GQA QKV --- optimum/neuron/distributed/utils.py | 31 ++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py index 9b7b65782..816f30f3a 100644 --- a/optimum/neuron/distributed/utils.py +++ b/optimum/neuron/distributed/utils.py @@ -215,9 +215,15 @@ def get_parameter_names_mapping( parent_module_name, _ = fully_qualified_name.rsplit(".", maxsplit=1) mapping = {} for qkv_proj_name, proj_name in self._qkv_proj_name_to_proj_name.items(): - mapping[f"{parent_module_name}.{proj_name}.weight"] = f"{fully_qualified_name}.weight_{qkv_proj_name}" + if self.fuse_qkv: + mapping[f"{parent_module_name}.{proj_name}.weight"] = 
f"{fully_qualified_name}.weight_qkv" + else: + mapping[f"{parent_module_name}.{proj_name}.weight"] = f"{fully_qualified_name}.weight_{qkv_proj_name}" if self.use_bias: - mapping[f"{parent_module_name}.{proj_name}.bias"] = f"{fully_qualified_name}.bias_{qkv_proj_name}" + if self.fuse_qkv: + mapping[f"{parent_module_name}.{proj_name}.bias"] = f"{fully_qualified_name}.bias_qkv" + else: + mapping[f"{parent_module_name}.{proj_name}.bias"] = f"{fully_qualified_name}.bias_{qkv_proj_name}" if reversed: mapping = {v: k for k, v in mapping.items()} return mapping @@ -762,8 +768,12 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( ) proj_name = weight_name[-1] - weight = getattr(layer, weight_name) - bias = getattr(layer, f"bias_{proj_name}") + if layer.fuse_qkv: + weight = getattr(layer, "weight_qkv") + bias = getattr(layer, f"bias_qkv") + else: + weight = getattr(layer, weight_name) + bias = getattr(layer, f"bias_{proj_name}") num_attention_heads = layer.num_attention_heads num_key_value_heads = layer.num_key_value_heads @@ -781,11 +791,22 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( weight_data = create_kv_proj_local_weight_from_regular_weight( weight_data, kv_size_multiplier, weight.size(0) ) + print(weight_data.shape) else: weight_data = create_query_or_output_projection_local_weight_from_regular_weight( weight_data, num_attention_heads, num_key_value_heads, kv_size_multiplier, "query" ) - weight.copy_(weight_data) + if layer.fuse_qkv: + if proj_name == "q": + s = slice(0, layer.q_output_size_per_partition) + elif proj_name == "k": + s = slice(layer.q_output_size_per_partition, layer.q_output_size_per_partition + layer.kv_output_size_per_partition) + else: + s = slice(layer.q_output_size_per_partition + layer.kv_output_size_per_partition, None) + print(layer.q_output_size_per_partition, layer.kv_output_size_per_partition) + weight[s, :] = weight_data + else: + weight.copy_(weight_data) mark_parameter_init_status_during_parallelization(weight, True) else: mark_parameter_init_status_during_parallelization(weight, False) From 04894249eb86726da69a43cadfbe953ecacc31bc Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 15 Jan 2025 17:55:24 +0100 Subject: [PATCH 13/18] [WIP] GQA checkpointing works, but output_proj does not work --- optimum/neuron/distributed/base.py | 8 +- optimum/neuron/distributed/checkpointing.py | 83 +++++++++++-------- optimum/neuron/distributed/parallel_layers.py | 6 +- optimum/neuron/distributed/utils.py | 34 ++++++-- 4 files changed, 88 insertions(+), 43 deletions(-) diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py index 20bbcec6a..6aacbe755 100644 --- a/optimum/neuron/distributed/base.py +++ b/optimum/neuron/distributed/base.py @@ -484,7 +484,7 @@ def initialize(mod: GQAQKVColumnParallelLinear, proj_name: str, output_size: int else: # TODO: change kv heads. 
maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( - mod, f"weight_{proj_name}", linear_layer=fake_linear_mod + mod, proj_name, f"weight_{proj_name}", linear_layer=fake_linear_mod ) del fake_linear_mod @@ -678,6 +678,9 @@ def should_parallelize_layer_predicate_func(layer): "num_attention_heads": None, "num_key_value_heads": None, "kv_size_multiplier": None, + "fuse_qkv": None, + "q_output_size_per_partition": None, + "kv_output_size_per_partition": None , } for mod in model.modules(): if isinstance(mod, OptimumGQAQKVColumnParallelLinear): @@ -690,6 +693,9 @@ def should_parallelize_layer_predicate_func(layer): "num_attention_heads": num_attention_heads, "num_key_value_heads": num_key_value_heads, "kv_size_multiplier": kv_size_multiplier, + "fuse_qkv": mod.fuse_qkv, + "q_output_size_per_partition": mod.q_output_size_per_partition, + "kv_output_size_per_partition": mod.kv_output_size_per_partition, } break diff --git a/optimum/neuron/distributed/checkpointing.py b/optimum/neuron/distributed/checkpointing.py index 7f9ce7d78..064bda7c8 100644 --- a/optimum/neuron/distributed/checkpointing.py +++ b/optimum/neuron/distributed/checkpointing.py @@ -134,46 +134,63 @@ def consolidate_tensor_parallel_checkpoints( for name in parameter_names: # We need to handle the mapping between the GQA parameter names and the original names. is_gqa_qkv_weight = name in gqa_qkv_names_to_original_names + is_fuse_qkv = gqa_qkv_metadata["fuse_qkv"] if is_gqa_qkv_weight: - original_name = gqa_qkv_names_to_original_names[name] - weight_name = name.rsplit(".", maxsplit=1)[1] + if is_fuse_qkv: + original_names = [k for k, v in original_parameter_names_to_gqa_qkv_names.items() if v == name] + weight_names = [name.rsplit(".", maxsplit=1)[1] for name in original_names] + weight_names = ["weight_q", "weight_k", "weight_v"] + else: + original_names = [gqa_qkv_names_to_original_names[name]] + weight_names = [name.rsplit(".", maxsplit=1)[1]] else: - original_name = name - weight_name = "" # Not needed. + original_names = [name] + weight_names = [""] # Not needed. # For now all parameter metadatas are equal so it is enough to take the first element. # This might not be the case anymore when `ParameterMetadata` uses slices. 
sharded_metadata = sharded_metadatas[name] - if sharded_metadata.is_tied: - consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu").contiguous() - else: - # Ensure that all tensors are contiguous before concatenating or further processing - weights = [state_dict[name].contiguous() for state_dict in state_dicts] - tp_size = len(weights) - - full_weight = ( - torch.cat( - weights, - dim=sharded_metadata.partition_dim, - ) - .to("cpu") - .contiguous() - ) # Ensure the result is also contiguous - - if weight_name in ["weight_k", "weight_v", "bias_k", "bias_v"]: + for original_name, weight_name in zip(original_names, weight_names): + if sharded_metadata.is_tied: + consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu").contiguous() + else: + if is_fuse_qkv: + if weight_name == "weight_q": + s = slice(0, gqa_qkv_metadata["q_output_size_per_partition"]) + elif weight_name == "weight_k": + s = slice(gqa_qkv_metadata["q_output_size_per_partition"], gqa_qkv_metadata["q_output_size_per_partition"] + gqa_qkv_metadata["kv_output_size_per_partition"]) + else: + s = slice(gqa_qkv_metadata["q_output_size_per_partition"] + gqa_qkv_metadata["kv_output_size_per_partition"], None) + else: + s = slice(None, None) + + # Ensure that all tensors are contiguous before concatenating or further processing + weights = [state_dict[name][s].contiguous() for state_dict in state_dicts] + tp_size = len(weights) + full_weight = ( - torch.chunk(full_weight, gqa_qkv_metadata["kv_size_multiplier"], dim=0)[0].detach().clone() - ) - elif weight_name == "weight_q" or original_name in gqa_qkv_output_projections_names: - full_weight = create_gqa_query_or_output_projection_weight_from_full_weight( - full_weight, - tp_size, - gqa_qkv_metadata["num_attention_heads"], - gqa_qkv_metadata["num_key_value_heads"], - gqa_qkv_metadata["kv_size_multiplier"], - "query" if weight_name == "weight_q" else "output", - ) - consolidated_state_dict[original_name] = full_weight + torch.cat( + weights, + dim=sharded_metadata.partition_dim, + ) + .to("cpu") + .contiguous() + ) # Ensure the result is also contiguous + + if weight_name in ["weight_k", "weight_v", "bias_k", "bias_v"]: + full_weight = ( + torch.chunk(full_weight, gqa_qkv_metadata["kv_size_multiplier"], dim=0)[0].detach().clone() + ) + elif weight_name == "weight_q" or original_name in gqa_qkv_output_projections_names: + full_weight = create_gqa_query_or_output_projection_weight_from_full_weight( + full_weight, + tp_size, + gqa_qkv_metadata["num_attention_heads"], + gqa_qkv_metadata["num_key_value_heads"], + gqa_qkv_metadata["kv_size_multiplier"], + "query" if weight_name == "weight_q" else "output", + ) + consolidated_state_dict[original_name] = full_weight return consolidated_state_dict diff --git a/optimum/neuron/distributed/parallel_layers.py b/optimum/neuron/distributed/parallel_layers.py index 8e5ce0819..a90cbf074 100644 --- a/optimum/neuron/distributed/parallel_layers.py +++ b/optimum/neuron/distributed/parallel_layers.py @@ -379,8 +379,8 @@ def replace_qkv_by_gqa_qkv_column_parallel_linear( key_linear = getattr(attention_layer, cls.KEYS_NAME) hidden_size = query_linear.weight.size(1) - query_in_features = query_linear.weight.size(0) - key_value_in_features = key_linear.weight.size(0) + query_out_features = query_linear.out_features + key_value_out_features = key_linear.out_features if kv_size_multiplier is None: kv_size_multiplier = get_tensor_model_parallel_size() // num_key_value_heads @@ -397,7 +397,7 @@ def 
replace_qkv_by_gqa_qkv_column_parallel_linear( num_attention_heads, num_key_value_heads, hidden_size, - [query_in_features, key_value_in_features], + [query_out_features, key_value_out_features], gather_output=False, bias=query_linear.bias is not None, sequence_parallel_enabled=sequence_parallel_enabled, diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py index 816f30f3a..d2059ff9b 100644 --- a/optimum/neuron/distributed/utils.py +++ b/optimum/neuron/distributed/utils.py @@ -167,6 +167,7 @@ class OptimumGQAQKVColumnParallelLinear(GQAQKVColumnParallelLinear): Same as GQAQKVColumnParallelLinear with the needed metadata for `optimum-neuron`. """ + @requires_neuronx_distributed def __init__( self, query_proj_name: str, @@ -186,6 +187,9 @@ def __init__( keep_master_weight: bool = False, kv_size_multiplier: int = 1, ): + from neuronx_distributed.parallel_layers.utils import set_tensor_model_parallel_attributes + from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size + super().__init__( input_size, output_sizes, @@ -199,6 +203,15 @@ def __init__( kv_size_multiplier=kv_size_multiplier, ) + if self.fuse_qkv: + set_tensor_model_parallel_attributes( + tensor=self.weight_qkv, + is_parallel=True, + dim=0, + stride=1, + num_partitions=get_tensor_model_parallel_size(), + ) + self.query_proj_name = query_proj_name self.key_proj_name = key_proj_name self.value_proj_name = value_proj_name @@ -612,7 +625,8 @@ def create_kv_proj_local_weight_from_regular_weight( tp_rank = get_tensor_model_parallel_rank() repeated_weight = weight_data.repeat(kv_size_multiplier, 1) split = torch.split(repeated_weight, output_size_per_partition, dim=0) - return torch.cat(split[tp_rank::tp_size], dim=0) + res = torch.cat(split[tp_rank::tp_size], dim=0) + return res def compute_query_indices_for_rank( @@ -751,6 +765,7 @@ def create_local_bias_from_regular_bias( @requires_neuronx_distributed def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( layer: OptimumGQAQKVColumnParallelLinear, + proj_name: str, weight_name: str, linear_layer_weight_info: Optional[WeightInformation] = None, linear_layer_bias_weight_info: Optional[WeightInformation] = None, @@ -767,7 +782,7 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( "A linear's layer WeightInformation or a linear layer to copy the weights from need to specified." 
) - proj_name = weight_name[-1] + # proj_name = weight_name[-1] if layer.fuse_qkv: weight = getattr(layer, "weight_qkv") bias = getattr(layer, f"bias_qkv") @@ -780,7 +795,7 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( kv_size_multiplier = layer.kv_size_multiplier with torch.no_grad(): - if not was_already_initialized_during_parallelization(weight): + if layer.fuse_qkv or not was_already_initialized_during_parallelization(weight): weight_data = None if linear_layer_weight_info is not None: weight_data = load_tensor_for_weight(linear_layer_weight_info) @@ -788,10 +803,10 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( weight_data = linear_layer.weight.data if weight_data is not None: if proj_name in "kv": + output_size = layer.kv_output_size_per_partition weight_data = create_kv_proj_local_weight_from_regular_weight( - weight_data, kv_size_multiplier, weight.size(0) + weight_data, kv_size_multiplier, output_size ) - print(weight_data.shape) else: weight_data = create_query_or_output_projection_local_weight_from_regular_weight( weight_data, num_attention_heads, num_key_value_heads, kv_size_multiplier, "query" @@ -803,7 +818,6 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( s = slice(layer.q_output_size_per_partition, layer.q_output_size_per_partition + layer.kv_output_size_per_partition) else: s = slice(layer.q_output_size_per_partition + layer.kv_output_size_per_partition, None) - print(layer.q_output_size_per_partition, layer.kv_output_size_per_partition) weight[s, :] = weight_data else: weight.copy_(weight_data) @@ -844,6 +858,12 @@ def maybe_load_weights_to_gqa_qkv_column_parallel_linear( original_to_gqa = layer.get_parameter_names_mapping(named_modules) for orig_name, gqa_name in original_to_gqa.items(): + if layer.query_proj_name in orig_name: + proj_name = "q" + elif layer.key_proj_name in orig_name: + proj_name = "k" + else: + proj_name = "v" linear_layer_qualified_name, _ = orig_name.rsplit(".", maxsplit=1) linear_weight_info, linear_bias_weight_info = get_linear_weight_info( weight_map, linear_layer_qualified_name, fail_if_not_found=False @@ -852,6 +872,7 @@ def maybe_load_weights_to_gqa_qkv_column_parallel_linear( if try_from_checkpoint and linear_weight_info is not None: maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( layer, + proj_name, weight_name, linear_layer_weight_info=linear_weight_info, linear_layer_bias_weight_info=linear_bias_weight_info, @@ -860,6 +881,7 @@ def maybe_load_weights_to_gqa_qkv_column_parallel_linear( orig_layer_name, _ = orig_name.rsplit(".", maxsplit=1) maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( layer, + proj_name, weight_name, linear_layer=model.get_submodule(orig_layer_name), ) From 9c2165c3ab7476e570264926283aaf4b3d41db93 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 15 Jan 2025 19:26:50 +0100 Subject: [PATCH 14/18] Fix output_proj --- optimum/neuron/distributed/checkpointing.py | 4 +++- optimum/neuron/distributed/utils.py | 3 +-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/optimum/neuron/distributed/checkpointing.py b/optimum/neuron/distributed/checkpointing.py index 064bda7c8..58500ef0f 100644 --- a/optimum/neuron/distributed/checkpointing.py +++ b/optimum/neuron/distributed/checkpointing.py @@ -159,8 +159,10 @@ def consolidate_tensor_parallel_checkpoints( s = slice(0, gqa_qkv_metadata["q_output_size_per_partition"]) elif weight_name == "weight_k": s = slice(gqa_qkv_metadata["q_output_size_per_partition"], 
gqa_qkv_metadata["q_output_size_per_partition"] + gqa_qkv_metadata["kv_output_size_per_partition"]) - else: + elif weight_name == "weight_v": s = slice(gqa_qkv_metadata["q_output_size_per_partition"] + gqa_qkv_metadata["kv_output_size_per_partition"], None) + else: + s = slice(None, None) else: s = slice(None, None) diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py index d2059ff9b..1d57afa39 100644 --- a/optimum/neuron/distributed/utils.py +++ b/optimum/neuron/distributed/utils.py @@ -625,8 +625,7 @@ def create_kv_proj_local_weight_from_regular_weight( tp_rank = get_tensor_model_parallel_rank() repeated_weight = weight_data.repeat(kv_size_multiplier, 1) split = torch.split(repeated_weight, output_size_per_partition, dim=0) - res = torch.cat(split[tp_rank::tp_size], dim=0) - return res + return torch.cat(split[tp_rank::tp_size], dim=0) def compute_query_indices_for_rank( From 01d4a1868a400ae717218310459fc7df596a7171 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 15 Jan 2025 19:27:09 +0100 Subject: [PATCH 15/18] Styling --- optimum/neuron/distributed/base.py | 2 +- optimum/neuron/distributed/checkpointing.py | 12 ++++++++++-- optimum/neuron/distributed/utils.py | 9 ++++++--- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py index 6aacbe755..e1312f3ac 100644 --- a/optimum/neuron/distributed/base.py +++ b/optimum/neuron/distributed/base.py @@ -680,7 +680,7 @@ def should_parallelize_layer_predicate_func(layer): "kv_size_multiplier": None, "fuse_qkv": None, "q_output_size_per_partition": None, - "kv_output_size_per_partition": None , + "kv_output_size_per_partition": None, } for mod in model.modules(): if isinstance(mod, OptimumGQAQKVColumnParallelLinear): diff --git a/optimum/neuron/distributed/checkpointing.py b/optimum/neuron/distributed/checkpointing.py index 58500ef0f..77acb75d8 100644 --- a/optimum/neuron/distributed/checkpointing.py +++ b/optimum/neuron/distributed/checkpointing.py @@ -158,9 +158,17 @@ def consolidate_tensor_parallel_checkpoints( if weight_name == "weight_q": s = slice(0, gqa_qkv_metadata["q_output_size_per_partition"]) elif weight_name == "weight_k": - s = slice(gqa_qkv_metadata["q_output_size_per_partition"], gqa_qkv_metadata["q_output_size_per_partition"] + gqa_qkv_metadata["kv_output_size_per_partition"]) + s = slice( + gqa_qkv_metadata["q_output_size_per_partition"], + gqa_qkv_metadata["q_output_size_per_partition"] + + gqa_qkv_metadata["kv_output_size_per_partition"], + ) elif weight_name == "weight_v": - s = slice(gqa_qkv_metadata["q_output_size_per_partition"] + gqa_qkv_metadata["kv_output_size_per_partition"], None) + s = slice( + gqa_qkv_metadata["q_output_size_per_partition"] + + gqa_qkv_metadata["kv_output_size_per_partition"], + None, + ) else: s = slice(None, None) else: diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py index 1d57afa39..54b8d0683 100644 --- a/optimum/neuron/distributed/utils.py +++ b/optimum/neuron/distributed/utils.py @@ -187,8 +187,8 @@ def __init__( keep_master_weight: bool = False, kv_size_multiplier: int = 1, ): - from neuronx_distributed.parallel_layers.utils import set_tensor_model_parallel_attributes from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size + from neuronx_distributed.parallel_layers.utils import set_tensor_model_parallel_attributes super().__init__( input_size, @@ -784,7 +784,7 @@ def 
maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( # proj_name = weight_name[-1] if layer.fuse_qkv: weight = getattr(layer, "weight_qkv") - bias = getattr(layer, f"bias_qkv") + bias = getattr(layer, "bias_qkv") else: weight = getattr(layer, weight_name) bias = getattr(layer, f"bias_{proj_name}") @@ -814,7 +814,10 @@ maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( if proj_name == "q": s = slice(0, layer.q_output_size_per_partition) elif proj_name == "k": - s = slice(layer.q_output_size_per_partition, layer.q_output_size_per_partition + layer.kv_output_size_per_partition) + s = slice( + layer.q_output_size_per_partition, + layer.q_output_size_per_partition + layer.kv_output_size_per_partition, + ) else: s = slice(layer.q_output_size_per_partition + layer.kv_output_size_per_partition, None) weight[s, :] = weight_data From 0a58c8e2d1a94ecc0777a436110d55ed6a7a3bf7 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 15 Jan 2025 19:28:03 +0100 Subject: [PATCH 16/18] Remove comment --- optimum/neuron/distributed/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py index 54b8d0683..8d3664d2c 100644 --- a/optimum/neuron/distributed/utils.py +++ b/optimum/neuron/distributed/utils.py @@ -781,7 +781,6 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( "A linear's layer WeightInformation or a linear layer to copy the weights from need to specified." ) - # proj_name = weight_name[-1] From 0e0ca881412bbe8c32c0026df16fe83f230aa505 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 16 Jan 2025 11:32:39 +0100 Subject: [PATCH 17/18] nits --- optimum/neuron/distributed/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py index 8d3664d2c..5d1169588 100644 --- a/optimum/neuron/distributed/utils.py +++ b/optimum/neuron/distributed/utils.py @@ -203,6 +203,8 @@ def __init__( kv_size_multiplier=kv_size_multiplier, ) + # This is a bug in neuronx_distributed: self.weight_qkv has no parallel attributes, which are actually needed. + # It should be fixed in the next release.
if self.fuse_qkv: set_tensor_model_parallel_attributes( tensor=self.weight_qkv, @@ -801,9 +803,8 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( weight_data = linear_layer.weight.data if weight_data is not None: if proj_name in "kv": - output_size = layer.kv_output_size_per_partition weight_data = create_kv_proj_local_weight_from_regular_weight( - weight_data, kv_size_multiplier, output_size + weight_data, kv_size_multiplier, layer.kv_output_size_per_partition ) else: weight_data = create_query_or_output_projection_local_weight_from_regular_weight( From d2a7487b0e4b4222f61e1b91f82468b30c291512 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 16 Jan 2025 18:37:09 +0100 Subject: [PATCH 18/18] Fix lora --- optimum/neuron/distributed/utils.py | 2 ++ tests/peft/test_peft_training.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py index 5d1169588..46f82be5b 100644 --- a/optimum/neuron/distributed/utils.py +++ b/optimum/neuron/distributed/utils.py @@ -434,6 +434,7 @@ def _peft_tuner_embedding_to_parallel_embedding( config.init_lora_weights, config.use_rslora, config.use_dora, + config.lora_bias, ) mark_parameter_init_status_during_parallelization(parent.lora_embedding_A[adapter_name], True) mark_parameter_init_status_during_parallelization(parent.lora_embedding_B[adapter_name], True) @@ -1124,6 +1125,7 @@ def _peft_tuner_linear_to_parallel_linear( config.init_lora_weights, config.use_rslora, config.use_dora, + config.lora_bias, ) if axis == "row": layer_to_parallelize = parent.lora_A[adapter_name] diff --git a/tests/peft/test_peft_training.py b/tests/peft/test_peft_training.py index 976c1fed3..041dca330 100644 --- a/tests/peft/test_peft_training.py +++ b/tests/peft/test_peft_training.py @@ -158,7 +158,7 @@ def test_save_pretrained(self, parallel_sizes, tmpdir): print(f"Checking that the parameter {name} matches") torch.testing.assert_close(tensor, state_dict[name]) - def test_peft_training(self, parallel_sizes, tmpdir): + def test_training_peft_model(self, parallel_sizes, tmpdir): _, tp_size, pp_size = parallel_sizes per_device_train_batch_size = 1 @@ -177,6 +177,8 @@ def test_peft_training(self, parallel_sizes, tmpdir): ) tokenizer, model = get_tokenizer_and_tiny_llama_model() + peft_config = get_peft_config() + peft_model = get_peft_model(model, peft_config) num_train_samples = num_eval_samples = 50 datasets = create_dummy_causal_lm_dataset( @@ -184,7 +186,7 @@ def test_peft_training(self, parallel_sizes, tmpdir): ) trainer = NeuronTrainer( - model, + peft_model, args, tokenizer=tokenizer, train_dataset=datasets["train"],
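Note on the fused-QKV layout used by patches 12 through 17: within each tensor-parallel rank, the fused weight_qkv shard stores the query rows first, then the key rows, then the value rows, with boundaries at q_output_size_per_partition and q_output_size_per_partition + kv_output_size_per_partition. The standalone Python sketch below only illustrates that row layout; the split_fused_qkv helper is hypothetical and is not part of optimum-neuron or neuronx_distributed.

import torch

def split_fused_qkv(weight_qkv: torch.Tensor, q_rows: int, kv_rows: int):
    # Assumed per-rank layout: rows [0, q_rows) hold the query projection,
    # rows [q_rows, q_rows + kv_rows) hold the key projection, and the
    # remaining rows hold the value projection.
    q = weight_qkv[:q_rows]
    k = weight_qkv[q_rows:q_rows + kv_rows]
    v = weight_qkv[q_rows + kv_rows:]
    return q, k, v

# Toy sizes: 4 query rows and 1 key/value row per rank, hidden size 8.
fused = torch.randn(4 + 2 * 1, 8)
q, k, v = split_fused_qkv(fused, q_rows=4, kv_rows=1)
assert q.shape == (4, 8) and k.shape == (1, 8) and v.shape == (1, 8)

Weight loading in these patches writes each projection into the matching slice of weight_qkv, and checkpoint consolidation applies the same slices to every rank's shard before concatenating across ranks, which is why both code paths share the q_output_size_per_partition / kv_output_size_per_partition boundaries.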