Skip to content

Commit 9053089

Browse files
authored
add ability to compare intermedidate outputs
Differential Revision: D80118857 Pull Request resolved: #13482
1 parent 3b23603 commit 9053089

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -855,16 +855,16 @@ def _to_edge_and_lower_llama_xnnpack(
855855

856856
# TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
857857
if generate_etrecord:
858-
raise NotImplementedError(
859-
"export_llama does not support XNNPack and generating ETRecord at the moment."
860-
)
858+
builder_exported.generate_etrecord = True
861859

862860
builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
863861
partitioners
864862
)
865863
if verbose:
866864
print_delegation_info(builder.edge_manager.exported_program().graph_module)
867865

866+
# we need builder.export_program
867+
868868
return builder.to_executorch(passes=additional_passes)
869869

870870

extension/llm/export/builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def __init__(
9696
metadata: Optional[dict] = None,
9797
dynamic_shapes: Optional[Any] = None,
9898
save_exported_program: bool = False,
99+
generate_etrecord: bool = False,
99100
):
100101
# Store necessary constructor arguments.
101102
self.model = model
@@ -116,6 +117,7 @@ def __init__(
116117
self.metadata = metadata
117118
self.dynamic_shapes = dynamic_shapes
118119
self.save_exported_program = save_exported_program
120+
self.generate_etrecord = generate_etrecord
119121

120122
# Note: treat this as the source of truth for the result of
121123
# torch.export'ing a model. If the overall ExportedProgram is needed,
@@ -481,6 +483,7 @@ def to_edge_transform_and_lower(
481483
partitioner=partitioners,
482484
compile_config=edge_config,
483485
constant_methods=self.metadata,
486+
generate_etrecord=self.generate_etrecord,
484487
)
485488
if self.verbose:
486489
logging.info(f"Exported graph:\n{self.edge_manager.exported_program()}")

0 commit comments

Comments
 (0)