@@ -1105,3 +1105,223 @@ TEST_P(LocalMemoryMultiUpdateTest, UpdateWithoutBlocking) {
1105
1105
uint32_t *new_Y = (uint32_t *)shared_ptrs[4 ];
1106
1106
Validate (new_output, new_X, new_Y, new_A, global_size, local_size);
1107
1107
}
1108
+
1109
+ struct LocalMemoryUpdateTestBaseOutOfOrder : LocalMemoryUpdateTestBase {
1110
+ virtual void SetUp () override {
1111
+ program_name = " saxpy_usm_local_mem" ;
1112
+ UUR_RETURN_ON_FATAL_FAILURE (
1113
+ urUpdatableCommandBufferExpExecutionTest::SetUp ());
1114
+
1115
+ if (backend == UR_PLATFORM_BACKEND_LEVEL_ZERO) {
1116
+ GTEST_SKIP ()
1117
+ << " Local memory argument update not supported on Level Zero." ;
1118
+ }
1119
+
1120
+ // HIP has extra args for local memory so we define an offset for arg
1121
+ // indices here for updating
1122
+ hip_arg_offset = backend == UR_PLATFORM_BACKEND_HIP ? 3 : 0 ;
1123
+ ur_device_usm_access_capability_flags_t shared_usm_flags;
1124
+ ASSERT_SUCCESS (
1125
+ uur::GetDeviceUSMSingleSharedSupport (device, shared_usm_flags));
1126
+ if (!(shared_usm_flags & UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS)) {
1127
+ GTEST_SKIP () << " Shared USM is not supported." ;
1128
+ }
1129
+
1130
+ const size_t allocation_size =
1131
+ sizeof (uint32_t ) * global_size * local_size;
1132
+ for (auto &shared_ptr : shared_ptrs) {
1133
+ ASSERT_SUCCESS (urUSMSharedAlloc (context, device, nullptr , nullptr ,
1134
+ allocation_size, &shared_ptr));
1135
+ ASSERT_NE (shared_ptr, nullptr );
1136
+
1137
+ std::vector<uint8_t > pattern (allocation_size);
1138
+ uur::generateMemFillPattern (pattern);
1139
+ std::memcpy (shared_ptr, pattern.data (), allocation_size);
1140
+ }
1141
+
1142
+ std::array<size_t , 12 > index_order{};
1143
+ if (backend != UR_PLATFORM_BACKEND_HIP) {
1144
+ index_order = {3 , 2 , 4 , 5 , 1 , 0 };
1145
+ } else {
1146
+ index_order = {9 , 8 , 10 , 11 , 4 , 5 , 6 , 7 , 0 , 1 , 2 , 3 };
1147
+ }
1148
+ size_t current_index = 0 ;
1149
+
1150
+ // Index 3 is A
1151
+ ASSERT_SUCCESS (urKernelSetArgValue (kernel, index_order[current_index++],
1152
+ sizeof (A), nullptr , &A));
1153
+ // Index 2 is output
1154
+ ASSERT_SUCCESS (urKernelSetArgPointer (
1155
+ kernel, index_order[current_index++], nullptr , shared_ptrs[0 ]));
1156
+
1157
+ // Index 4 is X
1158
+ ASSERT_SUCCESS (urKernelSetArgPointer (
1159
+ kernel, index_order[current_index++], nullptr , shared_ptrs[1 ]));
1160
+ // Index 5 is Y
1161
+ ASSERT_SUCCESS (urKernelSetArgPointer (
1162
+ kernel, index_order[current_index++], nullptr , shared_ptrs[2 ]));
1163
+
1164
+ // Index 1 is local_mem_b arg
1165
+ ASSERT_SUCCESS (urKernelSetArgLocal (kernel, index_order[current_index++],
1166
+ local_mem_b_size, nullptr ));
1167
+ if (backend == UR_PLATFORM_BACKEND_HIP) {
1168
+ ASSERT_SUCCESS (urKernelSetArgValue (
1169
+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1170
+ nullptr , &hip_local_offset));
1171
+ ASSERT_SUCCESS (urKernelSetArgValue (
1172
+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1173
+ nullptr , &hip_local_offset));
1174
+ ASSERT_SUCCESS (urKernelSetArgValue (
1175
+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1176
+ nullptr , &hip_local_offset));
1177
+ }
1178
+
1179
+ // Index 0 is local_mem_a arg
1180
+ ASSERT_SUCCESS (urKernelSetArgLocal (kernel, index_order[current_index++],
1181
+ local_mem_a_size, nullptr ));
1182
+
1183
+ // Hip has extra args for local mem at index 1-3
1184
+ if (backend == UR_PLATFORM_BACKEND_HIP) {
1185
+ ASSERT_SUCCESS (urKernelSetArgValue (
1186
+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1187
+ nullptr , &hip_local_offset));
1188
+ ASSERT_SUCCESS (urKernelSetArgValue (
1189
+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1190
+ nullptr , &hip_local_offset));
1191
+ ASSERT_SUCCESS (urKernelSetArgValue (
1192
+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1193
+ nullptr , &hip_local_offset));
1194
+ }
1195
+ }
1196
+ };
1197
+
1198
+ struct LocalMemoryUpdateTestOutOfOrder : LocalMemoryUpdateTestBaseOutOfOrder {
1199
+ void SetUp () override {
1200
+ UUR_RETURN_ON_FATAL_FAILURE (
1201
+ LocalMemoryUpdateTestBaseOutOfOrder::SetUp ());
1202
+
1203
+ // Append kernel command to command-buffer and close command-buffer
1204
+ ASSERT_SUCCESS (urCommandBufferAppendKernelLaunchExp (
1205
+ updatable_cmd_buf_handle, kernel, n_dimensions, &global_offset,
1206
+ &global_size, &local_size, 0 , nullptr , 0 , nullptr , 0 , nullptr ,
1207
+ nullptr , nullptr , &command_handle));
1208
+ ASSERT_NE (command_handle, nullptr );
1209
+
1210
+ ASSERT_SUCCESS (urCommandBufferFinalizeExp (updatable_cmd_buf_handle));
1211
+ }
1212
+
1213
+ void TearDown () override {
1214
+ if (command_handle) {
1215
+ EXPECT_SUCCESS (urCommandBufferReleaseCommandExp (command_handle));
1216
+ }
1217
+
1218
+ UUR_RETURN_ON_FATAL_FAILURE (
1219
+ LocalMemoryUpdateTestBaseOutOfOrder::TearDown ());
1220
+ }
1221
+
1222
+ ur_exp_command_buffer_command_handle_t command_handle = nullptr ;
1223
+ };
1224
+
1225
+ UUR_INSTANTIATE_DEVICE_TEST_SUITE_P (LocalMemoryUpdateTestOutOfOrder);
1226
+
1227
+ // Test updating A,X,Y parameters to new values and local memory to larger
1228
+ // values when the kernel arguments were added out of order.
1229
+ TEST_P (LocalMemoryUpdateTestOutOfOrder, UpdateAllParameters) {
1230
+ // Run command-buffer prior to update and verify output
1231
+ ASSERT_SUCCESS (urCommandBufferEnqueueExp (updatable_cmd_buf_handle, queue, 0 ,
1232
+ nullptr , nullptr ));
1233
+ ASSERT_SUCCESS (urQueueFinish (queue));
1234
+
1235
+ uint32_t *output = (uint32_t *)shared_ptrs[0 ];
1236
+ uint32_t *X = (uint32_t *)shared_ptrs[1 ];
1237
+ uint32_t *Y = (uint32_t *)shared_ptrs[2 ];
1238
+ Validate (output, X, Y, A, global_size, local_size);
1239
+
1240
+ // Update inputs
1241
+ std::array<ur_exp_command_buffer_update_pointer_arg_desc_t , 2 >
1242
+ new_input_descs;
1243
+ std::array<ur_exp_command_buffer_update_value_arg_desc_t , 3 >
1244
+ new_value_descs;
1245
+
1246
+ size_t new_local_size = local_size * 4 ;
1247
+ size_t new_local_mem_a_size = new_local_size * sizeof (uint32_t );
1248
+
1249
+ // New local_mem_a at index 0
1250
+ new_value_descs[0 ] = {
1251
+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1252
+ nullptr , // pNext
1253
+ 0 , // argIndex
1254
+ new_local_mem_a_size, // argSize
1255
+ nullptr , // pProperties
1256
+ nullptr , // hArgValue
1257
+ };
1258
+
1259
+ // New local_mem_b at index 1
1260
+ new_value_descs[1 ] = {
1261
+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1262
+ nullptr , // pNext
1263
+ 1 + hip_arg_offset, // argIndex
1264
+ local_mem_b_size, // argSize
1265
+ nullptr , // pProperties
1266
+ nullptr , // hArgValue
1267
+ };
1268
+
1269
+ // New A at index 3
1270
+ uint32_t new_A = 33 ;
1271
+ new_value_descs[2 ] = {
1272
+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1273
+ nullptr , // pNext
1274
+ 3 + (2 * hip_arg_offset), // argIndex
1275
+ sizeof (new_A), // argSize
1276
+ nullptr , // pProperties
1277
+ &new_A, // hArgValue
1278
+ };
1279
+
1280
+ // New X at index 4
1281
+ new_input_descs[0 ] = {
1282
+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype
1283
+ nullptr , // pNext
1284
+ 4 + (2 * hip_arg_offset), // argIndex
1285
+ nullptr , // pProperties
1286
+ &shared_ptrs[3 ], // pArgValue
1287
+ };
1288
+
1289
+ // New Y at index 5
1290
+ new_input_descs[1 ] = {
1291
+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype
1292
+ nullptr , // pNext
1293
+ 5 + (2 * hip_arg_offset), // argIndex
1294
+ nullptr , // pProperties
1295
+ &shared_ptrs[4 ], // pArgValue
1296
+ };
1297
+
1298
+ // Update kernel inputs
1299
+ ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = {
1300
+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype
1301
+ nullptr , // pNext
1302
+ kernel, // hNewKernel
1303
+ 0 , // numNewMemObjArgs
1304
+ new_input_descs.size (), // numNewPointerArgs
1305
+ new_value_descs.size (), // numNewValueArgs
1306
+ n_dimensions, // newWorkDim
1307
+ nullptr , // pNewMemObjArgList
1308
+ new_input_descs.data (), // pNewPointerArgList
1309
+ new_value_descs.data (), // pNewValueArgList
1310
+ nullptr , // pNewGlobalWorkOffset
1311
+ nullptr , // pNewGlobalWorkSize
1312
+ nullptr , // pNewLocalWorkSize
1313
+ };
1314
+
1315
+ // Update kernel and enqueue command-buffer again
1316
+ ASSERT_SUCCESS (
1317
+ urCommandBufferUpdateKernelLaunchExp (command_handle, &update_desc));
1318
+ ASSERT_SUCCESS (urCommandBufferEnqueueExp (updatable_cmd_buf_handle, queue, 0 ,
1319
+ nullptr , nullptr ));
1320
+ ASSERT_SUCCESS (urQueueFinish (queue));
1321
+
1322
+ // Verify that update occurred correctly
1323
+ uint32_t *new_output = (uint32_t *)shared_ptrs[0 ];
1324
+ uint32_t *new_X = (uint32_t *)shared_ptrs[3 ];
1325
+ uint32_t *new_Y = (uint32_t *)shared_ptrs[4 ];
1326
+ Validate (new_output, new_X, new_Y, new_A, global_size, local_size);
1327
+ }
0 commit comments