@@ -13,7 +13,7 @@
 import contextlib
 import logging
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from unittest.mock import patch

 import torch
@@ -80,14 +80,13 @@ class LLMEdgeManager:

     def __init__(
         self,
-        model,
-        modelname,
-        max_seq_len,
-        dtype,
-        use_kv_cache,
-        example_inputs,
+        model: torch.nn.Module,
+        modelname: str,
+        max_seq_len: int,
+        dtype: DType,
+        use_kv_cache: bool,
+        example_inputs: Tuple[torch.Tensor, ...],
         example_kwarg_inputs: Optional[Dict] = None,
-        args: Optional[Any] = None,
         enable_dynamic_shape: bool = False,
         generate_full_logits: bool = False,
         calibration_tasks: Optional[List[str]] = None,
@@ -98,36 +97,42 @@ def __init__(
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
+        use_legacy_export: bool = False,
+        save_exported_program: bool = False,
     ):
+        # Store necessary constructor arguments.
         self.model = model
-        # Note: treat this as the source of truth for the result of
-        # torch.export'ing a model. If the overall ExportedProgram is needed,
-        # make sure to re-export this graph module to persist any changes. See
-        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
-        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
         self.dtype = dtype
+        self.use_kv_cache = use_kv_cache
         self.example_inputs = example_inputs
         self.example_kwarg_inputs = example_kwarg_inputs
-        self.use_kv_cache = use_kv_cache
-        self.generate_full_logits = generate_full_logits
         self.enable_dynamic_shape = enable_dynamic_shape
-        self.verbose = verbose
-        self.metadata = metadata
-        self.applied_source_transforms = []
-        self.edge_manager: Optional[EdgeProgramManager] = None
-        self.export_program = None
-        self.output_dir = "."
-        self.dynamic_shapes = dynamic_shapes
-        self._saved_pte_filename = None
-        self.args = args
+        self.generate_full_logits = generate_full_logits
         self.calibration_tasks = calibration_tasks
         self.calibration_limit = calibration_limit
         self.calibration_seq_length = calibration_seq_length
         self.calibration_data = calibration_data
         self.tokenizer_path = tokenizer_path
-        self.canonical_passes = [RemoveRedundantTransposes()]
+        self.verbose = verbose
+        self.metadata = metadata
+        self.dynamic_shapes = dynamic_shapes
+        self.use_legacy_export = use_legacy_export
+        self.save_exported_program = save_exported_program
+
+        # Note: treat this as the source of truth for the result of
+        # torch.export'ing a model. If the overall ExportedProgram is needed,
+        # make sure to re-export this graph module to persist any changes. See
+        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
+        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
+        self.edge_manager: Optional[EdgeProgramManager] = None
+        self.canonical_passes = [
+            RemoveRedundantTransposes()
+        ]  # Graph transformations optimizations.
+        self.export_program = None  # Final result of lowering to executorch.
+        self.output_dir = "."
+        self._saved_pte_filename = None

     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
@@ -166,10 +171,9 @@ def source_transform(
         """
         for transform in transforms:
             self.model = transform(self.model)
-        self.applied_source_transforms.extend(transforms)

         if self.verbose:
-            logging.info(f"Applied source transforms: {self.applied_source_transforms}")
+            logging.info(f"Applied source transforms: {transforms}")
             logging.info(f"Model after source transforms: {self.model}")
         return self

@@ -203,8 +207,8 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if hasattr(self.args, "qnn") and self.args.qnn:
-                # TODO: this is temporary, as qnn flow does not work with new, non-functional export IR.
+            if self.use_legacy_export:
+                # TODO: for use cases such as qnn, which does not work with new, non-functional export IR.
                 # See issue: https://github.com/pytorch/executorch/issues/7373

                 with patch.object(
@@ -250,8 +254,12 @@ def export(self) -> "LLMEdgeManager":
         # Persisting those changes back to an ExportedProgram will require
         # an additional export().
         self.pre_autograd_graph_module = exported_module.module()
-        if hasattr(self.args, "export_only") and self.args.export_only:
-            torch.export.save(exported_module, self.args.output_name)
+        if self.save_exported_program:
+            export_output = f"{self.modelname}.pt2"
+            logging.info(
+                f"Saving torch.export()/export_for_training() result to {export_output}"
+            )
+            torch.export.save(exported_module, export_output)
         return self

     def run_canonical_optimizations(self):
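
The artifact written when `save_exported_program=True` follows the `f"{modelname}.pt2"` pattern above and can be reloaded for inspection with the standard `torch.export` API. A short sketch, with the file name value assumed:

```python
import torch

# Reload the ExportedProgram saved by export() when save_exported_program=True.
ep = torch.export.load("my_llm.pt2")
print(ep)  # inspect the captured graph and its signature
```
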
@@ -415,7 +423,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
         self.export()

         override_export_behaviour = contextlib.nullcontext()
-        if hasattr(self.args, "qnn") and self.args.qnn:
+        if self.use_legacy_export:
             override_export_behaviour = patch.object(
                 torch._utils_internal,
                 "export_training_ir_rollout_check",