|
20 | 20 | Iterator, |
21 | 21 | List, |
22 | 22 | Optional, |
| 23 | + Sequence, |
23 | 24 | Set, |
24 | 25 | Tuple, |
25 | 26 | TYPE_CHECKING, |
|
43 | 44 |
|
44 | 45 | from .. import async_compile, config, ir |
45 | 46 | from ..codecache import output_code_log |
46 | | -from ..ir import ReinterpretView |
| 47 | +from ..ir import IRNode, ReinterpretView |
47 | 48 | from ..runtime import triton_heuristics |
48 | 49 | from ..runtime.hints import DeviceProperties |
49 | 50 | from ..utils import ( |
@@ -1016,7 +1017,7 @@ def generate_user_defined_triton_kernel( |
1016 | 1017 |
|
1017 | 1018 | args = [self.val_to_arg_str(v) for v in raw_args] |
1018 | 1019 | arg_types = [ |
1019 | | - arg.get_dtype() if hasattr(arg, "get_dtype") else type(arg) |
| 1020 | + arg.get_dtype() if isinstance(arg, IRNode) else type(arg) |
1020 | 1021 | for arg in raw_args |
1021 | 1022 | ] |
1022 | 1023 | self.generate_kernel_call( |
@@ -1306,15 +1307,15 @@ def codegen_sizevar(self, x: Expr) -> str: |
1306 | 1307 | def codegen_tuple_access(self, basename: str, name: str, index: str) -> str: |
1307 | 1308 | return f"{basename}[{index}]" |
1308 | 1309 |
|
1309 | | - def codegen_python_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: |
1310 | | - parts = list(map(self.codegen_python_sizevar, shape)) |
| 1310 | + def codegen_python_shape_tuple(self, shape: Sequence[Expr]) -> str: |
| 1311 | + parts = [*map(self.codegen_python_sizevar, shape)] |
1311 | 1312 | if len(parts) == 0: |
1312 | 1313 | return "()" |
1313 | 1314 | if len(parts) == 1: |
1314 | 1315 | return f"({parts[0]}, )" |
1315 | 1316 | return f"({', '.join(parts)})" |
1316 | 1317 |
|
1317 | | - def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: |
| 1318 | + def codegen_shape_tuple(self, shape: Sequence[Expr]) -> str: |
1318 | 1319 | return self.codegen_python_shape_tuple(shape) |
1319 | 1320 |
|
1320 | 1321 | def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: |
|
0 commit comments