Skip to content

[WIP][UR][SYCL] Implement USM prefetch from device to host in SYCL runtime and UR #19437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 16 commits into
base: sycl
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sycl/include/sycl/detail/cg_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ enum class CGType : unsigned int {
EnqueueNativeCommand = 27,
AsyncAlloc = 28,
AsyncFree = 29,
PrefetchUSMExpD2H = 30,
};

template <typename, typename T> struct check_fn_signature {
Expand Down
13 changes: 10 additions & 3 deletions sycl/include/sycl/ext/oneapi/experimental/enqueue_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <sycl/detail/common.hpp>
#include <sycl/event.hpp>
#include <sycl/ext/oneapi/experimental/enqueue_types.hpp>
#include <sycl/ext/oneapi/experimental/graph.hpp>
#include <sycl/ext/oneapi/properties/properties.hpp>
#include <sycl/handler.hpp>
Expand Down Expand Up @@ -369,15 +370,21 @@ void fill(sycl::queue Q, T *Ptr, const T &Pattern, size_t Count,
CodeLoc);
}

inline void prefetch(handler &CGH, void *Ptr, size_t NumBytes) {
CGH.prefetch(Ptr, NumBytes);
/// Records a USM prefetch of the region starting at \p Ptr and spanning
/// \p NumBytes on the handler \p CGH.
///
/// \param CGH handler that receives the prefetch action.
/// \param Ptr start of the region to prefetch.
/// \param NumBytes size of the region.
/// \param Type prefetch_type::device (the default) uses handler::prefetch;
///        any other value uses handler::ext_oneapi_prefetch_d2h.
inline void prefetch(handler &CGH, void *Ptr, size_t NumBytes,
                     prefetch_type Type = prefetch_type::device) {
  // Everything that is not an explicit device prefetch takes the
  // device-to-host path.
  if (Type != prefetch_type::device) {
    CGH.ext_oneapi_prefetch_d2h(Ptr, NumBytes);
    return;
  }
  CGH.prefetch(Ptr, NumBytes);
}

inline void prefetch(queue Q, void *Ptr, size_t NumBytes,
prefetch_type Type = prefetch_type::device,
const sycl::detail::code_location &CodeLoc =
sycl::detail::code_location::current()) {
submit(
std::move(Q), [&](handler &CGH) { prefetch(CGH, Ptr, NumBytes); },
std::move(Q), [&](handler &CGH) { prefetch(CGH, Ptr, NumBytes, Type); },
CodeLoc);
}

Expand Down
33 changes: 33 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/enqueue_types.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//==--------------- enqueue_types.hpp ---- SYCL enqueue types --------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#pragma once

#include <string>

namespace sycl {
inline namespace _V1 {
namespace ext::oneapi::experimental {

/// @brief Selects where USM data should be prefetched to: the device or the
/// host.
///
/// handler.hpp forward-declares this type as `enum class prefetch_type;`, so
/// the underlying type has to remain the implicit default.
enum class prefetch_type { device, host };

/// @brief Returns a printable name for a prefetch_type value.
/// @param value Enumerator to describe.
/// @return "prefetch_type::device" or "prefetch_type::host" for the named
/// enumerators, and "prefetch_type::unknown" for any other value.
inline std::string prefetchTypeToString(prefetch_type value) {
  // Deliberately no `default:` label: the switch names every enumerator, so
  // the compiler can warn (-Wswitch) when a new one is added without a case.
  switch (value) {
  case prefetch_type::device:
    return "prefetch_type::device";
  case prefetch_type::host:
    return "prefetch_type::host";
  }
  // Reached only for values that do not correspond to a named enumerator.
  return "prefetch_type::unknown";
}

} // namespace ext::oneapi::experimental
} // namespace _V1
} // namespace sycl
12 changes: 12 additions & 0 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ namespace ext ::oneapi ::experimental {
template <typename, typename> class work_group_memory;
template <typename, typename> class dynamic_work_group_memory;
struct image_descriptor;
enum class prefetch_type;
void prefetch(handler &CGH, void *Ptr, size_t NumBytes, prefetch_type Type);

__SYCL_EXPORT void async_free(sycl::handler &h, void *ptr);
__SYCL_EXPORT void *async_malloc(sycl::handler &h, sycl::usm::alloc kind,
size_t size);
Expand Down Expand Up @@ -3687,6 +3690,15 @@ class __SYCL_EXPORT handler {
void ext_oneapi_memset2d_impl(void *Dest, size_t DestPitch, int Value,
size_t Width, size_t Height);

// Implementation of prefetch from device back to host
void ext_oneapi_prefetch_d2h(const void *Ptr, size_t Count);

// The enqueue_functions module's prefetch function is friended in order for
// it to be able to call private handler function ext_oneapi_prefetch_d2h.
friend void sycl::ext::oneapi::experimental::prefetch(
handler &CGH, void *Ptr, size_t NumBytes,
sycl::ext::oneapi::experimental::prefetch_type Type);

// Implementation of memcpy to device_global.
void memcpyToDeviceGlobal(const void *DeviceGlobalPtr, const void *Src,
bool IsDeviceImageScoped, size_t NumBytes,
Expand Down
14 changes: 14 additions & 0 deletions sycl/source/detail/cg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,20 @@ class CGPrefetchUSM : public CG {
size_t getLength() { return MLength; }
};

/// Command group for the experimental device-to-host USM prefetch
/// (CGType::PrefetchUSMExpD2H).
class CGPrefetchUSMExpD2H : public CG {
public:
  /// Stores \p DstPtr and \p Length as given and forwards \p CGData and
  /// \p loc to the CG base class.
  CGPrefetchUSMExpD2H(void *DstPtr, size_t Length, CG::StorageInitHelper CGData,
                      detail::code_location loc = {})
      : CG(CGType::PrefetchUSMExpD2H, std::move(CGData), std::move(loc)),
        MDst{DstPtr}, MLength{Length} {}

  /// Pointer supplied at construction.
  void *getDst() { return MDst; }

  /// Length supplied at construction.
  size_t getLength() { return MLength; }

private:
  void *MDst;
  size_t MLength;
};

/// "Advise USM" command group class.
class CGAdviseUSM : public CG {
void *MDst;
Expand Down
12 changes: 12 additions & 0 deletions sycl/source/detail/graph/node_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ inline node_type getNodeTypeFromCG(sycl::detail::CGType CGType) {
case sycl::detail::CGType::FillUSM:
return node_type::memfill;
case sycl::detail::CGType::PrefetchUSM:
case sycl::detail::CGType::PrefetchUSMExpD2H:
return node_type::prefetch;
case sycl::detail::CGType::AdviseUSM:
return node_type::memadvise;
Expand Down Expand Up @@ -247,6 +248,8 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
return createCGCopy<sycl::detail::CGFillUSM>();
case sycl::detail::CGType::PrefetchUSM:
return createCGCopy<sycl::detail::CGPrefetchUSM>();
case sycl::detail::CGType::PrefetchUSMExpD2H:
return createCGCopy<sycl::detail::CGPrefetchUSMExpD2H>();
case sycl::detail::CGType::AdviseUSM:
return createCGCopy<sycl::detail::CGAdviseUSM>();
case sycl::detail::CGType::Copy2DUSM:
Expand Down Expand Up @@ -658,6 +661,15 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
<< " Length: " << Prefetch->getLength() << "\\n";
}
break;
case sycl::detail::CGType::PrefetchUSMExpD2H:
Stream << "CGPrefetchUSMExpD2H (Experimental, Device to host) \\n";
if (Verbose) {
sycl::detail::CGPrefetchUSMExpD2H *Prefetch =
static_cast<sycl::detail::CGPrefetchUSMExpD2H *>(MCommandGroup.get());
Stream << "Dst: " << Prefetch->getDst()
<< " Length: " << Prefetch->getLength() << "\\n";
}
break;
case sycl::detail::CGType::AdviseUSM:
Stream << "CGAdviseUSM \\n";
if (Verbose) {
Expand Down
20 changes: 15 additions & 5 deletions sycl/source/detail/memory_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -924,10 +924,15 @@ void MemoryManager::fill_usm(void *Mem, queue_impl &Queue, size_t Length,

void MemoryManager::prefetch_usm(void *Mem, queue_impl &Queue, size_t Length,
std::vector<ur_event_handle_t> DepEvents,
ur_event_handle_t *OutEvent) {
ur_event_handle_t *OutEvent,
sycl::ext::oneapi::experimental::prefetch_type Dest) {
adapter_impl &Adapter = Queue.getAdapter();
ur_usm_migration_flags_t MigrationFlag =
(Dest == sycl::ext::oneapi::experimental::prefetch_type::device)
? UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE
: UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST;
Adapter.call<UrApiKind::urEnqueueUSMPrefetch>(Queue.getHandleRef(), Mem,
Length, 0u, DepEvents.size(),
Length, MigrationFlag, DepEvents.size(),
DepEvents.data(), OutEvent);
}

Expand Down Expand Up @@ -1539,11 +1544,16 @@ void MemoryManager::ext_oneapi_prefetch_usm_cmd_buffer(
sycl::detail::context_impl *Context,
ur_exp_command_buffer_handle_t CommandBuffer, void *Mem, size_t Length,
std::vector<ur_exp_command_buffer_sync_point_t> Deps,
ur_exp_command_buffer_sync_point_t *OutSyncPoint) {
ur_exp_command_buffer_sync_point_t *OutSyncPoint,
sycl::ext::oneapi::experimental::prefetch_type Dest) {
adapter_impl &Adapter = Context->getAdapter();
ur_usm_migration_flags_t MigrationFlag =
(Dest == sycl::ext::oneapi::experimental::prefetch_type::device)
? UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE
: UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST;
Adapter.call<UrApiKind::urCommandBufferAppendUSMPrefetchExp>(
CommandBuffer, Mem, Length, ur_usm_migration_flags_t(0), Deps.size(),
Deps.data(), 0u, nullptr, OutSyncPoint, nullptr, nullptr);
CommandBuffer, Mem, Length, MigrationFlag, Deps.size(),
Deps.data(), 0, nullptr, OutSyncPoint, nullptr, nullptr);
}

void MemoryManager::ext_oneapi_advise_usm_cmd_buffer(
Expand Down
14 changes: 10 additions & 4 deletions sycl/source/detail/memory_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <detail/sycl_mem_obj_i.hpp>
#include <sycl/access/access.hpp>
#include <sycl/detail/export.hpp>
#include <sycl/ext/oneapi/experimental/enqueue_types.hpp> // for prefetch_type
#include <sycl/id.hpp>
#include <sycl/property_list.hpp>
#include <sycl/range.hpp>
Expand Down Expand Up @@ -146,9 +147,12 @@ class MemoryManager {
std::vector<ur_event_handle_t> DepEvents,
ur_event_handle_t *OutEvent);

static void prefetch_usm(void *Ptr, queue_impl &Queue, size_t Len,
std::vector<ur_event_handle_t> DepEvents,
ur_event_handle_t *OutEvent);
static void prefetch_usm(
void *Ptr, queue_impl &Queue, size_t Len,
std::vector<ur_event_handle_t> DepEvents,
ur_event_handle_t *OutEvent,
sycl::ext::oneapi::experimental::prefetch_type Dest =
sycl::ext::oneapi::experimental::prefetch_type::device);

static void advise_usm(const void *Ptr, queue_impl &Queue, size_t Len,
ur_usm_advice_flags_t Advice,
Expand Down Expand Up @@ -245,7 +249,9 @@ class MemoryManager {
sycl::detail::context_impl *Context,
ur_exp_command_buffer_handle_t CommandBuffer, void *Mem, size_t Length,
std::vector<ur_exp_command_buffer_sync_point_t> Deps,
ur_exp_command_buffer_sync_point_t *OutSyncPoint);
ur_exp_command_buffer_sync_point_t *OutSyncPoint,
sycl::ext::oneapi::experimental::prefetch_type Dest =
sycl::ext::oneapi::experimental::prefetch_type::device);

static void ext_oneapi_advise_usm_cmd_buffer(
sycl::detail::context_impl *Context,
Expand Down
34 changes: 32 additions & 2 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1913,6 +1913,9 @@ static std::string_view cgTypeToString(detail::CGType Type) {
case detail::CGType::PrefetchUSM:
return "prefetch usm";
break;
case detail::CGType::PrefetchUSMExpD2H:
return "prefetch usm (experimental, device to host)";
break;
case detail::CGType::CodeplayHostTask:
return "host task";
break;
Expand Down Expand Up @@ -2989,7 +2992,21 @@ ur_result_t ExecCGCommand::enqueueImpCommandBuffer() {
if (auto Result = callMemOpHelper(
MemoryManager::ext_oneapi_prefetch_usm_cmd_buffer,
&MQueue->getContextImpl(), MCommandBuffer, Prefetch->getDst(),
Prefetch->getLength(), std::move(MSyncPointDeps), &OutSyncPoint);
Prefetch->getLength(), std::move(MSyncPointDeps), &OutSyncPoint,
sycl::ext::oneapi::experimental::prefetch_type::device);
Result != UR_RESULT_SUCCESS)
return Result;

MEvent->setSyncPoint(OutSyncPoint);
return UR_RESULT_SUCCESS;
}
case CGType::PrefetchUSMExpD2H: {
CGPrefetchUSMExpD2H *Prefetch = (CGPrefetchUSMExpD2H *)MCommandGroup.get();
if (auto Result = callMemOpHelper(
MemoryManager::ext_oneapi_prefetch_usm_cmd_buffer,
&MQueue->getContextImpl(), MCommandBuffer, Prefetch->getDst(),
Prefetch->getLength(), std::move(MSyncPointDeps), &OutSyncPoint,
sycl::ext::oneapi::experimental::prefetch_type::host);
Result != UR_RESULT_SUCCESS)
return Result;

Expand Down Expand Up @@ -3300,7 +3317,20 @@ ur_result_t ExecCGCommand::enqueueImpQueue() {
CGPrefetchUSM *Prefetch = (CGPrefetchUSM *)MCommandGroup.get();
if (auto Result = callMemOpHelper(
MemoryManager::prefetch_usm, Prefetch->getDst(), *MQueue,
Prefetch->getLength(), std::move(RawEvents), Event);
Prefetch->getLength(), std::move(RawEvents), Event,
sycl::ext::oneapi::experimental::prefetch_type::device);
Result != UR_RESULT_SUCCESS)
return Result;

SetEventHandleOrDiscard();
return UR_RESULT_SUCCESS;
}
case CGType::PrefetchUSMExpD2H: {
CGPrefetchUSM *Prefetch = (CGPrefetchUSM *)MCommandGroup.get();
if (auto Result = callMemOpHelper(
MemoryManager::prefetch_usm, Prefetch->getDst(), *MQueue,
Prefetch->getLength(), std::move(RawEvents), Event,
sycl::ext::oneapi::experimental::prefetch_type::host);
Result != UR_RESULT_SUCCESS)
return Result;

Expand Down
11 changes: 11 additions & 0 deletions sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,10 @@ event handler::finalize() {
CommandGroup.reset(new detail::CGPrefetchUSM(
MDstPtr, MLength, std::move(impl->CGData), MCodeLoc));
break;
case detail::CGType::PrefetchUSMExpD2H:
CommandGroup.reset(new detail::CGPrefetchUSMExpD2H(
MDstPtr, MLength, std::move(impl->CGData), MCodeLoc));
break;
case detail::CGType::AdviseUSM:
CommandGroup.reset(new detail::CGAdviseUSM(MDstPtr, MLength, impl->MAdvice,
std::move(impl->CGData),
Expand Down Expand Up @@ -1479,6 +1483,13 @@ void handler::prefetch(const void *Ptr, size_t Count) {
setType(detail::CGType::PrefetchUSM);
}

/// Records an experimental device-to-host USM prefetch on this handler.
///
/// \param Ptr start of the region to prefetch.
/// \param Count length of the region.
void handler::ext_oneapi_prefetch_d2h(const void *Ptr, size_t Count) {
  // Must run before any handler state is touched.
  // NOTE(review): presumably rejects a handler that already has an action.
  throwIfActionIsCreated();

  MLength = Count;
  // The destination is stored as a mutable pointer, hence the const_cast.
  MDstPtr = const_cast<void *>(Ptr);

  // Tag the command group so finalize() builds a CGPrefetchUSMExpD2H.
  setType(detail::CGType::PrefetchUSMExpD2H);
}

void handler::mem_advise(const void *Ptr, size_t Count, int Advice) {
throwIfActionIsCreated();
MDstPtr = const_cast<void *>(Ptr);
Expand Down
Loading
Loading