 import contextlib
 import logging
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from unittest.mock import patch
 
 import torch
@@ -80,14 +80,13 @@ class LLMEdgeManager:
 
     def __init__(
         self,
-        model,
-        modelname,
-        max_seq_len,
-        dtype,
-        use_kv_cache,
-        example_inputs,
+        model: torch.nn.Module,
+        modelname: str,
+        max_seq_len: int,
+        dtype: DType,
+        use_kv_cache: bool,
+        example_inputs: Tuple[torch.Tensor, ...],
         example_kwarg_inputs: Optional[Dict] = None,
-        args: Optional[Any] = None,
         enable_dynamic_shape: bool = False,
         generate_full_logits: bool = False,
         calibration_tasks: Optional[List[str]] = None,
@@ -98,36 +97,42 @@ def __init__(
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
+        use_legacy_export: bool = False,
+        save_exported_program: bool = False,
     ):
+        # Store necessary constructor arguments.
         self.model = model
-        # Note: treat this as the source of truth for the result of
-        # torch.export'ing a model. If the overall ExportedProgram is needed,
-        # make sure to re-export this graph module to persist any changes. See
-        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
-        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
         self.dtype = dtype
+        self.use_kv_cache = use_kv_cache
         self.example_inputs = example_inputs
         self.example_kwarg_inputs = example_kwarg_inputs
-        self.use_kv_cache = use_kv_cache
-        self.generate_full_logits = generate_full_logits
         self.enable_dynamic_shape = enable_dynamic_shape
-        self.verbose = verbose
-        self.metadata = metadata
-        self.applied_source_transforms = []
-        self.edge_manager: Optional[EdgeProgramManager] = None
-        self.export_program = None
-        self.output_dir = "."
-        self.dynamic_shapes = dynamic_shapes
-        self._saved_pte_filename = None
-        self.args = args
+        self.generate_full_logits = generate_full_logits
         self.calibration_tasks = calibration_tasks
         self.calibration_limit = calibration_limit
         self.calibration_seq_length = calibration_seq_length
         self.calibration_data = calibration_data
         self.tokenizer_path = tokenizer_path
-        self.canonical_passes = [RemoveRedundantTransposes()]
+        self.verbose = verbose
+        self.metadata = metadata
+        self.dynamic_shapes = dynamic_shapes
+        self.use_legacy_export = use_legacy_export
+        self.save_exported_program = save_exported_program
+
+        # Note: treat this as the source of truth for the result of
+        # torch.export'ing a model. If the overall ExportedProgram is needed,
+        # make sure to re-export this graph module to persist any changes. See
+        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
+        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
+        self.edge_manager: Optional[EdgeProgramManager] = None
+        self.canonical_passes = [
+            RemoveRedundantTransposes()
+        ]  # Canonical graph transformation passes.
+        self.export_program = None  # Final result of lowering to executorch.
+        self.output_dir = "."
+        self._saved_pte_filename = None
 
     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
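With the `args` namespace removed, the export options are now explicit, typed constructor parameters. A minimal construction sketch follows; the model, model name, and argument values are placeholders (not taken from this PR), and it assumes the `DType` enum defines an `fp32` member:

    # Illustrative only: model, modelname, and the argument values are placeholders.
    manager = LLMEdgeManager(
        model=model,  # any torch.nn.Module
        modelname="llama3_2",  # hypothetical; also used as the stem of the saved .pt2 file
        max_seq_len=128,
        dtype=DType.fp32,  # assumes DType exposes an fp32 member
        use_kv_cache=True,
        example_inputs=(torch.tensor([[1, 2, 3]], dtype=torch.long),),
        enable_dynamic_shape=True,
        use_legacy_export=False,  # set True for flows (e.g. qnn) that need the legacy export IR
        save_exported_program=True,  # export() will write "<modelname>.pt2"
    )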
@@ -166,10 +171,9 @@ def source_transform(
         """
         for transform in transforms:
             self.model = transform(self.model)
-        self.applied_source_transforms.extend(transforms)
 
         if self.verbose:
-            logging.info(f"Applied source transforms: {self.applied_source_transforms}")
+            logging.info(f"Applied source transforms: {transforms}")
             logging.info(f"Model after source transforms: {self.model}")
         return self
 
@@ -203,8 +207,8 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if hasattr(self.args, "qnn") and self.args.qnn:
-                # TODO: this is temporary, as qnn flow does not work with new, non-functional export IR.
+            if self.use_legacy_export:
+                # TODO: needed for use cases such as qnn, which do not work with the new, non-functional export IR.
                 # See issue: https://github.com/pytorch/executorch/issues/7373
 
                 with patch.object(
@@ -250,8 +254,12 @@ def export(self) -> "LLMEdgeManager":
         # Persisting those changes back to an ExportedProgram will require
         # an additional export().
         self.pre_autograd_graph_module = exported_module.module()
-        if hasattr(self.args, "export_only") and self.args.export_only:
-            torch.export.save(exported_module, self.args.output_name)
+        if self.save_exported_program:
+            export_output = f"{self.modelname}.pt2"
+            logging.info(
+                f"Saving torch.export()/export_for_training() result to {export_output}"
+            )
+            torch.export.save(exported_module, export_output)
         return self
 
     def run_canonical_optimizations(self):
@@ -415,7 +423,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
                 self.export()
 
             override_export_behaviour = contextlib.nullcontext()
-            if hasattr(self.args, "qnn") and self.args.qnn:
+            if self.use_legacy_export:
                 override_export_behaviour = patch.object(
                     torch._utils_internal,
                     "export_training_ir_rollout_check",
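With `save_exported_program` enabled, `export()` writes the `torch.export()`/`export_for_training()` result to `<modelname>.pt2` in the current working directory. A hedged sketch of reloading that artifact for inspection; the file name is an assumed example following the pattern above:

    import torch

    # Reload the ExportedProgram written by export() when save_exported_program=True.
    # "llama3_2.pt2" is a hypothetical file name following f"{modelname}.pt2".
    exported_program = torch.export.load("llama3_2.pt2")
    print(exported_program.graph_module.graph)  # inspect the captured graph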