Skip to content

[CUDA][HIP] Fix kernel arguments being overwritten when added out of order #2559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions source/adapters/cuda/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
ThreadsPerBlock, BlocksPerGrid));

// Set node param structure with the kernel related data
auto &ArgIndices = hKernel->getArgIndices();
auto &ArgPointers = hKernel->getArgPointers();
CUDA_KERNEL_NODE_PARAMS NodeParams = {};
NodeParams.func = CuFunc;
NodeParams.gridDimX = BlocksPerGrid[0];
Expand All @@ -533,7 +533,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
NodeParams.blockDimY = ThreadsPerBlock[1];
NodeParams.blockDimZ = ThreadsPerBlock[2];
NodeParams.sharedMemBytes = LocalSize;
NodeParams.kernelParams = const_cast<void **>(ArgIndices.data());
NodeParams.kernelParams = const_cast<void **>(ArgPointers.data());

// Create and add a new kernel node to the Cuda graph
UR_CHECK_ERROR(cuGraphAddKernelNode(&GraphNode, hCommandBuffer->CudaGraph,
Expand Down Expand Up @@ -1398,7 +1398,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
Params.blockDimZ = ThreadsPerBlock[2];
Params.sharedMemBytes = KernelCommandHandle->Kernel->getLocalSize();
Params.kernelParams =
const_cast<void **>(KernelCommandHandle->Kernel->getArgIndices().data());
const_cast<void **>(KernelCommandHandle->Kernel->getArgPointers().data());

CUgraphNode Node = KernelCommandHandle->Node;
CUgraphExec CudaGraphExec = CommandBuffer->CudaGraphExec;
Expand Down
8 changes: 4 additions & 4 deletions source/adapters/cuda/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,11 +492,11 @@ enqueueKernelLaunch(ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel,
UR_CHECK_ERROR(RetImplEvent->start());
}

auto &ArgIndices = hKernel->getArgIndices();
auto &ArgPointers = hKernel->getArgPointers();
UR_CHECK_ERROR(cuLaunchKernel(
CuFunc, BlocksPerGrid[0], BlocksPerGrid[1], BlocksPerGrid[2],
ThreadsPerBlock[0], ThreadsPerBlock[1], ThreadsPerBlock[2], LocalSize,
CuStream, const_cast<void **>(ArgIndices.data()), nullptr));
CuStream, const_cast<void **>(ArgPointers.data()), nullptr));

if (phEvent) {
UR_CHECK_ERROR(RetImplEvent->record());
Expand Down Expand Up @@ -680,7 +680,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
UR_CHECK_ERROR(RetImplEvent->start());
}

auto &ArgIndices = hKernel->getArgIndices();
auto &ArgPointers = hKernel->getArgPointers();

CUlaunchConfig launch_config;
launch_config.gridDimX = BlocksPerGrid[0];
Expand All @@ -696,7 +696,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
launch_config.numAttrs = launch_attribute.size();

UR_CHECK_ERROR(cuLaunchKernelEx(&launch_config, CuFunc,
const_cast<void **>(ArgIndices.data()),
const_cast<void **>(ArgPointers.data()),
nullptr));

if (phEvent) {
Expand Down
62 changes: 36 additions & 26 deletions source/adapters/cuda/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ struct ur_kernel_handle_t_ {
args_t Storage;
/// Aligned size of each parameter, including padding.
args_size_t ParamSizes;
/// Byte offset into /p Storage allocation for each parameter.
args_index_t Indices;
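/// Pointer to the stored value of each argument. Explicit arguments point
/// into the \p Storage allocation; the final entry points to the implicit
/// offset argument.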
args_index_t ArgPointers;
/// Position in the Storage array where the next argument should be added.
size_t InsertPos = 0;
/// Aligned size in bytes for each local memory parameter after padding has
/// been added. Zero if the argument at the index isn't a local memory
/// argument.
Expand All @@ -90,33 +92,43 @@ struct ur_kernel_handle_t_ {
std::uint32_t ImplicitOffsetArgs[3] = {0, 0, 0};

arguments() {
// Place the implicit offset index at the end of the indicies collection
Indices.emplace_back(&ImplicitOffsetArgs);
// Place the implicit offset index at the end of the ArgPointers
// collection.
ArgPointers.emplace_back(&ImplicitOffsetArgs);
}

/// Add an argument to the kernel.
/// If the argument existed before, it is replaced.
/// Otherwise, it is added.
/// Gaps are filled with empty arguments.
/// Implicit offset argument is kept at the back of the indices collection.
/// Implicit offset argument is kept at the back of the ArgPointers
/// collection.
void addArg(size_t Index, size_t Size, const void *Arg,
size_t LocalSize = 0) {
if (Index + 2 > Indices.size()) {
// Expand storage to accommodate this Index if needed.
if (Index + 2 > ArgPointers.size()) {
// Move the implicit offset argument pointer to the new end
Indices.resize(Index + 2, Indices.back());
ArgPointers.resize(Index + 2, ArgPointers.back());
// Ensure enough space for the new argument
ParamSizes.resize(Index + 1);
AlignedLocalMemSize.resize(Index + 1);
OriginalLocalMemSize.resize(Index + 1);
}
ParamSizes[Index] = Size;
// calculate the insertion point on the array
size_t InsertPos = std::accumulate(std::begin(ParamSizes),
std::begin(ParamSizes) + Index, 0);
// Update the stored value for the argument
std::memcpy(&Storage[InsertPos], Arg, Size);
Indices[Index] = &Storage[InsertPos];
AlignedLocalMemSize[Index] = LocalSize;

// Copy new argument to storage if it hasn't been added before.
if (ParamSizes[Index] == 0) {
ParamSizes[Index] = Size;
std::memcpy(&Storage[InsertPos], Arg, Size);
ArgPointers[Index] = &Storage[InsertPos];
AlignedLocalMemSize[Index] = LocalSize;
InsertPos += Size;
}
// Otherwise, update the existing argument.
else {
std::memcpy(ArgPointers[Index], Arg, Size);
AlignedLocalMemSize[Index] = LocalSize;
assert(Size == ParamSizes[Index]);
}
}

/// Returns the padded size and offset of a local memory argument.
Expand All @@ -128,7 +140,7 @@ struct ur_kernel_handle_t_ {
std::pair<size_t, size_t> calcAlignedLocalArgument(size_t Index,
size_t Size) {
// Store the unpadded size of the local argument
if (Index + 2 > Indices.size()) {
if (Index + 2 > ArgPointers.size()) {
AlignedLocalMemSize.resize(Index + 1);
OriginalLocalMemSize.resize(Index + 1);
}
Expand Down Expand Up @@ -158,10 +170,11 @@ struct ur_kernel_handle_t_ {
return std::make_pair(AlignedLocalSize, AlignedLocalOffset);
}

// Iterate over all existing local argument which follows StartIndex
// Iterate over each existing local argument which follows StartIndex
// index, update the offset and pointer into the kernel local memory.
void updateLocalArgOffset(size_t StartIndex) {
const size_t NumArgs = Indices.size() - 1; // Accounts for implicit arg
const size_t NumArgs =
ArgPointers.size() - 1; // Accounts for implicit arg
for (auto SuccIndex = StartIndex; SuccIndex < NumArgs; SuccIndex++) {
const size_t OriginalLocalSize = OriginalLocalMemSize[SuccIndex];
if (OriginalLocalSize == 0) {
Expand All @@ -177,10 +190,7 @@ struct ur_kernel_handle_t_ {
AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;

// Store new offset into local data
const size_t InsertPos =
std::accumulate(std::begin(ParamSizes),
std::begin(ParamSizes) + SuccIndex, size_t{0});
std::memcpy(&Storage[InsertPos], &SuccAlignedLocalOffset,
std::memcpy(ArgPointers[SuccIndex], &SuccAlignedLocalOffset,
sizeof(size_t));
}
}
Expand Down Expand Up @@ -228,7 +238,7 @@ struct ur_kernel_handle_t_ {
std::memcpy(ImplicitOffsetArgs, ImplicitOffset, Size);
}

const args_index_t &getIndices() const noexcept { return Indices; }
const args_index_t &getArgPointers() const noexcept { return ArgPointers; }

uint32_t getLocalSize() const {
return std::accumulate(std::begin(AlignedLocalMemSize),
Expand Down Expand Up @@ -299,7 +309,7 @@ struct ur_kernel_handle_t_ {
/// real one required by the kernel, since this cannot be queried from
/// the CUDA Driver API
uint32_t getNumArgs() const noexcept {
return static_cast<uint32_t>(Args.Indices.size() - 1);
return static_cast<uint32_t>(Args.ArgPointers.size() - 1);
}

void setKernelArg(int Index, size_t Size, const void *Arg) {
Expand All @@ -314,8 +324,8 @@ struct ur_kernel_handle_t_ {
return Args.setImplicitOffset(Size, ImplicitOffset);
}

const arguments::args_index_t &getArgIndices() const {
return Args.getIndices();
const arguments::args_index_t &getArgPointers() const {
return Args.getArgPointers();
}

void setWorkGroupMemory(size_t MemSize) { Args.setWorkGroupMemory(MemSize); }
Expand Down
6 changes: 3 additions & 3 deletions source/adapters/hip/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
pLocalWorkSize, hKernel, HIPFunc, ThreadsPerBlock, BlocksPerGrid));

// Set node param structure with the kernel related data
auto &ArgIndices = hKernel->getArgIndices();
auto &ArgPointers = hKernel->getArgPointers();
hipKernelNodeParams NodeParams;
NodeParams.func = HIPFunc;
NodeParams.gridDim.x = BlocksPerGrid[0];
Expand All @@ -388,7 +388,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
NodeParams.blockDim.y = ThreadsPerBlock[1];
NodeParams.blockDim.z = ThreadsPerBlock[2];
NodeParams.sharedMemBytes = LocalSize;
NodeParams.kernelParams = const_cast<void **>(ArgIndices.data());
NodeParams.kernelParams = const_cast<void **>(ArgPointers.data());
NodeParams.extra = nullptr;

// Create and add a new kernel node to the HIP graph
Expand Down Expand Up @@ -1098,7 +1098,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
Params.blockDim.z = ThreadsPerBlock[2];
Params.sharedMemBytes = hCommand->Kernel->getLocalSize();
Params.kernelParams =
const_cast<void **>(hCommand->Kernel->getArgIndices().data());
const_cast<void **>(hCommand->Kernel->getArgPointers().data());

hipGraphNode_t Node = hCommand->Node;
hipGraphExec_t HipGraphExec = CommandBuffer->HIPGraphExec;
Expand Down
4 changes: 2 additions & 2 deletions source/adapters/hip/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
}
}

auto ArgIndices = hKernel->getArgIndices();
auto ArgPointers = hKernel->getArgPointers();

// If migration of mem across buffer is needed, an event must be associated
// with this command, implicitly if phEvent is nullptr
Expand All @@ -322,7 +322,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
UR_CHECK_ERROR(hipModuleLaunchKernel(
HIPFunc, BlocksPerGrid[0], BlocksPerGrid[1], BlocksPerGrid[2],
ThreadsPerBlock[0], ThreadsPerBlock[1], ThreadsPerBlock[2],
hKernel->getLocalSize(), HIPStream, ArgIndices.data(), nullptr));
hKernel->getLocalSize(), HIPStream, ArgPointers.data(), nullptr));

if (phEvent) {
UR_CHECK_ERROR(RetImplEvent->record());
Expand Down
Loading
Loading