A770: Possible miscompilation of mapreduce-like code #636

Open
@maleadt

While working on Julia support for oneAPI, I isolated a test failure that only occurs on my A770 to the following Julia code:

using oneAPI

# complete reduction values by a group, using local memory for communication
@inline function partialsum_group(val::T, neutral) where {T}
    items = get_local_size(0)
    item = get_local_id(0)

    # local mem for a complete reduction
    shared = oneLocalArray(T, (1024,))
    @inbounds shared[item] = val

    # perform a reduction
    d = 1
    while d < items
        barrier()
        index = 2 * d * (item-1) + 1
        @inbounds if index <= items
            other_val = if index + d <= items
                shared[index+d]
            else
                neutral
            end
            shared[index] = shared[index] + other_val
        end
        d *= 2
    end

    # load the final value on the first item
    if item == 1
        val = @inbounds shared[item]
    end

    return val
end

# partial reduction of the input vector, using a grid-stride loop
function partialsum(elements, reduced, input)
    localIdx = get_local_id(0)
    localDim = get_local_size(0)
    groupIdx = get_group_id(0)
    groupDim = get_num_groups(0)

    @inbounds begin
        # load the neutral value
        #
        # for this MWE, the value is always 0, but hard-coding it makes the bug occur
        # less often.
        neutral = reduced[groupIdx]
        val = neutral + neutral

        # reduce serially across chunks of the input vector.
        #
        # for this MWE, we only execute the loop body once (as we allocate exactly
        # items * groups elements), but removing the loop makes the bug happen less often.
        ireduce = localIdx + (groupIdx - 1) * localDim
        while ireduce <= elements
            val = val + input[ireduce]
            ireduce += localDim * groupDim
        end

        # reduce all values within the group
        val = partialsum_group(val, neutral)

        # write back to memory
        if localIdx == 1
            reduced[groupIdx] = val
        end
    end

    return
end

items = 800
groups = 100

elements = Int32(items * groups)
input = oneAPI.ones(Int32, elements)
reduced = oneAPI.zeros(Int32, groups)

@oneapi items=items groups=groups partialsum(elements, reduced, input)
reduced

This function computes a partial sum: starting from an array of 800*100 = 80000 32-bit integers all set to 1, it should produce an array of 100 32-bit integers all set to 800. However, I often find that some groups have not been fully reduced:

julia> main()
Int32[800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 776, 752, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 768, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800]

Notice the 776, 752, and 768 in there. I haven't been able to reproduce this on other hardware: lowering the group size to, e.g., the 512 that my NUC with integrated Xe GPU supports makes the bug disappear. It's of course possible that there's something wrong with my implementation, so feel free to point out any issues.


The above code compiles down to the following LLVM IR:

; ModuleID = 'text'
source_filename = "text"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir64-unknown-unknown"

@local_memory = internal unnamed_addr addrspace(3) global [1024 x i32] zeroinitializer, align 4

declare i64 @_Z12get_local_idj(i32)

declare i64 @_Z14get_local_sizej(i32)

declare i64 @_Z12get_group_idj(i32)

declare i64 @_Z14get_num_groupsj(i32)

declare void @_Z7barrierj(i32)

define spir_kernel void @_Z10partialsum5Int32PS_PS_(i32 signext %0, i64 zeroext %1, i64 zeroext %2) local_unnamed_addr #0 {
conversion:
  %3 = call i64 @_Z12get_local_idj(i32 0)
  %4 = add i64 %3, 1
  %5 = call i64 @_Z14get_local_sizej(i32 0)
  %6 = call i64 @_Z12get_group_idj(i32 0)
  %7 = call i64 @_Z14get_num_groupsj(i32 0)
  %8 = inttoptr i64 %1 to i32*
  %9 = getelementptr inbounds i32, i32* %8, i64 %6
  %10 = load i32, i32* %9, align 1
  %11 = shl i32 %10, 1
  %12 = mul i64 %6, %5
  %13 = add i64 %4, %12
  %14 = sext i32 %0 to i64
  %.not8 = icmp sgt i64 %13, %14
  br i1 %.not8, label %L42, label %L37.lr.ph

