From 1d60cef42d626da72ac58a311a82d123cb35ba2a Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Wed, 18 Dec 2024 13:42:26 +0000 Subject: [PATCH] fix(qwen2): adapt to latest TnX --- optimum/neuron/models/qwen2/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optimum/neuron/models/qwen2/model.py b/optimum/neuron/models/qwen2/model.py index 8ee60d9b4..f34a5bd63 100644 --- a/optimum/neuron/models/qwen2/model.py +++ b/optimum/neuron/models/qwen2/model.py @@ -287,6 +287,7 @@ def preprocess_and_embed(self, input_ids, cache_ids=None, start_ids=None, **kwar return padded_inputs, input_embeddings, *rst def forward(self, input_ids, cache_ids=None, start_ids=None, last_token_id=None, input_embeddings=None, **kwargs): + original_input_ids = input_ids if last_token_id is not None: # preprocess_and_embed() has already been invoked rst = cache_ids, start_ids, last_token_id else: # invoke preprocess_and_embed() @@ -294,5 +295,5 @@ def forward(self, input_ids, cache_ids=None, start_ids=None, last_token_id=None, # either input_embeddings are generated (off device embedding), or input_ids will be padded from preprocess_and_embed (on device embedding) inputs = input_embeddings if input_embeddings is not None else input_ids logits = self._forward(inputs, *rst) - logits = self._postprocess(logits, start_ids=start_ids, **kwargs) + logits = self._postprocess(original_input_ids, logits, start_ids=start_ids, **kwargs) return logits