Description
Working on Julia support for oneAPI, I've isolated a test failure that only occurs on my A770 to the following Julia code:
using oneAPI
# complete reduction values by a group, using local memory for communication
#
# Every work-item in the group must call this function (it contains a
# barrier). Only item 1 is guaranteed to receive the fully reduced value;
# other items get an unspecified partial result.
@inline function partialsum_group(val::T, neutral) where {T}
items = get_local_size(0)
item = get_local_id(0)   # 1-based work-item index within the group
# local mem for a complete reduction
#
# NOTE(review): hard-coded 1024 slots — assumes the group size never
# exceeds 1024 items; verify against the launch configuration.
shared = oneLocalArray(T, (1024,))
@inbounds shared[item] = val
# perform a reduction
#
# Interleaved-addressing tree reduction: each round doubles the stride `d`;
# item k combines slot 2d(k-1)+1 with the slot `d` positions further on.
d = 1
while d < items
# NOTE(review): in the quoted LLVM IR / SPIR-V below, this lowers to
# `call @_Z7barrierj(i32 0)` / `OpControlBarrier %uint_2 %uint_2 %uint_0`
# — i.e. Memory Semantics "None", with no local-memory fence. Confirm
# whether barrier() should pass CLK_LOCAL_MEM_FENCE; a missing fence on
# the `shared` stores could explain the sporadic missed updates.
barrier()
index = 2 * d * (item-1) + 1
@inbounds if index <= items
other_val = if index + d <= items
shared[index+d]
else
neutral   # partner slot lies past the group; use the identity element
end
shared[index] = shared[index] + other_val
end
d *= 2
end
# load the final value on the first item
if item == 1
val = @inbounds shared[item]
end
return val
end
# partial reduction of the input vector, using a grid-stride loop
#
# Kernel entry point: each group serially accumulates its grid-stride slice
# of `input`, reduces the per-item sums within the group via
# partialsum_group, and item 1 writes the group total to reduced[groupIdx].
function partialsum(elements, reduced, input)
localIdx = get_local_id(0)    # 1-based item index within the group
localDim = get_local_size(0)  # items per group
groupIdx = get_group_id(0)    # 1-based group index
groupDim = get_num_groups(0)  # groups in the launch
@inbounds begin
# load the neutral value
#
# for this MWE, the value is always 0, but hard-coding it makes the bug occur
# less often.
neutral = reduced[groupIdx]
val = neutral + neutral
# reduce serially across chunks of the input vector.
#
# for this MWE, we only execute the loop body once (as we allocate exactly
# items * groups elements), but removing the loop makes the bug happen less often.
ireduce = localIdx + (groupIdx - 1) * localDim
while ireduce <= elements
val = val + input[ireduce]
ireduce += localDim * groupDim   # stride over the whole grid
end
# reduce all values within the group
val = partialsum_group(val, neutral)
# write back to memory
if localIdx == 1
reduced[groupIdx] = val
end
end
return
end
# Launch configuration: 800 items per group, 100 groups (80_000 elements).
items = 800
groups = 100
elements = Int32(items * groups)
input = oneAPI.ones(Int32, elements)    # all ones, so each group should sum to `items`
reduced = oneAPI.zeros(Int32, groups)   # one partial sum per group
@oneapi items=items groups=groups partialsum(elements, reduced, input)
reduced  # expected: 100 entries, each equal to 800
This function computes a partial summation, e.g., starting with an array of 800*100=80000 32-bit integers set to 1, it should reduce to an array of 100 32-bit integers set to 800. However, often I'm getting that some groups haven't fully reduced:
julia> main()
Int32[800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 776, 752, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 768, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800]
Notice the 776 and 752 in there. I haven't been able to reproduce this on other hardware, because lowering the groupsize to e.g. the 512 that my NUC with integrated Xe GPU supports results in the bug not happening anymore. It's of course possible that there's something wrong with my implementation, so feel free to point out any issues.
The above code compiles down to the following LLVM IR:
; ModuleID = 'text'
source_filename = "text"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir64-unknown-unknown"
@local_memory = internal unnamed_addr addrspace(3) global [1024 x i32] zeroinitializer, align 4
declare i64 @_Z12get_local_idj(i32)
declare i64 @_Z14get_local_sizej(i32)
declare i64 @_Z12get_group_idj(i32)
declare i64 @_Z14get_num_groupsj(i32)
declare void @_Z7barrierj(i32)
define spir_kernel void @_Z10partialsum5Int32PS_PS_(i32 signext %0, i64 zeroext %1, i64 zeroext %2) local_unnamed_addr #0 {
conversion:
%3 = call i64 @_Z12get_local_idj(i32 0)
%4 = add i64 %3, 1
%5 = call i64 @_Z14get_local_sizej(i32 0)
%6 = call i64 @_Z12get_group_idj(i32 0)
%7 = call i64 @_Z14get_num_groupsj(i32 0)
%8 = inttoptr i64 %1 to i32*
%9 = getelementptr inbounds i32, i32* %8, i64 %6
%10 = load i32, i32* %9, align 1
%11 = shl i32 %10, 1
%12 = mul i64 %6, %5
%13 = add i64 %4, %12
%14 = sext i32 %0 to i64
%.not8 = icmp sgt i64 %13, %14
br i1 %.not8, label %L42, label %L37.lr.ph
L37.lr.ph: ; preds = %conversion
%15 = inttoptr i64 %2 to i32*
%16 = mul i64 %7, %5
br label %L37
L37: ; preds = %L37, %L37.lr.ph
%value_phi110 = phi i32 [ %11, %L37.lr.ph ], [ %20, %L37 ]
%value_phi9 = phi i64 [ %13, %L37.lr.ph ], [ %21, %L37 ]
%17 = add i64 %value_phi9, -1
%18 = getelementptr inbounds i32, i32* %15, i64 %17
%19 = load i32, i32* %18, align 1
%20 = add i32 %19, %value_phi110
%21 = add i64 %value_phi9, %16
%.not = icmp sgt i64 %21, %14
br i1 %.not, label %L42, label %L37
L42: ; preds = %L37, %conversion
%value_phi1.lcssa = phi i32 [ %11, %conversion ], [ %20, %L37 ]
%22 = call i64 @_Z14get_local_sizej(i32 0)
%23 = call i64 @_Z12get_local_idj(i32 0)
%24 = getelementptr inbounds [1024 x i32], [1024 x i32] addrspace(3)* @local_memory, i64 0, i64 %23
store i32 %value_phi1.lcssa, i32 addrspace(3)* %24, align 4
%.not16 = icmp sgt i64 %22, 1
br i1 %.not16, label %L95, label %L176
L95: ; preds = %L174, %L42
%value_phi27 = phi i64 [ %25, %L174 ], [ 1, %L42 ]
call void @_Z7barrierj(i32 0)
%25 = shl i64 %value_phi27, 1
%26 = mul i64 %25, %23
%27 = or i64 %26, 1
%.not2 = icmp sgt i64 %27, %22
br i1 %.not2, label %L174, label %L104
L104: ; preds = %L95
%28 = add i64 %27, %value_phi27
%.not3 = icmp sgt i64 %28, %22
br i1 %.not3, label %L145, label %L107
L107: ; preds = %L104
%29 = add i64 %28, -1
%30 = getelementptr inbounds [1024 x i32], [1024 x i32] addrspace(3)* @local_memory, i64 0, i64 %29
%31 = load i32, i32 addrspace(3)* %30, align 4
br label %L145
L145: ; preds = %L107, %L104
%value_phi3 = phi i32 [ %31, %L107 ], [ %10, %L104 ]
%32 = getelementptr inbounds [1024 x i32], [1024 x i32] addrspace(3)* @local_memory, i64 0, i64 %26
%33 = load i32, i32 addrspace(3)* %32, align 4
%34 = add i32 %33, %value_phi3
store i32 %34, i32 addrspace(3)* %32, align 4
br label %L174
L174: ; preds = %L145, %L95
%.not1 = icmp slt i64 %25, %22
br i1 %.not1, label %L95, label %L176
L176: ; preds = %L174, %L42
%.not4 = icmp eq i64 %23, 0
br i1 %.not4, label %L192, label %L201
L192: ; preds = %L176
%35 = load i32, i32 addrspace(3)* %24, align 4
br label %L201
L201: ; preds = %L192, %L176
%value_phi7 = phi i32 [ %35, %L192 ], [ %value_phi1.lcssa, %L176 ]
%.not5 = icmp eq i64 %3, 0
br i1 %.not5, label %L203, label %L206
L203: ; preds = %L201
store i32 %value_phi7, i32* %9, align 1
br label %L206
L206: ; preds = %L203, %L201
ret void
}
attributes #0 = { "probe-stack"="inline-asm" }
!llvm.module.flags = !{!0, !1}
!opencl.ocl.version = !{!2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2}
!opencl.spirv.version = !{!3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3}
!julia.kernel = !{!4}
!0 = !{i32 2, !"Dwarf Version", i32 4}
!1 = !{i32 2, !"Debug Info Version", i32 3}
!2 = !{i32 2, i32 0}
!3 = !{i32 1, i32 5}
!4 = !{void (i32, i64, i64)* @_Z10partialsum5Int32PS_PS_}
... which is then compiled to the following SPIR-V:
; SPIR-V
; Version: 1.0
; Generator: Khronos LLVM/SPIR-V Translator; 14
; Bound: 100
; Schema: 0
OpCapability Addresses
OpCapability Linkage
OpCapability Kernel
OpCapability Int64
%1 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %17 "_Z10partialsum5Int32PS_PS_" %__spirv_BuiltInLocalInvocationId %__spirv_BuiltInWorkgroupSize %__spirv_BuiltInWorkgroupId %__spirv_BuiltInNumWorkgroups
OpSource OpenCL_C 200000
OpName %local_memory "local_memory"
OpName %__spirv_BuiltInLocalInvocationId "__spirv_BuiltInLocalInvocationId"
OpName %__spirv_BuiltInWorkgroupSize "__spirv_BuiltInWorkgroupSize"
OpName %__spirv_BuiltInWorkgroupId "__spirv_BuiltInWorkgroupId"
OpName %__spirv_BuiltInNumWorkgroups "__spirv_BuiltInNumWorkgroups"
OpName %conversion "conversion"
OpName %L37_lr_ph "L37.lr.ph"
OpName %L37 "L37"
OpName %L42 "L42"
OpName %L95 "L95"
OpName %L104 "L104"
OpName %L107 "L107"
OpName %L145 "L145"
OpName %L174 "L174"
OpName %L176 "L176"
OpName %L192 "L192"
OpName %L201 "L201"
OpName %L203 "L203"
OpName %L206 "L206"
OpName %_not8 ".not8"
OpName %value_phi110 "value_phi110"
OpName %value_phi9 "value_phi9"
OpName %_not ".not"
OpName %value_phi1_lcssa "value_phi1.lcssa"
OpName %_not16 ".not16"
OpName %value_phi27 "value_phi27"
OpName %_not2 ".not2"
OpName %_not3 ".not3"
OpName %value_phi3 "value_phi3"
OpName %_not1 ".not1"
OpName %_not4 ".not4"
OpName %value_phi7 "value_phi7"
OpName %_not5 ".not5"
OpDecorate %__spirv_BuiltInNumWorkgroups BuiltIn NumWorkgroups
OpDecorate %__spirv_BuiltInWorkgroupSize BuiltIn WorkgroupSize
OpDecorate %__spirv_BuiltInWorkgroupId BuiltIn WorkgroupId
OpDecorate %__spirv_BuiltInLocalInvocationId BuiltIn LocalInvocationId
OpDecorate %__spirv_BuiltInLocalInvocationId Constant
OpDecorate %__spirv_BuiltInWorkgroupSize Constant
OpDecorate %__spirv_BuiltInWorkgroupId Constant
OpDecorate %__spirv_BuiltInNumWorkgroups Constant
OpDecorate %19 FuncParamAttr Zext
OpDecorate %20 FuncParamAttr Zext
OpDecorate %18 FuncParamAttr Sext
OpDecorate %__spirv_BuiltInWorkgroupId LinkageAttributes "__spirv_BuiltInWorkgroupId" Import
OpDecorate %__spirv_BuiltInNumWorkgroups LinkageAttributes "__spirv_BuiltInNumWorkgroups" Import
OpDecorate %__spirv_BuiltInWorkgroupSize LinkageAttributes "__spirv_BuiltInWorkgroupSize" Import
OpDecorate %__spirv_BuiltInLocalInvocationId LinkageAttributes "__spirv_BuiltInLocalInvocationId" Import
OpDecorate %local_memory Alignment 4
%ulong = OpTypeInt 64 0
%uint = OpTypeInt 32 0
%ulong_1024 = OpConstant %ulong 1024
%ulong_1 = OpConstant %ulong 1
%uint_1 = OpConstant %uint 1
%ulong_18446744073709551615 = OpConstant %ulong 18446744073709551615
%ulong_0 = OpConstant %ulong 0
%uint_2 = OpConstant %uint 2
%uint_0 = OpConstant %uint 0
%_arr_uint_ulong_1024 = OpTypeArray %uint %ulong_1024
%_ptr_Workgroup__arr_uint_ulong_1024 = OpTypePointer Workgroup %_arr_uint_ulong_1024
%v3ulong = OpTypeVector %ulong 3
%_ptr_Input_v3ulong = OpTypePointer Input %v3ulong
%void = OpTypeVoid
%16 = OpTypeFunction %void %uint %ulong %ulong
%_ptr_Function_uint = OpTypePointer Function %uint
%bool = OpTypeBool
%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
%6 = OpConstantNull %_arr_uint_ulong_1024
%local_memory = OpVariable %_ptr_Workgroup__arr_uint_ulong_1024 Workgroup %6
%__spirv_BuiltInLocalInvocationId = OpVariable %_ptr_Input_v3ulong Input
%__spirv_BuiltInWorkgroupSize = OpVariable %_ptr_Input_v3ulong Input
%__spirv_BuiltInWorkgroupId = OpVariable %_ptr_Input_v3ulong Input
%__spirv_BuiltInNumWorkgroups = OpVariable %_ptr_Input_v3ulong Input
%17 = OpFunction %void None %16
%18 = OpFunctionParameter %uint
%19 = OpFunctionParameter %ulong
%20 = OpFunctionParameter %ulong
%conversion = OpLabel
%35 = OpLoad %v3ulong %__spirv_BuiltInLocalInvocationId Aligned 32
%36 = OpCompositeExtract %ulong %35 0
%38 = OpIAdd %ulong %36 %ulong_1
%39 = OpLoad %v3ulong %__spirv_BuiltInWorkgroupSize Aligned 32
%40 = OpCompositeExtract %ulong %39 0
%41 = OpLoad %v3ulong %__spirv_BuiltInWorkgroupId Aligned 32
%42 = OpCompositeExtract %ulong %41 0
%43 = OpLoad %v3ulong %__spirv_BuiltInNumWorkgroups Aligned 32
%44 = OpCompositeExtract %ulong %43 0
%46 = OpConvertUToPtr %_ptr_Function_uint %19
%47 = OpInBoundsPtrAccessChain %_ptr_Function_uint %46 %42
%48 = OpLoad %uint %47 Aligned 1
%50 = OpShiftLeftLogical %uint %48 %uint_1
%51 = OpIMul %ulong %42 %40
%52 = OpIAdd %ulong %38 %51
%53 = OpSConvert %ulong %18
%_not8 = OpSGreaterThan %bool %52 %53
OpBranchConditional %_not8 %L42 %L37_lr_ph
%L37_lr_ph = OpLabel
%56 = OpConvertUToPtr %_ptr_Function_uint %20
%57 = OpIMul %ulong %44 %40
OpBranch %L37
%L37 = OpLabel
%value_phi110 = OpPhi %uint %50 %L37_lr_ph %58 %L37
%value_phi9 = OpPhi %ulong %52 %L37_lr_ph %60 %L37
%63 = OpIAdd %ulong %value_phi9 %ulong_18446744073709551615
%64 = OpInBoundsPtrAccessChain %_ptr_Function_uint %56 %63
%65 = OpLoad %uint %64 Aligned 1
%58 = OpIAdd %uint %65 %value_phi110
%60 = OpIAdd %ulong %value_phi9 %57
%_not = OpSGreaterThan %bool %60 %53
OpBranchConditional %_not %L42 %L37
%L42 = OpLabel
%value_phi1_lcssa = OpPhi %uint %50 %conversion %58 %L37
%70 = OpLoad %v3ulong %__spirv_BuiltInWorkgroupSize Aligned 32
%71 = OpCompositeExtract %ulong %70 0
%72 = OpLoad %v3ulong %__spirv_BuiltInLocalInvocationId Aligned 32
%73 = OpCompositeExtract %ulong %72 0
%76 = OpInBoundsPtrAccessChain %_ptr_Workgroup_uint %local_memory %ulong_0 %73
OpStore %76 %value_phi1_lcssa Aligned 4
%_not16 = OpSGreaterThan %bool %71 %ulong_1
OpBranchConditional %_not16 %L95 %L176
%L95 = OpLabel
%value_phi27 = OpPhi %ulong %78 %L174 %ulong_1 %L42
OpControlBarrier %uint_2 %uint_2 %uint_0
%78 = OpShiftLeftLogical %ulong %value_phi27 %ulong_1
%83 = OpIMul %ulong %78 %73
%84 = OpBitwiseOr %ulong %83 %ulong_1
%_not2 = OpSGreaterThan %bool %84 %71
OpBranchConditional %_not2 %L174 %L104
%L104 = OpLabel
%86 = OpIAdd %ulong %84 %value_phi27
%_not3 = OpSGreaterThan %bool %86 %71
OpBranchConditional %_not3 %L145 %L107
%L107 = OpLabel
%88 = OpIAdd %ulong %86 %ulong_18446744073709551615
%89 = OpInBoundsPtrAccessChain %_ptr_Workgroup_uint %local_memory %ulong_0 %88
%90 = OpLoad %uint %89 Aligned 4
OpBranch %L145
%L145 = OpLabel
%value_phi3 = OpPhi %uint %90 %L107 %48 %L104
%92 = OpInBoundsPtrAccessChain %_ptr_Workgroup_uint %local_memory %ulong_0 %83
%93 = OpLoad %uint %92 Aligned 4
%94 = OpIAdd %uint %93 %value_phi3
OpStore %92 %94 Aligned 4
OpBranch %L174
%L174 = OpLabel
%_not1 = OpSLessThan %bool %78 %71
OpBranchConditional %_not1 %L95 %L176
%L176 = OpLabel
%_not4 = OpIEqual %bool %73 %ulong_0
OpBranchConditional %_not4 %L192 %L201
%L192 = OpLabel
%97 = OpLoad %uint %76 Aligned 4
OpBranch %L201
%L201 = OpLabel
%value_phi7 = OpPhi %uint %97 %L192 %value_phi1_lcssa %L176
%_not5 = OpIEqual %bool %36 %ulong_0
OpBranchConditional %_not5 %L203 %L206
%L203 = OpLabel
OpStore %47 %value_phi7 Aligned 1
OpBranch %L206
%L206 = OpLabel
OpReturn
OpFunctionEnd
I'll attach the compiled kernel to this issue.
For your convenience, I also have a C-based loader of this code:
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <level_zero/ze_api.h>
void load_spirv(const char* filename, uint8_t** spirv, size_t* spirv_len);
// Host-side loader for the partialsum SPIR-V kernel.
//
// Sets up a Level Zero context and queue, builds a module from the SPIR-V
// file named on the command line, then repeatedly (up to 10 iterations):
// fills the input with ones and the result with zeros, launches the
// reduction, copies the result back, and prints it. Stops early as soon as
// one run produces a value differing from the expected per-group sum.
int main(int argc, char* argv[]) {
    if (argc != 2) {
        fprintf(stderr, "Usage: %s <path to SPIR-V file>\n", argv[0]);
        exit(EXIT_FAILURE);
    }

    ze_result_t result;
    ze_context_handle_t context;
    ze_driver_handle_t driver;
    ze_device_handle_t device;
    ze_module_handle_t module;
    ze_kernel_handle_t kernel;

    uint8_t* spirv;
    size_t spirv_len;
    load_spirv(argv[1], &spirv, &spirv_len);

    // Launch configuration; must match the kernel (whose local array has
    // 1024 slots, so `items` may not exceed 1024).
    const uint32_t items = 800;
    const uint32_t groups = 100;
    const uint32_t elements = items * groups;

    // Initialize oneAPI Level Zero
    result = zeInit(0);
    assert(result == ZE_RESULT_SUCCESS);

    // Initialize the driver. Query the count first, then clamp it to 1:
    // passing the full count with single-element storage would let the
    // loader write past `driver` when several drivers are installed.
    uint32_t driver_count = 0;
    result = zeDriverGet(&driver_count, NULL);
    assert(result == ZE_RESULT_SUCCESS && driver_count > 0);
    driver_count = 1;
    result = zeDriverGet(&driver_count, &driver);
    assert(result == ZE_RESULT_SUCCESS);

    // Create a context
    ze_context_desc_t context_desc = {
        .stype = ZE_STRUCTURE_TYPE_CONTEXT_DESC,
        .pNext = NULL,
        .flags = 0
    };
    result = zeContextCreate(driver, &context_desc, &context);
    assert(result == ZE_RESULT_SUCCESS);

    // Get a device handle (count clamped to 1 for the same reason as above)
    uint32_t device_count = 0;
    result = zeDeviceGet(driver, &device_count, NULL);
    assert(result == ZE_RESULT_SUCCESS);
    device_count = 1;
    result = zeDeviceGet(driver, &device_count, &device);
    assert(result == ZE_RESULT_SUCCESS);

    // Create a module from SPIR-V
    ze_module_desc_t module_desc = {
        .stype = ZE_STRUCTURE_TYPE_MODULE_DESC,
        .pNext = NULL,
        .format = ZE_MODULE_FORMAT_IL_SPIRV,
        .inputSize = spirv_len,
        .pInputModule = spirv,
        .pBuildFlags = NULL
    };
    result = zeModuleCreate(context, device, &module_desc, &module, NULL);
    assert(result == ZE_RESULT_SUCCESS);

    // Get a kernel handle
    ze_kernel_desc_t kernel_desc = {
        .stype = ZE_STRUCTURE_TYPE_KERNEL_DESC,
        .pNext = NULL,
        .flags = 0,
        .pKernelName = "_Z10partialsum5Int32PS_PS_"
    };
    result = zeKernelCreate(module, &kernel_desc, &kernel);
    assert(result == ZE_RESULT_SUCCESS);

    // Create a command queue
    ze_command_queue_desc_t queue_desc = {
        .stype = ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
        .pNext = NULL,
        .ordinal = 0,
        .mode = ZE_COMMAND_QUEUE_MODE_DEFAULT,
        .priority = ZE_COMMAND_QUEUE_PRIORITY_NORMAL,
        .flags = 0
    };
    ze_command_queue_handle_t queue;
    result = zeCommandQueueCreate(context, device, &queue_desc, &queue);
    assert(result == ZE_RESULT_SUCCESS);

    // Create a command list
    ze_command_list_desc_t cmd_list_desc = {
        .stype = ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
        .pNext = NULL,
        .commandQueueGroupOrdinal = 0,
        .flags = 0
    };
    ze_command_list_handle_t cmd_list;
    result = zeCommandListCreate(context, device, &cmd_list_desc, &cmd_list);
    assert(result == ZE_RESULT_SUCCESS);

    for (int iter = 0; iter < 10; ++iter) {
        // Allocate device memory
        ze_device_mem_alloc_desc_t device_mem_alloc_desc = {
            .stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
            .pNext = NULL,
            .flags = 0,
            .ordinal = 0
        };
        int* input;
        result = zeMemAllocDevice(context, &device_mem_alloc_desc, elements * sizeof(int), 1, device, (void**)&input);
        assert(result == ZE_RESULT_SUCCESS);
        int* reduced;
        result = zeMemAllocDevice(context, &device_mem_alloc_desc, groups * sizeof(int), 1, device, (void**)&reduced);
        assert(result == ZE_RESULT_SUCCESS);

        // Initialize input and reduced arrays
        int input_value = 1;
        result = zeCommandListAppendMemoryFill(cmd_list, input, &input_value, sizeof(int), elements * sizeof(int), NULL, 0, NULL);
        assert(result == ZE_RESULT_SUCCESS);
        int reduced_value = 0;
        result = zeCommandListAppendMemoryFill(cmd_list, reduced, &reduced_value, sizeof(int), groups * sizeof(int), NULL, 0, NULL);
        assert(result == ZE_RESULT_SUCCESS);

        // Flush and wait before appending the kernel launch. Commands
        // appended to a Level Zero command list carry no implicit ordering
        // guarantee among themselves — ordering requires events or
        // zeCommandListAppendBarrier — so each stage (fill, launch, copy)
        // is submitted as its own batch with a full synchronize in between.
        //
        // zeCommandQueueSynchronize takes a uint64_t timeout in
        // nanoseconds; UINT32_MAX would only wait ~4.3 seconds, so pass
        // UINT64_MAX to block until completion.
        result = zeCommandListClose(cmd_list);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandQueueExecuteCommandLists(queue, 1, &cmd_list, NULL);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandQueueSynchronize(queue, UINT64_MAX);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandListReset(cmd_list);
        assert(result == ZE_RESULT_SUCCESS);

        // Set kernel arguments
        result = zeKernelSetArgumentValue(kernel, 0, sizeof(uint32_t), &elements);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeKernelSetArgumentValue(kernel, 1, sizeof(int*), &reduced);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeKernelSetArgumentValue(kernel, 2, sizeof(int*), &input);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeKernelSetGroupSize(kernel, items, 1, 1);
        assert(result == ZE_RESULT_SUCCESS);

        // Launch the kernel
        ze_group_count_t group_count = {groups, 1, 1};
        result = zeCommandListAppendLaunchKernel(cmd_list, kernel, &group_count, NULL, 0, NULL);
        assert(result == ZE_RESULT_SUCCESS);

        // Execute and wait for the kernel before reading results back
        result = zeCommandListClose(cmd_list);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandQueueExecuteCommandLists(queue, 1, &cmd_list, NULL);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandQueueSynchronize(queue, UINT64_MAX);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandListReset(cmd_list);
        assert(result == ZE_RESULT_SUCCESS);

        // Read the reduced array back to the host
        int* reduced_host = (int*)calloc(groups, sizeof(int));
        result = zeCommandListAppendMemoryCopy(cmd_list, reduced_host, reduced, groups * sizeof(int), NULL, 0, NULL);
        assert(result == ZE_RESULT_SUCCESS);

        // Execute and wait for pending commands
        result = zeCommandListClose(cmd_list);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandQueueExecuteCommandLists(queue, 1, &cmd_list, NULL);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandQueueSynchronize(queue, UINT64_MAX);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandListReset(cmd_list);
        assert(result == ZE_RESULT_SUCCESS);

        // Print the reduced items and check them against the expected sum
        int valid = 1;
        printf("Int32[");
        for (uint32_t i = 0; i < groups; ++i) {
            printf("%d, ", reduced_host[i]);
            if (reduced_host[i] != items) {
                valid = 0;
            }
        }
        printf("]\n");

        // Free memory and exit loop if any element of the reduced array is not equal to items
        free(reduced_host);
        result = zeMemFree(context, input);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeMemFree(context, reduced);
        assert(result == ZE_RESULT_SUCCESS);
        if (!valid) {
            break;
        }
    }

    // Clean up
    result = zeKernelDestroy(kernel);
    assert(result == ZE_RESULT_SUCCESS);
    result = zeModuleDestroy(module);
    assert(result == ZE_RESULT_SUCCESS);
    result = zeCommandListDestroy(cmd_list);
    assert(result == ZE_RESULT_SUCCESS);
    result = zeCommandQueueDestroy(queue);
    assert(result == ZE_RESULT_SUCCESS);
    result = zeContextDestroy(context);
    assert(result == ZE_RESULT_SUCCESS);
    free(spirv);
    return 0;
}
// Read the entire contents of `filename` into a freshly malloc()'d buffer.
//
// On success, `*spirv` points to the buffer (owned by the caller; release
// with free()) and `*spirv_len` holds its size in bytes. Any failure prints
// a diagnostic to stderr and terminates the process.
void load_spirv(const char* filename, uint8_t** spirv, size_t* spirv_len) {
    FILE* file = fopen(filename, "rb");
    if (!file) {
        fprintf(stderr, "Error opening SPIR-V file: %s\n", filename);
        exit(EXIT_FAILURE);
    }

    // Get the file size by seeking to the end. ftell() can fail (it returns
    // a negative value, e.g. for non-seekable streams), so check it before
    // converting to size_t.
    fseek(file, 0, SEEK_END);
    long file_size = ftell(file);
    if (file_size < 0) {
        fprintf(stderr, "Error reading SPIR-V binary\n");
        fclose(file);
        exit(EXIT_FAILURE);
    }
    *spirv_len = (size_t)file_size;
    fseek(file, 0, SEEK_SET);

    // Allocate memory and read the SPIR-V binary
    *spirv = (uint8_t*)malloc(*spirv_len);
    if (!*spirv && *spirv_len > 0) {
        fprintf(stderr, "Error allocating memory for SPIR-V binary\n");
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (fread(*spirv, 1, *spirv_len, file) != *spirv_len) {
        fprintf(stderr, "Error reading SPIR-V binary\n");
        fclose(file);
        exit(EXIT_FAILURE);
    }
    fclose(file);
}
Running this code produces similar results:
❯ gcc -g -o partialsum partialsum.c -lze_loader && ./partialsum partialsum.spv
Int32[800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 784, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 784, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, ]
Also note that I'm doing some seemingly redundant sequences of closing/executing the command list and synchronizing the command queue in between memory initialization, kernel launch, and memory readback; I'm not sure why those are needed, as I expected operations on command lists to be ordered? I'd appreciate some clarification here 🙂
All of the above has been tested on Linux, kernel 6.2, with an A770, using both the driver packages provided by Arch Linux (compute-runtime 22.43.24595.30 and graphics compiler 1.0.12812.26) and our own builds (compute-runtime 22.53.25593 and graphics compiler 1.0.13230).