Skip to content

Commit

Permalink
fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Dec 11, 2024
2 parents bf1b408 + 6ba3db9 commit a0e4184
Show file tree
Hide file tree
Showing 61 changed files with 807 additions and 171 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/inference_cache_llm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
EOF
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
sudo apt-get update -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.14.0-6e27b8d5b aws-neuronx-collectives=2.22.26.0-17a033bc8 -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y
export PATH=/opt/aws/neuron/bin:$PATH
- name: Checkout
uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/inference_cache_stable_diffusion.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
EOF
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
sudo apt-get update -y
sudo apt-get install aws-neuronx-tools=2.18.3.0 aws-neuronx-runtime-lib=2.21.41.0-fb1705f5f aws-neuronx-collectives=2.21.46.0-69b77134b -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y
export PATH=/opt/aws/neuron/bin:$PATH
- name: Checkout
uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_inf2.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
EOF
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
sudo apt-get update -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.14.0-6e27b8d5b aws-neuronx-collectives=2.22.26.0-17a033bc8 -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y
export PATH=/opt/aws/neuron/bin:$PATH
- name: Checkout
uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_inf2_export.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
EOF
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
sudo apt-get update -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.14.0-6e27b8d5b aws-neuronx-collectives=2.22.26.0-17a033bc8 -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y
export PATH=/opt/aws/neuron/bin:$PATH
- name: Checkout
uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_inf2_full_export.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
EOF
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
sudo apt-get update -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.14.0-6e27b8d5b aws-neuronx-collectives=2.22.26.0-17a033bc8 -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y
export PATH=/opt/aws/neuron/bin:$PATH
- name: Checkout
uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_inf2_inference.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
EOF
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
sudo apt-get update -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.14.0-6e27b8d5b aws-neuronx-collectives=2.22.26.0-17a033bc8 -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y
export PATH=/opt/aws/neuron/bin:$PATH
- name: Install cv2 dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_inf2_tgi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
EOF
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
sudo apt-get update -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.14.0-6e27b8d5b aws-neuronx-collectives=2.22.26.0-17a033bc8 -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y
export PATH=/opt/aws/neuron/bin:$PATH
- name: Checkout
uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_trainium_common.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
EOF
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
sudo apt-get update -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.14.0-6e27b8d5b aws-neuronx-collectives=2.22.26.0-17a033bc8 -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y
export PATH=/opt/aws/neuron/bin:$PATH
- name: Install cv2 dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_trainium_distributed.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
EOF
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
sudo apt-get update -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.14.0-6e27b8d5b aws-neuronx-collectives=2.22.26.0-17a033bc8 -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y
export PATH=/opt/aws/neuron/bin:$PATH
- name: Install cv2 dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_trainium_examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
EOF
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
sudo apt-get update -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.14.0-6e27b8d5b aws-neuronx-collectives=2.22.26.0-17a033bc8 -y
sudo apt-get install aws-neuronx-tools=2.19.0.0 aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 aws-neuronx-collectives=2.22.33.0-d2128d1aa -y
export PATH=/opt/aws/neuron/bin:$PATH
- name: Install cv2 dependencies
run: |
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
MODEL_ID='Qwen/Qwen2.5-7B-Instruct'
HF_AUTO_CAST_TYPE='bf16'
MAX_BATCH_SIZE=32
MAX_INPUT_TOKENS=4000
MAX_TOTAL_TOKENS=4096
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
version: '3.7'

services:
tgi-1:
image: neuronx-tgi:latest
ports:
- "8081:8081"
environment:
- PORT=8081
- MODEL_ID=${MODEL_ID}
- HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE}
- HF_NUM_CORES=8
- MAX_BATCH_SIZE=${MAX_BATCH_SIZE}
- MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS}
- MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS}
- MAX_CONCURRENT_REQUESTS=512
- HF_TOKEN=${HF_TOKEN}
devices:
- "/dev/neuron0"
- "/dev/neuron1"
- "/dev/neuron2"
- "/dev/neuron3"

