
Commit 8707043

jackzhxng authored and facebook-github-bot committed
Remove args from LLMEdgeManager and misc cleanup (#10288)
Summary: Decouple `export_llama`'s `argparse.Namespace` from `LLMEdgeManager`, plus misc cleanup.

Test Plan: CI

Reviewed By: iseeyuan

Differential Revision: D73225183

Pulled By: jackzhxng
1 parent cbca483 commit 8707043
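
For context on the new interface, here is a minimal sketch of how `LLMEdgeManager` is constructed after this change, with explicit, typed keyword arguments in place of the old `args` namespace. The toy module, name, and values are purely illustrative, and the import path assumes the usual `executorch.extension.llm.export.builder` module:

import torch

from executorch.extension.llm.export.builder import DType, LLMEdgeManager

# Illustrative toy module; any torch.nn.Module can be managed.
toy_model = torch.nn.Linear(8, 8)

manager = LLMEdgeManager(
    model=toy_model,
    modelname="toy_llm",
    max_seq_len=128,
    use_kv_cache=False,
    example_inputs=(torch.randn(1, 8),),
    dtype=DType.fp32,
    use_legacy_export=False,      # previously derived from args.qnn
    save_exported_program=False,  # previously derived from args.export_only
)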

File tree: 3 files changed, +45 -38 lines


examples/models/llama/export_llama_lib.py (+4 -3)
@@ -653,7 +653,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         _get_source_transforms(
             modelname=args.model,
             dtype_override=dtype_override,
-            checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),
+            checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
             args=args,
         )
     )
@@ -1106,7 +1106,7 @@ def _load_llama_model(
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
-        max_seq_len=model.max_seq_len,
+        max_seq_len=model.max_seq_len,  # type: ignore
         dtype=dtype_override,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
@@ -1119,6 +1119,8 @@ def _load_llama_model(
         calibration_seq_length=calibration_seq_length,
         calibration_data=calibration_data,
         tokenizer_path=tokenizer_path,
+        use_legacy_export=args.qnn,
+        save_exported_program=args.export_only,
         verbose=verbose,
         metadata=_load_llama_model_metadata(
             weight_type,
@@ -1139,7 +1141,6 @@ def _load_llama_model(
             model.vocab_size,
             metadata_str,
         ),
-        args=args,
     )
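
On the caller side, the argparse namespace now stops at `_load_llama_model`: the flags the builder used to read from `args` are translated into the new constructor parameters at this boundary. A hedged sketch of that pattern (the flag spellings below are illustrative, not necessarily the exact `export_llama` options):

import argparse

# Parse CLI flags once at the boundary (illustrative flags, not the real export_llama CLI).
parser = argparse.ArgumentParser()
parser.add_argument("--qnn", action="store_true")
parser.add_argument("--export_only", action="store_true")
args = parser.parse_args([])

# Forward them as explicit constructor parameters instead of handing the
# whole namespace to LLMEdgeManager.
builder_kwargs = {
    "use_legacy_export": args.qnn,
    "save_exported_program": args.export_only,
}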

examples/models/llava/export_llava.py (-2)
@@ -92,7 +92,6 @@ def forward(self, input_pos, embeddings):
         use_kv_cache=True,
         example_inputs=(torch.tensor([0], dtype=torch.int64), embeddings),
         dynamic_shapes=dynamic_shapes,
-        args=llava.text_model_args,
     )

     dtype_override = DType.fp32
@@ -161,7 +160,6 @@ def forward(self, images):
             use_kv_cache=True,
             example_inputs=(resized,),
             dynamic_shapes=dynamic_shapes,
-            args=None,
         )
         .export()
         .pt2e_quantize([quantizer])

extension/llm/export/builder.py (+41 -33)
@@ -13,7 +13,7 @@
 import contextlib
 import logging
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from unittest.mock import patch

 import torch
@@ -81,14 +81,13 @@ class LLMEdgeManager:

     def __init__(
         self,
-        model,
-        modelname,
-        max_seq_len,
-        dtype,
-        use_kv_cache,
-        example_inputs,
+        model: torch.nn.Module,
+        modelname: str,
+        max_seq_len: int,
+        use_kv_cache: bool,
+        example_inputs: Tuple[torch.Tensor, ...],
+        dtype: Optional[DType] = None,
         example_kwarg_inputs: Optional[Dict] = None,
-        args: Optional[Any] = None,
         enable_dynamic_shape: bool = False,
         generate_full_logits: bool = False,
         calibration_tasks: Optional[List[str]] = None,
@@ -99,36 +98,42 @@ def __init__(
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
+        use_legacy_export: bool = False,
+        save_exported_program: bool = False,
     ):
+        # Store necessary constructor arguments.
         self.model = model
-        # Note: treat this as the source of truth for the result of
-        # torch.export'ing a model. If the overall ExportedProgram is needed,
-        # make sure to re-export this graph module to persist any changes. See
-        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
-        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
-        self.dtype = dtype
+        self.use_kv_cache = use_kv_cache
         self.example_inputs = example_inputs
+        self.dtype = dtype
         self.example_kwarg_inputs = example_kwarg_inputs
-        self.use_kv_cache = use_kv_cache
-        self.generate_full_logits = generate_full_logits
         self.enable_dynamic_shape = enable_dynamic_shape
-        self.verbose = verbose
-        self.metadata = metadata
-        self.applied_source_transforms = []
-        self.edge_manager: Optional[EdgeProgramManager] = None
-        self.export_program = None
-        self.output_dir = "."
-        self.dynamic_shapes = dynamic_shapes
-        self._saved_pte_filename = None
-        self.args = args
+        self.generate_full_logits = generate_full_logits
         self.calibration_tasks = calibration_tasks
         self.calibration_limit = calibration_limit
         self.calibration_seq_length = calibration_seq_length
         self.calibration_data = calibration_data
         self.tokenizer_path = tokenizer_path
-        self.canonical_passes = [RemoveRedundantTransposes()]
+        self.verbose = verbose
+        self.metadata = metadata
+        self.dynamic_shapes = dynamic_shapes
+        self.use_legacy_export = use_legacy_export
+        self.save_exported_program = save_exported_program
+
+        # Note: treat this as the source of truth for the result of
+        # torch.export'ing a model. If the overall ExportedProgram is needed,
+        # make sure to re-export this graph module to persist any changes. See
+        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
+        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
+        self.edge_manager: Optional[EdgeProgramManager] = None
+        self.canonical_passes = [
+            RemoveRedundantTransposes()
+        ]  # Graph transformations optimizations.
+        self.export_program = None  # Final result of lowering to executorch.
+        self.output_dir = "."
+        self._saved_pte_filename = None

     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
@@ -167,10 +172,9 @@ def source_transform(
         """
         for transform in transforms:
             self.model = transform(self.model)
-        self.applied_source_transforms.extend(transforms)

         if self.verbose:
-            logging.info(f"Applied source transforms: {self.applied_source_transforms}")
+            logging.info(f"Applied source transforms: {transforms}")
             logging.info(f"Model after source transforms: {self.model}")
         return self
@@ -209,8 +213,8 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if hasattr(self.args, "qnn") and self.args.qnn:
-                # TODO: this is temporary, as qnn flow does not work with new, non-functional export IR.
+            if self.use_legacy_export:
+                # TODO: for use cases such as qnn, which does not work with new, non-functional export IR.
                 # See issue: https://github.com/pytorch/executorch/issues/7373

                 with patch.object(
@@ -256,8 +260,12 @@ def export(self) -> "LLMEdgeManager":
         # Persisting those changes back to an ExportedProgram will require
         # an additional export().
         self.pre_autograd_graph_module = exported_module.module()
-        if hasattr(self.args, "export_only") and self.args.export_only:
-            torch.export.save(exported_module, self.args.output_name)
+        if self.save_exported_program:
+            export_output = f"{self.modelname}.pt2"
+            logging.info(
+                f"Saving torch.export()/export_for_training() result to {export_output}"
+            )
+            torch.export.save(exported_module, export_output)
         return self

     def run_canonical_optimizations(self):
@@ -421,7 +429,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
                 self.export()

             override_export_behaviour = contextlib.nullcontext()
-            if hasattr(self.args, "qnn") and self.args.qnn:
+            if self.use_legacy_export:
                 override_export_behaviour = patch.object(
                     torch._utils_internal,
                     "export_training_ir_rollout_check",

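Taken together, the two new flags replace the builder's last reads of `args`: `save_exported_program=True` makes `export()` also write the `torch.export` result to `{modelname}.pt2`, and `use_legacy_export=True` routes `export()`/`export_to_edge()` through the legacy, pre-training-IR export path that flows such as QNN still need. A hedged end-to-end sketch, again with an illustrative toy module and the assumed import path:

import torch

from executorch.extension.llm.export.builder import DType, LLMEdgeManager

class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

manager = LLMEdgeManager(
    model=TinyModel(),
    modelname="tiny_model",
    max_seq_len=8,
    use_kv_cache=False,
    example_inputs=(torch.randn(1, 8),),
    dtype=DType.fp32,
    use_legacy_export=False,     # True would select the patched, legacy export IR path
    save_exported_program=True,  # export() will additionally save tiny_model.pt2
)
manager.export()  # sketch: also writes the exported program to "tiny_model.pt2"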