
[WIN] torch.compile for XPU using MSVC will fail with AssertionError #4465

@e87tn95h

Description


Describe the bug

I just followed the "Set Up Environment" steps in the How to use torch.compile on Windows CPU/XPU
docs for XPU on Windows 11 with PyTorch 2.7, and I get this AssertionError.

To reproduce it, remove the triton subfolder in TORCHINDUCTOR_CACHE_DIR ($Env:TEMP\torchinductor_$Env:USERNAME by default) if it exists on your machine.
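
For convenience, here is a small helper sketch of that cleanup (my own script, not from the docs; it assumes the default cache location above):

# clear_inductor_triton_cache.py (hypothetical helper, not part of PyTorch)
import getpass
import os
import shutil
import tempfile

cache_dir = os.environ.get(
    "TORCHINDUCTOR_CACHE_DIR",
    os.path.join(tempfile.gettempdir(), f"torchinductor_{getpass.getuser()}"),
)
triton_subdir = os.path.join(cache_dir, "triton")
if os.path.isdir(triton_subdir):
    shutil.rmtree(triton_subdir)
    print(f"removed {triton_subdir}")
else:
    print(f"nothing to remove at {triton_subdir}")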

Code:

import torch
device="xpu"
def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(x)
    return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10).to(device), torch.randn(10, 10).to(device)))

Output (tail part):

  File "C:\Users\__username__\AppData\Local\Programs\Python\Python311\Lib\functools.py", line 1001, in __get__    
    val = self.func(instance)
          ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 109, in _compute_compilation_options_lazy
    include_dir, self._libsycl_dir = find_sycl(include_dir)
                                     ^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 78, in find_sycl
    assert len(sycl_dirs) != 0
           ^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError:

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Output (All):

E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] Triton compilation failed: triton_poi_fused_add_cos_sin_0
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] def triton_poi_fused_add_cos_sin_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     xnumel = 100
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     xoffset = tl.program_id(0) * XBLOCK
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     xindex = xoffset + tl.arange(0, XBLOCK)[:]
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     xmask = xindex < xnumel
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     x0 = xindex
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     tmp0 = tl.load(in_ptr0 + (x0), xmask)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     tmp1 = tl_math.sin(tmp0)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     tmp2 = tl_math.cos(tmp0)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     tmp3 = tmp1 + tmp2
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     tl.store(out_ptr0 + (x0), tmp3, xmask)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] metadata: {'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': 0, 'constants': {'XBLOCK': 128}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}], 'device_type': 'xpu', 'num_warps': 4, 'num_stages': 1, 'debug': True, 'cc': {'architecture': 13182698496, 'driver_version': '1.3.29516', 'gpu_eu_count': 128, 'gpu_subslice_count': 16, 'has_atomic64': True, 'has_bfloat16_conversions': False, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': False, 'has_subgroup_matrix_multiply_accumulate': False, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 128, 'max_num_sub_groups': 128, 'max_work_group_size': 1024, 'name': 'Intel(R) Arc(TM) Graphics', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [8, 16, 32], 'total_memory': 15526334464, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.71.4'}}
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] Traceback (most recent call last):
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py", line 537, in _precompile_config
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     binary = triton.compile(*compile_args, **compile_kwargs)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\compiler\compiler.py", line 220, in compile
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     backend = make_backend(target)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]               ^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\compiler\compiler.py", line 326, in make_backend
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     return actives[0](target)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]            ^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\compiler.py", line 134, in __init__
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     mod = compile_module_from_src(Path(os.path.join(dirname, "arch_parser.c")).read_text(), "arch_utils")
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 265, in compile_module_from_src
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     if COMPILATION_HELPER.libsycl_dir:
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\AppData\Local\Programs\Python\Python311\Lib\functools.py", line 1001, in __get__
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     val = self.func(instance)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]           ^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 145, in libsycl_dir
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     self._compute_compilation_options_lazy
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\AppData\Local\Programs\Python\Python311\Lib\functools.py", line 1001, in __get__
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     val = self.func(instance)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]           ^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 109, in _compute_compilation_options_lazy
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     include_dir, self._libsycl_dir = find_sycl(include_dir)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]                                      ^^^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 78, in find_sycl
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     assert len(sycl_dirs) != 0
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]            ^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] AssertionError
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] Triton compilation failed: triton_poi_fused_add_cos_sin_0
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] def triton_poi_fused_add_cos_sin_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     xnumel = 100
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     xoffset = tl.program_id(0) * XBLOCK
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     xindex = xoffset + tl.arange(0, XBLOCK)[:]
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     xmask = xindex < xnumel
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     x0 = xindex
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     tmp0 = tl.load(in_ptr0 + (x0), xmask)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     tmp1 = tl_math.sin(tmp0)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     tmp2 = tl_math.cos(tmp0)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     tmp3 = tmp1 + tmp2
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     tl.store(out_ptr0 + (x0), tmp3, xmask)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] metadata: {'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': 0, 'constants': {'XBLOCK': 128}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}], 'device_type': 'xpu', 'num_warps': 4, 'num_stages': 1, 'debug': True, 'cc': {'architecture': 13182698496, 'driver_version': '1.3.29516', 'gpu_eu_count': 128, 'gpu_subslice_count': 16, 'has_atomic64': True, 'has_bfloat16_conversions': False, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': False, 'has_subgroup_matrix_multiply_accumulate': False, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 128, 'max_num_sub_groups': 128, 'max_work_group_size': 1024, 'name': 'Intel(R) Arc(TM) Graphics', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [8, 16, 32], 'total_memory': 15526334464, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.71.4'}}
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] Traceback (most recent call last):
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py", line 537, in _precompile_config
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     binary = triton.compile(*compile_args, **compile_kwargs)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\compiler\compiler.py", line 220, in compile
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     backend = make_backend(target)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]               ^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\compiler\compiler.py", line 326, in make_backend
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     return actives[0](target)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]            ^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\compiler.py", line 134, in __init__
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     mod = compile_module_from_src(Path(os.path.join(dirname, "arch_parser.c")).read_text(), "arch_utils")
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 265, in compile_module_from_src
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     if COMPILATION_HELPER.libsycl_dir:
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\AppData\Local\Programs\Python\Python311\Lib\functools.py", line 1001, in __get__
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     val = self.func(instance)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]           ^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 145, in libsycl_dir
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     self._compute_compilation_options_lazy
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\AppData\Local\Programs\Python\Python311\Lib\functools.py", line 1001, in __get__
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     val = self.func(instance)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]           ^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 109, in _compute_compilation_options_lazy
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     include_dir, self._libsycl_dir = find_sycl(include_dir)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]                                      ^^^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]   File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 78, in find_sycl
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]     assert len(sycl_dirs) != 0
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]            ^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] AssertionError
Traceback (most recent call last):
  File "C:\Users\__username__\PyTorch\pytorch-test\example_xpu.py", line 8, in <module>
    print(opt_foo1(torch.randn(10, 10).to(device), torch.randn(10, 10).to(device)))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_dynamo\eval_frame.py", line 663, in _fn
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_dynamo\output_graph.py", line 1544, in _call_user_compiler
    raise BackendCompilerFailed(
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_dynamo\output_graph.py", line 1519, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_dynamo\repro\after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\__init__.py", line 2347, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\compile_fx.py", line 2088, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_dynamo\backends\common.py", line 101, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_functorch\aot_autograd.py", line 1168, in aot_module_simplified
    compiled_fn = AOTAutogradCache.load(
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_functorch\_aot_autograd\autograd_cache.py", line 775, in load
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_functorch\aot_autograd.py", line 1153, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_functorch\aot_autograd.py", line 570, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_functorch\aot_autograd.py", line 820, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_functorch\_aot_autograd\jit_compile_runtime_wrappers.py", line 219, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_functorch\aot_autograd.py", line 479, in __call__
    return self.compiler_fn(gm, example_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\compile_fx.py", line 1943, in fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\compile_fx.py", line 628, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_dynamo\repro\after_aot.py", line 124, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\compile_fx.py", line 720, in _compile_fx_inner
    mb_compiled_graph, cache_info = FxGraphCache.load_with_key(
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\codecache.py", line 1296, in load_with_key
    compiled_graph, cache_info = FxGraphCache._lookup_graph(
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\codecache.py", line 1060, in _lookup_graph
    artifact_path = graph.after_deserialization(constants)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\output_code.py", line 554, in after_deserialization
    self.current_callable = PyCodeCache.load_by_key_path(
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\codecache.py", line 2747, in load_by_key_path
    mod = _reload_python_module(key, path)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\compile_tasks.py", line 36, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "C:\Users\__username__\AppData\Local\Temp\torchinductor___username__\wu\cwuzqx6h22bkw6egpccl4pikysbkqih24mbrmilaminlxsjbh7ic.py", line 46, in <module>
    triton_poi_fused_add_cos_sin_0 = async_compile.triton('triton_poi_fused_add_cos_sin_0', '''
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\async_compile.py", line 346, in triton
    kernel.precompile(warm_cache_only=False)
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py", line 276, in precompile
    self._precompile_worker()
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py", line 296, in _precompile_worker
    compile_results.append(self._precompile_config(c))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py", line 537, in _precompile_config
    binary = triton.compile(*compile_args, **compile_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\compiler\compiler.py", line 220, in compile
    backend = make_backend(target)
              ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\compiler\compiler.py", line 326, in make_backend
    return actives[0](target)
           ^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\compiler.py", line 134, in __init__
    mod = compile_module_from_src(Path(os.path.join(dirname, "arch_parser.c")).read_text(), "arch_utils")
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 265, in compile_module_from_src
    if COMPILATION_HELPER.libsycl_dir:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\AppData\Local\Programs\Python\Python311\Lib\functools.py", line 1001, in __get__
    val = self.func(instance)
          ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 145, in libsycl_dir
    self._compute_compilation_options_lazy
  File "C:\Users\__username__\AppData\Local\Programs\Python\Python311\Lib\functools.py", line 1001, in __get__
    val = self.func(instance)
          ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 109, in _compute_compilation_options_lazy
    include_dir, self._libsycl_dir = find_sycl(include_dir)
                                     ^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 78, in find_sycl
    assert len(sycl_dirs) != 0
           ^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError:

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

My Findings

I found that this assert sits at the end of the lines below. We reach these lines only when neither the icpx command nor the ONEAPI_ROOT environment variable is available at runtime.

sycl_dirs = []
for f in importlib.metadata.files("intel-sycl-rt"):
    # sycl/sycl.hpp and sycl/CL/sycl.hpp results in both folders
    # being add: include and include/sycl.
    if f.name == "sycl.hpp":
        include_dir += [str(f.locate().parent.parent.resolve())]
    if f.name in ["libsycl.so", "sycl8.dll", "sycl8.lib"]:
        sycl_dir = str(f.locate().parent.resolve())
        # should we handle `_` somehow?
        if os.name == "nt":
            _ = os.add_dll_directory(sycl_dir)
        sycl_dirs.append(sycl_dir)
assert len(sycl_dirs) != 0
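
To confirm which path is taken on a given machine, here is a quick standalone check based on that finding (the names here are mine, not Triton's):

# check_sycl_fallback.py (sketch, not part of Triton)
import os
import shutil

has_icpx = shutil.which("icpx") is not None
has_oneapi_root = "ONEAPI_ROOT" in os.environ
print(f"icpx on PATH: {has_icpx}, ONEAPI_ROOT set: {has_oneapi_root}")
# If both are False, find_sycl falls back to scanning the intel-sycl-rt
# package files as quoted above, and on Windows sycl_dirs stays empty.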

I also found an existing report of this mentioned at #3175 (comment). That issue is marked as closed, but the problem still persists.

The f.name values of the PackagePath entries returned by importlib.metadata.files differ slightly between Windows and Linux, as the results below show. It looks safer to call f.locate() first to get an absolute path and then take the filename from that.

# libsycl_check.py
import importlib.metadata

def test():
    for f in importlib.metadata.files("intel-sycl-rt"):
        for name in ["sycl.hpp", "libsycl.so", "sycl8.dll", "sycl8.lib"]:
            if name in str(f):
                print(f"{str(f)}  {str(f.name)}  {f.locate().name}")
                break

if __name__ == "__main__":
    test()

Windows:

> python .\libsycl_check.py
..\..\Library\bin\sycl8.dll  ..\..\Library\bin\sycl8.dll  sycl8.dll
..\..\Library\include\sycl\CL\sycl.hpp  ..\..\Library\include\sycl\CL\sycl.hpp  sycl.hpp
..\..\Library\include\sycl\sycl.hpp  ..\..\Library\include\sycl\sycl.hpp  sycl.hpp
..\..\Library\lib\sycl8.lib  ..\..\Library\lib\sycl8.lib  sycl8.lib

Linux:

$ python libsycl_check.py
../../../include/sycl/CL/sycl.hpp  sycl.hpp  sycl.hpp
../../../include/sycl/sycl.hpp  sycl.hpp  sycl.hpp
../../__pycache__/libsycl.so.8.0.0-gdb.cpython-312.pyc  libsycl.so.8.0.0-gdb.cpython-312.pyc  libsycl.so.8.0.0-gdb.cpython-312.pyc
../../libsycl.so  libsycl.so  libsycl.so
../../libsycl.so.8  libsycl.so.8  libsycl.so.8
../../libsycl.so.8.0.0  libsycl.so.8.0.0  libsycl.so.8.0.0
../../libsycl.so.8.0.0-gdb.py  libsycl.so.8.0.0-gdb.py  libsycl.so.8.0.0-gdb.py
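
A minimal sketch of what that change could look like in the loop quoted above (my rewording, not a tested patch): compare against the filename of the path resolved by f.locate() instead of f.name, so the Windows relative-path form still matches.

# sketch only: same loop as above, but using the resolved filename
import importlib.metadata
import os

include_dir = []
sycl_dirs = []
for f in importlib.metadata.files("intel-sycl-rt"):
    resolved = f.locate()  # absolute Path; .name is the bare filename on both OSes
    if resolved.name == "sycl.hpp":
        include_dir += [str(resolved.parent.parent.resolve())]
    if resolved.name in ["libsycl.so", "sycl8.dll", "sycl8.lib"]:
        sycl_dir = str(resolved.parent.resolve())
        if os.name == "nt":
            _ = os.add_dll_directory(sycl_dir)
        sycl_dirs.append(sycl_dir)
assert len(sycl_dirs) != 0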

Finally, I apologize that I am not able to prepare and send a PR for your project. If this idea sounds reasonable to you, feel free to take it and close this issue at any time.

Environment details

Triton: 3.3.1 (installed as a dependency of torch 2.7.1+xpu)
GPU: MTL (Intel Core Ultra 7 155H)
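
A quick way to print the same version and device info (a sketch; it assumes this build exposes torch.xpu):

# print_env.py (sketch)
import torch
import triton

print("torch:", torch.__version__)
print("triton:", triton.__version__)
if torch.xpu.is_available():
    print("xpu device:", torch.xpu.get_device_name(0))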