tgi-2:
image: neuronx-tgi:latest
ports:
- "8082:8082"
environment:
- PORT=8082
- MODEL_ID=${MODEL_ID}
- HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE}
- HF_NUM_CORES=8
- MAX_BATCH_SIZE=${MAX_BATCH_SIZE}
- MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS}
- MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS}
- MAX_CONCURRENT_REQUESTS=512
- HF_TOKEN=${HF_TOKEN}
devices:
- "/dev/neuron4"
- "/dev/neuron5"
- "/dev/neuron6"
- "/dev/neuron7"

tgi-3:
image: neuronx-tgi:latest
ports:
- "8083:8083"
environment:
- PORT=8083
- MODEL_ID=${MODEL_ID}
- HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE}
- HF_NUM_CORES=8
- MAX_BATCH_SIZE=${MAX_BATCH_SIZE}
- MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS}
- MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS}
- MAX_CONCURRENT_REQUESTS=512
- HF_TOKEN=${HF_TOKEN}
devices:
- "/dev/neuron8"
- "/dev/neuron9"
- "/dev/neuron10"
- "/dev/neuron11"

loadbalancer:
image: nginx:alpine
ports:
- "8080:80"
volumes:
- ./nginx.conf:/etc/nginx/nginx.conf:ro
depends_on:
- tgi-1
- tgi-2
- tgi-3
deploy:
placement:
constraints: [node.role == manager]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
### Nginx TGI Load Balancer
events {}
http {
upstream tgicluster {
server tgi-1:8081;
server tgi-2:8082;
server tgi-3:8083;
}
server {
listen 80;
location / {
proxy_pass http://tgicluster;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
model_id,Date,Input type,Requests per Second,Request Latency (s),Time-to-first-token (ms),Inter Token Latency (ms),Output Token Throughput (t/s)
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,synchronous,0.16124266966166817,6.200973322516994,309.0427423778333,25.97485797662497,36.57662664430473
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,[email protected] req/sec,0.49461558754572243,11.853755130606183,361.0207387956522,48.287324351631526,117.7268931509251
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,[email protected] req/sec,0.8060968082815412,16.24768308375744,375.21653479718145,67.57783339749032,189.26981548491378
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,[email protected] req/sec,1.083945791108799,21.60137509382688,391.8051444567167,90.79909233959562,253.1763846248275
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,[email protected] req/sec,1.360321529639815,22.870551178060428,896.7999958553197,94.77224932706102,315.15228174103277
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,[email protected] req/sec,1.6004688460192356,27.518067228297394,1464.1120346883934,112.11173711716121,371.45881623077696
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,[email protected] req/sec,1.8374073942778475,29.824548766450974,1626.0160196174695,122.31885821081055,423.5491627411547
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,[email protected] req/sec,2.0547734036381797,33.20240214091389,2375.624083671249,133.96148232126046,472.60651633847726
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,[email protected] req/sec,2.0780593811446972,40.66464872033365,8195.832600516658,138.66332340499426,486.5759282406912
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,[email protected] req/sec,2.116392255309062,36.28229375148383,4732.812661824264,134.96114258046998,494.68585904605914
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,throughput,3.6428876172319473,25.543468462793452,7593.348495583786,77.56031301334828,844.191272561698
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@


def get_node_results(node_url):

metrics = requests.get(node_url + "/metrics").text

counters = {
Expand Down
7 changes: 7 additions & 0 deletions benchmark/text-generation/accuracy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,10 @@ You can evaluate:
| | |none | 0|acc_norm ||0.7581|± |0.0043|
|lambada_openai| 1|none | 0|acc ||0.7173|± |0.0063|
| | |none | 0|perplexity ||3.1102|± |0.0769|

### Qwen/Qwen2.5-Math-7B-Instruct

|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match||0.8878|± |0.0087|
| | |strict-match | 5|exact_match||0.8870|± |0.0087|
2 changes: 1 addition & 1 deletion docs/source/containers.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ print(f"llm image uri: {llm_image}")
| Type | Optimum Version | Image URI |
|-----------------------------|-----------------|---------------------------------------------|
| Training | 0.0.24 | `763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training-neuronx:2.1.2-transformers4.41.1-neuronx-py310-sdk2.19.1-ubuntu20.04` |
| Training | 0.0.25 | `763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training-neuronx:2.1.2-transformers4.43.2-neuronx-py310-sdk2.20.0-ubuntu20.04` |
| Inference | 0.0.25 | `763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference-neuronx:2.1.2-transformers4.43.2-neuronx-py310-sdk2.20.0-ubuntu20.04` |
| Text Generation Inference | 0.0.25 | `763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.2-optimum0.0.25-neuronx-py310-ubuntu22.04` |
Expand Down
19 changes: 10 additions & 9 deletions optimum/exporters/neuron/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import torch
from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoConfig, AutoTokenizer, PretrainedConfig

Expand All @@ -29,8 +30,8 @@
DIFFUSION_MODEL_CONTROLNET_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_2_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_NAME,
DIFFUSION_MODEL_UNET_NAME,
DIFFUSION_MODEL_TRANSFORMER_NAME,
DIFFUSION_MODEL_UNET_NAME,
DIFFUSION_MODEL_VAE_DECODER_NAME,
DIFFUSION_MODEL_VAE_ENCODER_NAME,
ENCODER_NAME,
Expand All @@ -52,8 +53,8 @@
from .utils import (
build_stable_diffusion_components_mandatory_shapes,
check_mandatory_input_shapes,
get_encoder_decoder_models_for_export,
get_diffusion_models_for_export,
get_encoder_decoder_models_for_export,
replace_stable_diffusion_submodels,
)

Expand Down Expand Up @@ -216,7 +217,9 @@ def infer_stable_diffusion_shapes_from_diffusers(
elif hasattr(model, "tokenizer_2") and model.tokenizer_2 is not None:
max_sequence_length = model.tokenizer_2.model_max_length
else:
raise AttributeError(f"Cannot infer max sequence_length from {type(model)} as there is no tokenizer as attribute.")
raise AttributeError(
f"Cannot infer max sequence_length from {type(model)} as there is no tokenizer as attribute."
)
vae_encoder_num_channels = model.vae.config.in_channels
vae_decoder_num_channels = model.vae.config.latent_channels
vae_scale_factor = 2 ** (len(model.vae.config.block_out_channels) - 1) or 8
Expand All @@ -230,7 +233,7 @@ def infer_stable_diffusion_shapes_from_diffusers(
input_shapes["text_encoder"].update({"sequence_length": max_sequence_length})
if hasattr(model, "text_encoder_2"):
input_shapes["text_encoder_2"] = input_shapes["text_encoder"]

# UNet or Transformer
unet_or_transformer_name = "transformer" if hasattr(model, "transformer") else "unet"
unet_or_transformer_num_channels = getattr(model, unet_or_transformer_name).config.in_channels
Expand All @@ -245,9 +248,9 @@ def infer_stable_diffusion_shapes_from_diffusers(
input_shapes["unet_or_transformer"]["sequence_length"] = max_sequence_length
input_shapes["unet_or_transformer"]["vae_scale_factor"] = vae_scale_factor
input_shapes[unet_or_transformer_name] = input_shapes.pop("unet_or_transformer")
if unet_or_transformer_name=="transformer":
if unet_or_transformer_name == "transformer":
input_shapes[unet_or_transformer_name]["encoder_hidden_size"] = model.text_encoder.config.hidden_size

# VAE
input_shapes["vae_encoder"].update({"num_channels": vae_encoder_num_channels, "height": height, "width": width})
input_shapes["vae_decoder"].update(
Expand Down Expand Up @@ -435,9 +438,7 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, NEURON_FILE_NAME
)
if getattr(model, "unet", None) is not None:
output_model_names[DIFFUSION_MODEL_UNET_NAME] = os.path.join(
DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME
)
output_model_names[DIFFUSION_MODEL_UNET_NAME] = os.path.join(DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME)
if getattr(model, "transformer", None) is not None:
output_model_names[DIFFUSION_MODEL_TRANSFORMER_NAME] = os.path.join(
DIFFUSION_MODEL_TRANSFORMER_NAME, NEURON_FILE_NAME
Expand Down
27 changes: 20 additions & 7 deletions optimum/exporters/neuron/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@

import torch

from optimum.utils import logging

from ...exporters.base import ExportConfig
from ...neuron.utils import is_neuron_available, is_transformers_neuronx_available
from optimum.utils import logging


if TYPE_CHECKING:
Expand Down Expand Up @@ -323,7 +324,9 @@ def generate_dummy_inputs(
float_dtype = mapper[self.float_dtype]
else:
float_dtype = self.float_dtype
dummy_inputs[input_name] = dummy_input_gen.generate(input_name, framework="pt", int_dtype=self.int_dtype, float_dtype=float_dtype)
dummy_inputs[input_name] = dummy_input_gen.generate(
input_name, framework="pt", int_dtype=self.int_dtype, float_dtype=float_dtype
)
input_was_inserted = True
break
if not input_was_inserted:
Expand Down Expand Up @@ -450,17 +453,23 @@ class NeuronDecoderConfig(NeuronConfig):
NEURONX_CLASS = None
CONTINUOUS_BATCHING = False
ATTENTION_lAYOUT = "HSB"
FUSE_QKV = True

def __init__(self, task: str):
if not is_transformers_neuronx_available():
raise ModuleNotFoundError(
"The mandatory transformers-neuronx package is missing. Please install optimum[neuronx]."
)
module_name, class_name = self.NEURONX_CLASS.rsplit(".", maxsplit=1)
module = importlib.import_module(f"transformers_neuronx.{module_name}")
self._neuronx_class = getattr(module, class_name, None)
if self._neuronx_class is None:
raise ImportError(f"{class_name} not found in {module_name}. Please check transformers-neuronx version.")
if isinstance(self.NEURONX_CLASS, type):
self._neuronx_class = self.NEURONX_CLASS
else:
module_name, class_name = self.NEURONX_CLASS.rsplit(".", maxsplit=1)
module = importlib.import_module(f"transformers_neuronx.{module_name}")
self._neuronx_class = getattr(module, class_name, None)
if self._neuronx_class is None:
raise ImportError(
f"{class_name} not found in {module_name}. Please check transformers-neuronx version."
)

@property
def neuronx_class(self):
Expand All @@ -473,3 +482,7 @@ def continuous_batching(self):
@property
def attention_layout(self):
return self.ATTENTION_lAYOUT

@property
def fuse_qkv(self):
return self.FUSE_QKV
10 changes: 6 additions & 4 deletions optimum/exporters/neuron/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,9 @@ def validate_model_outputs(
ref_inputs = tuple(ref_inputs.values())
ref_outputs = reference_model(*ref_inputs)
neuron_inputs = tuple(inputs.values())
elif any(pattern in getattr(config._config, "_class_name", "").lower() for pattern in ["controlnet", "transformer"]):
elif any(
pattern in getattr(config._config, "_class_name", "").lower() for pattern in ["controlnet", "transformer"]
):
reference_model = config.patch_model_for_export(reference_model, ref_inputs)
neuron_inputs = ref_inputs = tuple(ref_inputs.values())
ref_outputs = reference_model(*ref_inputs)
Expand Down Expand Up @@ -253,8 +255,8 @@ def validate_model_outputs(
ref_output = torch.stack(ref_outputs[name])
neuron_output = torch.stack(neuron_output)
elif isinstance(neuron_output, list):
ref_output = [output for output in ref_outputs[name]]
neuron_output = [output for output in neuron_output]
ref_output = ref_outputs[name]
neuron_output = neuron_output

logger.info(f'\t- Validating Neuron Model output "{name}":')

Expand Down Expand Up @@ -529,7 +531,7 @@ def export_neuronx(
for axis in config.mandatory_axes:
input_shapes[axis] = getattr(config, axis)

dummy_inputs = config.generate_dummy_inputs(**input_shapes)
dummy_inputs = config.generate_dummy_inputs(**input_shapes)
dummy_inputs = config.flatten_inputs(dummy_inputs)
dummy_inputs_tuple = tuple(dummy_inputs.values())

Expand Down
Loading

0 comments on commit a0e4184

Please sign in to comment.