Commit 3ef2dfc

zhxchen17 authored and pytorchmergebot committed

[export] Implement cpp deserializer. (pytorch#136398)

Differential Revision: D63206258

This diff introduces a mechanism to generate a JSON-compatible deserializer in C++ using nlohmann json (already used by AOTI). Why do we need this? Because there will be many cases where people don't want to use Python to load the graph (e.g. a C++ runtime); instead they can use this header to deserialize the JSON graph. Every time update_schema.py is run to update the schema, the header is auto-generated and included into the source files.

Pull Request resolved: pytorch#136398
Approved by: https://github.com/angelayi
1 parent f98c601 commit 3ef2dfc

15 files changed: +2657 −63 lines
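
As a quick illustration of the flow this enables, here is a minimal sketch of the Python/C++ round trip, mirroring the new test/export/test_cpp_serdes.py below; the toy Add module is illustrative only and not part of the commit:

# Sketch of the round trip through the new torch._C._export bindings
# (see torch/_C/_export.pyi below). The Add module is a made-up example.
import torch
from torch._export.serde.serialize import deserialize, serialize
from torch.export import export


class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y


ep = export(Add(), (torch.randn(2), torch.randn(2)))
payload = serialize(ep)  # JSON produced by the Python serializer

# Parse the JSON graph with the generated nlohmann-json based C++ deserializer,
# re-serialize it from C++, and load the result back with the Python deserializer.
cpp_ep = torch._C._export.deserialize_exported_program(payload.exported_program)
roundtripped_json = torch._C._export.serialize_exported_program(cpp_ep)
payload.exported_program = roundtripped_json.encode()
loaded_ep = deserialize(payload)
print(loaded_ep)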

.gitattributes

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@
 .github/scripts/gql_mocks.json linguist-generated=true
 third_party/LICENSES_BUNDLED.txt linguist-generated=true
 tools/build/bazel/requirements.txt linguist-generated=true
+torch/csrc/utils/generated_serialization_types.h linguist-generated=true

build_variables.bzl

Lines changed: 1 addition & 0 deletions
@@ -844,6 +844,7 @@ libtorch_python_core_sources = [
     "torch/csrc/fx/node.cpp",
     "torch/csrc/mps/Module.cpp",
     "torch/csrc/mtia/Module.cpp",
+    "torch/csrc/export/pybind.cpp",
     "torch/csrc/inductor/aoti_package/pybind.cpp",
     "torch/csrc/inductor/aoti_runner/pybind.cpp",
     "torch/csrc/inductor/aoti_eager/kernel_holder.cpp",

scripts/export/update_schema.py

Lines changed: 21 additions & 10 deletions
@@ -29,7 +29,7 @@
 
 commit = schema_check.update_schema()
 
-if os.path.exists(args.prefix + commit.path):
+if os.path.exists(args.prefix + commit.yaml_path):
     if commit.result["SCHEMA_VERSION"] < commit.base["SCHEMA_VERSION"]:
         raise RuntimeError(
             f"Schema version downgraded from {commit.base['SCHEMA_VERSION']} to {commit.result['SCHEMA_VERSION']}."
@@ -55,17 +55,28 @@
             + f"Reason: {reason}"
         )
 
-header = (
-    "# @" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py"
+first_line = (
+    "@" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py"
 )
-header += f"\n# checksum<<{commit.checksum_result}>>"
-payload = dump(commit.result, Dumper=Dumper, sort_keys=False)
+checksum = f"checksum<<{commit.checksum_result}>>"
+yaml_header = "# " + first_line
+yaml_header += "\n# " + checksum
+yaml_payload = dump(commit.result, Dumper=Dumper, sort_keys=False)
 
-content = header + "\n" + payload
+cpp_header = "// " + first_line
+cpp_header += "\n// " + checksum
+cpp_header += "\n// clang-format off"
+cpp_header += "\n" + commit.cpp_header
+cpp_header += "\n// clang-format on"
+cpp_header += "\n"
+
+yaml_content = yaml_header + "\n" + yaml_payload
 
 if args.dry_run:
-    print(content)
-    print("\nWill write the above schema to" + args.prefix + commit.path)
+    print(yaml_content)
+    print("\nWill write the above schema to" + args.prefix + commit.yaml_path)
 else:
-    with open(args.prefix + commit.path, "w") as f:
-        f.write(content)
+    with open(args.prefix + commit.yaml_path, "w") as f:
+        f.write(yaml_content)
+    with open(args.prefix + commit.cpp_header_path, "w") as f:
+        f.write(cpp_header)
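
For reference, a rough sketch of the preambles the updated script now writes; the checksum value and the generated body shown here are placeholders, not real output:

# Illustrative only: approximate shape of the YAML and C++ preambles emitted by
# update_schema.py after this change. Checksum and body are placeholders.
yaml_preamble = "\n".join(
    [
        "# @" + "generated by update_schema.py",
        "# checksum<<0123abcd>>",
        "# ...YAML dump of the schema follows...",
    ]
)
cpp_preamble = "\n".join(
    [
        "// @" + "generated by update_schema.py",
        "// checksum<<0123abcd>>",
        "// clang-format off",
        "// ...auto-generated nlohmann-json (de)serialization code follows...",
        "// clang-format on",
    ]
)
print(yaml_preamble)
print(cpp_preamble)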

test/export/test_cpp_serdes.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+# Owner(s): ["oncall: export"]
+
+
+import torch
+from torch._export.serde.serialize import deserialize, serialize
+
+
+try:
+    from . import test_export, testing
+except ImportError:
+    import test_export  # @manual=fbcode//caffe2/test:test_export-library
+    import testing  # @manual=fbcode//caffe2/test:test_export-library
+
+from torch.export import export
+
+
+test_classes = {}
+
+
+def mocked_cpp_serdes_export(*args, **kwargs):
+    ep = export(*args, **kwargs)
+    try:
+        payload = serialize(ep)
+    except Exception:
+        return ep
+    cpp_ep = torch._C._export.deserialize_exported_program(payload.exported_program)
+    loaded_json = torch._C._export.serialize_exported_program(cpp_ep)
+    payload.exported_program = loaded_json.encode()
+    loaded_ep = deserialize(payload)
+    return loaded_ep
+
+
+def make_dynamic_cls(cls):
+    cls_prefix = "CppSerdes"
+
+    test_class = testing.make_test_cls_with_mocked_export(
+        cls,
+        cls_prefix,
+        "_cpp_serdes",
+        mocked_cpp_serdes_export,
+        xfail_prop="_expected_failure_cpp_serdes",
+    )
+
+    test_classes[test_class.__name__] = test_class
+    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
+    globals()[test_class.__name__] = test_class
+    test_class.__module__ = __name__
+
+
+tests = [
+    test_export.TestExport,
+]
+for test in tests:
+    make_dynamic_cls(test)
+del test
+
+if __name__ == "__main__":
+    from torch._dynamo.test_case import run_tests
+
+    run_tests()

test/export/test_export.py

Lines changed: 6 additions & 0 deletions
@@ -2951,6 +2951,7 @@ def forward(self, x):
         export(N(), inputs, dynamic_shapes=dynamic_shapes)
 
     @testing.expectedFailureSerDer  # no unbacked bindings after deserialization?
+    @testing.expectedFailureCppSerDes  # no unbacked bindings after deserialization?
     @testing.expectedFailureSerDerNonStrict
     def test_unbacked_bindings_for_divisible_u_symint(self):
         with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
@@ -3673,6 +3674,7 @@ def forward(self, arg1, arg2, *args, kw1, kw2, **kwargs):
         self._test_export_same_as_eager(kw_func, args, kwargs)
 
     @testing.expectedFailureSerDer  # we don't save placeholder metadata
+    @testing.expectedFailureCppSerDes  # we don't save placeholder metadata
     @testing.expectedFailureSerDerNonStrict
     @testing.expectedFailureNonStrict
     @testing.expectedFailureTrainingIRToRunDecompNonStrict  # source_fn_stack failure
@@ -8078,6 +8080,7 @@ def forward(self, x):
         export(f, (inputs,), dynamic_shapes=dynamic_shapes)
 
     @testing.expectedFailureRetraceabilityNonStrict
+    @testing.expectedFailureCppSerDes  # dynamic shape serialization
    def test_disable_forced_specializations_ok(self):
        # check that we don't force specialization, and defer to runtime asserts
        # with allow_complex_guards_as_runtime_asserts=True to successfully export
@@ -8198,6 +8201,7 @@ def forward(self, w, x, y, z):
 
     # TODO requires_grad doesn't seem to work with serialization.
     @testing.expectedFailureSerDer
+    @testing.expectedFailureCppSerDes
     @testing.expectedFailureSerDerNonStrict
     def test_preserve_requires_grad_placeholders(self):
         class Module(torch.nn.Module):
@@ -8536,6 +8540,7 @@ def forward(self, x, y):
             ep.graph_module.code
         )
 
+    @testing.expectedFailureCppSerDes
     def test_slice_with_floordiv(self):
         # slice operation emits runtime assert s0//2 <= s1
         class M1(torch.nn.Module):
@@ -9105,6 +9110,7 @@ def test_dynamic_shapes_serdes_user_errors(self):
             _load_dynamic_shapes(spec, from_dict=True)
 
     @testing.expectedFailureSerDer  # TODO(pianpwk): PowByNatural valuerange deserialization
+    @testing.expectedFailureCppSerDes  # TODO(pianpwk): PowByNatural valuerange deserialization
     @testing.expectedFailureSerDerNonStrict
     @testing.expectedFailureRetraceabilityNonStrict
     def test_dim_dynamic(self):

test/export/test_schema.py

Lines changed: 21 additions & 7 deletions
@@ -106,11 +106,13 @@ def test_schema_check(self):
         commit = _Commit(
             result=src,
             checksum_result="",
-            path="",
+            yaml_path="",
             additions=additions,
             subtractions=subtractions,
             base=dst,
             checksum_base="",
+            cpp_header="",
+            cpp_header_path="",
         )
         next_version, _ = check(commit)
         self.assertEqual(next_version, [4, 1])
@@ -138,11 +140,13 @@ def test_schema_check(self):
         commit = _Commit(
             result=src,
             checksum_result="",
-            path="",
+            yaml_path="",
            additions=additions,
            subtractions=subtractions,
            base=dst,
            checksum_base="",
+            cpp_header="",
+            cpp_header_path="",
         )
         next_version, _ = check(commit)
         self.assertEqual(next_version, [4, 1])
@@ -173,11 +177,13 @@ def test_schema_check(self):
         commit = _Commit(
             result=src,
             checksum_result="",
-            path="",
+            yaml_path="",
             additions=additions,
             subtractions=subtractions,
             base=dst,
             checksum_base="",
+            cpp_header="",
+            cpp_header_path="",
         )
         next_version, _ = check(commit)
         self.assertEqual(next_version, [3, 3])
@@ -231,11 +237,13 @@ def test_schema_check(self):
         commit = _Commit(
             result=src,
             checksum_result="",
-            path="",
+            yaml_path="",
             additions=additions,
             subtractions=subtractions,
             base=dst,
             checksum_base="",
+            cpp_header="",
+            cpp_header_path="",
         )
         next_version, _ = check(commit)
         self.assertEqual(next_version, [3, 3])
@@ -259,11 +267,13 @@ def test_schema_check(self):
         commit = _Commit(
             result=src,
             checksum_result="",
-            path="",
+            yaml_path="",
             additions=additions,
             subtractions=subtractions,
             base=dst,
             checksum_base="",
+            cpp_header="",
+            cpp_header_path="",
         )
         next_version, _ = check(commit)
         self.assertEqual(next_version, [3, 3])
@@ -294,11 +304,13 @@ def test_schema_check(self):
         commit = _Commit(
             result=src,
             checksum_result="",
-            path="",
+            yaml_path="",
             additions=additions,
             subtractions=subtractions,
             base=dst,
             checksum_base="",
+            cpp_header="",
+            cpp_header_path="",
         )
         next_version, _ = check(commit)
         self.assertEqual(next_version, [3, 3])
@@ -326,11 +338,13 @@ def test_schema_check(self):
         commit = _Commit(
             result=src,
             checksum_result="",
-            path="",
+            yaml_path="",
             additions=additions,
             subtractions=subtractions,
             base=dst,
             checksum_base="",
+            cpp_header="",
+            cpp_header_path="",
         )
         next_version, _ = check(commit)
         self.assertEqual(next_version, [4, 1])

test/export/testing.py

Lines changed: 5 additions & 0 deletions
@@ -284,3 +284,8 @@ def expectedFailureSerDerPreDispatch(fn):
 def expectedFailurePreDispatchRunDecomp(fn):
     fn._expected_failure_pre_dispatch = True
     return fn
+
+
+def expectedFailureCppSerDes(fn):
+    fn._expected_failure_cpp_serdes = True
+    return fn
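
The new marker pairs with the harness in test/export/test_cpp_serdes.py, which looks for the attribute via xfail_prop="_expected_failure_cpp_serdes". A self-contained sketch of that mechanism; the sample test is hypothetical:

# The decorator only tags the function; the cpp-serdes test generator turns any
# test carrying this attribute into an expected failure. The sample test is made up.
def expectedFailureCppSerDes(fn):
    fn._expected_failure_cpp_serdes = True
    return fn


@expectedFailureCppSerDes
def test_sample():
    pass


assert getattr(test_sample, "_expected_failure_cpp_serdes", False)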

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ from torch.utils._python_dispatch import TorchDispatchMode
 
 from . import (
     _aoti,
+    _export,
     _cpu,
     _dynamo,
     _functorch,

torch/_C/_export.pyi

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+# Defined in torch/csrc/export/pybind.cpp
+
+class CppExportedProgram: ...
+
+def deserialize_exported_program(
+    serialized_program: str,
+) -> CppExportedProgram: ...
+def serialize_exported_program(
+    cpp_exported_program: CppExportedProgram,
+) -> str: ...
