Commit 43e029b

jackzhxng authored and facebook-github-bot committed
Remove args from LLMEdgeManager and misc cleanup (#10288)
Summary: Decouple `export_llama`'s `argparse.Namespace` from `LLMEdgeManager`, plus miscellaneous cleanup.

Test Plan: CI

Differential Revision: D73225183

Pulled By: jackzhxng
1 parent 381ae5d commit 43e029b
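In practice, callers now map the argparse fields the builder actually needs onto explicit keyword arguments instead of forwarding the whole namespace. A minimal before/after sketch based on the diff below; the surrounding variable names (`model`, `args`, etc.) are illustrative placeholders, not code from this commit:

# Before: the whole argparse namespace was handed to the builder.
manager = LLMEdgeManager(
    model=model,
    modelname=modelname,
    max_seq_len=max_seq_len,
    dtype=dtype,
    use_kv_cache=use_kv_cache,
    example_inputs=example_inputs,
    args=args,  # removed in this commit
)

# After: only the specific fields the builder needs are passed explicitly.
manager = LLMEdgeManager(
    model=model,
    modelname=modelname,
    max_seq_len=max_seq_len,
    dtype=dtype,
    use_kv_cache=use_kv_cache,
    example_inputs=example_inputs,
    use_legacy_export=args.qnn,
    save_exported_program=args.export_only,
)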

3 files changed (+42 -34 lines)

examples/models/llama/export_llama_lib.py

+2 -1

@@ -1119,6 +1119,8 @@ def _load_llama_model(
         calibration_seq_length=calibration_seq_length,
         calibration_data=calibration_data,
         tokenizer_path=tokenizer_path,
+        use_legacy_export=args.qnn,
+        save_exported_program=args.export_only,
         verbose=verbose,
         metadata=_load_llama_model_metadata(
             weight_type,
@@ -1139,7 +1141,6 @@ def _load_llama_model(
             model.vocab_size,
             metadata_str,
         ),
-        args=args,
     )

examples/models/llava/export_llava.py

-1

@@ -92,7 +92,6 @@ def forward(self, input_pos, embeddings):
         use_kv_cache=True,
         example_inputs=(torch.tensor([0], dtype=torch.int64), embeddings),
         dynamic_shapes=dynamic_shapes,
-        args=llava.text_model_args,
     )

     dtype_override = DType.fp32

extension/llm/export/builder.py

+40 -32

@@ -13,7 +13,7 @@
 import contextlib
 import logging
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from unittest.mock import patch

 import torch
@@ -80,14 +80,13 @@ class LLMEdgeManager:

     def __init__(
         self,
-        model,
-        modelname,
-        max_seq_len,
-        dtype,
-        use_kv_cache,
-        example_inputs,
+        model: torch.nn.Module,
+        modelname: str,
+        max_seq_len: int,
+        dtype: DType,
+        use_kv_cache: bool,
+        example_inputs: Tuple[torch.Tensor, ...],
         example_kwarg_inputs: Optional[Dict] = None,
-        args: Optional[Any] = None,
         enable_dynamic_shape: bool = False,
         generate_full_logits: bool = False,
         calibration_tasks: Optional[List[str]] = None,
@@ -98,36 +97,42 @@ def __init__(
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
+        use_legacy_export: bool = False,
+        save_exported_program: bool = False,
     ):
+        # Store necessary constructor arguments.
         self.model = model
-        # Note: treat this as the source of truth for the result of
-        # torch.export'ing a model. If the overall ExportedProgram is needed,
-        # make sure to re-export this graph module to persist any changes. See
-        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
-        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
         self.dtype = dtype
+        self.use_kv_cache = use_kv_cache
         self.example_inputs = example_inputs
         self.example_kwarg_inputs = example_kwarg_inputs
-        self.use_kv_cache = use_kv_cache
-        self.generate_full_logits = generate_full_logits
         self.enable_dynamic_shape = enable_dynamic_shape
-        self.verbose = verbose
-        self.metadata = metadata
-        self.applied_source_transforms = []
-        self.edge_manager: Optional[EdgeProgramManager] = None
-        self.export_program = None
-        self.output_dir = "."
-        self.dynamic_shapes = dynamic_shapes
-        self._saved_pte_filename = None
-        self.args = args
+        self.generate_full_logits = generate_full_logits
         self.calibration_tasks = calibration_tasks
         self.calibration_limit = calibration_limit
         self.calibration_seq_length = calibration_seq_length
         self.calibration_data = calibration_data
         self.tokenizer_path = tokenizer_path
-        self.canonical_passes = [RemoveRedundantTransposes()]
+        self.verbose = verbose
+        self.metadata = metadata
+        self.dynamic_shapes = dynamic_shapes
+        self.use_legacy_export = use_legacy_export
+        self.save_exported_program = save_exported_program
+
+        # Note: treat this as the source of truth for the result of
+        # torch.export'ing a model. If the overall ExportedProgram is needed,
+        # make sure to re-export this graph module to persist any changes. See
+        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
+        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
+        self.edge_manager: Optional[EdgeProgramManager] = None
+        self.canonical_passes = [
+            RemoveRedundantTransposes()
+        ]  # Graph transformations optimizations.
+        self.export_program = None  # Final result of lowering to executorch.
+        self.output_dir = "."
+        self._saved_pte_filename = None

     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
@@ -166,10 +171,9 @@ def source_transform(
         """
         for transform in transforms:
             self.model = transform(self.model)
-        self.applied_source_transforms.extend(transforms)

         if self.verbose:
-            logging.info(f"Applied source transforms: {self.applied_source_transforms}")
+            logging.info(f"Applied source transforms: {transforms}")
             logging.info(f"Model after source transforms: {self.model}")
         return self

@@ -203,8 +207,8 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if hasattr(self.args, "qnn") and self.args.qnn:
-                # TODO: this is temporary, as qnn flow does not work with new, non-functional export IR.
+            if self.use_legacy_export:
+                # TODO: for use cases such as qnn, which does not work with new, non-functional export IR.
                 # See issue: https://github.com/pytorch/executorch/issues/7373

                 with patch.object(
@@ -250,8 +254,12 @@ def export(self) -> "LLMEdgeManager":
         # Persisting those changes back to an ExportedProgram will require
         # an additional export().
         self.pre_autograd_graph_module = exported_module.module()
-        if hasattr(self.args, "export_only") and self.args.export_only:
-            torch.export.save(exported_module, self.args.output_name)
+        if self.save_exported_program:
+            export_output = f"{self.modelname}.pt2"
+            logging.info(
+                f"Saving torch.export()/export_for_training() result to {export_output}"
+            )
+            torch.export.save(exported_module, export_output)
         return self

     def run_canonical_optimizations(self):
@@ -415,7 +423,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
             self.export()

         override_export_behaviour = contextlib.nullcontext()
-        if hasattr(self.args, "qnn") and self.args.qnn:
+        if self.use_legacy_export:
             override_export_behaviour = patch.object(
                 torch._utils_internal,
                 "export_training_ir_rollout_check",

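With the argparse coupling removed, the builder can also be driven directly from Python without an `export_llama` command line. A rough usage sketch, assuming the import path `executorch.extension.llm.export.builder` and a toy placeholder model; only the class, `DType`, and the two new flags come from this commit, everything else is illustrative:

import torch

from executorch.extension.llm.export.builder import DType, LLMEdgeManager

# Placeholder model and example inputs; any traceable torch.nn.Module works the same way.
toy_model = torch.nn.Linear(8, 8)
example_inputs = (torch.randn(1, 8),)

manager = LLMEdgeManager(
    model=toy_model,
    modelname="toy_model",
    max_seq_len=128,
    dtype=DType.fp32,
    use_kv_cache=False,
    example_inputs=example_inputs,
    use_legacy_export=False,  # True only for flows (e.g. QNN) that need the old, non-functional export IR
    save_exported_program=True,  # export() now writes "<modelname>.pt2" instead of args.output_name
)

# export() and export_to_edge() read the flags from the instance rather than an argparse namespace.
manager.export()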