
Commit 29cc517

ppwwyyxx authored and facebook-github-bot committed
call patch_builtin_len in TracingAdapter
Summary: so it's hidden from user code. Also simplify the flattening code.
Reviewed By: alexander-kirillov
Differential Revision: D26609834
fbshipit-source-id: adf56b191419ad05bbdb2a61114756dc9045a950
1 parent b3ca212 commit 29cc517

File tree: 4 files changed, +56 -52 lines

detectron2/export/flatten.py

Lines changed: 45 additions & 37 deletions
@@ -5,6 +5,9 @@
 from torch import nn
 
 from detectron2.structures import Boxes, Instances
+from detectron2.utils.registry import _convert_target_to_string, locate
+
+from .torchscript_patch import patch_builtin_len
 
 
 @dataclass
@@ -39,48 +42,52 @@ def __call__(self, values):
     @staticmethod
     def _concat(values):
         ret = ()
-        idx_mapping = []
+        sizes = []
         for v in values:
             assert isinstance(v, tuple), "Flattened results must be a tuple"
-            oldlen = len(ret)
             ret = ret + v
-            idx_mapping.append([oldlen, len(ret)])
-        return ret, idx_mapping
+            sizes.append(len(v))
+        return ret, sizes
 
     @staticmethod
-    def _split(values, idx_mapping):
-        if len(idx_mapping):
-            expected_len = idx_mapping[-1][-1]
+    def _split(values, sizes):
+        if len(sizes):
+            expected_len = sum(sizes)
             assert (
                 len(values) == expected_len
             ), f"Values has length {len(values)} but expect length {expected_len}."
         ret = []
-        for (start, end) in idx_mapping:
-            ret.append(values[start:end])
+        for k in range(len(sizes)):
+            begin, end = sum(sizes[:k]), sum(sizes[: k + 1])
+            ret.append(values[begin:end])
         return ret
 
 
 @dataclass
 class ListSchema(Schema):
-    schemas: List[Schema]
-    idx_mapping: List[List[int]]
-    is_tuple: bool
+    schemas: List[Schema]  # the schemas that define how to flatten each element in the list
+    sizes: List[int]  # the flattened length of each element
 
     def __call__(self, values):
-        values = self._split(values, self.idx_mapping)
+        values = self._split(values, self.sizes)
         if len(values) != len(self.schemas):
             raise ValueError(
                 f"Values has length {len(values)} but schemas " f"has length {len(self.schemas)}!"
             )
         values = [m(v) for m, v in zip(self.schemas, values)]
-        return list(values) if not self.is_tuple else tuple(values)
+        return list(values)
 
     @classmethod
    def flatten(cls, obj):
-        is_tuple = isinstance(obj, tuple)
         res = [flatten_to_tuple(k) for k in obj]
-        values, idx = cls._concat([k[0] for k in res])
-        return values, cls([k[1] for k in res], idx, is_tuple)
+        values, sizes = cls._concat([k[0] for k in res])
+        return values, cls([k[1] for k in res], sizes)
+
+
+@dataclass
+class TupleSchema(ListSchema):
+    def __call__(self, values):
+        return tuple(super().__call__(values))
 
 
 @dataclass
@@ -94,12 +101,11 @@ def flatten(cls, obj):
 
 
 @dataclass
-class DictSchema(Schema):
+class DictSchema(ListSchema):
     keys: List[str]
-    value_schema: ListSchema
 
     def __call__(self, values):
-        values = self.value_schema(values)
+        values = super().__call__(values)
         return dict(zip(self.keys, values))
 
     @classmethod
@@ -110,39 +116,40 @@ def flatten(cls, obj):
         keys = sorted(obj.keys())
         values = [obj[k] for k in keys]
         ret, schema = ListSchema.flatten(values)
-        return ret, cls(keys, schema)
+        return ret, cls(schema.schemas, schema.sizes, keys)
 
 
 @dataclass
-class InstancesSchema(Schema):
-    field_names: List[str]
-    field_schema: ListSchema
-
+class InstancesSchema(DictSchema):
     def __call__(self, values):
         image_size, fields = values[-1], values[:-1]
-        fields = self.field_schema(fields)
-        fields = dict(zip(self.field_names, fields))
+        fields = super().__call__(fields)
         return Instances(image_size, **fields)
 
     @classmethod
     def flatten(cls, obj):
-        field_names = sorted(obj.get_fields().keys())
-        values = [obj.get(f) for f in field_names]
-        ret, schema = ListSchema.flatten(values)
+        ret, schema = super().flatten(obj.get_fields())
         size = obj.image_size
         if not isinstance(size, torch.Tensor):
             size = torch.tensor(size)
-        return ret + (size,), cls(field_names, schema)
+        return ret + (size,), schema
 
 
 @dataclass
-class BoxesSchema(Schema):
+class TensorWrapSchema(Schema):
+    """
+    For classes that are simple wrapper of tensors, e.g.
+    Boxes, RotatedBoxes, BitMasks
+    """
+
+    class_name: str
+
     def __call__(self, values):
-        return Boxes(values[0])
+        return locate(self.class_name)(values[0])
 
     @classmethod
     def flatten(cls, obj):
-        return (obj.tensor,), cls()
+        return (obj.tensor,), cls(_convert_target_to_string(type(obj)))
 
 
 # if more custom structures needed in the future, can allow
@@ -159,10 +166,11 @@ def flatten_to_tuple(obj):
     """
     schemas = [
         ((str, bytes), IdentitySchema),
-        (collections.abc.Sequence, ListSchema),
+        (list, ListSchema),
+        (tuple, TupleSchema),
         (collections.abc.Mapping, DictSchema),
         (Instances, InstancesSchema),
-        (Boxes, BoxesSchema),
+        (Boxes, TensorWrapSchema),
     ]
     for klass, schema in schemas:
         if isinstance(obj, klass):
@@ -244,7 +252,7 @@ def __init__(self, model: nn.Module, inputs, inference_func: Optional[Callable]
                )
 
     def forward(self, *args: torch.Tensor):
-        with torch.no_grad():
+        with torch.no_grad(), patch_builtin_len():
            inputs_orig_format = self.inputs_schema(args)
            outputs = self.inference_func(self.model, *inputs_orig_format)
            flattened_outputs, schema = flatten_to_tuple(outputs)

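The bookkeeping change above replaces the per-element (start, end) index pairs with a plain list of flattened lengths, and DictSchema and InstancesSchema now reuse ListSchema directly. A minimal standalone sketch of the new round trip is below; the helper names concat and split are illustrative stand-ins that mirror the _concat/_split static methods, not the detectron2 API.

from typing import List, Tuple

def concat(values: List[tuple]) -> Tuple[tuple, List[int]]:
    # Concatenate flattened tuples into one flat tuple, recording each length.
    ret, sizes = (), []
    for v in values:
        assert isinstance(v, tuple), "Flattened results must be a tuple"
        ret = ret + v
        sizes.append(len(v))
    return ret, sizes

def split(values: tuple, sizes: List[int]) -> List[tuple]:
    # Invert concat() using only the per-element sizes.
    assert len(values) == sum(sizes)
    return [values[sum(sizes[:k]): sum(sizes[: k + 1])] for k in range(len(sizes))]

flat, sizes = concat([(1, 2), (3,), (4, 5, 6)])
assert flat == (1, 2, 3, 4, 5, 6) and sizes == [2, 1, 3]
assert split(flat, sizes) == [(1, 2), (3,), (4, 5, 6)]

Storing only sizes keeps the schema dataclasses smaller and lets the split positions be recomputed on the way back.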
detectron2/export/torchscript_patch.py

Lines changed: 0 additions & 1 deletion
@@ -270,7 +270,6 @@ def _import(path):
     )
 
 
-# TODO: this is a private utility. Should be made more useful through a model export api.
 @contextmanager
 def patch_builtin_len(modules=()):
     """

tests/test_export_torchscript.py

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ def _test_model(self, config_path, inference_func):
 
         wrapper = TracingAdapter(model, image, inference_func)
         wrapper.eval()
-        with torch.no_grad(), patch_builtin_len():
+        with torch.no_grad():
             small_image = nn.functional.interpolate(image, scale_factor=0.5)
             # trace with a different image, and the trace must still work
             traced_model = torch.jit.trace(wrapper, (small_image,))

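The test change reflects the new contract: patch_builtin_len() is entered inside TracingAdapter.forward, so callers only need torch.no_grad() when tracing. A hedged, self-contained sketch of that call pattern follows; the Conv2d stand-in and the inference function are assumptions for illustration, whereas the real test wraps a detectron2 model built from a config.

import torch
from torch import nn
from detectron2.export.flatten import TracingAdapter

# Toy stand-ins (assumptions, not the test fixtures): any eval-mode module and
# an inference function (model, *inputs) -> outputs that the schemas can flatten.
model = nn.Conv2d(3, 8, kernel_size=3, padding=1).eval()

def inference_func(model, image):
    return model(image)

image = torch.rand(1, 3, 64, 64)

wrapper = TracingAdapter(model, image, inference_func)
wrapper.eval()
with torch.no_grad():  # no explicit patch_builtin_len() needed any more
    traced_model = torch.jit.trace(wrapper, (image,))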
tools/deploy/export_model.py

Lines changed: 10 additions & 13 deletions
@@ -91,18 +91,15 @@ def inference(model, inputs):
 
     traceable_model = TracingAdapter(torch_model, inputs, inference)
 
-    from detectron2.export.torchscript_patch import patch_builtin_len
-
-    with patch_builtin_len():
-        if args.format == "torchscript":
-            ts_model = torch.jit.trace(traceable_model, (image,))
-            with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
-                torch.jit.save(ts_model, f)
-            dump_torchscript_IR(ts_model, args.output)
-        elif args.format == "onnx":
-            # NOTE onnx export currently failing in pytorch
-            with PathManager.open(os.path.join(args.output, "model.onnx"), "wb") as f:
-                torch.onnx.export(traceable_model, (image,), f)
+    if args.format == "torchscript":
+        ts_model = torch.jit.trace(traceable_model, (image,))
+        with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
+            torch.jit.save(ts_model, f)
+        dump_torchscript_IR(ts_model, args.output)
+    elif args.format == "onnx":
+        # NOTE onnx export currently failing in pytorch
+        with PathManager.open(os.path.join(args.output, "model.onnx"), "wb") as f:
+            torch.onnx.export(traceable_model, (image,), f)
     logger.info("Inputs schema: " + str(traceable_model.inputs_schema))
     logger.info("Outputs schema: " + str(traceable_model.outputs_schema))
 
@@ -114,7 +111,7 @@ def inference(model, inputs):
     def eval_wrapper(inputs):
         """
         The exported model does not contain the final resize step, which is typically
-        useless for deployment but needed for evaluation. We add it manually here.
+        unused in deployment but needed for evaluation. We add it manually here.
         """
         input = inputs[0]
         instances = traceable_model.outputs_schema(ts_model(input["image"]))[0]["instances"]

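As eval_wrapper above shows, the exported TorchScript model returns flat tensors, and traceable_model.outputs_schema rebuilds the original structured outputs from them. A small sketch of that round trip using flatten_to_tuple directly is below; the field names and tensor sizes are made up for illustration.

import torch
from detectron2.export.flatten import flatten_to_tuple
from detectron2.structures import Boxes, Instances

# Build a toy "list of dict of Instances" output (field names are illustrative).
inst = Instances((480, 640))
inst.pred_boxes = Boxes(torch.rand(3, 4))
inst.scores = torch.rand(3)
outputs = [{"instances": inst}]

flat, schema = flatten_to_tuple(outputs)   # flat is a tuple of plain tensors
rebuilt = schema(flat)                     # same nesting as `outputs`
assert isinstance(rebuilt[0]["instances"].pred_boxes, Boxes)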
0 commit comments