L37.lr.ph:                                        ; preds = %conversion
  %15 = inttoptr i64 %2 to i32*
  %16 = mul i64 %7, %5
  br label %L37

L37:                                              ; preds = %L37, %L37.lr.ph
  %value_phi110 = phi i32 [ %11, %L37.lr.ph ], [ %20, %L37 ]
  %value_phi9 = phi i64 [ %13, %L37.lr.ph ], [ %21, %L37 ]
  %17 = add i64 %value_phi9, -1
  %18 = getelementptr inbounds i32, i32* %15, i64 %17
  %19 = load i32, i32* %18, align 1
  %20 = add i32 %19, %value_phi110
  %21 = add i64 %value_phi9, %16
  %.not = icmp sgt i64 %21, %14
  br i1 %.not, label %L42, label %L37

L42:                                              ; preds = %L37, %conversion
  %value_phi1.lcssa = phi i32 [ %11, %conversion ], [ %20, %L37 ]
  %22 = call i64 @_Z14get_local_sizej(i32 0)
  %23 = call i64 @_Z12get_local_idj(i32 0)
  %24 = getelementptr inbounds [1024 x i32], [1024 x i32] addrspace(3)* @local_memory, i64 0, i64 %23
  store i32 %value_phi1.lcssa, i32 addrspace(3)* %24, align 4
  %.not16 = icmp sgt i64 %22, 1
  br i1 %.not16, label %L95, label %L176

L95:                                              ; preds = %L174, %L42
  %value_phi27 = phi i64 [ %25, %L174 ], [ 1, %L42 ]
  call void @_Z7barrierj(i32 0)
  %25 = shl i64 %value_phi27, 1
  %26 = mul i64 %25, %23
  %27 = or i64 %26, 1
  %.not2 = icmp sgt i64 %27, %22
  br i1 %.not2, label %L174, label %L104

L104:                                             ; preds = %L95
  %28 = add i64 %27, %value_phi27
  %.not3 = icmp sgt i64 %28, %22
  br i1 %.not3, label %L145, label %L107

L107:                                             ; preds = %L104
  %29 = add i64 %28, -1
  %30 = getelementptr inbounds [1024 x i32], [1024 x i32] addrspace(3)* @local_memory, i64 0, i64 %29
  %31 = load i32, i32 addrspace(3)* %30, align 4
  br label %L145

L145:                                             ; preds = %L107, %L104
  %value_phi3 = phi i32 [ %31, %L107 ], [ %10, %L104 ]
  %32 = getelementptr inbounds [1024 x i32], [1024 x i32] addrspace(3)* @local_memory, i64 0, i64 %26
  %33 = load i32, i32 addrspace(3)* %32, align 4
  %34 = add i32 %33, %value_phi3
  store i32 %34, i32 addrspace(3)* %32, align 4
  br label %L174

L174:                                             ; preds = %L145, %L95
  %.not1 = icmp slt i64 %25, %22
  br i1 %.not1, label %L95, label %L176

L176:                                             ; preds = %L174, %L42
  %.not4 = icmp eq i64 %23, 0
  br i1 %.not4, label %L192, label %L201

L192:                                             ; preds = %L176
  %35 = load i32, i32 addrspace(3)* %24, align 4
  br label %L201

L201:                                             ; preds = %L192, %L176
  %value_phi7 = phi i32 [ %35, %L192 ], [ %value_phi1.lcssa, %L176 ]
  %.not5 = icmp eq i64 %3, 0
  br i1 %.not5, label %L203, label %L206

L203:                                             ; preds = %L201
  store i32 %value_phi7, i32* %9, align 1
  br label %L206

L206:                                             ; preds = %L203, %L201
  ret void
}

attributes #0 = { "probe-stack"="inline-asm" }

!llvm.module.flags = !{!0, !1}
!opencl.ocl.version = !{!2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2}
!opencl.spirv.version = !{!3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3, !3}
!julia.kernel = !{!4}

!0 = !{i32 2, !"Dwarf Version", i32 4}
!1 = !{i32 2, !"Debug Info Version", i32 3}
!2 = !{i32 2, i32 0}
!3 = !{i32 1, i32 5}
!4 = !{void (i32, i64, i64)* @_Z10partialsum5Int32PS_PS_}

