 import contextlib
 import logging
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from unittest.mock import patch
 
 import torch
@@ -81,14 +81,13 @@ class LLMEdgeManager:
 
     def __init__(
         self,
-        model,
-        modelname,
-        max_seq_len,
-        dtype,
-        use_kv_cache,
-        example_inputs,
+        model: torch.nn.Module,
+        modelname: str,
+        max_seq_len: int,
+        use_kv_cache: bool,
+        example_inputs: Tuple[torch.Tensor, ...],
+        dtype: Optional[DType] = None,
         example_kwarg_inputs: Optional[Dict] = None,
-        args: Optional[Any] = None,
         enable_dynamic_shape: bool = False,
         generate_full_logits: bool = False,
         calibration_tasks: Optional[List[str]] = None,
@@ -99,36 +98,42 @@ def __init__(
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
+        use_legacy_export: bool = False,
+        save_exported_program: bool = False,
     ):
+        # Store necessary constructor arguments.
         self.model = model
-        # Note: treat this as the source of truth for the result of
-        # torch.export'ing a model. If the overall ExportedProgram is needed,
-        # make sure to re-export this graph module to persist any changes. See
-        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
-        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
-        self.dtype = dtype
+        self.use_kv_cache = use_kv_cache
         self.example_inputs = example_inputs
+        self.dtype = dtype
         self.example_kwarg_inputs = example_kwarg_inputs
-        self.use_kv_cache = use_kv_cache
-        self.generate_full_logits = generate_full_logits
         self.enable_dynamic_shape = enable_dynamic_shape
-        self.verbose = verbose
-        self.metadata = metadata
-        self.applied_source_transforms = []
-        self.edge_manager: Optional[EdgeProgramManager] = None
-        self.export_program = None
-        self.output_dir = "."
-        self.dynamic_shapes = dynamic_shapes
-        self._saved_pte_filename = None
-        self.args = args
+        self.generate_full_logits = generate_full_logits
         self.calibration_tasks = calibration_tasks
         self.calibration_limit = calibration_limit
         self.calibration_seq_length = calibration_seq_length
         self.calibration_data = calibration_data
         self.tokenizer_path = tokenizer_path
-        self.canonical_passes = [RemoveRedundantTransposes()]
+        self.verbose = verbose
+        self.metadata = metadata
+        self.dynamic_shapes = dynamic_shapes
+        self.use_legacy_export = use_legacy_export
+        self.save_exported_program = save_exported_program
+
+        # Note: treat this as the source of truth for the result of
+        # torch.export'ing a model. If the overall ExportedProgram is needed,
+        # make sure to re-export this graph module to persist any changes. See
+        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
+        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
+        self.edge_manager: Optional[EdgeProgramManager] = None
+        self.canonical_passes = [
+            RemoveRedundantTransposes()
+        ]  # Graph transformations optimizations.
+        self.export_program = None  # Final result of lowering to executorch.
+        self.output_dir = "."
+        self._saved_pte_filename = None
 
     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
@@ -167,10 +172,9 @@ def source_transform(
         """
         for transform in transforms:
             self.model = transform(self.model)
-        self.applied_source_transforms.extend(transforms)
 
         if self.verbose:
-            logging.info(f"Applied source transforms: {self.applied_source_transforms}")
+            logging.info(f"Applied source transforms: {transforms}")
         logging.info(f"Model after source transforms: {self.model}")
         return self
 
@@ -209,8 +213,8 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if hasattr(self.args, "qnn") and self.args.qnn:
-                # TODO: this is temporary, as qnn flow does not work with new, non-functional export IR.
+            if self.use_legacy_export:
+                # TODO: for use cases such as qnn, which does not work with new, non-functional export IR.
                 # See issue: https://github.com/pytorch/executorch/issues/7373
 
                 with patch.object(
@@ -256,8 +260,12 @@ def export(self) -> "LLMEdgeManager":
         # Persisting those changes back to an ExportedProgram will require
         # an additional export().
         self.pre_autograd_graph_module = exported_module.module()
-        if hasattr(self.args, "export_only") and self.args.export_only:
-            torch.export.save(exported_module, self.args.output_name)
+        if self.save_exported_program:
+            export_output = f"{self.modelname}.pt2"
+            logging.info(
+                f"Saving torch.export()/export_for_training() result to {export_output}"
+            )
+            torch.export.save(exported_module, export_output)
         return self
 
     def run_canonical_optimizations(self):
@@ -421,7 +429,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
                 self.export()
 
             override_export_behaviour = contextlib.nullcontext()
-            if hasattr(self.args, "qnn") and self.args.qnn:
+            if self.use_legacy_export:
                 override_export_behaviour = patch.object(
                     torch._utils_internal,
                     "export_training_ir_rollout_check",