From 8707043483d2084ef2566de3c3c9e21494d81515 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Fri, 18 Apr 2025 16:38:50 -0700 Subject: [PATCH] Remove args from LLMEdgeManager and misc cleanup (#10288) Summary: Decouple `export_llama`'s `argparse.namespace` from `LLMEdgeManager` + misc cleanup Test Plan: CI Reviewed By: iseeyuan Differential Revision: D73225183 Pulled By: jackzhxng --- examples/models/llama/export_llama_lib.py | 7 ++- examples/models/llava/export_llava.py | 2 - extension/llm/export/builder.py | 74 +++++++++++++---------- 3 files changed, 45 insertions(+), 38 deletions(-) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 2553f82139a..21bee7c6680 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -653,7 +653,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager: _get_source_transforms( modelname=args.model, dtype_override=dtype_override, - checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype), + checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype), # type: ignore args=args, ) ) @@ -1106,7 +1106,7 @@ def _load_llama_model( return LLMEdgeManager( model=model, modelname=modelname, - max_seq_len=model.max_seq_len, + max_seq_len=model.max_seq_len, # type: ignore dtype=dtype_override, use_kv_cache=use_kv_cache, generate_full_logits=generate_full_logits, @@ -1119,6 +1119,8 @@ def _load_llama_model( calibration_seq_length=calibration_seq_length, calibration_data=calibration_data, tokenizer_path=tokenizer_path, + use_legacy_export=args.qnn, + save_exported_program=args.export_only, verbose=verbose, metadata=_load_llama_model_metadata( weight_type, @@ -1139,7 +1141,6 @@ def _load_llama_model( model.vocab_size, metadata_str, ), - args=args, ) diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py index 5fcddb610b7..66b61840866 100644 --- a/examples/models/llava/export_llava.py +++ b/examples/models/llava/export_llava.py @@ -92,7 +92,6 @@ def forward(self, input_pos, embeddings): use_kv_cache=True, example_inputs=(torch.tensor([0], dtype=torch.int64), embeddings), dynamic_shapes=dynamic_shapes, - args=llava.text_model_args, ) dtype_override = DType.fp32 @@ -161,7 +160,6 @@ def forward(self, images): use_kv_cache=True, example_inputs=(resized,), dynamic_shapes=dynamic_shapes, - args=None, ) .export() .pt2e_quantize([quantizer]) diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index d0cf9a2d9d2..2dee6b0954a 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -13,7 +13,7 @@ import contextlib import logging from enum import Enum -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple from unittest.mock import patch import torch @@ -81,14 +81,13 @@ class LLMEdgeManager: def __init__( self, - model, - modelname, - max_seq_len, - dtype, - use_kv_cache, - example_inputs, + model: torch.nn.Module, + modelname: str, + max_seq_len: int, + use_kv_cache: bool, + example_inputs: Tuple[torch.Tensor, ...], + dtype: Optional[DType] = None, example_kwarg_inputs: Optional[Dict] = None, - args: Optional[Any] = None, enable_dynamic_shape: bool = False, generate_full_logits: bool = False, calibration_tasks: Optional[List[str]] = None, @@ -99,36 +98,42 @@ def __init__( verbose: bool = False, metadata: Optional[dict] = None, dynamic_shapes: Optional[Any] = None, + use_legacy_export: 
bool = False, + save_exported_program: bool = False, ): + # Store necessary constructor arguments. self.model = model - # Note: treat this as the source of truth for the result of - # torch.export'ing a model. If the overall ExportedProgram is needed, - # make sure to re-export this graph module to persist any changes. See - # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921 - self.pre_autograd_graph_module: Optional[torch.nn.Module] = None self.modelname = modelname self.max_seq_len = max_seq_len - self.dtype = dtype + self.use_kv_cache = use_kv_cache self.example_inputs = example_inputs + self.dtype = dtype self.example_kwarg_inputs = example_kwarg_inputs - self.use_kv_cache = use_kv_cache - self.generate_full_logits = generate_full_logits self.enable_dynamic_shape = enable_dynamic_shape - self.verbose = verbose - self.metadata = metadata - self.applied_source_transforms = [] - self.edge_manager: Optional[EdgeProgramManager] = None - self.export_program = None - self.output_dir = "." - self.dynamic_shapes = dynamic_shapes - self._saved_pte_filename = None - self.args = args + self.generate_full_logits = generate_full_logits self.calibration_tasks = calibration_tasks self.calibration_limit = calibration_limit self.calibration_seq_length = calibration_seq_length self.calibration_data = calibration_data self.tokenizer_path = tokenizer_path - self.canonical_passes = [RemoveRedundantTransposes()] + self.verbose = verbose + self.metadata = metadata + self.dynamic_shapes = dynamic_shapes + self.use_legacy_export = use_legacy_export + self.save_exported_program = save_exported_program + + # Note: treat this as the source of truth for the result of + # torch.export'ing a model. If the overall ExportedProgram is needed, + # make sure to re-export this graph module to persist any changes. See + # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921 + self.pre_autograd_graph_module: Optional[torch.nn.Module] = None + self.edge_manager: Optional[EdgeProgramManager] = None + self.canonical_passes = [ + RemoveRedundantTransposes() + ] # Graph transformations optimizations. + self.export_program = None # Final result of lowering to executorch. + self.output_dir = "." + self._saved_pte_filename = None def set_output_dir(self, output_dir: str) -> "LLMEdgeManager": """ @@ -167,10 +172,9 @@ def source_transform( """ for transform in transforms: self.model = transform(self.model) - self.applied_source_transforms.extend(transforms) if self.verbose: - logging.info(f"Applied source transforms: {self.applied_source_transforms}") + logging.info(f"Applied source transforms: {transforms}") logging.info(f"Model after source transforms: {self.model}") return self @@ -209,8 +213,8 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram: # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up) with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): - if hasattr(self.args, "qnn") and self.args.qnn: - # TODO: this is temporary, as qnn flow does not work with new, non-functional export IR. + if self.use_legacy_export: + # TODO: for use cases such as qnn, which does not work with new, non-functional export IR. 
# See issue: https://github.com/pytorch/executorch/issues/7373 with patch.object( @@ -256,8 +260,12 @@ def export(self) -> "LLMEdgeManager": # Persisting those changes back to an ExportedProgram will require # an additional export(). self.pre_autograd_graph_module = exported_module.module() - if hasattr(self.args, "export_only") and self.args.export_only: - torch.export.save(exported_module, self.args.output_name) + if self.save_exported_program: + export_output = f"{self.modelname}.pt2" + logging.info( + f"Saving torch.export()/export_for_training() result to {export_output}" + ) + torch.export.save(exported_module, export_output) return self def run_canonical_optimizations(self): @@ -421,7 +429,7 @@ def export_to_edge(self) -> "LLMEdgeManager": self.export() override_export_behaviour = contextlib.nullcontext() - if hasattr(self.args, "qnn") and self.args.qnn: + if self.use_legacy_export: override_export_behaviour = patch.object( torch._utils_internal, "export_training_ir_rollout_check",
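
Usage note (editor's illustration, not part of the commit): after this change, callers no longer hand the whole argparse.Namespace to LLMEdgeManager; they map the relevant CLI flags (args.qnn, args.export_only) onto the new explicit constructor parameters use_legacy_export and save_exported_program, as the _load_llama_model hunk above does. The sketch below is a minimal, hypothetical call site: the toy module, flag values, and the import path for the builder module are assumptions for illustration only, not the real wiring in export_llama_lib.py.

    # Illustrative sketch only. TinyModel stands in for the real Llama model,
    # and the import path assumes the builder module touched by this patch.
    import torch
    from executorch.extension.llm.export.builder import DType, LLMEdgeManager


    class TinyModel(torch.nn.Module):
        def forward(self, tokens: torch.Tensor) -> torch.Tensor:
            return tokens


    # Before this patch: LLMEdgeManager(..., args=args), with builder.py reaching
    # into args.qnn / args.export_only / args.output_name internally.
    # After this patch: the caller translates its own flags into explicit params.
    manager = LLMEdgeManager(
        model=TinyModel(),
        modelname="tiny_llama",                 # hypothetical name
        max_seq_len=128,
        use_kv_cache=True,
        example_inputs=(torch.tensor([[1, 2, 3]], dtype=torch.long),),
        dtype=DType.fp32,
        use_legacy_export=False,                # was: args.qnn
        save_exported_program=False,            # was: args.export_only
        verbose=True,
    )
    # export() now derives the save path itself; per the diff it writes
    # f"{self.modelname}.pt2" when save_exported_program=True, instead of
    # reading args.output_name.
    manager.export()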