... which is then compiled to the following SPIR-V:

; SPIR-V
; Version: 1.0
; Generator: Khronos LLVM/SPIR-V Translator; 14
; Bound: 100
; Schema: 0
               OpCapability Addresses
               OpCapability Linkage
               OpCapability Kernel
               OpCapability Int64
          %1 = OpExtInstImport "OpenCL.std"
               OpMemoryModel Physical64 OpenCL
               OpEntryPoint Kernel %17 "_Z10partialsum5Int32PS_PS_" %__spirv_BuiltInLocalInvocationId %__spirv_BuiltInWorkgroupSize %__spirv_BuiltInWorkgroupId %__spirv_BuiltInNumWorkgroups
               OpSource OpenCL_C 200000
               OpName %local_memory "local_memory"
               OpName %__spirv_BuiltInLocalInvocationId "__spirv_BuiltInLocalInvocationId"
               OpName %__spirv_BuiltInWorkgroupSize "__spirv_BuiltInWorkgroupSize"
               OpName %__spirv_BuiltInWorkgroupId "__spirv_BuiltInWorkgroupId"
               OpName %__spirv_BuiltInNumWorkgroups "__spirv_BuiltInNumWorkgroups"
               OpName %conversion "conversion"
               OpName %L37_lr_ph "L37.lr.ph"
               OpName %L37 "L37"
               OpName %L42 "L42"
               OpName %L95 "L95"
               OpName %L104 "L104"
               OpName %L107 "L107"
               OpName %L145 "L145"
               OpName %L174 "L174"
               OpName %L176 "L176"
               OpName %L192 "L192"
               OpName %L201 "L201"
               OpName %L203 "L203"
               OpName %L206 "L206"
               OpName %_not8 ".not8"
               OpName %value_phi110 "value_phi110"
               OpName %value_phi9 "value_phi9"
               OpName %_not ".not"
               OpName %value_phi1_lcssa "value_phi1.lcssa"
               OpName %_not16 ".not16"
               OpName %value_phi27 "value_phi27"
               OpName %_not2 ".not2"
               OpName %_not3 ".not3"
               OpName %value_phi3 "value_phi3"
               OpName %_not1 ".not1"
               OpName %_not4 ".not4"
               OpName %value_phi7 "value_phi7"
               OpName %_not5 ".not5"
               OpDecorate %__spirv_BuiltInNumWorkgroups BuiltIn NumWorkgroups
               OpDecorate %__spirv_BuiltInWorkgroupSize BuiltIn WorkgroupSize
               OpDecorate %__spirv_BuiltInWorkgroupId BuiltIn WorkgroupId
               OpDecorate %__spirv_BuiltInLocalInvocationId BuiltIn LocalInvocationId
               OpDecorate %__spirv_BuiltInLocalInvocationId Constant
               OpDecorate %__spirv_BuiltInWorkgroupSize Constant
               OpDecorate %__spirv_BuiltInWorkgroupId Constant
               OpDecorate %__spirv_BuiltInNumWorkgroups Constant
               OpDecorate %19 FuncParamAttr Zext
               OpDecorate %20 FuncParamAttr Zext
               OpDecorate %18 FuncParamAttr Sext
               OpDecorate %__spirv_BuiltInWorkgroupId LinkageAttributes "__spirv_BuiltInWorkgroupId" Import
               OpDecorate %__spirv_BuiltInNumWorkgroups LinkageAttributes "__spirv_BuiltInNumWorkgroups" Import
               OpDecorate %__spirv_BuiltInWorkgroupSize LinkageAttributes "__spirv_BuiltInWorkgroupSize" Import
               OpDecorate %__spirv_BuiltInLocalInvocationId LinkageAttributes "__spirv_BuiltInLocalInvocationId" Import
               OpDecorate %local_memory Alignment 4
      %ulong = OpTypeInt 64 0
       %uint = OpTypeInt 32 0
 %ulong_1024 = OpConstant %ulong 1024
    %ulong_1 = OpConstant %ulong 1
     %uint_1 = OpConstant %uint 1
