Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1293,6 +1293,303 @@ TEST_P(
lzt::destroy_function(mulKernel);
}

TEST_F(
zeMutableCommandListTests,
GivenMutationOfMultipleKernelCapabilitiesAndEventsWhenCommandListIsClosedThenEverythingIsUpdatedCorrectly) {
if (!CheckExtensionSupport(ZE_MUTABLE_COMMAND_LIST_EXP_VERSION_1_1) ||
!signalEventSupport || !waitEventsSupport || !kernelInstructionSupport ||
!globalOffsetSupport || !groupCountSupport || !groupSizeSupport ||
!kernelArgumentsSupport) {
GTEST_SKIP() << "Not all required extensions are supported";
}

const int32_t buffer_size = 16384;
const int32_t init_buffer_val = 100;
const int32_t add_val = 20;
const int32_t mul_val = 30;
const int32_t sub_val = 40;
const int32_t div_val = 4;
const int32_t global_offset_x = 5;
const int32_t mutated_global_offset_x = 5;
const int32_t mutated_mul_val = 60;
const int32_t mutated_sub_val = 80;
const int32_t part_of_buffer_to_fill_1 = 2;
const int32_t part_of_buffer_to_fill_2 = 8;

lzt::zeEventPool event_pool;
const uint32_t events_number = 4;
std::vector<ze_event_handle_t> events(events_number, nullptr);
event_pool.InitEventPool(context, events_number,
ZE_EVENT_POOL_FLAG_HOST_VISIBLE);
event_pool.create_events(events, events_number);

int32_t *in_out_buffer_1 = reinterpret_cast<int32_t *>(
lzt::allocate_host_memory(buffer_size * sizeof(int32_t)));
int32_t *in_out_buffer_2 = reinterpret_cast<int32_t *>(
lzt::allocate_host_memory(buffer_size * sizeof(int32_t)));
for (size_t i = 0; i < buffer_size; i++) {
in_out_buffer_1[i] = init_buffer_val;
in_out_buffer_2[i] = init_buffer_val;
}

uint32_t group_size_x = 0;
uint32_t group_size_y = 0;
uint32_t group_size_z = 0;

ze_kernel_handle_t add_kernel = lzt::create_function(module, "addValue");
ze_kernel_handle_t mul_kernel = lzt::create_function(module, "mulValue");
ze_kernel_handle_t sub_kernel = lzt::create_function(module, "subValue");
ze_kernel_handle_t div_kernel = lzt::create_function(module, "divValue");

uint64_t kernel_command_id_1 = 0;
uint64_t kernel_command_id_2 = 0;
uint64_t kernel_command_id_3 = 0;
std::vector<ze_kernel_handle_t> kernels{add_kernel, mul_kernel, sub_kernel,
div_kernel};
commandIdDesc.flags = ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS |
ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_COUNT |
ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE |
ZE_MUTABLE_COMMAND_EXP_FLAG_SIGNAL_EVENT |
ZE_MUTABLE_COMMAND_EXP_FLAG_WAIT_EVENTS |
ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION |
ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET;

lzt::suggest_group_size(add_kernel, buffer_size, 1, 1, group_size_x,
group_size_y, group_size_z);
const uint32_t mutated_group_size_x = group_size_x / 2;

ze_group_count_t group_count{
buffer_size / group_size_x / part_of_buffer_to_fill_1, 1, 1};
ze_group_count_t mutated_group_count{
buffer_size / mutated_group_size_x / part_of_buffer_to_fill_2, 1, 1};

// 1 add_kernel
EXPECT_EQ(ZE_RESULT_SUCCESS,
zeKernelSetGlobalOffsetExp(add_kernel, global_offset_x, 0, 0));
lzt::set_group_size(add_kernel, group_size_x, group_size_y, group_size_z);
lzt::set_argument_value(add_kernel, 0, sizeof(void *), &in_out_buffer_1);
lzt::set_argument_value(add_kernel, 1, sizeof(add_val), &add_val);
EXPECT_EQ(ZE_RESULT_SUCCESS,
zeCommandListGetNextCommandIdWithKernelsExp(
mutableCmdList, &commandIdDesc, kernels.size(), kernels.data(),
&kernel_command_id_1));
lzt::append_launch_function(mutableCmdList, add_kernel, &group_count,
events[0], 0, nullptr);
// 2 mul_kernel
EXPECT_EQ(ZE_RESULT_SUCCESS,
zeKernelSetGlobalOffsetExp(mul_kernel, global_offset_x, 0, 0));
lzt::set_group_size(mul_kernel, group_size_x, group_size_y, group_size_z);
lzt::set_argument_value(mul_kernel, 0, sizeof(void *), &in_out_buffer_1);
lzt::set_argument_value(mul_kernel, 1, sizeof(mul_val), &mul_val);

EXPECT_EQ(ZE_RESULT_SUCCESS,
zeCommandListGetNextCommandIdWithKernelsExp(
mutableCmdList, &commandIdDesc, kernels.size(), kernels.data(),
&kernel_command_id_2));
lzt::append_launch_function(mutableCmdList, mul_kernel, &group_count,
events[1], 1, &events[0]);
// 3 sub_kernel
EXPECT_EQ(ZE_RESULT_SUCCESS,
zeKernelSetGlobalOffsetExp(sub_kernel, global_offset_x, 0, 0));
lzt::set_group_size(sub_kernel, group_size_x, group_size_y, group_size_z);
lzt::set_argument_value(sub_kernel, 0, sizeof(void *), &in_out_buffer_1);
lzt::set_argument_value(sub_kernel, 1, sizeof(sub_val), &sub_val);

EXPECT_EQ(ZE_RESULT_SUCCESS,
zeCommandListGetNextCommandIdWithKernelsExp(
mutableCmdList, &commandIdDesc, kernels.size(), kernels.data(),
&kernel_command_id_3));
lzt::append_launch_function(mutableCmdList, sub_kernel, &group_count, nullptr,
2, &events[0]);

lzt::close_command_list(mutableCmdList);
lzt::execute_command_lists(queue, 1, &mutableCmdList, nullptr);
lzt::synchronize(queue, std::numeric_limits<uint64_t>::max());

EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventQueryStatus(events[0]));
EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventQueryStatus(events[1]));
EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventHostReset(events[0]));
EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventHostReset(events[1]));
const uint32_t first_result =
((init_buffer_val + add_val) * mul_val) - sub_val;
for (size_t i = 0; i < global_offset_x; i++) {
EXPECT_EQ(in_out_buffer_1[i], init_buffer_val);
}
for (size_t i = global_offset_x;
i < buffer_size / part_of_buffer_to_fill_1 + global_offset_x; i++) {
EXPECT_EQ(in_out_buffer_1[i], first_result);
}
for (size_t i = buffer_size / part_of_buffer_to_fill_1 + global_offset_x;
i < buffer_size; i++) {
EXPECT_EQ(in_out_buffer_1[i], init_buffer_val);
}

