Add torch ops for d2go models #1509
Open · dncnbuck wants to merge 59 commits into apple:main from dncnbuck:add-torch-ops-for-d2go-models (+230 −7)
Viewing changes from 16 of 59 commits.

Commits:
- c5fe58a — handle-split-op-when-num-splits-1 (dncnbuck)
- 5542de8 — handle when unpacked tuple contains only single value (dncnbuck)
- fdd1590 — roi_align implimentation (dncnbuck)
- 7b4cbd9 — add torch op numel (dncnbuck)
- b002207 — add torch op nms (dncnbuck)
- 6389427 — add torch op repeat_interleave (dncnbuck)
- cecae9c — add torch op narrow (dncnbuck)
- 15185d8 — add torch op logicaland (dncnbuck)
- e1b7d0f — handle broadcasting indicies for torch index op (dncnbuck)
- 70f1954 — patch torch clamp op to handle int dtype (dncnbuck)
- 9d2d092 — return copy of inpt tensor if no dtype is given (dncnbuck)
- b913630 — remove accidential typo (dncnbuck)
- bd08a2b — Merge branch 'main' into add-torch-ops-for-d2go-models (dncnbuck)
- b0074cc — remove logicaland op and alias new logical_and op (dncnbuck)
- a9fb7ed — consistent use of double quotes (dncnbuck)
- 29217d5 — remove link to crop and resize layer in NN (dncnbuck)
- fb0cd19 — Merge branch 'main' into add-torch-ops-for-d2go-models (dncnbuck)
- 12662dd — Merge branch 'main' into add-torch-ops-for-d2go-models (dncnbuck)
- b268f9b — 6.0b1 Release (#1508) (TobyRoseman)
- c58abbd — Add 6.0b1 install instructions to README.md (#1510) (TobyRoseman)
- df41d90 — Update README.md (#1511) (ArjunSharda)
- 573f103 — remove logicaland op and alias new logical_and op (dncnbuck)
- 8834011 — consistent use of double quotes (dncnbuck)
- 12b3cc1 — remove link to crop and resize layer in NN (dncnbuck)
- bdcfe40 — Docs for v6 with layer_norm fix (#1514) (tonybove-apple)
- 4508f19 — Update ---bug-report.md (#1513) (ArjunSharda)
- 7944178 — Fix a bug when destructing coreml model (#1515) (jakesabathia2)
- 3356450 — Formatting fixes and compression submenu (#1518) (tonybove-apple)
- 203b555 — Update CONTRIBUTING.md (#1521) (ArjunSharda)
- 01983e6 — Add torch AdaptiveAvgPool2d test. (#1502) (fukatani)
- f181995 — Update BUILDING.md (#1523) (ArjunSharda)
- 9715d07 — Update ---feature-request.md (change of wording mostly) (#1524) (ArjunSharda)
- d13735d — Torch eq and ne ops supports bool type. (#1501) (fukatani)
- 7ce9f6e — Merge branch 'add-torch-ops-for-d2go-models' of https://github.com/dn… (dncnbuck)
- 5d842ec — accept incoming changes (dncnbuck)
- 4353c4c — Add tests for numel and narrow (dncnbuck)
- ed2f33e — Add tests for torch.op.nms (dncnbuck)
- bf5de6b — tidy up (dncnbuck)
- b2e8153 — tidy up (dncnbuck)
- 20da0e2 — handle-split-op-when-num-splits-1 (dncnbuck)
- ca4cd92 — handle when unpacked tuple contains only single value (dncnbuck)
- c80a3a7 — handle broadcasting indicies for torch index op (dncnbuck)
- 8631d1b — patch torch clamp op to handle int dtype (dncnbuck)
- 2f05538 — return copy of inpt tensor if no dtype is given (dncnbuck)
- ed02c4d — remove accidential typo (dncnbuck)
- ec550ca — Docs for v6 with layer_norm fix (#1514) (tonybove-apple)
- 78ab5fd — Update ---bug-report.md (#1513) (ArjunSharda)
- f8e1776 — Fix a bug when destructing coreml model (#1515) (jakesabathia2)
- d96b7d6 — Formatting fixes and compression submenu (#1518) (tonybove-apple)
- c082d4c — Update CONTRIBUTING.md (#1521) (ArjunSharda)
- 108f5da — Add torch AdaptiveAvgPool2d test. (#1502) (fukatani)
- 47debd3 — Update BUILDING.md (#1523) (ArjunSharda)
- e1aaf57 — Update ---feature-request.md (change of wording mostly) (#1524) (ArjunSharda)
- 1f29b6a — Torch eq and ne ops supports bool type. (#1501) (fukatani)
- f25a684 — Add tests for numel and narrow (dncnbuck)
- f2f795b — Add tests for torch.op.nms (dncnbuck)
- 9be029f — tidy up (dncnbuck)
- 37eef0e — resolve conflict (dncnbuck)
- 9e842a2 — some code clean up (dncnbuck)
```diff
@@ -2586,6 +2586,10 @@ def upsample_nearest2d(context, node):
 def tupleunpack(context, node):
     inputs = _get_inputs(context, node, expected=1)
     values = inputs[0]
+
+    if len(node.outputs) == 1:
+        values = [values]
+
     # Node input could have been turned into constant array in @tupleconstruct
     if not isinstance(values, tuple) and not isinstance(values, list):
         values = values.val
```
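The hunk above guards the single-output case: when a tuple-unpack node has exactly one output, the lone value is wrapped in a list so the downstream iteration over outputs still works. A plain-Python sketch of that guard (the converter's real node class is not used here):

```python
def unpack_values(values, num_outputs):
    """Wrap a lone value in a list when the unpack target has a single
    output, so callers can always iterate output-by-output.
    (Plain-Python sketch of the tupleunpack guard, not the converter's code.)"""
    if num_outputs == 1 and not isinstance(values, (tuple, list)):
        values = [values]
    return list(values)
```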
```diff
@@ -3085,8 +3089,11 @@ def index(context, node):
     # For multiple index axes case, we now assume that all the index have equal shape
     for index in valid_indices:
         if not is_compatible_symbolic_vector(index.shape, valid_indices[0].shape):
-            raise NotImplementedError("Broadcasable tensor index not supported.")
+            broadcast_inputs = _broadcast_tensors([valid_indices[0], index])
+            index = broadcast_inputs[1]
+            valid_indices[0] = broadcast_inputs[0]
+        valid_indices.append(index)
 
     # First stack the index together
     indices_rank = valid_indices[0].rank
     indices = mb.stack(values=valid_indices, axis=indices_rank)
```
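Instead of rejecting index tensors of unequal shape, the new code broadcasts them to a common shape before stacking. The broadcasting rule involved is the usual NumPy-style one; a minimal pure-Python sketch of how two shapes combine (this is an illustrative helper, not `_broadcast_tensors` itself):

```python
from itertools import zip_longest

def broadcast_shapes(a, b):
    """NumPy-style broadcasting of two shapes: align from the trailing
    dimension, size-1 axes stretch, mismatched non-1 sizes are an error."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError("Cannot broadcast {} and {}".format(a, b))
        out.append(max(x, y))
    return tuple(reversed(out))
```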
```diff
@@ -3386,6 +3393,18 @@ def _slice(context, node):
     context.add(res)
 
 
+def _num_splits_and_sizes(split_sizes):
+    if split_sizes.sym_val is not None:
+        return len(split_sizes.sym_val), split_sizes.sym_val
+
+    if any_symbolic(split_sizes.shape):
+        raise ValueError("Unable to determine number of splits")
+
+    num_splits = len(split_sizes.shape)
+    sizes = [get_new_symbol() for _ in range(num_splits)]
+    return num_splits, sizes
+
+
 @register_torch_op(torch_alias=["split_with_sizes"])
 def split(context, node):
     inputs = _get_inputs(context, node, expected=3)
```
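The intent of `_num_splits_and_sizes` is: if the split sizes are concretely known, the number of splits is simply how many sizes there are; otherwise, with a static sizes vector, emit one placeholder symbol per split. A pure-Python sketch of that logic, using a hypothetical `FakeVar` stand-in for a MIL variable (note: this sketch reads the sizes-vector length from `shape[0]`, whereas the hunk above uses `len(split_sizes.shape)`):

```python
class FakeVar:
    """Hypothetical stand-in for a MIL variable: concrete values in
    .sym_val (or None), static shape in .shape."""
    def __init__(self, sym_val=None, shape=()):
        self.sym_val = sym_val
        self.shape = shape

def num_splits_and_sizes(split_sizes):
    """With known sizes, the split count is just len(sizes); otherwise
    produce one placeholder symbol per entry of the 1-D sizes vector."""
    if split_sizes.sym_val is not None:
        return len(split_sizes.sym_val), split_sizes.sym_val
    num_splits = split_sizes.shape[0]  # length of the 1-D sizes vector
    return num_splits, ["sym_%d" % i for i in range(num_splits)]
```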
```diff
@@ -3413,6 +3432,14 @@ def split(context, node):
     else:
         partial_size = mb.mul(x=tmp, y=remainder)
         split_sizes = mb.concat(values=[whole_sizes, partial_size], axis=0)
+
+
+    num_splits, sizes = _num_splits_and_sizes(split_sizes=split_sizes)
+    if num_splits == 1:
+        out = mb.identity(x=x, name=node.name)
+        context.add(out, node.name)
+        return
+
     res = mb.split(x=x, split_sizes=split_sizes, axis=dim, name=node.name)
     context.add(res, torch_name=node.name)
```
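The fast path above short-circuits a degenerate split: when only one chunk is requested, the op reduces to an identity and no `mb.split` is emitted. The behavior it preserves can be sketched on plain lists:

```python
def split_list(x, sizes):
    """Sketch of the split fast path: a single requested split returns
    the whole input as one chunk; otherwise slice chunk by chunk."""
    if len(sizes) == 1:
        return [x]  # identity: one chunk == the whole input
    out, i = [], 0
    for s in sizes:
        out.append(x[i:i + s])
        i += s
    return out
```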
```diff
@@ -3470,6 +3497,13 @@ def to(context, node):
             "Received invalid arguments for PyTorch conversion of op {}".format(node)
         )
 
+    # Handle the case where no dtype is given: it should be inferred from
+    # the tensor's dtype. See https://pytorch.org/docs/stable/generated/torch.Tensor.to.html
+    if dtype is None:
+        out = mb.identity(x=_input, name=node.name)
+        context.add(out, node.name)
+        return
+
     torch_dtype = NUM_TO_TORCH_DTYPE[dtype]
     if isinstance(_input, Var) and _input.val is not None:
         _input = _input.val
```
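The fallback mirrors `torch.Tensor.to()` semantics: with no target dtype, the tensor passes through unchanged. A minimal sketch of that contract on plain Python lists (illustrative only, not the converter code):

```python
def convert_to(values, dtype=None):
    """Sketch of the to() fallback: with no target dtype the input is
    returned as an identity copy; otherwise every element is cast."""
    if dtype is None:
        return list(values)  # identity copy, dtype unchanged
    return [dtype(v) for v in values]
```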
```diff
@@ -3912,8 +3946,20 @@ def ceil(context, node):
 @register_torch_op
 def clamp(context, node):
     inputs = _get_inputs(context, node, expected=3)
-    min_val = inputs[1] if inputs[1] else _np.finfo(_np.float32).min
-    max_val = inputs[2] if inputs[2] else _np.finfo(_np.float32).max
+    if not inputs[1]:
+        min_val = _np.finfo(_np.float32).min
+    else:
+        min_val = inputs[1]
+        if types.builtin_to_string(min_val.dtype).startswith('int'):
+            min_val = mb.cast(x=min_val, dtype='fp32')
+
+    if not inputs[2]:
+        max_val = _np.finfo(_np.float32).max
+    else:
+        max_val = inputs[2]
+        if types.builtin_to_string(max_val.dtype).startswith('int'):
+            max_val = mb.cast(x=max_val, dtype='fp32')
+
     context.add(mb.clip(x=inputs[0], alpha=min_val, beta=max_val, name=node.name))
 
 @register_torch_op
```
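The clamp patch does two things: missing bounds default to the float32 extremes, and integer bounds are promoted to float so `mb.clip` sees a consistent dtype. A pure-Python sketch of the defaulting behavior (the float32 limits are hard-coded approximations of `np.finfo(np.float32)`):

```python
F32_MIN = -3.4028235e38  # approximately np.finfo(np.float32).min
F32_MAX = 3.4028235e38   # approximately np.finfo(np.float32).max

def clamp(x, min_val=None, max_val=None):
    """Sketch of the clamp defaulting: missing bounds fall back to the
    float32 extremes, and integer bounds are promoted to float first."""
    lo = float(min_val) if min_val is not None else F32_MIN
    hi = float(max_val) if max_val is not None else F32_MAX
    return min(max(x, lo), hi)
```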
```diff
@@ -4062,7 +4108,7 @@ def is_floating_point(context, node):
     is_float = types.is_float(inputs[0].dtype)
     context.add(mb.const(val=is_float, name=node.name))
 
-@register_torch_op()
+@register_torch_op(torch_alias=["__and_", "__and__"])
 def logical_and(context, node):
     inputs = _get_inputs(context, node, expected=2)
     x, y = inputs
```
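The one-line change above registers `logical_and` under its TorchScript aliases as well, so `__and__` nodes dispatch to the same converter. A toy sketch of how such alias registration works (the registry and decorator here are hypothetical, not coremltools internals):

```python
TORCH_OP_REGISTRY = {}

def register_torch_op(torch_alias=()):
    """Sketch of alias registration: one converter function serves its
    own name plus any torch aliases (hypothetical registry)."""
    def wrap(fn):
        for name in (fn.__name__, *torch_alias):
            TORCH_OP_REGISTRY[name] = fn
        return fn
    return wrap

@register_torch_op(torch_alias=["__and_", "__and__"])
def logical_and(x, y):
    # Toy body; the real converter emits a MIL logical_and op.
    return bool(x) and bool(y)
```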
```diff
@@ -4241,6 +4287,11 @@ def _make_tensor(list_of_tensor, name, rank):
         context.add(mb.identity(x=val, name=node.name))
         return
 
+    if inputs[2] is None:
+        res = mb.const(val=[val.val], name=node.name)
+        context.add(res, torch_name=node.name)
+        return
+
     # Case 2: Create a tensor filled with a single value
     val = val.val  # element val to fill
     msg_prefix = 'torch::tensor {} '.format(node.name)
```
```diff
@@ -4471,7 +4522,6 @@ def _scatter(context, inputs, mode, name):
                        axis=axis, mode=mode, name=name)
     context.add(result)
 
-
 @register_torch_op
 def scatter(context, node):
     inputs = _get_inputs(context, node)
```
```diff
@@ -4489,8 +4539,155 @@ def scatter(context, node):
 
     _scatter(context, inputs, mode, node.name)
 
+
+@register_torch_op
+def scatter_add(context, node):
+    inputs = _get_inputs(context, node)
+    _scatter(context, inputs, 'add', node.name)
+
+@register_torch_op
+def roi_align(context, node):
+    inputs = _get_inputs(context, node)
+
+    x = context[node.inputs[0]]
+    input_shape = x.shape  # (B, h_in, w_in, C)
+    if len(input_shape) != 4:
+        raise ValueError(
+            '"CropResize" op: expected input rank 4, got {}'.format(x.rank)
+        )
+    Hin, Win = input_shape[1:3]
+
+    const_box_info = True
+    if context[node.inputs[1]].val is None or context[node.inputs[2]].val is None:
+        const_box_info = False
+
+    extrapolation_value = context[node.inputs[2]].val
+
+    # CoreML index information along with boxes
+    if const_box_info:
+        boxes = context[node.inputs[1]].val
+        # CoreML expects boxes/ROI in [N, 1, 5, 1, 1] format
+        boxes = boxes.reshape(boxes.shape[0], 1, boxes.shape[1], 1, 1)
+    else:
+        boxes = inputs[1]
+        boxes = mb.reshape(x=boxes, shape=[boxes.shape[0], 1, boxes.shape[1], 1, 1])
+    # Get height and width of crop
+    h_out = inputs[3]
+    w_out = inputs[4]
+
+    # Torch input format: [B, C, h_in, w_in]
+    # CoreML input format: [B, C, h_in, w_in]
+
+    # Crop Resize
+    x = mb.crop_resize(
+        x=x,
+        roi=boxes,
+        target_height=h_out.val,
+        target_width=w_out.val,
+        normalized_coordinates=True,
+        spatial_scale=extrapolation_value,
+        box_coordinate_mode="CORNERS_HEIGHT_FIRST",
+        sampling_mode='OFFSET_CORNERS',
+    )
+
+    # CoreML output format: [N, 1, C, h_out, w_out]
+    # Torch output format: [N, C, h_out, w_out]
+    x = mb.squeeze(x=x, axes=[1])
+
+    context.add(x, torch_name=node.outputs[0])
+
```

> Review comment on `roi_align`: Are there unit tests for this method?
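The key shape juggling in `roi_align` is the ROI layout: torch supplies boxes as `[N, 5]` (batch index plus four coordinates), while Core ML's `crop_resize` expects `[N, 1, 5, 1, 1]`. A tiny sketch of that reshape target (illustrative helper, not part of the PR):

```python
def coreml_roi_shape(boxes_shape):
    """crop_resize expects ROIs as [N, 1, 5, 1, 1]; compute the reshape
    target from a torch-style [N, 5] boxes shape (sketch)."""
    n, k = boxes_shape
    return (n, 1, k, 1, 1)
```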
```diff
+@register_torch_op
+def numel(context, node):
+    inputs = _get_inputs(context, node, expected=1)
+    context.add(mb.reduce_prod(x=inputs[0], name=node.name), torch_name=node.outputs[0])
+
```
```diff
+@register_torch_op
+def nms(context, node):
+    inputs = _get_inputs(context, node)
+    boxes = inputs[0]
+
+    num_boxes = boxes.shape[1]
+    max_boxes = num_boxes  # set max_boxes to the number of input boxes
+
+    scores = inputs[1]
+    iou_threshold = inputs[2]
+    boxes = mb.expand_dims(x=boxes, axes=[0])
+    scores = mb.expand_dims(x=scores, axes=[0, -1])
+
+    # Follow the TensorFlow op example: TensorFlow's default score_threshold
+    # is float('-inf'), which Core ML does not support, so use the minimum
+    # float32 value instead.
+    score_threshold = -3.4e38
+
+    _, _, x, _ = mb.non_maximum_suppression(
+        boxes=boxes,
+        scores=scores,
+        iou_threshold=iou_threshold,
+        score_threshold=score_threshold,
+        max_boxes=max_boxes
+    )
+
+    if not is_symbolic(num_boxes):
+        x = mb.squeeze(x=x, axes=[0])
+        x = mb.slice_by_index(x=x, begin=[0], end=[max_boxes], name=node.name)
+    else:
+        x = mb.squeeze(x=x, axes=[0], name=node.name)
+    context.add(x, torch_name=node.name)
+
```
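For reference, the semantics being lowered here: NMS keeps the highest-scoring boxes and discards any box whose IoU with an already-kept box exceeds the threshold. A self-contained pure-Python reference (not the Core ML kernel; boxes are `[x1, y1, x2, y2]`):

```python
def iou(a, b):
    """Intersection-over-union of two [x1, y1, x2, y2] boxes."""
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area = lambda r: (r[2] - r[0]) * (r[3] - r[1])
    union = area(a) + area(b) - inter
    return inter / union if union else 0.0

def nms(boxes, scores, iou_threshold):
    """Reference NMS: visit boxes by descending score, keep a box only
    if it does not overlap any kept box beyond the threshold."""
    order = sorted(range(len(boxes)), key=lambda i: -scores[i])
    keep = []
    for i in order:
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep
```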
```diff
+@register_torch_op
+def repeat_interleave(context, node):
+    inputs = _get_inputs(context, node)
+
+    x = inputs[0]
+    reps = inputs[1]
+    dim = inputs[2] if inputs[2] else 0
+
+    perm = [] + [axis for axis in range(x.rank) if axis not in []]
+
+    x = mb.transpose(x=x, perm=perm)  # torch.transpose(x, 0, 1)
+    x = mb.tile(x=x, reps=reps.val[0], name=node.name)  # torch.repeat(x, size)
+    x = mb.reshape(x=x, shape=(-1, x.shape[0]))  # x.view(-1, 2)
+    x = mb.transpose(x=x, perm=(-1, 0))  # torch.transpose(x, 0, 1)
+    dims = list(x.shape)
+
+    # Implementation of flatten
+    total = 1
+    start_val = dim
+    end_val = -1
+    start = len(dims) + start_val if start_val < 0 else start_val
+    end = len(dims) + end_val if end_val < 0 else end_val
+
+    if start > len(dims) or end > len(dims) or start < 0 or end < 0:
+        raise ValueError(
+            "Invalid start and end. (start, end) == ({}, {})".format(start, end_val)
+        )
+    if start > end:
+        raise ValueError(
+            "Start must be before end. (start, end) == ({}, {})".format(start, end_val)
+        )
+    x_shape = mb.shape(x=x)
+
+    shape1 = mb.slice_by_index(x=x_shape, begin=[0], end=[start])
+    shape2 = mb.slice_by_index(x=x_shape, begin=[end + 1], end=[len(dims)])
+
+    flatten_dim = -1
+    if not any_symbolic(x.shape):
+        flatten_dim = 1
+        for dim in dims[start: end + 1]:
+            flatten_dim *= dim
+
+    shape = mb.concat(values=(shape1, [flatten_dim], shape2), axis=0)
+    shape = mb.cast(x=shape, dtype="int32")
+    reshape = mb.reshape(x=x, shape=shape, name=node.name)
+
+    context.add(reshape, node.name)
+
```
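The transpose/tile/reshape dance above is reproducing this torch behavior: with a scalar repeat count, each element is repeated in place rather than the whole sequence being tiled. A reference of the 1-D semantics in plain Python:

```python
def repeat_interleave(xs, reps):
    """Reference semantics of torch.repeat_interleave with a scalar
    repeat count on a 1-D sequence: each element repeats `reps` times
    in place (unlike tile/repeat, which would append whole copies)."""
    return [x for x in xs for _ in range(reps)]
```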
```diff
+@register_torch_op
+def narrow(context, node):
+    data, dim, start, length = _get_inputs(context, node, expected=4)
+    data_shape = mb.shape(x=data).val
+    begin = [0] * len(data_shape)
+    end = [x for x in data_shape]
+    begin[dim.val] = start.val
+    end[dim.val] = start.val + length.val
+    out = mb.slice_by_index(x=data, begin=begin, end=end)
+    context.add(out, torch_name=node.name)
```
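`narrow` lowers to a `slice_by_index` whose begin/end match the full shape on every axis except the narrowed one. The 1-D semantics it implements, as a plain-Python reference:

```python
def narrow(data, start, length):
    """Reference semantics of torch.narrow on a 1-D sequence: the slice
    [start, start + length) along the chosen dimension (sketch)."""
    return data[start:start + length]
```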
> Review comment: Should this just be an inner method of the `split` method?