%ulong_18446744073709551615 = OpConstant %ulong 18446744073709551615
    %ulong_0 = OpConstant %ulong 0
     %uint_2 = OpConstant %uint 2
     %uint_0 = OpConstant %uint 0
%_arr_uint_ulong_1024 = OpTypeArray %uint %ulong_1024
%_ptr_Workgroup__arr_uint_ulong_1024 = OpTypePointer Workgroup %_arr_uint_ulong_1024
    %v3ulong = OpTypeVector %ulong 3
%_ptr_Input_v3ulong = OpTypePointer Input %v3ulong
       %void = OpTypeVoid
         %16 = OpTypeFunction %void %uint %ulong %ulong
%_ptr_Function_uint = OpTypePointer Function %uint
       %bool = OpTypeBool
%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
          %6 = OpConstantNull %_arr_uint_ulong_1024
%local_memory = OpVariable %_ptr_Workgroup__arr_uint_ulong_1024 Workgroup %6
%__spirv_BuiltInLocalInvocationId = OpVariable %_ptr_Input_v3ulong Input
%__spirv_BuiltInWorkgroupSize = OpVariable %_ptr_Input_v3ulong Input
%__spirv_BuiltInWorkgroupId = OpVariable %_ptr_Input_v3ulong Input
%__spirv_BuiltInNumWorkgroups = OpVariable %_ptr_Input_v3ulong Input
         %17 = OpFunction %void None %16
         %18 = OpFunctionParameter %uint
         %19 = OpFunctionParameter %ulong
         %20 = OpFunctionParameter %ulong
 %conversion = OpLabel
         %35 = OpLoad %v3ulong %__spirv_BuiltInLocalInvocationId Aligned 32
         %36 = OpCompositeExtract %ulong %35 0
         %38 = OpIAdd %ulong %36 %ulong_1
         %39 = OpLoad %v3ulong %__spirv_BuiltInWorkgroupSize Aligned 32
         %40 = OpCompositeExtract %ulong %39 0
         %41 = OpLoad %v3ulong %__spirv_BuiltInWorkgroupId Aligned 32
         %42 = OpCompositeExtract %ulong %41 0
         %43 = OpLoad %v3ulong %__spirv_BuiltInNumWorkgroups Aligned 32
         %44 = OpCompositeExtract %ulong %43 0
         %46 = OpConvertUToPtr %_ptr_Function_uint %19
         %47 = OpInBoundsPtrAccessChain %_ptr_Function_uint %46 %42
         %48 = OpLoad %uint %47 Aligned 1
         %50 = OpShiftLeftLogical %uint %48 %uint_1
         %51 = OpIMul %ulong %42 %40
         %52 = OpIAdd %ulong %38 %51
         %53 = OpSConvert %ulong %18
      %_not8 = OpSGreaterThan %bool %52 %53
               OpBranchConditional %_not8 %L42 %L37_lr_ph
  %L37_lr_ph = OpLabel
         %56 = OpConvertUToPtr %_ptr_Function_uint %20
         %57 = OpIMul %ulong %44 %40
               OpBranch %L37
        %L37 = OpLabel
%value_phi110 = OpPhi %uint %50 %L37_lr_ph %58 %L37
 %value_phi9 = OpPhi %ulong %52 %L37_lr_ph %60 %L37
         %63 = OpIAdd %ulong %value_phi9 %ulong_18446744073709551615
         %64 = OpInBoundsPtrAccessChain %_ptr_Function_uint %56 %63
         %65 = OpLoad %uint %64 Aligned 1
         %58 = OpIAdd %uint %65 %value_phi110
         %60 = OpIAdd %ulong %value_phi9 %57
       %_not = OpSGreaterThan %bool %60 %53
               OpBranchConditional %_not %L42 %L37
        %L42 = OpLabel
%value_phi1_lcssa = OpPhi %uint %50 %conversion %58 %L37
         %70 = OpLoad %v3ulong %__spirv_BuiltInWorkgroupSize Aligned 32
         %71 = OpCompositeExtract %ulong %70 0
         %72 = OpLoad %v3ulong %__spirv_BuiltInLocalInvocationId Aligned 32
         %73 = OpCompositeExtract %ulong %72 0
         %76 = OpInBoundsPtrAccessChain %_ptr_Workgroup_uint %local_memory %ulong_0 %73
               OpStore %76 %value_phi1_lcssa Aligned 4
     %_not16 = OpSGreaterThan %bool %71 %ulong_1
               OpBranchConditional %_not16 %L95 %L176
        %L95 = OpLabel
