@@ -5,6 +5,9 @@
 from torch import nn
 
 from detectron2.structures import Boxes, Instances
+from detectron2.utils.registry import _convert_target_to_string, locate
+
+from .torchscript_patch import patch_builtin_len
 
 
 @dataclass
@@ -39,48 +42,52 @@ def __call__(self, values):
     @staticmethod
     def _concat(values):
         ret = ()
-        idx_mapping = []
+        sizes = []
         for v in values:
             assert isinstance(v, tuple), "Flattened results must be a tuple"
-            oldlen = len(ret)
             ret = ret + v
-            idx_mapping.append([oldlen, len(ret)])
-        return ret, idx_mapping
+            sizes.append(len(v))
+        return ret, sizes
 
     @staticmethod
-    def _split(values, idx_mapping):
-        if len(idx_mapping):
-            expected_len = idx_mapping[-1][-1]
+    def _split(values, sizes):
+        if len(sizes):
+            expected_len = sum(sizes)
             assert (
                 len(values) == expected_len
             ), f"Values has length {len(values)} but expect length {expected_len}."
         ret = []
-        for (start, end) in idx_mapping:
-            ret.append(values[start:end])
+        for k in range(len(sizes)):
+            begin, end = sum(sizes[:k]), sum(sizes[: k + 1])
+            ret.append(values[begin:end])
         return ret
 
 
 @dataclass
 class ListSchema(Schema):
-    schemas: List[Schema]
-    idx_mapping: List[List[int]]
-    is_tuple: bool
+    schemas: List[Schema]  # the schemas that define how to flatten each element in the list
+    sizes: List[int]  # the flattened length of each element
 
     def __call__(self, values):
-        values = self._split(values, self.idx_mapping)
+        values = self._split(values, self.sizes)
         if len(values) != len(self.schemas):
             raise ValueError(
                 f"Values has length {len(values)} but schemas " f"has length {len(self.schemas)}!"
             )
         values = [m(v) for m, v in zip(self.schemas, values)]
-        return list(values) if not self.is_tuple else tuple(values)
+        return list(values)
 
     @classmethod
     def flatten(cls, obj):
-        is_tuple = isinstance(obj, tuple)
         res = [flatten_to_tuple(k) for k in obj]
-        values, idx = cls._concat([k[0] for k in res])
-        return values, cls([k[1] for k in res], idx, is_tuple)
+        values, sizes = cls._concat([k[0] for k in res])
+        return values, cls([k[1] for k in res], sizes)
+
+
+@dataclass
+class TupleSchema(ListSchema):
+    def __call__(self, values):
+        return tuple(super().__call__(values))
 
 
 @dataclass
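Note on the change above: the sizes bookkeeping replaces idx_mapping, so each element only records its flattened length and offsets are recomputed in _split; the is_tuple flag is gone because the new TupleSchema subclass now restores the container type. A minimal round-trip sketch (illustrative only, assuming the module path detectron2.export.flatten):

    import torch
    from detectron2.export.flatten import flatten_to_tuple

    # A tuple holding a list of tensors plus a bare tensor.
    obj = ([torch.rand(2), torch.rand(3)], torch.rand(4))
    flat, schema = flatten_to_tuple(obj)
    assert isinstance(flat, tuple) and len(flat) == 3  # three leaf tensors

    rebuilt = schema(flat)
    assert isinstance(rebuilt, tuple)    # TupleSchema restores the tuple type
    assert isinstance(rebuilt[0], list)  # the nested ListSchema restores the list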
@@ -94,12 +101,11 @@ def flatten(cls, obj):
 
 
 @dataclass
-class DictSchema(Schema):
+class DictSchema(ListSchema):
     keys: List[str]
-    value_schema: ListSchema
 
     def __call__(self, values):
-        values = self.value_schema(values)
+        values = super().__call__(values)
         return dict(zip(self.keys, values))
 
     @classmethod
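With DictSchema now a ListSchema subclass, a dict flattens to its values in sorted-key order, and the keys ride along as an extra dataclass field. A small sketch (illustrative, same assumptions as above):

    import torch
    from detectron2.export.flatten import flatten_to_tuple

    d = {"scores": torch.rand(3), "boxes": torch.rand(3, 4)}
    flat, schema = flatten_to_tuple(d)
    # Keys are sorted, so flat[0] is d["boxes"] and flat[1] is d["scores"].
    assert torch.equal(flat[0], d["boxes"])

    rebuilt = schema(flat)
    assert rebuilt.keys() == d.keys() and torch.equal(rebuilt["scores"], d["scores"])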
@@ -110,39 +116,40 @@ def flatten(cls, obj):
         keys = sorted(obj.keys())
         values = [obj[k] for k in keys]
         ret, schema = ListSchema.flatten(values)
-        return ret, cls(keys, schema)
+        return ret, cls(schema.schemas, schema.sizes, keys)
 
 
 @dataclass
-class InstancesSchema(Schema):
-    field_names: List[str]
-    field_schema: ListSchema
-
+class InstancesSchema(DictSchema):
     def __call__(self, values):
         image_size, fields = values[-1], values[:-1]
-        fields = self.field_schema(fields)
-        fields = dict(zip(self.field_names, fields))
+        fields = super().__call__(fields)
         return Instances(image_size, **fields)
 
     @classmethod
     def flatten(cls, obj):
-        field_names = sorted(obj.get_fields().keys())
-        values = [obj.get(f) for f in field_names]
-        ret, schema = ListSchema.flatten(values)
+        ret, schema = super().flatten(obj.get_fields())
         size = obj.image_size
         if not isinstance(size, torch.Tensor):
             size = torch.tensor(size)
-        return ret + (size,), cls(field_names, schema)
+        return ret + (size,), schema
 
 
 @dataclass
-class BoxesSchema(Schema):
+class TensorWrapSchema(Schema):
+    """
+    For classes that are simple wrapper of tensors, e.g.
+    Boxes, RotatedBoxes, BitMasks
+    """
+
+    class_name: str
+
     def __call__(self, values):
-        return Boxes(values[0])
+        return locate(self.class_name)(values[0])
 
     @classmethod
     def flatten(cls, obj):
-        return (obj.tensor,), cls()
+        return (obj.tensor,), cls(_convert_target_to_string(type(obj)))
 
 
 # if more custom structures needed in the future, can allow
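TensorWrapSchema generalizes the old BoxesSchema: instead of hard-coding Boxes, it stores the wrapper's qualified class name (via _convert_target_to_string) and resolves it back with locate() at reconstruction time, so one schema serves Boxes, RotatedBoxes, BitMasks, and similar single-tensor wrappers. Illustrative sketch:

    import torch
    from detectron2.structures import Boxes
    from detectron2.export.flatten import flatten_to_tuple

    boxes = Boxes(torch.rand(5, 4))
    flat, schema = flatten_to_tuple(boxes)
    # schema.class_name is a dotted path such as "detectron2.structures.Boxes",
    # which locate() resolves back to the class on reconstruction.
    rebuilt = schema(flat)
    assert isinstance(rebuilt, Boxes) and torch.equal(rebuilt.tensor, boxes.tensor)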
@@ -159,10 +166,11 @@ def flatten_to_tuple(obj):
     """
     schemas = [
         ((str, bytes), IdentitySchema),
-        (collections.abc.Sequence, ListSchema),
+        (list, ListSchema),
+        (tuple, TupleSchema),
         (collections.abc.Mapping, DictSchema),
         (Instances, InstancesSchema),
-        (Boxes, BoxesSchema),
+        (Boxes, TensorWrapSchema),
     ]
     for klass, schema in schemas:
         if isinstance(obj, klass):
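The dispatch table now matches list and tuple separately rather than the broader collections.abc.Sequence, which is what lets ListSchema and TupleSchema each reproduce the right container type. For instance (illustrative, same assumptions as above):

    import torch
    from detectron2.export.flatten import flatten_to_tuple

    flat, schema = flatten_to_tuple([torch.rand(2), (torch.rand(3),)])
    out = schema(flat)
    assert isinstance(out, list) and isinstance(out[1], tuple)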
@@ -244,7 +252,7 @@ def __init__(self, model: nn.Module, inputs, inference_func: Optional[Callable]
         )
 
     def forward(self, *args: torch.Tensor):
-        with torch.no_grad():
+        with torch.no_grad(), patch_builtin_len():
             inputs_orig_format = self.inputs_schema(args)
             outputs = self.inference_func(self.model, *inputs_orig_format)
             flattened_outputs, schema = flatten_to_tuple(outputs)
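Wrapping forward in patch_builtin_len() (the new torchscript_patch import) keeps len() calls in a few detectron2 modules tracing-friendly while the wrapped model runs. End to end, the adapter is used roughly like this; a sketch assuming the surrounding class is detectron2's TracingAdapter and that its flattened_inputs / outputs_schema attributes behave as documented:

    import torch
    from torch import nn
    from detectron2.export.flatten import TracingAdapter

    class ToyModel(nn.Module):
        # Hypothetical stand-in for a detectron2 model: structured in, structured out.
        def forward(self, inputs):
            return {"sum": inputs["a"] + inputs["b"]}

    inputs = {"a": torch.rand(3), "b": torch.rand(3)}
    adapter = TracingAdapter(ToyModel(), inputs)

    # The adapter exposes a flat tensor-tuple interface, so it is traceable.
    traced = torch.jit.trace(adapter, adapter.flattened_inputs)

    # The traced module returns flat tensors; the schema restores the structure.
    outputs = adapter.outputs_schema(traced(*adapter.flattened_inputs))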