// Update events
EXPECT_EQ(ZE_RESULT_SUCCESS,
zeCommandListUpdateMutableCommandSignalEventExp(
mutableCmdList, kernel_command_id_1, events[2]));
EXPECT_EQ(ZE_RESULT_SUCCESS,
zeCommandListUpdateMutableCommandSignalEventExp(
mutableCmdList, kernel_command_id_2, events[3]));
EXPECT_EQ(ZE_RESULT_SUCCESS,
zeCommandListUpdateMutableCommandWaitEventsExp(
mutableCmdList, kernel_command_id_2, 1, &events[2]));
EXPECT_EQ(ZE_RESULT_SUCCESS,
zeCommandListUpdateMutableCommandWaitEventsExp(
mutableCmdList, kernel_command_id_3, 2, &events[2]));

// Change kernels sequence from add, mul, sub to mul, sub, div
std::vector<uint64_t> commandIds{kernel_command_id_1, kernel_command_id_2,
kernel_command_id_3};
std::vector<ze_kernel_handle_t> newSequenceOfKernels{mul_kernel, sub_kernel,
div_kernel};

EXPECT_EQ(ZE_RESULT_SUCCESS, zeCommandListUpdateMutableCommandKernelsExp(
mutableCmdList, 3, commandIds.data(),
newSequenceOfKernels.data()));