%value_phi27 = OpPhi %ulong %78 %L174 %ulong_1 %L42
               OpControlBarrier %uint_2 %uint_2 %uint_0
         %78 = OpShiftLeftLogical %ulong %value_phi27 %ulong_1
         %83 = OpIMul %ulong %78 %73
         %84 = OpBitwiseOr %ulong %83 %ulong_1
      %_not2 = OpSGreaterThan %bool %84 %71
               OpBranchConditional %_not2 %L174 %L104
       %L104 = OpLabel
         %86 = OpIAdd %ulong %84 %value_phi27
      %_not3 = OpSGreaterThan %bool %86 %71
               OpBranchConditional %_not3 %L145 %L107
       %L107 = OpLabel
         %88 = OpIAdd %ulong %86 %ulong_18446744073709551615
         %89 = OpInBoundsPtrAccessChain %_ptr_Workgroup_uint %local_memory %ulong_0 %88
         %90 = OpLoad %uint %89 Aligned 4
               OpBranch %L145
       %L145 = OpLabel
 %value_phi3 = OpPhi %uint %90 %L107 %48 %L104
         %92 = OpInBoundsPtrAccessChain %_ptr_Workgroup_uint %local_memory %ulong_0 %83
         %93 = OpLoad %uint %92 Aligned 4
         %94 = OpIAdd %uint %93 %value_phi3
               OpStore %92 %94 Aligned 4
               OpBranch %L174
       %L174 = OpLabel
      %_not1 = OpSLessThan %bool %78 %71
               OpBranchConditional %_not1 %L95 %L176
       %L176 = OpLabel
      %_not4 = OpIEqual %bool %73 %ulong_0
               OpBranchConditional %_not4 %L192 %L201
       %L192 = OpLabel
         %97 = OpLoad %uint %76 Aligned 4
               OpBranch %L201
       %L201 = OpLabel
 %value_phi7 = OpPhi %uint %97 %L192 %value_phi1_lcssa %L176
      %_not5 = OpIEqual %bool %36 %ulong_0
               OpBranchConditional %_not5 %L203 %L206
       %L203 = OpLabel
               OpStore %47 %value_phi7 Aligned 1
               OpBranch %L206
       %L206 = OpLabel
               OpReturn
               OpFunctionEnd

I'll attach the compiled kernel to this issue.


For your convenience, I also have a C-based loader of this code:

#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <level_zero/ze_api.h>

void load_spirv(const char* filename, uint8_t** spirv, size_t* spirv_len);

