22
33import logging
44import operator
5- from typing import Callable , Dict , Optional , Sequence , Tuple , Union
5+ from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple , Union
66
77import numpy as np
88import torch
@@ -217,18 +217,51 @@ def aten_ops_native_group_norm(
217217 )
218218
219219
220+ def parse_cat_args (
221+ args : Tuple [Argument , ...], kwargs : Dict [str , Any ]
222+ ) -> Tuple [List [Any ], int ]:
223+ """
224+ Process inputs for torch.ops.aten.cat.default.
225+
226+ Handles these valid patterns:
227+ 1. args = ((t1, t2, ...), dim)
228+ 2. args = ((t1, t2, ...),), kwargs = {dim: X} with optional dim in kwargs
229+
230+ Returns:
231+ (input_tensors, dim)
232+ input_tensors: tuple of tensor arguments
233+ dim: integer concatenation dimension (default 0)
234+ """
235+
236+ if len (args ) > 1 and isinstance (args [0 ], (list , tuple )):
237+ input_tensors = list (args [0 ])
238+ dim = args_bounds_check (args , 1 , 0 )
239+
240+ else :
241+ # If single arg is itself a tuple/list, unwrap it
242+ if len (args ) == 1 and isinstance (args [0 ], (list , tuple )):
243+ input_tensors = list (args [0 ])
244+ else :
245+ input_tensors = list (args )
246+
247+ dim = kwargs .get ("dim" , 0 )
248+
249+ return input_tensors , dim
250+
251+
220252def cat_validator (node : Node , settings : Optional [CompilationSettings ] = None ) -> bool :
221253 # empty tensor in cat input as ITensor leads to [RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed.
222- for each_input in node .args [0 ]:
254+ inputs , _ = parse_cat_args (node .args , node .kwargs )
255+ for each_input in inputs :
223256 if isinstance (each_input , TRTTensor ) and any (s == 0 for s in each_input .shape ):
224257 return False
225258 return True
226259
227260
228261@dynamo_tensorrt_converter (
229262 torch .ops .aten .cat .default ,
230- capability_validator = cat_validator ,
231263 supports_dynamic_shapes = True ,
264+ capability_validator = cat_validator ,
232265)
233266def aten_ops_cat (
234267 ctx : ConversionContext ,
@@ -237,13 +270,14 @@ def aten_ops_cat(
237270 kwargs : Dict [str , Argument ],
238271 name : str ,
239272) -> Union [TRTTensor , Sequence [TRTTensor ]]:
273+ inputs , dim = parse_cat_args (args , kwargs )
240274 return impl .cat .cat (
241275 ctx ,
242276 target ,
243277 SourceIR .ATEN ,
244278 name ,
245- input = args [ 0 ] ,
246- dim = args_bounds_check ( args , 1 , 0 ) ,
279+ input = inputs ,
280+ dim = dim ,
247281 )
248282
249283
0 commit comments