// Mutate invalidated data for kernel 1
ze_mutable_global_offset_exp_desc_t mutate_global_offset = {
ZE_STRUCTURE_TYPE_MUTABLE_GLOBAL_OFFSET_EXP_DESC};
mutate_global_offset.commandId = kernel_command_id_1;
mutate_global_offset.offsetX = mutated_global_offset_x;
mutate_global_offset.offsetY = 0;
mutate_global_offset.offsetZ = 0;
mutate_global_offset.pNext = nullptr;
ze_mutable_group_count_exp_desc_t mutate_group_count{
ZE_STRUCTURE_TYPE_MUTABLE_GROUP_COUNT_EXP_DESC};
mutate_group_count.commandId = kernel_command_id_1;
mutate_group_count.pGroupCount = &mutated_group_count;
mutate_group_count.pNext = &mutate_global_offset;
ze_mutable_group_size_exp_desc_t mutate_group_size{
ZE_STRUCTURE_TYPE_MUTABLE_GROUP_SIZE_EXP_DESC};
mutate_group_size.commandId = kernel_command_id_1;
mutate_group_size.groupSizeX = mutated_group_size_x;
mutate_group_size.groupSizeY = group_size_y;
mutate_group_size.groupSizeZ = group_size_z;
mutate_group_size.pNext = &mutate_group_count;
ze_mutable_kernel_argument_exp_desc_t mutate_buffer_kernel_arg{
ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC};
mutate_buffer_kernel_arg.commandId = kernel_command_id_1;
mutate_buffer_kernel_arg.argIndex = 0;
mutate_buffer_kernel_arg.argSize = sizeof(void *);
mutate_buffer_kernel_arg.pArgValue = &in_out_buffer_2;
mutate_buffer_kernel_arg.pNext = &mutate_group_size;
ze_mutable_kernel_argument_exp_desc_t mutate_scalar_kernel_arg{
ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC};
mutate_scalar_kernel_arg.commandId = kernel_command_id_1;
mutate_scalar_kernel_arg.argIndex = 1;
mutate_scalar_kernel_arg.argSize = sizeof(mutated_mul_val);
mutate_scalar_kernel_arg.pArgValue = &mutated_mul_val;
mutate_scalar_kernel_arg.pNext = &mutate_buffer_kernel_arg;

// Mutate invalidated data for kernel 2
ze_mutable_global_offset_exp_desc_t mutate_global_offset_2 = {
ZE_STRUCTURE_TYPE_MUTABLE_GLOBAL_OFFSET_EXP_DESC};
mutate_global_offset_2.commandId = kernel_command_id_2;
mutate_global_offset_2.offsetX = mutated_global_offset_x;
mutate_global_offset_2.offsetY = 0;
mutate_global_offset_2.offsetZ = 0;
mutate_global_offset_2.pNext = &mutate_scalar_kernel_arg;
ze_mutable_group_count_exp_desc_t mutate_group_count_2{
ZE_STRUCTURE_TYPE_MUTABLE_GROUP_COUNT_EXP_DESC};
mutate_group_count_2.commandId = kernel_command_id_2;
mutate_group_count_2.pGroupCount = &mutated_group_count;
mutate_group_count_2.pNext = &mutate_global_offset_2;
ze_mutable_group_size_exp_desc_t mutate_group_size_2{
ZE_STRUCTURE_TYPE_MUTABLE_GROUP_SIZE_EXP_DESC};
mutate_group_size_2.commandId = kernel_command_id_2;
mutate_group_size_2.groupSizeX = mutated_group_size_x;
mutate_group_size_2.groupSizeY = group_size_y;
mutate_group_size_2.groupSizeZ = group_size_z;
mutate_group_size_2.pNext = &mutate_group_count_2;
ze_mutable_kernel_argument_exp_desc_t mutate_buffer_kernel_arg_2{
ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC};
mutate_buffer_kernel_arg_2.commandId = kernel_command_id_2;
mutate_buffer_kernel_arg_2.argIndex = 0;
mutate_buffer_kernel_arg_2.argSize = sizeof(void *);
mutate_buffer_kernel_arg_2.pArgValue = &in_out_buffer_2;
mutate_buffer_kernel_arg_2.pNext = &mutate_group_size_2;
ze_mutable_kernel_argument_exp_desc_t mutate_scalar_kernel_arg_2{
ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC};
mutate_scalar_kernel_arg_2.commandId = kernel_command_id_2;
mutate_scalar_kernel_arg_2.argIndex = 1;
mutate_scalar_kernel_arg_2.argSize = sizeof(mutated_sub_val);
mutate_scalar_kernel_arg_2.pArgValue = &mutated_sub_val;
mutate_scalar_kernel_arg_2.pNext = &mutate_buffer_kernel_arg_2;

