Skip to content

Commit bf6b6f9

Browse files
committed
[CUDA][HIP] Fix kernel arguments being overwritten when added out of order
In the CUDA and HIP adapters, when kernel arguments are added out of order (e.g. the argument at index 1 is added before the argument at index 0), existing arguments are currently overwritten. This happens because the sizes of some arguments may not yet be known when they are added out of order, and the code relies on those sizes to choose where to store each argument. This commit avoids the issue by storing the arguments in the same order in which they are added and accessing them using pointer offsets.
1 parent ef70004 commit bf6b6f9

File tree

4 files changed

+377
-38
lines changed

4 files changed

+377
-38
lines changed

source/adapters/cuda/kernel.hpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ struct ur_kernel_handle_t_ {
6868
args_size_t ParamSizes;
6969
/// Byte offset into /p Storage allocation for each parameter.
7070
args_index_t Indices;
71+
/// Position in the Storage array where the next argument should be added.
72+
size_t InsertPos = 0;
7173
/// Aligned size in bytes for each local memory parameter after padding has
7274
/// been added. Zero if the argument at the index isn't a local memory
7375
/// argument.
@@ -101,6 +103,7 @@ struct ur_kernel_handle_t_ {
101103
/// Implicit offset argument is kept at the back of the indices collection.
102104
void addArg(size_t Index, size_t Size, const void *Arg,
103105
size_t LocalSize = 0) {
106+
// Expand storage to accommodate this Index if needed.
104107
if (Index + 2 > Indices.size()) {
105108
// Move implicit offset argument index with the end
106109
Indices.resize(Index + 2, Indices.back());
@@ -109,14 +112,21 @@ struct ur_kernel_handle_t_ {
109112
AlignedLocalMemSize.resize(Index + 1);
110113
OriginalLocalMemSize.resize(Index + 1);
111114
}
112-
ParamSizes[Index] = Size;
113-
// calculate the insertion point on the array
114-
size_t InsertPos = std::accumulate(std::begin(ParamSizes),
115-
std::begin(ParamSizes) + Index, 0);
116-
// Update the stored value for the argument
117-
std::memcpy(&Storage[InsertPos], Arg, Size);
118-
Indices[Index] = &Storage[InsertPos];
119-
AlignedLocalMemSize[Index] = LocalSize;
115+
116+
// Copy new argument to storage if it hasn't been added before.
117+
if (ParamSizes[Index] == 0) {
118+
ParamSizes[Index] = Size;
119+
std::memcpy(&Storage[InsertPos], Arg, Size);
120+
Indices[Index] = &Storage[InsertPos];
121+
AlignedLocalMemSize[Index] = LocalSize;
122+
InsertPos += Size;
123+
}
124+
// Otherwise, update the existing argument.
125+
else {
126+
std::memcpy(Indices[Index], Arg, Size);
127+
AlignedLocalMemSize[Index] = LocalSize;
128+
assert(Size == ParamSizes[Index]);
129+
}
120130
}
121131

122132
/// Returns the padded size and offset of a local memory argument.
@@ -177,10 +187,7 @@ struct ur_kernel_handle_t_ {
177187
AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;
178188

179189
// Store new offset into local data
180-
const size_t InsertPos =
181-
std::accumulate(std::begin(ParamSizes),
182-
std::begin(ParamSizes) + SuccIndex, size_t{0});
183-
std::memcpy(&Storage[InsertPos], &SuccAlignedLocalOffset,
190+
std::memcpy(Indices[SuccIndex], &SuccAlignedLocalOffset,
184191
sizeof(size_t));
185192
}
186193
}

source/adapters/hip/kernel.hpp

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ struct ur_kernel_handle_t_ {
6363
args_size_t ParamSizes;
6464
/// Byte offset into /p Storage allocation for each parameter.
6565
args_index_t Indices;
66+
/// Position in the Storage array where the next argument should be added.
67+
size_t InsertPos = 0;
6668
/// Aligned size in bytes for each local memory parameter after padding has
6769
/// been added. Zero if the argument at the index isn't a local memory
6870
/// argument.
@@ -95,22 +97,30 @@ struct ur_kernel_handle_t_ {
9597
/// Implicit offset argument is kept at the back of the indices collection.
9698
void addArg(size_t Index, size_t Size, const void *Arg,
9799
size_t LocalSize = 0) {
100+
// Expand storage to accommodate this Index if needed.
98101
if (Index + 2 > Indices.size()) {
99-
// Move implicit offset argument Index with the end
102+
// Move implicit offset argument index with the end
100103
Indices.resize(Index + 2, Indices.back());
101104
// Ensure enough space for the new argument
102105
ParamSizes.resize(Index + 1);
103106
AlignedLocalMemSize.resize(Index + 1);
104107
OriginalLocalMemSize.resize(Index + 1);
105108
}
106-
ParamSizes[Index] = Size;
107-
// calculate the insertion point on the array
108-
size_t InsertPos = std::accumulate(std::begin(ParamSizes),
109-
std::begin(ParamSizes) + Index, 0);
110-
// Update the stored value for the argument
111-
std::memcpy(&Storage[InsertPos], Arg, Size);
112-
Indices[Index] = &Storage[InsertPos];
113-
AlignedLocalMemSize[Index] = LocalSize;
109+
110+
// Copy new argument to storage if it hasn't been added before.
111+
if (ParamSizes[Index] == 0) {
112+
ParamSizes[Index] = Size;
113+
std::memcpy(&Storage[InsertPos], Arg, Size);
114+
Indices[Index] = &Storage[InsertPos];
115+
AlignedLocalMemSize[Index] = LocalSize;
116+
InsertPos += Size;
117+
}
118+
// Otherwise, update the existing argument.
119+
else {
120+
std::memcpy(Indices[Index], Arg, Size);
121+
AlignedLocalMemSize[Index] = LocalSize;
122+
assert(Size == ParamSizes[Index]);
123+
}
114124
}
115125

116126
/// Returns the padded size and offset of a local memory argument.
@@ -151,20 +161,11 @@ struct ur_kernel_handle_t_ {
151161
return std::make_pair(AlignedLocalSize, AlignedLocalOffset);
152162
}
153163

154-
void addLocalArg(size_t Index, size_t Size) {
155-
// Get the aligned argument size and offset into local data
156-
auto [AlignedLocalSize, AlignedLocalOffset] =
157-
calcAlignedLocalArgument(Index, Size);
158-
159-
// Store argument details
160-
addArg(Index, sizeof(size_t), (const void *)&(AlignedLocalOffset),
161-
AlignedLocalSize);
162-
163-
// For every existing local argument which follows at later argument
164-
// indices, update the offset and pointer into the kernel local memory.
165-
// Required as padding will need to be recalculated.
164+
// Iterate over every existing local argument at or after StartIndex and
165+
// update its offset and pointer into the kernel local memory.
166+
void updateLocalArgOffset(size_t StartIndex) {
166167
const size_t NumArgs = Indices.size() - 1; // Accounts for implicit arg
167-
for (auto SuccIndex = Index + 1; SuccIndex < NumArgs; SuccIndex++) {
168+
for (auto SuccIndex = StartIndex; SuccIndex < NumArgs; SuccIndex++) {
168169
const size_t OriginalLocalSize = OriginalLocalMemSize[SuccIndex];
169170
if (OriginalLocalSize == 0) {
170171
// Skip if successor argument isn't a local memory arg
@@ -179,14 +180,26 @@ struct ur_kernel_handle_t_ {
179180
AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;
180181

181182
// Store new offset into local data
182-
const size_t InsertPos =
183-
std::accumulate(std::begin(ParamSizes),
184-
std::begin(ParamSizes) + SuccIndex, size_t{0});
185-
std::memcpy(&Storage[InsertPos], &SuccAlignedLocalOffset,
183+
std::memcpy(Indices[SuccIndex], &SuccAlignedLocalOffset,
186184
sizeof(size_t));
187185
}
188186
}
189187

188+
void addLocalArg(size_t Index, size_t Size) {
189+
// Get the aligned argument size and offset into local data
190+
auto [AlignedLocalSize, AlignedLocalOffset] =
191+
calcAlignedLocalArgument(Index, Size);
192+
193+
// Store argument details
194+
addArg(Index, sizeof(size_t), (const void *)&(AlignedLocalOffset),
195+
AlignedLocalSize);
196+
197+
// For every existing local argument which follows at later argument
198+
// indices, update the offset and pointer into the kernel local memory.
199+
// Required as padding will need to be recalculated.
200+
updateLocalArgOffset(Index + 1);
201+
}
202+
190203
void addMemObjArg(int Index, ur_mem_handle_t hMem, ur_mem_flags_t Flags) {
191204
assert(hMem && "Invalid mem handle");
192205
// To avoid redundancy we are not storing mem obj with index i at index

test/conformance/exp_command_buffer/update/local_memory_update.cpp

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,3 +1105,223 @@ TEST_P(LocalMemoryMultiUpdateTest, UpdateWithoutBlocking) {
11051105
uint32_t *new_Y = (uint32_t *)shared_ptrs[4];
11061106
Validate(new_output, new_X, new_Y, new_A, global_size, local_size);
11071107
}
1108+
1109+
struct LocalMemoryUpdateTestBaseOutOfOrder : LocalMemoryUpdateTestBase {
1110+
virtual void SetUp() override {
1111+
program_name = "saxpy_usm_local_mem";
1112+
UUR_RETURN_ON_FATAL_FAILURE(
1113+
urUpdatableCommandBufferExpExecutionTest::SetUp());
1114+
1115+
if (backend == UR_PLATFORM_BACKEND_LEVEL_ZERO) {
1116+
GTEST_SKIP()
1117+
<< "Local memory argument update not supported on Level Zero.";
1118+
}
1119+
1120+
// HIP has extra args for local memory so we define an offset for arg
1121+
// indices here for updating
1122+
hip_arg_offset = backend == UR_PLATFORM_BACKEND_HIP ? 3 : 0;
1123+
ur_device_usm_access_capability_flags_t shared_usm_flags;
1124+
ASSERT_SUCCESS(
1125+
uur::GetDeviceUSMSingleSharedSupport(device, shared_usm_flags));
1126+
if (!(shared_usm_flags & UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS)) {
1127+
GTEST_SKIP() << "Shared USM is not supported.";
1128+
}
1129+
1130+
const size_t allocation_size =
1131+
sizeof(uint32_t) * global_size * local_size;
1132+
for (auto &shared_ptr : shared_ptrs) {
1133+
ASSERT_SUCCESS(urUSMSharedAlloc(context, device, nullptr, nullptr,
1134+
allocation_size, &shared_ptr));
1135+
ASSERT_NE(shared_ptr, nullptr);
1136+
1137+
std::vector<uint8_t> pattern(allocation_size);
1138+
uur::generateMemFillPattern(pattern);
1139+
std::memcpy(shared_ptr, pattern.data(), allocation_size);
1140+
}
1141+
1142+
std::array<size_t, 12> index_order{};
1143+
if (backend != UR_PLATFORM_BACKEND_HIP) {
1144+
index_order = {3, 2, 4, 5, 1, 0};
1145+
} else {
1146+
index_order = {9, 8, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3};
1147+
}
1148+
size_t current_index = 0;
1149+
1150+
// Index 3 is A
1151+
ASSERT_SUCCESS(urKernelSetArgValue(kernel, index_order[current_index++],
1152+
sizeof(A), nullptr, &A));
1153+
// Index 2 is output
1154+
ASSERT_SUCCESS(urKernelSetArgPointer(
1155+
kernel, index_order[current_index++], nullptr, shared_ptrs[0]));
1156+
1157+
// Index 4 is X
1158+
ASSERT_SUCCESS(urKernelSetArgPointer(
1159+
kernel, index_order[current_index++], nullptr, shared_ptrs[1]));
1160+
// Index 5 is Y
1161+
ASSERT_SUCCESS(urKernelSetArgPointer(
1162+
kernel, index_order[current_index++], nullptr, shared_ptrs[2]));
1163+
1164+
// Index 1 is local_mem_b arg
1165+
ASSERT_SUCCESS(urKernelSetArgLocal(kernel, index_order[current_index++],
1166+
local_mem_b_size, nullptr));
1167+
if (backend == UR_PLATFORM_BACKEND_HIP) {
1168+
ASSERT_SUCCESS(urKernelSetArgValue(
1169+
kernel, index_order[current_index++], sizeof(hip_local_offset),
1170+
nullptr, &hip_local_offset));
1171+
ASSERT_SUCCESS(urKernelSetArgValue(
1172+
kernel, index_order[current_index++], sizeof(hip_local_offset),
1173+
nullptr, &hip_local_offset));
1174+
ASSERT_SUCCESS(urKernelSetArgValue(
1175+
kernel, index_order[current_index++], sizeof(hip_local_offset),
1176+
nullptr, &hip_local_offset));
1177+
}
1178+
1179+
// Index 0 is local_mem_a arg
1180+
ASSERT_SUCCESS(urKernelSetArgLocal(kernel, index_order[current_index++],
1181+
local_mem_a_size, nullptr));
1182+
1183+
// Hip has extra args for local mem at index 1-3
1184+
if (backend == UR_PLATFORM_BACKEND_HIP) {
1185+
ASSERT_SUCCESS(urKernelSetArgValue(
1186+
kernel, index_order[current_index++], sizeof(hip_local_offset),
1187+
nullptr, &hip_local_offset));
1188+
ASSERT_SUCCESS(urKernelSetArgValue(
1189+
kernel, index_order[current_index++], sizeof(hip_local_offset),
1190+
nullptr, &hip_local_offset));
1191+
ASSERT_SUCCESS(urKernelSetArgValue(
1192+
kernel, index_order[current_index++], sizeof(hip_local_offset),
1193+
nullptr, &hip_local_offset));
1194+
}
1195+
}
1196+
};
1197+
1198+
struct LocalMemoryUpdateTestOutOfOrder : LocalMemoryUpdateTestBaseOutOfOrder {
1199+
void SetUp() override {
1200+
UUR_RETURN_ON_FATAL_FAILURE(
1201+
LocalMemoryUpdateTestBaseOutOfOrder::SetUp());
1202+
1203+
// Append kernel command to command-buffer and close command-buffer
1204+
ASSERT_SUCCESS(urCommandBufferAppendKernelLaunchExp(
1205+
updatable_cmd_buf_handle, kernel, n_dimensions, &global_offset,
1206+
&global_size, &local_size, 0, nullptr, 0, nullptr, 0, nullptr,
1207+
nullptr, nullptr, &command_handle));
1208+
ASSERT_NE(command_handle, nullptr);
1209+
1210+
ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle));
1211+
}
1212+
1213+
void TearDown() override {
1214+
if (command_handle) {
1215+
EXPECT_SUCCESS(urCommandBufferReleaseCommandExp(command_handle));
1216+
}
1217+
1218+
UUR_RETURN_ON_FATAL_FAILURE(
1219+
LocalMemoryUpdateTestBaseOutOfOrder::TearDown());
1220+
}
1221+
1222+
ur_exp_command_buffer_command_handle_t command_handle = nullptr;
1223+
};
1224+
1225+
UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(LocalMemoryUpdateTestOutOfOrder);
1226+
1227+
// Test updating A,X,Y parameters to new values and local memory to larger
1228+
// values when the kernel arguments were added out of order.
1229+
TEST_P(LocalMemoryUpdateTestOutOfOrder, UpdateAllParameters) {
1230+
// Run command-buffer prior to update and verify output
1231+
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
1232+
nullptr, nullptr));
1233+
ASSERT_SUCCESS(urQueueFinish(queue));
1234+
1235+
uint32_t *output = (uint32_t *)shared_ptrs[0];
1236+
uint32_t *X = (uint32_t *)shared_ptrs[1];
1237+
uint32_t *Y = (uint32_t *)shared_ptrs[2];
1238+
Validate(output, X, Y, A, global_size, local_size);
1239+
1240+
// Update inputs
1241+
std::array<ur_exp_command_buffer_update_pointer_arg_desc_t, 2>
1242+
new_input_descs;
1243+
std::array<ur_exp_command_buffer_update_value_arg_desc_t, 3>
1244+
new_value_descs;
1245+
1246+
size_t new_local_size = local_size * 4;
1247+
size_t new_local_mem_a_size = new_local_size * sizeof(uint32_t);
1248+
1249+
// New local_mem_a at index 0
1250+
new_value_descs[0] = {
1251+
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1252+
nullptr, // pNext
1253+
0, // argIndex
1254+
new_local_mem_a_size, // argSize
1255+
nullptr, // pProperties
1256+
nullptr, // hArgValue
1257+
};
1258+
1259+
// New local_mem_b at index 1
1260+
new_value_descs[1] = {
1261+
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1262+
nullptr, // pNext
1263+
1 + hip_arg_offset, // argIndex
1264+
local_mem_b_size, // argSize
1265+
nullptr, // pProperties
1266+
nullptr, // hArgValue
1267+
};
1268+
1269+
// New A at index 3
1270+
uint32_t new_A = 33;
1271+
new_value_descs[2] = {
1272+
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1273+
nullptr, // pNext
1274+
3 + (2 * hip_arg_offset), // argIndex
1275+
sizeof(new_A), // argSize
1276+
nullptr, // pProperties
1277+
&new_A, // hArgValue
1278+
};
1279+
1280+
// New X at index 4
1281+
new_input_descs[0] = {
1282+
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype
1283+
nullptr, // pNext
1284+
4 + (2 * hip_arg_offset), // argIndex
1285+
nullptr, // pProperties
1286+
&shared_ptrs[3], // pArgValue
1287+
};
1288+
1289+
// New Y at index 5
1290+
new_input_descs[1] = {
1291+
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype
1292+
nullptr, // pNext
1293+
5 + (2 * hip_arg_offset), // argIndex
1294+
nullptr, // pProperties
1295+
&shared_ptrs[4], // pArgValue
1296+
};
1297+
1298+
// Update kernel inputs
1299+
ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = {
1300+
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype
1301+
nullptr, // pNext
1302+
kernel, // hNewKernel
1303+
0, // numNewMemObjArgs
1304+
new_input_descs.size(), // numNewPointerArgs
1305+
new_value_descs.size(), // numNewValueArgs
1306+
n_dimensions, // newWorkDim
1307+
nullptr, // pNewMemObjArgList
1308+
new_input_descs.data(), // pNewPointerArgList
1309+
new_value_descs.data(), // pNewValueArgList
1310+
nullptr, // pNewGlobalWorkOffset
1311+
nullptr, // pNewGlobalWorkSize
1312+
nullptr, // pNewLocalWorkSize
1313+
};
1314+
1315+
// Update kernel and enqueue command-buffer again
1316+
ASSERT_SUCCESS(
1317+
urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc));
1318+
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
1319+
nullptr, nullptr));
1320+
ASSERT_SUCCESS(urQueueFinish(queue));
1321+
1322+
// Verify that update occurred correctly
1323+
uint32_t *new_output = (uint32_t *)shared_ptrs[0];
1324+
uint32_t *new_X = (uint32_t *)shared_ptrs[3];
1325+
uint32_t *new_Y = (uint32_t *)shared_ptrs[4];
1326+
Validate(new_output, new_X, new_Y, new_A, global_size, local_size);
1327+
}

0 commit comments

Comments
 (0)