Description
Describe the bug
I followed the "Set Up Environment" steps in the How to use torch.compile on Windows CPU/XPU docs for XPU on Windows 11 with PyTorch 2.7, and got this AssertionError.
To reproduce it, remove the `triton` subfolder in `TORCHINDUCTOR_CACHE_DIR` (`$Env:TEMP\torchinductor_$Env:USERNAME`) if it exists on your machine.
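The cleanup step above can also be done from Python. This is a minimal sketch, assuming the cache lives in `TORCHINDUCTOR_CACHE_DIR` when that variable is set and otherwise in `<temp dir>\torchinductor_<user name>` as described above; PyTorch may sanitize the user name differently, so check the resulting path first. The helper names `inductor_cache_dir` and `clear_inductor_triton_cache` are hypothetical.

```python
import getpass
import os
import shutil
import tempfile
from pathlib import Path
from typing import Optional


def inductor_cache_dir() -> Path:
    # Assumption: an explicit TORCHINDUCTOR_CACHE_DIR takes precedence;
    # otherwise use <temp dir>/torchinductor_<user name>, the location
    # quoted in this report. PyTorch itself may sanitize the user name.
    override = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
    if override:
        return Path(override)
    return Path(tempfile.gettempdir()) / f"torchinductor_{getpass.getuser()}"


def clear_inductor_triton_cache() -> Optional[Path]:
    # Delete only the "triton" subfolder; other cache entries are kept.
    # Returns the removed folder, or None if there was nothing to remove.
    triton_dir = inductor_cache_dir() / "triton"
    if triton_dir.is_dir():
        shutil.rmtree(triton_dir)
        return triton_dir
    return None
```

Calling `clear_inductor_triton_cache()` once before running the reproduction code should be enough; it deliberately leaves the rest of the Inductor cache untouched.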
Code:
```python
import torch
device = "xpu"
def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(x)
    return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10).to(device), torch.randn(10, 10).to(device)))
```
Output (tail part):
```
File "C:\Users\__username__\AppData\Local\Programs\Python\Python311\Lib\functools.py", line 1001, in __get__
val = self.func(instance)
^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 109, in _compute_compilation_options_lazy
include_dir, self._libsycl_dir = find_sycl(include_dir)
^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 78, in find_sycl
assert len(sycl_dirs) != 0
^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError:
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```
Output (All):
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] Triton compilation failed: triton_poi_fused_add_cos_sin_0
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] def triton_poi_fused_add_cos_sin_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] xnumel = 100
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] xoffset = tl.program_id(0) * XBLOCK
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] xindex = xoffset + tl.arange(0, XBLOCK)[:]
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] xmask = xindex < xnumel
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] x0 = xindex
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] tmp0 = tl.load(in_ptr0 + (x0), xmask)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] tmp1 = tl_math.sin(tmp0)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] tmp2 = tl_math.cos(tmp0)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] tmp3 = tmp1 + tmp2
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] tl.store(out_ptr0 + (x0), tmp3, xmask)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] metadata: {'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': 0, 'constants': {'XBLOCK': 128}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}], 'device_type': 'xpu', 'num_warps': 4, 'num_stages': 1, 'debug': True, 'cc': {'architecture': 13182698496, 'driver_version': '1.3.29516', 'gpu_eu_count': 128, 'gpu_subslice_count': 16, 'has_atomic64': True, 'has_bfloat16_conversions': False, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': False, 'has_subgroup_matrix_multiply_accumulate': False, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 128, 'max_num_sub_groups': 128, 'max_work_group_size': 1024, 'name': 'Intel(R) Arc(TM) Graphics', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [8, 16, 32], 'total_memory': 15526334464, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.71.4'}}
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] Traceback (most recent call last):
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py", line 537, in _precompile_config
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] binary = triton.compile(*compile_args, **compile_kwargs)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\compiler\compiler.py", line 220, in compile
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] backend = make_backend(target)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\compiler\compiler.py", line 326, in make_backend
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] return actives[0](target)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\compiler.py", line 134, in __init__
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] mod = compile_module_from_src(Path(os.path.join(dirname, "arch_parser.c")).read_text(), "arch_utils")
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 265, in compile_module_from_src
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] if COMPILATION_HELPER.libsycl_dir:
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\AppData\Local\Programs\Python\Python311\Lib\functools.py", line 1001, in __get__
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] val = self.func(instance)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 145, in libsycl_dir
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] self._compute_compilation_options_lazy
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\AppData\Local\Programs\Python\Python311\Lib\functools.py", line 1001, in __get__
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] val = self.func(instance)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 109, in _compute_compilation_options_lazy
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] include_dir, self._libsycl_dir = find_sycl(include_dir)
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 78, in find_sycl
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] assert len(sycl_dirs) != 0
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.411000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] AssertionError
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] Triton compilation failed: triton_poi_fused_add_cos_sin_0
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] def triton_poi_fused_add_cos_sin_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] xnumel = 100
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] xoffset = tl.program_id(0) * XBLOCK
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] xindex = xoffset + tl.arange(0, XBLOCK)[:]
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] xmask = xindex < xnumel
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] x0 = xindex
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] tmp0 = tl.load(in_ptr0 + (x0), xmask)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] tmp1 = tl_math.sin(tmp0)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] tmp2 = tl_math.cos(tmp0)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] tmp3 = tmp1 + tmp2
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] tl.store(out_ptr0 + (x0), tmp3, xmask)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0]
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] metadata: {'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': 0, 'constants': {'XBLOCK': 128}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}], 'device_type': 'xpu', 'num_warps': 4, 'num_stages': 1, 'debug': True, 'cc': {'architecture': 13182698496, 'driver_version': '1.3.29516', 'gpu_eu_count': 128, 'gpu_subslice_count': 16, 'has_atomic64': True, 'has_bfloat16_conversions': False, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': False, 'has_subgroup_matrix_multiply_accumulate': False, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 128, 'max_num_sub_groups': 128, 'max_work_group_size': 1024, 'name': 'Intel(R) Arc(TM) Graphics', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [8, 16, 32], 'total_memory': 15526334464, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.71.4'}}
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] Traceback (most recent call last):
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py", line 537, in _precompile_config
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] binary = triton.compile(*compile_args, **compile_kwargs)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\compiler\compiler.py", line 220, in compile
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] backend = make_backend(target)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\compiler\compiler.py", line 326, in make_backend
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] return actives[0](target)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\compiler.py", line 134, in __init__
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] mod = compile_module_from_src(Path(os.path.join(dirname, "arch_parser.c")).read_text(), "arch_utils")
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 265, in compile_module_from_src
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] if COMPILATION_HELPER.libsycl_dir:
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\AppData\Local\Programs\Python\Python311\Lib\functools.py", line 1001, in __get__
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] val = self.func(instance)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 145, in libsycl_dir
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] self._compute_compilation_options_lazy
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\AppData\Local\Programs\Python\Python311\Lib\functools.py", line 1001, in __get__
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] val = self.func(instance)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 109, in _compute_compilation_options_lazy
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] include_dir, self._libsycl_dir = find_sycl(include_dir)
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 78, in find_sycl
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] assert len(sycl_dirs) != 0
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] ^^^^^^^^^^^^^^^^^^^
E0609 23:21:43.865000 1924 venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py:539] [0/0] AssertionError
Traceback (most recent call last):
File "C:\Users\__username__\PyTorch\pytorch-test\example_xpu.py", line 8, in <module>
print(opt_foo1(torch.randn(10, 10).to(device), torch.randn(10, 10).to(device)))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_dynamo\eval_frame.py", line 663, in _fn
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_dynamo\output_graph.py", line 1544, in _call_user_compiler
raise BackendCompilerFailed(
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_dynamo\output_graph.py", line 1519, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_dynamo\repro\after_dynamo.py", line 150, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\__init__.py", line 2347, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\compile_fx.py", line 2088, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_dynamo\backends\common.py", line 101, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_functorch\aot_autograd.py", line 1168, in aot_module_simplified
compiled_fn = AOTAutogradCache.load(
^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_functorch\_aot_autograd\autograd_cache.py", line 775, in load
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_functorch\aot_autograd.py", line 1153, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_functorch\aot_autograd.py", line 570, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_functorch\aot_autograd.py", line 820, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_functorch\_aot_autograd\jit_compile_runtime_wrappers.py", line 219, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_functorch\aot_autograd.py", line 479, in __call__
return self.compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\compile_fx.py", line 1943, in fw_compiler_base
return inner_compile(
^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\compile_fx.py", line 628, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_dynamo\repro\after_aot.py", line 124, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\compile_fx.py", line 720, in _compile_fx_inner
mb_compiled_graph, cache_info = FxGraphCache.load_with_key(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\codecache.py", line 1296, in load_with_key
compiled_graph, cache_info = FxGraphCache._lookup_graph(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\codecache.py", line 1060, in _lookup_graph
artifact_path = graph.after_deserialization(constants)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\output_code.py", line 554, in after_deserialization
self.current_callable = PyCodeCache.load_by_key_path(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\codecache.py", line 2747, in load_by_key_path
mod = _reload_python_module(key, path)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\compile_tasks.py", line 36, in _reload_python_module
exec(code, mod.__dict__, mod.__dict__)
File "C:\Users\__username__\AppData\Local\Temp\torchinductor___username__\wu\cwuzqx6h22bkw6egpccl4pikysbkqih24mbrmilaminlxsjbh7ic.py", line 46, in <module>
triton_poi_fused_add_cos_sin_0 = async_compile.triton('triton_poi_fused_add_cos_sin_0', '''
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\async_compile.py", line 346, in triton
kernel.precompile(warm_cache_only=False)
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py", line 276, in precompile
self._precompile_worker()
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py", line 296, in _precompile_worker
compile_results.append(self._precompile_config(c))
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\torch\_inductor\runtime\triton_heuristics.py", line 537, in _precompile_config
binary = triton.compile(*compile_args, **compile_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\compiler\compiler.py", line 220, in compile
backend = make_backend(target)
^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\compiler\compiler.py", line 326, in make_backend
return actives[0](target)
^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\compiler.py", line 134, in __init__
mod = compile_module_from_src(Path(os.path.join(dirname, "arch_parser.c")).read_text(), "arch_utils")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 265, in compile_module_from_src
if COMPILATION_HELPER.libsycl_dir:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\AppData\Local\Programs\Python\Python311\Lib\functools.py", line 1001, in __get__
val = self.func(instance)
^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 145, in libsycl_dir
self._compute_compilation_options_lazy
File "C:\Users\__username__\AppData\Local\Programs\Python\Python311\Lib\functools.py", line 1001, in __get__
val = self.func(instance)
^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 109, in _compute_compilation_options_lazy
include_dir, self._libsycl_dir = find_sycl(include_dir)
^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\__username__\PyTorch\pytorch-test\venv-pytorch27-xpu\Lib\site-packages\triton\backends\intel\driver.py", line 78, in find_sycl
assert len(sycl_dirs) != 0
^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError:
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
My Findings
I found that this assertion sits at the end of the lines below. This code path is only reached when neither the `icpx` command nor the `ONEAPI_ROOT` environment variable is available at runtime.
intel-xpu-backend-for-triton/third_party/intel/backend/driver.py, lines 66 to 79 in 8d32f5d:
```python
sycl_dirs = []
for f in importlib.metadata.files("intel-sycl-rt"):
    # sycl/sycl.hpp and sycl/CL/sycl.hpp results in both folders
    # being add: include and include/sycl.
    if f.name == "sycl.hpp":
        include_dir += [str(f.locate().parent.parent.resolve())]
    if f.name in ["libsycl.so", "sycl8.dll", "sycl8.lib"]:
        sycl_dir = str(f.locate().parent.resolve())
        # should we handle `_` somehow?
        if os.name == "nt":
            _ = os.add_dll_directory(sycl_dir)
        sycl_dirs.append(sycl_dir)

assert len(sycl_dirs) != 0
```
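To confirm which branch the SYCL lookup takes on a given machine, a small standalone check like the one below may help. It only mirrors the conditions described above (an `icpx` on `PATH`, a `ONEAPI_ROOT` variable, otherwise the file list of the `intel-sycl-rt` wheel) and is not Triton code; the function name `sycl_lookup_branch` and its return labels are hypothetical.

```python
import importlib.metadata
import os
import shutil


def sycl_lookup_branch() -> str:
    # 1) A oneAPI DPC++/C++ compiler (icpx) is found on PATH.
    if shutil.which("icpx"):
        return "icpx"
    # 2) A oneAPI installation is advertised through ONEAPI_ROOT.
    if os.environ.get("ONEAPI_ROOT"):
        return "ONEAPI_ROOT"
    # 3) Otherwise the lookup falls back to the files recorded for the
    #    intel-sycl-rt wheel, i.e. the loop ending in the assertion.
    try:
        files = importlib.metadata.files("intel-sycl-rt")
    except importlib.metadata.PackageNotFoundError:
        return "intel-sycl-rt not installed"
    if files is None:
        return "intel-sycl-rt has no file list"
    return "intel-sycl-rt"


if __name__ == "__main__":
    print(sycl_lookup_branch())
```

On the setup in this report, the expectation is that it prints `intel-sycl-rt`, i.e. the fallback loop is the one that runs.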
I also found that this problem was already reported in #3175 (comment). That issue is marked as closed, but the problem still persists.
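The `f.name` values of the `PackagePath` objects returned by `importlib.metadata.files` differ between Windows and Linux, as the results below show. On Windows, `f.name` is the whole relative path recorded for the file (for example `..\..\Library\bin\sycl8.dll`) instead of just `sycl8.dll`, so the `f.name` comparisons in the loop above never match and `sycl_dirs` stays empty. It looks safe to get the absolute path with `f.locate()` first and then use its file name.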
```python
# libsycl_check.py
import importlib.metadata
def test():
    for f in importlib.metadata.files("intel-sycl-rt"):
        for name in ["sycl.hpp", "libsycl.so", "sycl8.dll", "sycl8.lib"]:
            if name in str(f):
                # Columns: raw entry, PackagePath.name, name after f.locate()
                print(f"{str(f)} {str(f.name)} {f.locate().name}")
                break
if __name__ == "__main__":
    test()
```
Windows:
```
> python .\libsycl_check.py
..\..\Library\bin\sycl8.dll ..\..\Library\bin\sycl8.dll sycl8.dll
..\..\Library\include\sycl\CL\sycl.hpp ..\..\Library\include\sycl\CL\sycl.hpp sycl.hpp
..\..\Library\include\sycl\sycl.hpp ..\..\Library\include\sycl\sycl.hpp sycl.hpp
..\..\Library\lib\sycl8.lib ..\..\Library\lib\sycl8.lib sycl8.lib
```
Linux:
```
$ python libsycl_check.py
../../../include/sycl/CL/sycl.hpp sycl.hpp sycl.hpp
../../../include/sycl/sycl.hpp sycl.hpp sycl.hpp
../../__pycache__/libsycl.so.8.0.0-gdb.cpython-312.pyc libsycl.so.8.0.0-gdb.cpython-312.pyc libsycl.so.8.0.0-gdb.cpython-312.pyc
../../libsycl.so libsycl.so libsycl.so
../../libsycl.so.8 libsycl.so.8 libsycl.so.8
../../libsycl.so.8.0.0 libsycl.so.8.0.0 libsycl.so.8.0.0
../../libsycl.so.8.0.0-gdb.py libsycl.so.8.0.0-gdb.py libsycl.so.8.0.0-gdb.py
```
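The Windows result can be reproduced without installing anything. `importlib.metadata.PackagePath` is a subclass of `pathlib.PurePosixPath`, so a backslash-separated entry is treated as a single path component and `.name` returns the entire string, whereas Windows path rules split it. The sample entry is copied from the Windows output above; the variable names are only illustrative.

```python
from importlib.metadata import PackagePath
from pathlib import PureWindowsPath

# Entry as recorded for intel-sycl-rt on Windows (see the output above).
record_entry = r"..\..\Library\bin\sycl8.dll"

# PackagePath follows POSIX rules: "\" is not a separator, so the whole
# string is the last component and never equals "sycl8.dll".
posix_name = PackagePath(record_entry).name

# Interpreting the same string with Windows rules yields the file name.
windows_name = PureWindowsPath(record_entry).name

print(posix_name)    # ..\..\Library\bin\sycl8.dll
print(windows_name)  # sycl8.dll
```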
Finally, I apologize that I am not able to prepare and submit a PR for this project myself. If the idea seems reasonable, feel free to take it and close this issue at any time.
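The idea of resolving each entry with `f.locate()` before comparing file names could look roughly like the sketch below. It is a standalone approximation of the quoted loop, not a patch against the actual `driver.py`: the function name `find_sycl_dirs` is hypothetical, and the extra `files` parameter exists only so the logic can be exercised without the `intel-sycl-rt` wheel.

```python
import importlib.metadata
import os
from typing import Iterable, List, Optional, Tuple

# Library names matched by the quoted loop in driver.py.
SYCL_LIB_NAMES = ("libsycl.so", "sycl8.dll", "sycl8.lib")


def find_sycl_dirs(
    include_dir: List[str], files: Optional[Iterable] = None
) -> Tuple[List[str], List[str]]:
    # By default, scan the files recorded for the intel-sycl-rt wheel.
    if files is None:
        files = importlib.metadata.files("intel-sycl-rt") or []
    sycl_dirs: List[str] = []
    for f in files:
        # Resolve to a concrete, platform-native path first. Its .name is
        # the bare file name on Windows too, unlike PackagePath.name.
        path = f.locate()
        if path.name == "sycl.hpp":
            include_dir += [str(path.parent.parent.resolve())]
        if path.name in SYCL_LIB_NAMES:
            sycl_dir = str(path.parent.resolve())
            if os.name == "nt":
                # Same as the original loop: make sycl8.dll loadable.
                os.add_dll_directory(sycl_dir)
            sycl_dirs.append(sycl_dir)
    assert len(sycl_dirs) != 0
    return include_dir, sycl_dirs
```

Only the name comparison changes; the directories added to `include_dir` and `sycl_dirs` are computed from `f.locate()` exactly as before.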
Environment details
Triton: 3.3.1 (installed as a dependency of torch 2.7.1+xpu)
GPU: MTL (Intel Core Ultra 7 155H)