fix + test
JingyaHuang committed Feb 9, 2024
1 parent ab582ce commit 0110d5d
Showing 4 changed files with 79 additions and 100 deletions.
25 changes: 25 additions & 0 deletions optimum/commands/export/neuron.py
@@ -46,6 +46,21 @@ def parse_args_neuron(parser: "ArgumentParser"):
             f" {str(list(TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS.keys()) + list(TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS.keys()))}."
         ),
     )
+    optional_group.add_argument(
+        "--library-name",
+        type=str,
+        choices=["transformers", "sentence_transformers"],
+        default=None,
+        help="The library of the model. If not provided, will attempt to infer it from the local checkpoint.",
+    )
+    optional_group.add_argument(
+        "--subfolder",
+        type=str,
+        default="",
+        help=(
+            "In case the relevant files are located inside a subfolder of the model repo, either locally or on huggingface.co, specify the folder name here."
+        ),
+    )
     optional_group.add_argument(
         "--atol",
         type=float,
@@ -58,6 +73,16 @@ def parse_args_neuron(parser: "ArgumentParser"):
         action="store_true",
         help="Allow using custom code for the modeling hosted in the model repository. This option should only be set for repositories you trust and in which you have read the code, as it will execute arbitrary code present in the model repository on your local machine.",
     )
+    optional_group.add_argument(
+        "--compiler_workdir",
+        type=Path,
+        help="Path to the directory where intermediary files generated by the Neuron compiler will be stored.",
+    )
+    optional_group.add_argument(
+        "--disable-weights-neff-inline",
+        action="store_true",
+        help="Whether to disable inlining of the weights into the NEFF graph. Weights of Neuron-compiled models can only be replaced when weights/NEFF inlining was disabled during compilation.",
+    )
     optional_group.add_argument(
         "--disable-validation",
         action="store_true",
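For reference, a hypothetical invocation combining the new options with the existing shape arguments (the model ID, shapes, and output directory are illustrative placeholders, not values from this commit):

    optimum-cli export neuron \
      --model hf-internal-testing/tiny-random-BertModel \
      --task text-classification \
      --batch_size 1 --sequence_length 16 \
      --library-name transformers \
      --compiler_workdir ./compiler_cache \
      --disable-weights-neff-inline \
      bert_neuron/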
19 changes: 11 additions & 8 deletions optimum/exporters/neuron/__main__.py
@@ -150,16 +150,19 @@ def customize_optional_outputs(args: argparse.Namespace) -> Dict[str, bool]:
 
 def parse_optlevel(args: argparse.Namespace) -> Dict[str, bool]:
     """
-    Parse the level of optimization the compiler should perform. If not specified, apply `O2` (the best balance between model performance and compile time).
+    (NEURONX ONLY) Parse the level of optimization the compiler should perform. If not specified, apply `O2` (the best balance between model performance and compile time).
     """
-    if args.O1:
-        optlevel = "1"
-    elif args.O2:
-        optlevel = "2"
-    elif args.O3:
-        optlevel = "3"
+    if is_neuronx_available():
+        if args.O1:
+            optlevel = "1"
+        elif args.O2:
+            optlevel = "2"
+        elif args.O3:
+            optlevel = "3"
+        else:
+            optlevel = "2"
     else:
-        optlevel = "2"
+        optlevel = None
     return optlevel
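For illustration, a minimal argparse sketch of the dispatch above; the -O1/-O2/-O3 flag names are an assumption inferred from the args.O1/args.O2/args.O3 attributes, not confirmed by this diff:

    import argparse

    # Hypothetical flags mirroring args.O1/args.O2/args.O3 above.
    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group()
    group.add_argument("-O1", action="store_true", help="Core optimizations.")
    group.add_argument("-O2", action="store_true", help="Balanced (default).")
    group.add_argument("-O3", action="store_true", help="Best performance, slower compile.")

    args = parser.parse_args(["-O3"])
    # On a neuronx host, pick the requested level and fall back to "2";
    # on other hosts the function now returns None instead.
    optlevel = "1" if args.O1 else "3" if args.O3 else "2"
    print(optlevel)  # -> "3"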
22 changes: 21 additions & 1 deletion optimum/exporters/neuron/convert.py
@@ -405,7 +405,17 @@ def export(
     disable_fallback: bool = False,
 ) -> Tuple[List[str], List[str]]:
     if is_neuron_available():
-        return export_neuron(model, config, output, auto_cast, auto_cast_type, disable_fast_relayout, disable_fallback)
+        return export_neuron(
+            model=model,
+            config=config,
+            output=output,
+            compiler_workdir=compiler_workdir,
+            inline_weights_to_neff=inline_weights_to_neff,
+            auto_cast=auto_cast,
+            auto_cast_type=auto_cast_type,
+            disable_fast_relayout=disable_fast_relayout,
+            disable_fallback=disable_fallback,
+        )
     elif is_neuronx_available():
         return export_neuronx(
             model=model,
@@ -570,6 +580,8 @@ def export_neuron(
     model: "PreTrainedModel",
     config: "NeuronDefaultConfig",
     output: Path,
+    compiler_workdir: Optional[Path] = None,
+    inline_weights_to_neff: bool = True,
     auto_cast: Optional[str] = None,
     auto_cast_type: str = "bf16",
     disable_fast_relayout: bool = False,
@@ -585,6 +597,10 @@
         The Neuron configuration associated with the exported model.
     output (`Path`):
         Directory to store the exported Neuron model.
+    compiler_workdir (`Optional[Path]`, defaults to `None`):
+        The directory used by the Neuron compiler, where intermediary outputs (neff, weights, hlo, ...) can be found.
+    inline_weights_to_neff (`bool`, defaults to `True`):
+        Whether to inline the weights into the NEFF graph. If set to `False`, weights will be separated from the NEFF.
     auto_cast (`Optional[str]`, defaults to `None`):
         Whether to cast operations from FP32 to lower precision to speed up inference. Can be `None`, `"matmul"` or `"all"`; use `None` to disable any auto-casting, `"matmul"` to cast FP32 matrix multiplication operations, and `"all"` to cast all FP32 operations.
     auto_cast_type (`str`, defaults to `"bf16"`):
@@ -599,6 +615,8 @@
         the Neuron configuration.
     """
     output.parent.mkdir(parents=True, exist_ok=True)
+    if isinstance(compiler_workdir, Path):
+        compiler_workdir = compiler_workdir.as_posix()
 
     if hasattr(model, "config"):
         model.config.return_dict = True
@@ -626,6 +644,8 @@
         dummy_inputs_tuple,
         dynamic_batch_size=config.dynamic_batch_size,
         compiler_args=compiler_args,
+        compiler_workdir=compiler_workdir,
+        separate_weights=not inline_weights_to_neff,
         fallback=not disable_fallback,
     )
     torch.jit.save(neuron_model, output)
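Putting the pieces together, a hedged sketch of a direct call to export_neuron with the new keyword arguments; the model loading and config construction here are assumptions for illustration (the TasksManager constructor pattern is borrowed from the test helpers deleted below), and the exact constructor kwargs may differ:

    from pathlib import Path

    from transformers import AutoModelForSequenceClassification

    from optimum.exporters.neuron.convert import export_neuron
    from optimum.exporters.tasks import TasksManager

    model = AutoModelForSequenceClassification.from_pretrained("hf-internal-testing/tiny-random-BertModel")

    # Assumed config construction, mirroring TasksManager usage elsewhere in the repo.
    constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="neuron", task="text-classification")
    neuron_config = constructor(model.config, task="text-classification", batch_size=1, sequence_length=16)

    export_neuron(
        model=model,
        config=neuron_config,
        output=Path("out/model.neuron"),
        compiler_workdir=Path("compiler_cache"),  # keep neff/hlo intermediates for inspection
        inline_weights_to_neff=False,             # weights stay separate so they can be swapped later
    )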
113 changes: 22 additions & 91 deletions tests/cli/test_export_cli.py
@@ -13,102 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import random
 import subprocess
 import tempfile
 import unittest
-from itertools import product
-from typing import Dict, Optional
 
-from optimum.exporters.neuron.model_configs import *  # noqa: F403
-from optimum.exporters.tasks import TasksManager
-from optimum.neuron.utils import is_neuron_available, is_neuronx_available
 from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx
-from optimum.utils import DEFAULT_DUMMY_SHAPES, logging
+from optimum.utils import logging
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-_COMMOM_COMMANDS = {
-    "--auto_cast": ["none", "matmul", "all"],
-    "--auto_cast_type": ["bf16", "fp16"],  # "tf32", "mixed"
-}
-_NEURON_COMMANDS = {}
-_NEURONX_COMMANDS = {}
-_DYNAMIC_COMMANDS = {"neuron": ["--disable-fast-relayout"], "neuronx": []}
-
-
-def _get_models_to_test(export_models_dict: Dict, random_pick: Optional[int] = 1):
-    models_to_test = []
-    for model_type, model_names_tasks in export_models_dict.items():
-        model_type = model_type.replace("_", "-")
-        task_config_mapping = TasksManager.get_supported_tasks_for_model_type(model_type, "neuron")
-
-        if isinstance(model_names_tasks, str):  # test export of all tasks on the same model
-            tasks = list(task_config_mapping.keys())
-            model_tasks = {model_names_tasks: tasks}
-        else:
-            n_tested_tasks = sum(len(tasks) for tasks in model_names_tasks.values())
-            if n_tested_tasks != len(task_config_mapping):
-                logger.warning(f"Not all tasks are tested for {model_type}.")
-            model_tasks = model_names_tasks  # possibly, test different tasks on different models
-
-        for model_name, tasks in model_tasks.items():
-            for task in tasks:
-                default_shapes = dict(DEFAULT_DUMMY_SHAPES)
-                TasksManager.get_exporter_config_constructor(
-                    model_type=model_type,
-                    exporter="neuron",
-                    task=task,
-                    model_name=model_name,
-                    exporter_config_kwargs={**default_shapes},
-                )
-
-                models_to_test.append((f"{model_type}_{task}", model_name, task))
-
-    if random_pick is not None:
-        return sorted(random.choices(models_to_test, k=random_pick))
-    else:
-        return sorted(models_to_test)
-
-
-def _get_commands_to_test(models_to_test):
-    commands_to_test = []
-    for test_name, model_name, task in models_to_test:
-        if is_neuron_available():
-            command_items = dict(_COMMOM_COMMANDS, **_NEURON_COMMANDS)
-            dynamic_args = _DYNAMIC_COMMANDS["neuron"]
-        elif is_neuronx_available():
-            command_items = dict(_COMMOM_COMMANDS, **_NEURONX_COMMANDS)
-            dynamic_args = _DYNAMIC_COMMANDS["neuronx"]
-        else:
-            continue
-
-        base_command = f"optimum-cli export neuron --model {model_name} --task {task}"
-
-        # mandatory shape arguments
-        model = TasksManager.get_model_from_task(task, model_name, framework="pt")
-        neuron_config_constructor = TasksManager.get_exporter_config_constructor(
-            model=model, exporter="neuron", task=task
-        )
-        for axis in neuron_config_constructor.func.get_mandatory_axes_for_task(task):
-            default_size = DEFAULT_DUMMY_SHAPES[axis]
-            base_command += f" --{axis} {default_size}"
-
-        # compilation arguments
-        for extra_arg_options in product(*command_items.values()):
-            extra_command = " ".join(
-                [" ".join([arg, option]) for arg, option in zip(command_items, extra_arg_options)]
-            )
-            extra_command += " " + " ".join(random.choices(dynamic_args, k=random.randint(0, len(dynamic_args))))
-            command = base_command + " " + extra_command
-
-            commands_to_test.append((test_name + extra_command.strip(), command))
-
-    return sorted(commands_to_test)
-
-
 @is_inferentia_test
 class TestExportCLI(unittest.TestCase):
     def test_helps_no_raise(self):
@@ -121,12 +37,27 @@ def test_helps_no_raise(self):
         for command in commands:
             subprocess.run(command, shell=True, check=True)
 
-    # @parameterized.expand(_get_commands_to_test(_get_models_to_test(EXPORT_MODELS_TINY)), skip_on_empty=True)
-    # def test_export_commands(self, test_name, command_content):
-    #     with tempfile.TemporaryDirectory() as tempdir:
-    #         command = command_content + f" {tempdir}"
-
-    #         subprocess.run(command, shell=True, check=True)
+    def test_export_commands(self):
+        model_id = "hf-internal-testing/tiny-random-BertModel"
+        with tempfile.TemporaryDirectory() as tempdir:
+            subprocess.run(
+                [
+                    "optimum-cli",
+                    "export",
+                    "neuron",
+                    "--model",
+                    model_id,
+                    "--sequence_length",
+                    "16",
+                    "--batch_size",
+                    "1",
+                    "--task",
+                    "text-classification",
+                    tempdir,
+                ],
+                shell=False,
+                check=True,
+            )
 
     @requires_neuronx
     def test_dynamic_batching(self):
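For manual verification, the rewritten test is equivalent to running the following command by hand, with <output_dir> replaced by any writable directory:

    optimum-cli export neuron \
      --model hf-internal-testing/tiny-random-BertModel \
      --sequence_length 16 --batch_size 1 \
      --task text-classification \
      <output_dir>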
