diff --git a/conformance_tests/core/test_mutable_cmdlist/src/test_mutable_cmdlist.cpp b/conformance_tests/core/test_mutable_cmdlist/src/test_mutable_cmdlist.cpp index 89d3dc405..57043e21b 100644 --- a/conformance_tests/core/test_mutable_cmdlist/src/test_mutable_cmdlist.cpp +++ b/conformance_tests/core/test_mutable_cmdlist/src/test_mutable_cmdlist.cpp @@ -1293,6 +1293,303 @@ TEST_P( lzt::destroy_function(mulKernel); } +TEST_F( + zeMutableCommandListTests, + GivenMutationOfMultipleKernelCapabilitiesAndEventsWhenCommandListIsClosedThenEverythingIsUpdatedCorrectly) { + if (!CheckExtensionSupport(ZE_MUTABLE_COMMAND_LIST_EXP_VERSION_1_1) || + !signalEventSupport || !waitEventsSupport || !kernelInstructionSupport || + !globalOffsetSupport || !groupCountSupport || !groupSizeSupport || + !kernelArgumentsSupport) { + GTEST_SKIP() << "Not all required extensions are supported"; + } + + const int32_t buffer_size = 16384; + const int32_t init_buffer_val = 100; + const int32_t add_val = 20; + const int32_t mul_val = 30; + const int32_t sub_val = 40; + const int32_t div_val = 4; + const int32_t global_offset_x = 5; + const int32_t mutated_global_offset_x = 5; + const int32_t mutated_mul_val = 60; + const int32_t mutated_sub_val = 80; + const int32_t part_of_buffer_to_fill_1 = 2; + const int32_t part_of_buffer_to_fill_2 = 8; + + lzt::zeEventPool event_pool; + const uint32_t events_number = 4; + std::vector events(events_number, nullptr); + event_pool.InitEventPool(context, events_number, + ZE_EVENT_POOL_FLAG_HOST_VISIBLE); + event_pool.create_events(events, events_number); + + int32_t *in_out_buffer_1 = reinterpret_cast( + lzt::allocate_host_memory(buffer_size * sizeof(int32_t))); + int32_t *in_out_buffer_2 = reinterpret_cast( + lzt::allocate_host_memory(buffer_size * sizeof(int32_t))); + for (size_t i = 0; i < buffer_size; i++) { + in_out_buffer_1[i] = init_buffer_val; + in_out_buffer_2[i] = init_buffer_val; + } + + uint32_t group_size_x = 0; + uint32_t group_size_y = 0; + uint32_t group_size_z = 0; + + ze_kernel_handle_t add_kernel = lzt::create_function(module, "addValue"); + ze_kernel_handle_t mul_kernel = lzt::create_function(module, "mulValue"); + ze_kernel_handle_t sub_kernel = lzt::create_function(module, "subValue"); + ze_kernel_handle_t div_kernel = lzt::create_function(module, "divValue"); + + uint64_t kernel_command_id_1 = 0; + uint64_t kernel_command_id_2 = 0; + uint64_t kernel_command_id_3 = 0; + std::vector kernels{add_kernel, mul_kernel, sub_kernel, + div_kernel}; + commandIdDesc.flags = ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS | + ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_COUNT | + ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE | + ZE_MUTABLE_COMMAND_EXP_FLAG_SIGNAL_EVENT | + ZE_MUTABLE_COMMAND_EXP_FLAG_WAIT_EVENTS | + ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION | + ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET; + + lzt::suggest_group_size(add_kernel, buffer_size, 1, 1, group_size_x, + group_size_y, group_size_z); + const uint32_t mutated_group_size_x = group_size_x / 2; + + ze_group_count_t group_count{ + buffer_size / group_size_x / part_of_buffer_to_fill_1, 1, 1}; + ze_group_count_t mutated_group_count{ + buffer_size / mutated_group_size_x / part_of_buffer_to_fill_2, 1, 1}; + + // 1 add_kernel + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeKernelSetGlobalOffsetExp(add_kernel, global_offset_x, 0, 0)); + lzt::set_group_size(add_kernel, group_size_x, group_size_y, group_size_z); + lzt::set_argument_value(add_kernel, 0, sizeof(void *), &in_out_buffer_1); + lzt::set_argument_value(add_kernel, 1, sizeof(add_val), &add_val); + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListGetNextCommandIdWithKernelsExp( + mutableCmdList, &commandIdDesc, kernels.size(), kernels.data(), + &kernel_command_id_1)); + lzt::append_launch_function(mutableCmdList, add_kernel, &group_count, + events[0], 0, nullptr); + // 2 mul_kernel + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeKernelSetGlobalOffsetExp(mul_kernel, global_offset_x, 0, 0)); + lzt::set_group_size(mul_kernel, group_size_x, group_size_y, group_size_z); + lzt::set_argument_value(mul_kernel, 0, sizeof(void *), &in_out_buffer_1); + lzt::set_argument_value(mul_kernel, 1, sizeof(mul_val), &mul_val); + + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListGetNextCommandIdWithKernelsExp( + mutableCmdList, &commandIdDesc, kernels.size(), kernels.data(), + &kernel_command_id_2)); + lzt::append_launch_function(mutableCmdList, mul_kernel, &group_count, + events[1], 1, &events[0]); + // 3 sub_kernel + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeKernelSetGlobalOffsetExp(sub_kernel, global_offset_x, 0, 0)); + lzt::set_group_size(sub_kernel, group_size_x, group_size_y, group_size_z); + lzt::set_argument_value(sub_kernel, 0, sizeof(void *), &in_out_buffer_1); + lzt::set_argument_value(sub_kernel, 1, sizeof(sub_val), &sub_val); + + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListGetNextCommandIdWithKernelsExp( + mutableCmdList, &commandIdDesc, kernels.size(), kernels.data(), + &kernel_command_id_3)); + lzt::append_launch_function(mutableCmdList, sub_kernel, &group_count, nullptr, + 2, &events[0]); + + lzt::close_command_list(mutableCmdList); + lzt::execute_command_lists(queue, 1, &mutableCmdList, nullptr); + lzt::synchronize(queue, std::numeric_limits::max()); + + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventQueryStatus(events[0])); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventQueryStatus(events[1])); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventHostReset(events[0])); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventHostReset(events[1])); + const uint32_t first_result = + ((init_buffer_val + add_val) * mul_val) - sub_val; + for (size_t i = 0; i < global_offset_x; i++) { + EXPECT_EQ(in_out_buffer_1[i], init_buffer_val); + } + for (size_t i = global_offset_x; + i < buffer_size / part_of_buffer_to_fill_1 + global_offset_x; i++) { + EXPECT_EQ(in_out_buffer_1[i], first_result); + } + for (size_t i = buffer_size / part_of_buffer_to_fill_1 + global_offset_x; + i < buffer_size; i++) { + EXPECT_EQ(in_out_buffer_1[i], init_buffer_val); + } + + // Update events + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListUpdateMutableCommandSignalEventExp( + mutableCmdList, kernel_command_id_1, events[2])); + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListUpdateMutableCommandSignalEventExp( + mutableCmdList, kernel_command_id_2, events[3])); + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListUpdateMutableCommandWaitEventsExp( + mutableCmdList, kernel_command_id_2, 1, &events[2])); + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListUpdateMutableCommandWaitEventsExp( + mutableCmdList, kernel_command_id_3, 2, &events[2])); + + // Change kernels sequence from add, mul, sub to mul, sub, div + std::vector commandIds{kernel_command_id_1, kernel_command_id_2, + kernel_command_id_3}; + std::vector newSequenceOfKernels{mul_kernel, sub_kernel, + div_kernel}; + + EXPECT_EQ(ZE_RESULT_SUCCESS, zeCommandListUpdateMutableCommandKernelsExp( + mutableCmdList, 3, commandIds.data(), + newSequenceOfKernels.data())); + + // Mutate invalidated data for kernel 1 + ze_mutable_global_offset_exp_desc_t mutate_global_offset = { + ZE_STRUCTURE_TYPE_MUTABLE_GLOBAL_OFFSET_EXP_DESC}; + mutate_global_offset.commandId = kernel_command_id_1; + mutate_global_offset.offsetX = mutated_global_offset_x; + mutate_global_offset.offsetY = 0; + mutate_global_offset.offsetZ = 0; + mutate_global_offset.pNext = nullptr; + ze_mutable_group_count_exp_desc_t mutate_group_count{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_COUNT_EXP_DESC}; + mutate_group_count.commandId = kernel_command_id_1; + mutate_group_count.pGroupCount = &mutated_group_count; + mutate_group_count.pNext = &mutate_global_offset; + ze_mutable_group_size_exp_desc_t mutate_group_size{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_SIZE_EXP_DESC}; + mutate_group_size.commandId = kernel_command_id_1; + mutate_group_size.groupSizeX = mutated_group_size_x; + mutate_group_size.groupSizeY = group_size_y; + mutate_group_size.groupSizeZ = group_size_z; + mutate_group_size.pNext = &mutate_group_count; + ze_mutable_kernel_argument_exp_desc_t mutate_buffer_kernel_arg{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutate_buffer_kernel_arg.commandId = kernel_command_id_1; + mutate_buffer_kernel_arg.argIndex = 0; + mutate_buffer_kernel_arg.argSize = sizeof(void *); + mutate_buffer_kernel_arg.pArgValue = &in_out_buffer_2; + mutate_buffer_kernel_arg.pNext = &mutate_group_size; + ze_mutable_kernel_argument_exp_desc_t mutate_scalar_kernel_arg{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutate_scalar_kernel_arg.commandId = kernel_command_id_1; + mutate_scalar_kernel_arg.argIndex = 1; + mutate_scalar_kernel_arg.argSize = sizeof(mutated_mul_val); + mutate_scalar_kernel_arg.pArgValue = &mutated_mul_val; + mutate_scalar_kernel_arg.pNext = &mutate_buffer_kernel_arg; + + // Mutate invalidated data for kernel 2 + ze_mutable_global_offset_exp_desc_t mutate_global_offset_2 = { + ZE_STRUCTURE_TYPE_MUTABLE_GLOBAL_OFFSET_EXP_DESC}; + mutate_global_offset_2.commandId = kernel_command_id_2; + mutate_global_offset_2.offsetX = mutated_global_offset_x; + mutate_global_offset_2.offsetY = 0; + mutate_global_offset_2.offsetZ = 0; + mutate_global_offset_2.pNext = &mutate_scalar_kernel_arg; + ze_mutable_group_count_exp_desc_t mutate_group_count_2{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_COUNT_EXP_DESC}; + mutate_group_count_2.commandId = kernel_command_id_2; + mutate_group_count_2.pGroupCount = &mutated_group_count; + mutate_group_count_2.pNext = &mutate_global_offset_2; + ze_mutable_group_size_exp_desc_t mutate_group_size_2{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_SIZE_EXP_DESC}; + mutate_group_size_2.commandId = kernel_command_id_2; + mutate_group_size_2.groupSizeX = mutated_group_size_x; + mutate_group_size_2.groupSizeY = group_size_y; + mutate_group_size_2.groupSizeZ = group_size_z; + mutate_group_size_2.pNext = &mutate_group_count_2; + ze_mutable_kernel_argument_exp_desc_t mutate_buffer_kernel_arg_2{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutate_buffer_kernel_arg_2.commandId = kernel_command_id_2; + mutate_buffer_kernel_arg_2.argIndex = 0; + mutate_buffer_kernel_arg_2.argSize = sizeof(void *); + mutate_buffer_kernel_arg_2.pArgValue = &in_out_buffer_2; + mutate_buffer_kernel_arg_2.pNext = &mutate_group_size_2; + ze_mutable_kernel_argument_exp_desc_t mutate_scalar_kernel_arg_2{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutate_scalar_kernel_arg_2.commandId = kernel_command_id_2; + mutate_scalar_kernel_arg_2.argIndex = 1; + mutate_scalar_kernel_arg_2.argSize = sizeof(mutated_sub_val); + mutate_scalar_kernel_arg_2.pArgValue = &mutated_sub_val; + mutate_scalar_kernel_arg_2.pNext = &mutate_buffer_kernel_arg_2; + + // Mutate invalidated data for kernel 3 + ze_mutable_global_offset_exp_desc_t mutate_global_offset_3 = { + ZE_STRUCTURE_TYPE_MUTABLE_GLOBAL_OFFSET_EXP_DESC}; + mutate_global_offset_3.commandId = kernel_command_id_3; + mutate_global_offset_3.offsetX = mutated_global_offset_x; + mutate_global_offset_3.offsetY = 0; + mutate_global_offset_3.offsetZ = 0; + mutate_global_offset_3.pNext = &mutate_scalar_kernel_arg_2; + ze_mutable_group_count_exp_desc_t mutate_group_count_3{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_COUNT_EXP_DESC}; + mutate_group_count_3.commandId = kernel_command_id_3; + mutate_group_count_3.pGroupCount = &mutated_group_count; + mutate_group_count_3.pNext = &mutate_global_offset_3; + ze_mutable_group_size_exp_desc_t mutate_group_size_3{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_SIZE_EXP_DESC}; + mutate_group_size_3.commandId = kernel_command_id_3; + mutate_group_size_3.groupSizeX = mutated_group_size_x; + mutate_group_size_3.groupSizeY = group_size_y; + mutate_group_size_3.groupSizeZ = group_size_z; + mutate_group_size_3.pNext = &mutate_group_count_3; + ze_mutable_kernel_argument_exp_desc_t mutate_buffer_kernel_arg_3{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutate_buffer_kernel_arg_3.commandId = kernel_command_id_3; + mutate_buffer_kernel_arg_3.argIndex = 0; + mutate_buffer_kernel_arg_3.argSize = sizeof(void *); + mutate_buffer_kernel_arg_3.pArgValue = &in_out_buffer_2; + mutate_buffer_kernel_arg_3.pNext = &mutate_group_size_3; + ze_mutable_kernel_argument_exp_desc_t mutate_scalar_kernel_arg_3{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutate_scalar_kernel_arg_3.commandId = kernel_command_id_3; + mutate_scalar_kernel_arg_3.argIndex = 1; + mutate_scalar_kernel_arg_3.argSize = sizeof(div_val); + mutate_scalar_kernel_arg_3.pArgValue = &div_val; + mutate_scalar_kernel_arg_3.pNext = &mutate_buffer_kernel_arg_3; + + mutableCmdDesc.pNext = &mutate_scalar_kernel_arg_3; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeCommandListUpdateMutableCommandsExp( + mutableCmdList, &mutableCmdDesc)); + + lzt::close_command_list(mutableCmdList); + lzt::execute_command_lists(queue, 1, &mutableCmdList, nullptr); + lzt::synchronize(queue, std::numeric_limits::max()); + + EXPECT_EQ(ZE_RESULT_NOT_READY, zeEventQueryStatus(events[0])); + EXPECT_EQ(ZE_RESULT_NOT_READY, zeEventQueryStatus(events[1])); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventQueryStatus(events[2])); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventQueryStatus(events[3])); + const uint32_t second_result = + ((init_buffer_val * mutated_mul_val) - mutated_sub_val) / div_val; + for (size_t i = 0; i < mutated_global_offset_x; i++) { + EXPECT_EQ(in_out_buffer_2[i], init_buffer_val); + } + for (size_t i = mutated_global_offset_x; + i < buffer_size / part_of_buffer_to_fill_2 + mutated_global_offset_x; + i++) { + EXPECT_EQ(in_out_buffer_2[i], second_result); + } + for (size_t i = + buffer_size / part_of_buffer_to_fill_2 + mutated_global_offset_x; + i < buffer_size; i++) { + EXPECT_EQ(in_out_buffer_2[i], init_buffer_val); + } + + event_pool.destroy_events(events); + lzt::free_memory(in_out_buffer_1); + lzt::free_memory(in_out_buffer_2); + lzt::destroy_function(add_kernel); + lzt::destroy_function(mul_kernel); + lzt::destroy_function(sub_kernel); + lzt::destroy_function(div_kernel); +} + INSTANTIATE_TEST_SUITE_P( zeMutableCommandListTests, zeMutableCommandListTestsEvents, testing::Values(ZE_EVENT_POOL_FLAG_HOST_VISIBLE,