Remove args from LLMEdgeManager and misc cleanup #10288

Merged
merged 1 commit into from Apr 19, 2025

7 changes: 4 additions & 3 deletions examples/models/llama/export_llama_lib.py
@@ -653,7 +653,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         _get_source_transforms(
             modelname=args.model,
             dtype_override=dtype_override,
-            checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),
+            checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
             args=args,
         )
     )
@@ -1106,7 +1106,7 @@ def _load_llama_model(
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
-        max_seq_len=model.max_seq_len,
+        max_seq_len=model.max_seq_len,  # type: ignore
         dtype=dtype_override,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
@@ -1119,6 +1119,8 @@ def _load_llama_model(
         calibration_seq_length=calibration_seq_length,
         calibration_data=calibration_data,
         tokenizer_path=tokenizer_path,
+        use_legacy_export=args.qnn,
+        save_exported_program=args.export_only,
         verbose=verbose,
         metadata=_load_llama_model_metadata(
             weight_type,
@@ -1139,7 +1141,6 @@
             model.vocab_size,
             metadata_str,
         ),
-        args=args,
     )


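Note for readers: the construction above replaces the old catch-all argparse namespace with explicit, typed keyword arguments. A minimal sketch of the new constructor surface follows; the TinyLM module, its parameter values, and the import path are illustrative assumptions, not code from this PR. Only use_legacy_export and save_exported_program come from this diff.

import torch

from executorch.extension.llm.export.builder import LLMEdgeManager  # assumed import path


class TinyLM(torch.nn.Module):
    """Illustrative stand-in for a real llama-style model."""

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return tokens.float()


# The flags previously smuggled in through `args` are now explicit kwargs.
manager = LLMEdgeManager(
    model=TinyLM(),
    modelname="tiny_lm",
    max_seq_len=128,
    use_kv_cache=False,
    example_inputs=(torch.zeros(1, 8, dtype=torch.long),),
    use_legacy_export=False,  # replaces the old hasattr(self.args, "qnn") check
    save_exported_program=False,  # replaces the old args.export_only check
)
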
2 changes: 0 additions & 2 deletions examples/models/llava/export_llava.py
@@ -92,7 +92,6 @@ def forward(self, input_pos, embeddings):
         use_kv_cache=True,
         example_inputs=(torch.tensor([0], dtype=torch.int64), embeddings),
         dynamic_shapes=dynamic_shapes,
-        args=llava.text_model_args,
     )

     dtype_override = DType.fp32
@@ -161,7 +160,6 @@ def forward(self, images):
         use_kv_cache=True,
         example_inputs=(resized,),
         dynamic_shapes=dynamic_shapes,
-        args=None,
     )
     .export()
     .pt2e_quantize([quantizer])
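The llava callers simply drop the args kwarg; the fluent builder chain is otherwise unchanged. A condensed sketch of that pattern with a toy encoder is below. Everything except the removed kwarg is an illustrative assumption, and the chain is a sketch rather than a verified end-to-end run.

import torch

from executorch.extension.llm.export.builder import LLMEdgeManager  # assumed import path


class TinyEncoder(torch.nn.Module):
    """Illustrative stand-in for the llava image encoder."""

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        return images.mean(dim=(2, 3))


# Same chained-builder shape as export_llava.py, minus the removed `args=...`.
manager = (
    LLMEdgeManager(
        model=TinyEncoder(),
        modelname="tiny_encoder",
        max_seq_len=32,
        use_kv_cache=False,
        example_inputs=(torch.randn(1, 3, 336, 336),),
    ).export()
)
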
74 changes: 41 additions & 33 deletions extension/llm/export/builder.py
@@ -13,7 +13,7 @@
 import contextlib
 import logging
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from unittest.mock import patch

 import torch
@@ -81,14 +81,13 @@ class LLMEdgeManager:

     def __init__(
         self,
-        model,
-        modelname,
-        max_seq_len,
-        dtype,
-        use_kv_cache,
-        example_inputs,
+        model: torch.nn.Module,
+        modelname: str,
+        max_seq_len: int,
+        use_kv_cache: bool,
+        example_inputs: Tuple[torch.Tensor, ...],
+        dtype: Optional[DType] = None,
         example_kwarg_inputs: Optional[Dict] = None,
-        args: Optional[Any] = None,
         enable_dynamic_shape: bool = False,
         generate_full_logits: bool = False,
         calibration_tasks: Optional[List[str]] = None,
@@ -99,36 +98,42 @@ def __init__(
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
+        use_legacy_export: bool = False,
+        save_exported_program: bool = False,
     ):
+        # Store necessary constructor arguments.
         self.model = model
-        # Note: treat this as the source of truth for the result of
-        # torch.export'ing a model. If the overall ExportedProgram is needed,
-        # make sure to re-export this graph module to persist any changes. See
-        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
-        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
-        self.dtype = dtype
-        self.use_kv_cache = use_kv_cache
         self.example_inputs = example_inputs
+        self.dtype = dtype
         self.example_kwarg_inputs = example_kwarg_inputs
+        self.use_kv_cache = use_kv_cache
+        self.generate_full_logits = generate_full_logits
         self.enable_dynamic_shape = enable_dynamic_shape
-        self.verbose = verbose
-        self.metadata = metadata
-        self.applied_source_transforms = []
-        self.edge_manager: Optional[EdgeProgramManager] = None
-        self.export_program = None
-        self.output_dir = "."
-        self.dynamic_shapes = dynamic_shapes
-        self._saved_pte_filename = None
-        self.args = args
-        self.generate_full_logits = generate_full_logits
         self.calibration_tasks = calibration_tasks
         self.calibration_limit = calibration_limit
         self.calibration_seq_length = calibration_seq_length
         self.calibration_data = calibration_data
         self.tokenizer_path = tokenizer_path
-        self.canonical_passes = [RemoveRedundantTransposes()]
+        self.verbose = verbose
+        self.metadata = metadata
+        self.dynamic_shapes = dynamic_shapes
+        self.use_legacy_export = use_legacy_export
+        self.save_exported_program = save_exported_program
+
+        # Note: treat this as the source of truth for the result of
+        # torch.export'ing a model. If the overall ExportedProgram is needed,
+        # make sure to re-export this graph module to persist any changes. See
+        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
+        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
+        self.edge_manager: Optional[EdgeProgramManager] = None
+        self.canonical_passes = [
+            RemoveRedundantTransposes()
+        ]  # Graph transformations optimizations.
+        self.export_program = None  # Final result of lowering to executorch.
+        self.output_dir = "."
+        self._saved_pte_filename = None

     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
@@ -167,10 +172,9 @@ def source_transform(
         """
         for transform in transforms:
             self.model = transform(self.model)
-        self.applied_source_transforms.extend(transforms)

         if self.verbose:
-            logging.info(f"Applied source transforms: {self.applied_source_transforms}")
+            logging.info(f"Applied source transforms: {transforms}")
             logging.info(f"Model after source transforms: {self.model}")
         return self

@@ -209,8 +213,8 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if hasattr(self.args, "qnn") and self.args.qnn:
-                # TODO: this is temporary, as qnn flow does not work with new, non-functional export IR.
+            if self.use_legacy_export:
+                # TODO: for use cases such as qnn, which does not work with new, non-functional export IR.
                 # See issue: https://github.com/pytorch/executorch/issues/7373

                 with patch.object(
@@ -256,8 +260,12 @@ def export(self) -> "LLMEdgeManager":
         # Persisting those changes back to an ExportedProgram will require
         # an additional export().
         self.pre_autograd_graph_module = exported_module.module()
-        if hasattr(self.args, "export_only") and self.args.export_only:
-            torch.export.save(exported_module, self.args.output_name)
+        if self.save_exported_program:
+            export_output = f"{self.modelname}.pt2"
+            logging.info(
+                f"Saving torch.export()/export_for_training() result to {export_output}"
+            )
+            torch.export.save(exported_module, export_output)
         return self

     def run_canonical_optimizations(self):
Expand Down Expand Up @@ -421,7 +429,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
self.export()

override_export_behaviour = contextlib.nullcontext()
if hasattr(self.args, "qnn") and self.args.qnn:
if self.use_legacy_export:
override_export_behaviour = patch.object(
torch._utils_internal,
"export_training_ir_rollout_check",
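For reference, the new save_exported_program path in export() writes the torch.export result to <modelname>.pt2. Below is a small self-contained sketch of that save/load round trip using only public torch.export APIs; the Tiny module and the file name are assumptions for illustration.

import torch
from torch.export import export, load, save


class Tiny(torch.nn.Module):
    """Illustrative module standing in for the exported llama graph."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


# Mirrors what LLMEdgeManager.export() does when save_exported_program=True:
# serialize the ExportedProgram to a .pt2 file named after the model...
ep = export(Tiny(), (torch.randn(2),))
save(ep, "tiny.pt2")

# ...which can be loaded back later for inspection or further lowering.
loaded = load("tiny.pt2")
print(loaded.graph_module.code)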