int main(int argc, char* argv[]) {
    if (argc != 2) {
        fprintf(stderr, "Usage: %s <path to SPIR-V file>\n", argv[0]);
        exit(EXIT_FAILURE);
    }

    ze_result_t result;
    ze_context_handle_t context;
    ze_driver_handle_t driver;
    ze_device_handle_t device;
    ze_module_handle_t module;
    ze_kernel_handle_t kernel;
    uint8_t* spirv;
    size_t spirv_len;

    load_spirv(argv[1], &spirv, &spirv_len);

    const uint32_t items = 800;
    const uint32_t groups = 100;
    const uint32_t elements = items * groups;

    // Initialize oneAPI Level Zero
    result = zeInit(0);
    assert(result == ZE_RESULT_SUCCESS);

    // Initialize the driver
    uint32_t driver_count = 0;
    result = zeDriverGet(&driver_count, NULL);
    assert(result == ZE_RESULT_SUCCESS && driver_count > 0);
    result = zeDriverGet(&driver_count, &driver);
    assert(result == ZE_RESULT_SUCCESS);

    // Create a context
    ze_context_desc_t context_desc = {
        .stype = ZE_STRUCTURE_TYPE_CONTEXT_DESC,
        .pNext = NULL,
        .flags = 0
    };
    result = zeContextCreate(driver, &context_desc, &context);
    assert(result == ZE_RESULT_SUCCESS);

    // Get a device handle
    uint32_t device_count = 0;
    result = zeDeviceGet(driver, &device_count, NULL);
    assert(result == ZE_RESULT_SUCCESS);
    result = zeDeviceGet(driver, &device_count, &device);
    assert(result == ZE_RESULT_SUCCESS);

    // Create a module from SPIR-V
    ze_module_desc_t module_desc = {
        .stype = ZE_STRUCTURE_TYPE_MODULE_DESC,
        .pNext = NULL,
        .format = ZE_MODULE_FORMAT_IL_SPIRV,
        .inputSize = spirv_len,
        .pInputModule = spirv,
        .pBuildFlags = NULL
    };
    result = zeModuleCreate(context, device, &module_desc, &module, NULL);
    assert(result == ZE_RESULT_SUCCESS);

    // Get a kernel handle
    ze_kernel_desc_t kernel_desc = {
        .stype = ZE_STRUCTURE_TYPE_KERNEL_DESC,
        .pNext = NULL,
        .flags = 0,
        .pKernelName = "_Z10partialsum5Int32PS_PS_"
    };
    result = zeKernelCreate(module, &kernel_desc, &kernel);
    assert(result == ZE_RESULT_SUCCESS);

    // Create a command queue
    ze_command_queue_desc_t queue_desc = {
        .stype = ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
        .pNext = NULL,
        .ordinal = 0,
        .mode = ZE_COMMAND_QUEUE_MODE_DEFAULT,
        .priority = ZE_COMMAND_QUEUE_PRIORITY_NORMAL,
        .flags = 0
    };
    ze_command_queue_handle_t queue;
    result = zeCommandQueueCreate(context, device, &queue_desc, &queue);
    assert(result == ZE_RESULT_SUCCESS);

    // Create a command list
    ze_command_list_desc_t cmd_list_desc = {
        .stype = ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
        .pNext = NULL,
        .commandQueueGroupOrdinal = 0,
        .flags = 0
    };
    ze_command_list_handle_t cmd_list;
    result = zeCommandListCreate(context, device, &cmd_list_desc, &cmd_list);
    assert(result == ZE_RESULT_SUCCESS);

    for (int iter = 0; iter < 10; ++iter) {
        // Allocate device memory
        ze_device_mem_alloc_desc_t device_mem_alloc_desc = {
            .stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
            .pNext = NULL,
            .flags = 0,
            .ordinal = 0
        };
        int* input;
        result = zeMemAllocDevice(context, &device_mem_alloc_desc, elements * sizeof(int), 1, device, (void**)&input);
        assert(result == ZE_RESULT_SUCCESS);
        int* reduced;
        result = zeMemAllocDevice(context, &device_mem_alloc_desc, groups * sizeof(int), 1, device, (void**)&reduced);
        assert(result == ZE_RESULT_SUCCESS);

        // Initialize input and reduced arrays
        int input_value = 1;
        result = zeCommandListAppendMemoryFill(cmd_list, input, &input_value, sizeof(int), elements * sizeof(int), NULL, 0, NULL);
        assert(result == ZE_RESULT_SUCCESS);
        int reduced_value = 0;
        result = zeCommandListAppendMemoryFill(cmd_list, reduced, &reduced_value, sizeof(int), groups * sizeof(int), NULL, 0, NULL);
        assert(result == ZE_RESULT_SUCCESS);

        // Execute and wait for pending commands
        // XXX: why is this needed?
        result = zeCommandListClose(cmd_list);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandQueueExecuteCommandLists(queue, 1, &cmd_list, NULL);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandQueueSynchronize(queue, UINT64_MAX);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandListReset(cmd_list);
        assert(result == ZE_RESULT_SUCCESS);

        // Set kernel arguments
        result = zeKernelSetArgumentValue(kernel, 0, sizeof(uint32_t), &elements);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeKernelSetArgumentValue(kernel, 1, sizeof(int*), &reduced);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeKernelSetArgumentValue(kernel, 2, sizeof(int*), &input);
        assert(result == ZE_RESULT_SUCCESS);

        result = zeKernelSetGroupSize(kernel, items, 1, 1);
        assert(result == ZE_RESULT_SUCCESS);

        // Launch the kernel
        ze_group_count_t group_count = {groups, 1, 1};
        result = zeCommandListAppendLaunchKernel(cmd_list, kernel, &group_count, NULL, 0, NULL);
        assert(result == ZE_RESULT_SUCCESS);

        // Execute and wait for pending commands
        // XXX: why is this needed?
        result = zeCommandListClose(cmd_list);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandQueueExecuteCommandLists(queue, 1, &cmd_list, NULL);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandQueueSynchronize(queue, UINT64_MAX);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandListReset(cmd_list);
        assert(result == ZE_RESULT_SUCCESS);

        // Read the reduced array back to the host
        int* reduced_host = (int*)calloc(groups, sizeof(int));
        result = zeCommandListAppendMemoryCopy(cmd_list, reduced_host, reduced, groups * sizeof(int), NULL, 0, NULL);
        assert(result == ZE_RESULT_SUCCESS);

        // Execute and wait for pending commands
        result = zeCommandListClose(cmd_list);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandQueueExecuteCommandLists(queue, 1, &cmd_list, NULL);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandQueueSynchronize(queue, UINT64_MAX);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeCommandListReset(cmd_list);
        assert(result == ZE_RESULT_SUCCESS);

        // Print the reduced items and check each against the expected value
        int valid = 1;
        printf("Int32[");
        for (uint32_t i = 0; i < groups; ++i) {
            printf("%d, ", reduced_host[i]);
            if (reduced_host[i] != (int)items) {
                valid = 0;
            }
        }
        printf("]\n");

        // Free memory and exit loop if any element of the reduced array is not equal to items
        free(reduced_host);
        result = zeMemFree(context, input);
        assert(result == ZE_RESULT_SUCCESS);
        result = zeMemFree(context, reduced);
        assert(result == ZE_RESULT_SUCCESS);
        if (!valid) {
            break;
        }
    }

    // Clean up
    result = zeKernelDestroy(kernel);
    assert(result == ZE_RESULT_SUCCESS);
    result = zeModuleDestroy(module);
    assert(result == ZE_RESULT_SUCCESS);
    result = zeCommandListDestroy(cmd_list);
    assert(result == ZE_RESULT_SUCCESS);
    result = zeCommandQueueDestroy(queue);
    assert(result == ZE_RESULT_SUCCESS);
    result = zeContextDestroy(context);
    assert(result == ZE_RESULT_SUCCESS);
    free(spirv);

    return 0;
}

