Skip to content

Commit

Permalink
fix t5 decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Nov 4, 2024
1 parent 02c331d commit 5ec92d8
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
6 changes: 5 additions & 1 deletion optimum/exporters/neuron/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,9 +540,9 @@ def export_neuronx(
checked_model = config.patch_model_for_export(model_or_path, dummy_inputs)

# Construct compiler configurations
logger.info(f"Using Neuron: --auto-cast {auto_cast}")
if auto_cast is not None:
logger.info(f"Using Neuron: --auto-cast {auto_cast}")

auto_cast = "matmult" if auto_cast == "matmul" else auto_cast
compiler_args = ["--auto-cast", auto_cast]

Expand All @@ -552,6 +552,10 @@ def export_neuronx(
compiler_args = ["--auto-cast", "none"]

compiler_args.extend(["--optlevel", optlevel])
logger.info(f"Using Neuron: --optlevel {optlevel}")

if getattr(config._config, "is_encoder_decoder", False):
compiler_args.extend(["--model-type", "transformer"])

compiler_args = add_stable_diffusion_compiler_args(config, compiler_args) # diffusers specific

Expand Down
3 changes: 2 additions & 1 deletion optimum/exporters/neuron/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,8 @@ def update_past(self, past_key_values):

def reorder_cache(self, past_key_values, beam_idx):
    """Reorder every cached tensor along dim 0 according to ``beam_idx``.

    Each entry of ``past_key_values`` is replaced in place by a tensor whose
    row ``b`` is the old row ``beam_idx[b]``. The reorder is applied exactly
    once per tensor, using ``torch.gather`` with an index broadcast over the
    trailing dimensions.

    Args:
        past_key_values: Mutable sequence of 4-D tensors. The index is
            reshaped to ``[len(beam_idx), 1, 1, 1]`` before being expanded,
            so each tensor must be 4-D and its dim 0 must equal
            ``len(beam_idx)``.
        beam_idx: 1-D integer tensor of source row indices.

    Returns:
        The same ``past_key_values`` object, with every entry reordered.
    """
    for i, cached in enumerate(past_key_values):
        # Broadcast the per-row index over the three trailing dims so that
        # gather copies whole rows: out[b, ...] = cached[beam_idx[b], ...].
        gather_index = beam_idx.view([beam_idx.shape[0], 1, 1, 1]).expand_as(cached)
        past_key_values[i] = torch.gather(cached, dim=0, index=gather_index)
    return past_key_values

def forward(
Expand Down

0 comments on commit 5ec92d8

Please sign in to comment.