// Mutate invalidated data for kernel 3
ze_mutable_global_offset_exp_desc_t mutate_global_offset_3 = {
ZE_STRUCTURE_TYPE_MUTABLE_GLOBAL_OFFSET_EXP_DESC};
mutate_global_offset_3.commandId = kernel_command_id_3;
mutate_global_offset_3.offsetX = mutated_global_offset_x;
mutate_global_offset_3.offsetY = 0;
mutate_global_offset_3.offsetZ = 0;
mutate_global_offset_3.pNext = &mutate_scalar_kernel_arg_2;
ze_mutable_group_count_exp_desc_t mutate_group_count_3{
ZE_STRUCTURE_TYPE_MUTABLE_GROUP_COUNT_EXP_DESC};
mutate_group_count_3.commandId = kernel_command_id_3;
mutate_group_count_3.pGroupCount = &mutated_group_count;
mutate_group_count_3.pNext = &mutate_global_offset_3;
ze_mutable_group_size_exp_desc_t mutate_group_size_3{
ZE_STRUCTURE_TYPE_MUTABLE_GROUP_SIZE_EXP_DESC};
mutate_group_size_3.commandId = kernel_command_id_3;
mutate_group_size_3.groupSizeX = mutated_group_size_x;
mutate_group_size_3.groupSizeY = group_size_y;
mutate_group_size_3.groupSizeZ = group_size_z;
mutate_group_size_3.pNext = &mutate_group_count_3;
ze_mutable_kernel_argument_exp_desc_t mutate_buffer_kernel_arg_3{
ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC};
mutate_buffer_kernel_arg_3.commandId = kernel_command_id_3;
mutate_buffer_kernel_arg_3.argIndex = 0;
mutate_buffer_kernel_arg_3.argSize = sizeof(void *);
mutate_buffer_kernel_arg_3.pArgValue = &in_out_buffer_2;
mutate_buffer_kernel_arg_3.pNext = &mutate_group_size_3;
ze_mutable_kernel_argument_exp_desc_t mutate_scalar_kernel_arg_3{
ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC};
mutate_scalar_kernel_arg_3.commandId = kernel_command_id_3;
mutate_scalar_kernel_arg_3.argIndex = 1;
mutate_scalar_kernel_arg_3.argSize = sizeof(div_val);
mutate_scalar_kernel_arg_3.pArgValue = &div_val;
mutate_scalar_kernel_arg_3.pNext = &mutate_buffer_kernel_arg_3;

mutableCmdDesc.pNext = &mutate_scalar_kernel_arg_3;
EXPECT_EQ(ZE_RESULT_SUCCESS, zeCommandListUpdateMutableCommandsExp(
mutableCmdList, &mutableCmdDesc));

lzt::close_command_list(mutableCmdList);
lzt::execute_command_lists(queue, 1, &mutableCmdList, nullptr);
lzt::synchronize(queue, std::numeric_limits<uint64_t>::max());

EXPECT_EQ(ZE_RESULT_NOT_READY, zeEventQueryStatus(events[0]));
EXPECT_EQ(ZE_RESULT_NOT_READY, zeEventQueryStatus(events[1]));
EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventQueryStatus(events[2]));
EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventQueryStatus(events[3]));
const uint32_t second_result =
((init_buffer_val * mutated_mul_val) - mutated_sub_val) / div_val;
for (size_t i = 0; i < mutated_global_offset_x; i++) {
EXPECT_EQ(in_out_buffer_2[i], init_buffer_val);
}
for (size_t i = mutated_global_offset_x;
i < buffer_size / part_of_buffer_to_fill_2 + mutated_global_offset_x;
i++) {
EXPECT_EQ(in_out_buffer_2[i], second_result);
}
for (size_t i =
buffer_size / part_of_buffer_to_fill_2 + mutated_global_offset_x;
i < buffer_size; i++) {
EXPECT_EQ(in_out_buffer_2[i], init_buffer_val);
}

event_pool.destroy_events(events);
lzt::free_memory(in_out_buffer_1);
lzt::free_memory(in_out_buffer_2);
lzt::destroy_function(add_kernel);
lzt::destroy_function(mul_kernel);
lzt::destroy_function(sub_kernel);
lzt::destroy_function(div_kernel);
}

INSTANTIATE_TEST_SUITE_P(
zeMutableCommandListTests, zeMutableCommandListTestsEvents,
testing::Values(ZE_EVENT_POOL_FLAG_HOST_VISIBLE,
Expand Down
Loading