
Commit c58264c

desertfire authored and pytorchmergebot committed
[inductor] Support multiple symbolic numel expr in CudaWrapperCodeGen (pytorch#102093)
Summary: Add a set to avoid generating an extra `auto` declaration when the same symbolic numel expression is seen a second time.

Pull Request resolved: pytorch#102093
Approved by: https://github.com/jansel
1 parent: 7042e10
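
For context, the pattern the fix relies on is "declare once, then assign": the wrapper codegen keeps a set of symbolic numel expressions it has already emitted, prefixes the first occurrence with its `declare` string (which expands to `auto` in the C++/CUDA wrapper), and writes a plain assignment for every later occurrence. The sketch below is a minimal, hedged illustration of that idea only; the class and method names are invented for the example and do not mirror the real WrapperCodeGen API (in the commit itself the set is the `kenel_numel_expr` attribute and the check lives in TritonKernel.call_kernel).

# Minimal sketch of the "declare once, assign afterwards" pattern
# (hypothetical names; not the actual PyTorch inductor API).
class WrapperSketch:
    def __init__(self):
        self.declare = "auto "          # the C++ wrapper declares variables with `auto`
        self.ending = ";"
        self.numel_exprs_seen = set()   # expressions already declared
        self.lines = []

    def writeline(self, line):
        self.lines.append(line)

    def write_numel_expr(self, expr, value):
        # Only the first occurrence gets the `auto` declaration; repeating it
        # would produce invalid C++ (redefinition of the same variable).
        if expr not in self.numel_exprs_seen:
            self.numel_exprs_seen.add(expr)
            self.writeline(f"{self.declare}{expr} = {value}{self.ending}")
        else:
            self.writeline(f"{expr} = {value}{self.ending}")


wrapper = WrapperSketch()
wrapper.write_numel_expr("triton_kernel_0_xnumel", "64*s0")
wrapper.write_numel_expr("triton_kernel_0_xnumel", "64*s0")
print("\n".join(wrapper.lines))
# auto triton_kernel_0_xnumel = 64*s0;
# triton_kernel_0_xnumel = 64*s0;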

File tree: 4 files changed, +27 −25 lines

test/inductor/test_cpp_wrapper.py

Lines changed: 1 addition & 1 deletion
@@ -202,7 +202,7 @@ class BaseTest(NamedTuple):
         BaseTest("test_embedding_bag"),  # test default FallbackKernel
         BaseTest("test_index_put_deterministic_fallback"),
         BaseTest("test_linear1"),
-        # BaseTest("test_linear2"),
+        BaseTest("test_linear2"),
         BaseTest("test_mm_views"),
         BaseTest("test_multi_device"),
         BaseTest("test_profiler_mark_wrapper_call"),

torch/_inductor/codegen/triton.py

Lines changed: 16 additions & 10 deletions
@@ -1665,7 +1665,8 @@ def dense_size_str(self):
                 sizes.append("1")
         return f"[{', '.join(sizes)}]"
 
-    def call_kernel(self, code, name: str):
+    def call_kernel(self, name: str):
+        wrapper = V.graph.wrapper_code
         _, call_args, _ = self.args.python_argdefs()
         # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
         for i in range(len(call_args)):
@@ -1677,26 +1678,31 @@ def call_kernel(self, code, name: str):
             if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)):
                 expr = tree.numel
             else:
+                expr = f"{name}_{tree.prefix}numel"
+                # TODO(voz): Tragic. This should at the very least be a util to slapp on declare and ending.
+                # The real fix here is to revisit our cross language calling convention.
+                if expr not in wrapper.kenel_numel_expr:
+                    wrapper.kenel_numel_expr.add(expr)
+                    wrapper.writeline(
+                        f"{wrapper.declare}{expr} = {pexpr(tree.numel)}{wrapper.ending}"
+                    )
+                else:
+                    wrapper.writeline(f"{expr} = {pexpr(tree.numel)}{wrapper.ending}")
                 # We can get symbolic expressions here, like s0*64
                 # It is fine to have them here, but we need to handle them correctly as their own type
                 # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy*
                 # scalars as well.
                 # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for
                 # constant now, need type info. I agree, this needs type info, and while this is not true type info
                 # it suffices as a type hint for the purposes of producing the correct code for this type.
-                expr = SymbolicCallArg(f"{name}_{tree.prefix}numel")
-                # TODO(voz): Tragic. This should at the very least be a util to slapp on declare and ending.
-                # The real fix here is to revisit our cross language calling convention.
-                code.writeline(
-                    f"{code.declare}{expr} = {pexpr(tree.numel)}{code.ending}"
-                )
+                expr = SymbolicCallArg(expr)
 
             if tree.prefix != "r" or self.inside_reduction:
                 call_args.append(expr)
             if tree.prefix != "r":
                 grid.append(expr)
 
-        code.generate_kernel_call(
+        wrapper.generate_kernel_call(
             name,
             call_args,
             grid,
@@ -1985,7 +1991,7 @@ def codegen_node_schedule(self, node_schedule, numel, reduction_numel):
         src_code = kernel.codegen_kernel()
         kernel_name = self.define_kernel(src_code, node_schedule)
 
-        kernel.call_kernel(V.graph.wrapper_code, kernel_name)
+        kernel.call_kernel(kernel_name)
 
         if (
             V.graph.wrapper_code.supports_intermediate_hooks
@@ -2082,7 +2088,7 @@ def codegen_template(self, template_node, epilogue_nodes):
 
         src_code = render()
         kernel_name = self.define_kernel(src_code, [template_node, *epilogue_nodes])
-        kernel.call_kernel(V.graph.wrapper_code, kernel_name)
+        kernel.call_kernel(kernel_name)
         self.scheduler.free_buffers()
 
     def codegen_sync(self):

torch/_inductor/codegen/wrapper.py

Lines changed: 2 additions & 9 deletions
@@ -24,7 +24,7 @@
     sympy_product,
 )
 from ..virtualized import V
-from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel, PythonPrinter
+from .common import CodeGen, DeferredLine, IndentedBuffer, PythonPrinter
 
 
 pexpr = PythonPrinter().doprint
@@ -253,6 +253,7 @@ def __init__(self):
         self.wrapper_call = IndentedBuffer()
         self.src_to_kernel = {}
         self.kernel_to_hash = {}
+        self.kenel_numel_expr = set()
         self.lines = []
         self.declare = ""
         self.ending = ""
@@ -665,14 +666,6 @@ def generate_kernel_call(
         else:
             self.writeline(self.wrap_kernel_call(name, call_args))
 
-    def call_kernel(self, name: str, kernel: Kernel):
-        tmp = IndentedBuffer()
-        kernel.call_kernel(self, tmp, name)
-        for line in tmp.getvalue().split("\n"):
-            line = line.strip()
-            if line:
-                self.writeline(line)
-
     def writeline(self, line):
         self.lines.append(line)
 

torch/_inductor/select_algorithm.py

Lines changed: 8 additions & 5 deletions
@@ -289,26 +289,29 @@ def initialize_range_tree(self, pid_cache):
         self.body.clear()
         self.indexing_code.clear()
 
-    def call_kernel(self, code, name: str):
+    def call_kernel(self, name: str):
+        wrapper = V.graph.wrapper_code
         _, call_args, _ = self.args.python_argdefs()
 
         for i in range(len(call_args)):
             if V.graph.is_unspec_arg(call_args[i]):
                 call_args[i] = call_args[i] + ".item()"
         call_args = ", ".join(call_args)
 
-        stream_name = code.write_get_cuda_stream(V.graph.scheduler.current_device.index)
+        stream_name = wrapper.write_get_cuda_stream(
+            V.graph.scheduler.current_device.index
+        )
 
-        V.graph.wrapper_code.add_import_once(f"import {self.grid_fn.__module__}")
-        meta = V.graph.wrapper_code.add_meta_once(self.meta)
+        wrapper.add_import_once(f"import {self.grid_fn.__module__}")
+        meta = wrapper.add_meta_once(self.meta)
 
         grid_call = [texpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes] + [
             meta
         ]
         grid_call = (
             f"{self.grid_fn.__module__}.{self.grid_fn.__name__}({', '.join(grid_call)})"
         )
-        code.writeline(
+        wrapper.writeline(
             f"{name}.run({call_args}, grid={grid_call}, stream={stream_name})"
         )
 