void load_spirv(const char* filename, uint8_t** spirv, size_t* spirv_len) {
    FILE* file = fopen(filename, "rb");
    if (!file) {
        fprintf(stderr, "Error opening SPIR-V file: %s\n", filename);
        exit(EXIT_FAILURE);
    }

    // Get the file size
    fseek(file, 0, SEEK_END);
    *spirv_len = ftell(file);
    fseek(file, 0, SEEK_SET);

    // Allocate memory and read the SPIR-V binary
    *spirv = (uint8_t*)malloc(*spirv_len);
    if (!*spirv) {
        fprintf(stderr, "Error allocating memory for SPIR-V binary\n");
        exit(EXIT_FAILURE);
    }

    if (fread(*spirv, 1, *spirv_len, file) != *spirv_len) {
        fprintf(stderr, "Error reading SPIR-V binary\n");
        exit(EXIT_FAILURE);
    }

    fclose(file);
}

Running this code produces similar results:

❯ gcc -g -o partialsum partialsum.c -lze_loader && ./partialsum partialsum.spv
Int32[800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 784, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 784, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, 800, ]

Also note that I perform some seemingly redundant sequences of closing and executing the command list and synchronizing the command queue between memory initialization, kernel launch, and memory readback. I'm not sure why those are needed, as I expected operations on a command list to be ordered. I'd appreciate some clarification here 🙂

All of the above has been tested on Linux (kernel 6.2) with an A770, using both the driver packages from Arch Linux (compute-runtime 22.43.24595.30 and graphics compiler 1.0.12812.26) and our own builds (compute-runtime 22.53.25593 and graphics compiler 1.0.13230).

partialsum.spv.zip
