Skip to content

Commit f0df304

Browse files
committed
[SYCL][Graph] Initial implementation of graph-owned device allocations
- Allocations managed via new graph_mem_pool
- Device allocations use virtual memory
- Intercept async_alloc calls when adding nodes to graph
- New tests for functionality
1 parent c2b5a90 commit f0df304

20 files changed

+835
-126
lines changed

sycl/include/sycl/ext/oneapi/experimental/graph.hpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ enum class node_type {
121121
memadvise = 7,
122122
ext_oneapi_barrier = 8,
123123
host_task = 9,
124-
native_command = 10
124+
native_command = 10,
125+
async_malloc = 11,
126+
async_free = 12
125127
};
126128

127129
/// Class representing a node in the graph, returned by command_graph::add().
@@ -429,6 +431,10 @@ class __SYCL_EXPORT executable_command_graph
429431
/// @param Nodes The nodes to use for updating the graph.
430432
void update(const std::vector<node> &Nodes);
431433

434+
/// Return the total amount of memory required by this graph for graph-owned
435+
/// memory allocations.
436+
size_t get_required_mem_size() const;
437+
432438
/// Common Reference Semantics
433439
friend bool operator==(const executable_command_graph &LHS,
434440
const executable_command_graph &RHS) {

sycl/source/detail/async_alloc.cpp

+40-23
Original file line numberDiff line numberDiff line change
@@ -44,23 +44,28 @@ void *async_malloc(sycl::handler &h, sycl::usm::alloc kind, size_t size) {
4444
sycl::make_error_code(sycl::errc::feature_not_supported),
4545
"Only device backed asynchronous allocations are supported!");
4646

47-
h.throwIfGraphAssociated<
48-
ext::oneapi::experimental::detail::UnsupportedGraphFeatures::
49-
sycl_ext_oneapi_async_alloc>();
50-
5147
auto &Adapter = h.getContextImplPtr()->getAdapter();
52-
auto &Q = h.MQueue->getHandleRef();
5348

5449
// Get events to wait on.
5550
auto depEvents = getUrEvents(h.impl->CGData.MEvents);
5651
uint32_t numEvents = h.impl->CGData.MEvents.size();
5752

5853
void *alloc = nullptr;
59-
ur_event_handle_t Event;
60-
Adapter->call<sycl::errc::runtime,
61-
sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
62-
Q, (ur_usm_pool_handle_t)0, size, nullptr, numEvents, depEvents.data(),
63-
&alloc, &Event);
54+
55+
ur_event_handle_t Event = nullptr;
56+
// If a graph is present do the allocation from the graph memory pool instead.
57+
if (auto Graph = h.getCommandGraph(); Graph) {
58+
// size may be modified to reflect the aligned size based on device
59+
// granularity.
60+
alloc = Graph->getMemPool().malloc(size, kind);
61+
62+
} else {
63+
auto &Q = h.MQueue->getHandleRef();
64+
Adapter->call<sycl::errc::runtime,
65+
sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
66+
Q, (ur_usm_pool_handle_t)0, size, nullptr, numEvents, depEvents.data(),
67+
&alloc, &Event);
68+
}
6469

6570
// Async malloc must return a void* immediately.
6671
// Set up CommandGroup which is a no-op and pass the
@@ -90,25 +95,31 @@ __SYCL_EXPORT void *async_malloc(const sycl::queue &q, sycl::usm::alloc kind,
9095
__SYCL_EXPORT void *async_malloc_from_pool(sycl::handler &h, size_t size,
9196
const memory_pool &pool) {
9297

93-
h.throwIfGraphAssociated<
94-
ext::oneapi::experimental::detail::UnsupportedGraphFeatures::
95-
sycl_ext_oneapi_async_alloc>();
96-
9798
auto &Adapter = h.getContextImplPtr()->getAdapter();
98-
auto &Q = h.MQueue->getHandleRef();
9999
auto &memPoolImpl = sycl::detail::getSyclObjImpl(pool);
100100

101101
// Get events to wait on.
102102
auto depEvents = getUrEvents(h.impl->CGData.MEvents);
103103
uint32_t numEvents = h.impl->CGData.MEvents.size();
104104

105105
void *alloc = nullptr;
106-
ur_event_handle_t Event;
107-
Adapter->call<sycl::errc::runtime,
108-
sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
109-
Q, memPoolImpl.get()->get_handle(), size, nullptr, numEvents,
110-
depEvents.data(), &alloc, &Event);
111106

107+
ur_event_handle_t Event = nullptr;
108+
// If a graph is present do the allocation from the graph memory pool instead.
109+
if (auto Graph = h.getCommandGraph(); Graph) {
110+
// size may be modified to reflect the aligned size based on device
111+
// granularity.
112+
// Memory pool is passed as the graph may use some properties of it.
113+
alloc = Graph->getMemPool().malloc(size, pool.get_alloc_kind(),
114+
sycl::detail::getSyclObjImpl(pool));
115+
116+
} else {
117+
auto &Q = h.MQueue->getHandleRef();
118+
Adapter->call<sycl::errc::runtime,
119+
sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
120+
Q, memPoolImpl.get()->get_handle(), size, nullptr, numEvents,
121+
depEvents.data(), &alloc, &Event);
122+
}
112123
// Async malloc must return a void* immediately.
113124
// Set up CommandGroup which is a no-op and pass the event from the alloc.
114125
h.impl->MAsyncAllocEvent = Event;
@@ -135,9 +146,15 @@ async_malloc_from_pool(const sycl::queue &q, size_t size,
135146
}
136147

137148
__SYCL_EXPORT void async_free(sycl::handler &h, void *ptr) {
138-
h.throwIfGraphAssociated<
139-
ext::oneapi::experimental::detail::UnsupportedGraphFeatures::
140-
sycl_ext_oneapi_async_alloc>();
149+
if (auto Graph = h.getCommandGraph(); Graph) {
150+
// Check if the pointer to be freed has an associated allocation node, and
151+
// error if not
152+
if (!Graph->getMemPool().hasAllocation(ptr)) {
153+
throw sycl::exception(sycl::make_error_code(sycl::errc::invalid),
154+
"Cannot add a free node to a graph for which "
155+
"there is no associated allocation node!");
156+
}
157+
}
141158

142159
h.impl->MFreePtr = ptr;
143160
h.setType(detail::CGType::AsyncFree);

sycl/source/detail/graph_impl.cpp

+35-6
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ inline const char *nodeTypeToString(node_type NodeType) {
5757
return "host_task";
5858
case node_type::native_command:
5959
return "native_command";
60+
case node_type::async_malloc:
61+
return "async_malloc";
62+
case node_type::async_free:
63+
return "async_free";
6064
}
6165
assert(false && "Unhandled node type");
6266
return {};
@@ -340,7 +344,7 @@ graph_impl::graph_impl(const sycl::context &SyclContext,
340344
const sycl::device &SyclDevice,
341345
const sycl::property_list &PropList)
342346
: MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
343-
MEventsMap(), MInorderQueueMap(),
347+
MEventsMap(), MInorderQueueMap(), MGraphMemPool(SyclContext, SyclDevice),
344348
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
345349
checkGraphPropertiesAndThrow(PropList);
346350
if (PropList.has_property<property::graph::no_cycle_check>()) {
@@ -748,12 +752,12 @@ void graph_impl::beginRecording(
748752
}
749753
}
750754

751-
// Check if nodes are empty and if so loop back through predecessors until we
752-
// find the real dependency.
755+
// Check if nodes do not require enqueueing and if so loop back through
756+
// predecessors until we find the real dependency.
753757
void exec_graph_impl::findRealDeps(
754758
std::vector<ur_exp_command_buffer_sync_point_t> &Deps,
755759
std::shared_ptr<node_impl> CurrentNode, int ReferencePartitionNum) {
756-
if (CurrentNode->isEmpty()) {
760+
if (!CurrentNode->requiresEnqueue()) {
757761
for (auto &N : CurrentNode->MPredecessors) {
758762
auto NodeImpl = N.lock();
759763
findRealDeps(Deps, NodeImpl, ReferencePartitionNum);
@@ -873,9 +877,9 @@ void exec_graph_impl::createCommandBuffers(
873877
Partition->MCommandBuffers[Device] = OutCommandBuffer;
874878

875879
for (const auto &Node : Partition->MSchedule) {
876-
// Empty nodes are not processed as other nodes, but only their
880+
// Some nodes are not scheduled like other nodes, and only their
877881
// dependencies are propagated in findRealDeps
878-
if (Node->isEmpty())
882+
if (!Node->requiresEnqueue())
879883
continue;
880884

881885
sycl::detail::CGType type = Node->MCGType;
@@ -941,6 +945,8 @@ exec_graph_impl::exec_graph_impl(sycl::context Context,
941945

942946
exec_graph_impl::~exec_graph_impl() {
943947
try {
948+
MGraphImpl->markExecGraphDestroyed();
949+
944950
const sycl::detail::AdapterPtr &Adapter =
945951
sycl::detail::getSyclObjImpl(MContext)->getAdapter();
946952
MSchedule.clear();
@@ -950,6 +956,9 @@ exec_graph_impl::~exec_graph_impl() {
950956
Event->wait(Event);
951957
}
952958

959+
// Clean up any graph-owned allocations that were allocated
960+
MGraphImpl->getMemPool().deallocateAndUnmapAll();
961+
953962
for (const auto &Partition : MPartitions) {
954963
Partition->MSchedule.clear();
955964
for (const auto &Iter : Partition->MCommandBuffers) {
@@ -1867,6 +1876,14 @@ modifiable_command_graph::finalize(const sycl::property_list &PropList) const {
18671876
// Graph is read and written in this scope so we lock
18681877
// this graph with full privileges.
18691878
graph_impl::WriteLock Lock(impl->MMutex);
1879+
// If the graph uses graph-owned allocations and an executable graph already
1880+
// exists we must throw an error.
1881+
if (impl->getMemPool().hasAllocations() && impl->getExecGraphCount() > 0) {
1882+
throw sycl::exception(sycl::make_error_code(errc::invalid),
1883+
"Graphs containing allocations can only have a "
1884+
"single executable graph alive at any one time.");
1885+
}
1886+
18701887
return command_graph<graph_state::executable>{
18711888
this->impl, this->impl->getContext(), PropList};
18721889
}
@@ -1994,11 +2011,16 @@ executable_command_graph::executable_command_graph(
19942011
const property_list &PropList)
19952012
: impl(std::make_shared<detail::exec_graph_impl>(Ctx, Graph, PropList)) {
19962013
finalizeImpl(); // Create backend representation for executable graph
2014+
// Mark that we have created an executable graph from the modifiable graph.
2015+
Graph->markExecGraphCreated();
19972016
}
19982017

19992018
void executable_command_graph::finalizeImpl() {
20002019
impl->makePartitions();
20012020

2021+
// Handle any work required for graph-owned memory allocations
2022+
impl->finalizeMemoryAllocations();
2023+
20022024
auto Device = impl->getGraphImpl()->getDevice();
20032025
for (auto Partition : impl->getPartitions()) {
20042026
if (!Partition->isHostTask()) {
@@ -2026,6 +2048,13 @@ void executable_command_graph::update(const std::vector<node> &Nodes) {
20262048
impl->update(NodeImpls);
20272049
}
20282050

2051+
size_t executable_command_graph::get_required_mem_size() const {
2052+
// Since each graph has a unique mem pool, return the current memory usage for
2053+
// now. This call may change if we move to being able to share memory between
2054+
// unique graphs.
2055+
return impl->getGraphImpl()->getMemPool().getMemUseCurrent();
2056+
}
2057+
20292058
dynamic_parameter_base::dynamic_parameter_base(
20302059
command_graph<graph_state::modifiable> Graph)
20312060
: impl(std::make_shared<dynamic_parameter_impl>(

sycl/source/detail/graph_impl.hpp

+54
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <detail/accessor_impl.hpp>
1818
#include <detail/cg.hpp>
1919
#include <detail/event_impl.hpp>
20+
#include <detail/graph_memory_pool.hpp>
2021
#include <detail/host_task.hpp>
2122
#include <detail/kernel_impl.hpp>
2223
#include <detail/sycl_mem_obj_t.hpp>
@@ -73,6 +74,11 @@ inline node_type getNodeTypeFromCG(sycl::detail::CGType CGType) {
7374
return node_type::subgraph;
7475
case sycl::detail::CGType::EnqueueNativeCommand:
7576
return node_type::native_command;
77+
case sycl::detail::CGType::AsyncAlloc:
78+
return node_type::async_malloc;
79+
case sycl::detail::CGType::AsyncFree:
80+
return node_type::async_free;
81+
7682
default:
7783
assert(false && "Invalid Graph Node Type");
7884
return node_type::empty;
@@ -473,6 +479,15 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
473479
}
474480
}
475481

482+
/// Returns true if this node should be enqueued to the backend; if not, only
483+
/// its dependencies are considered.
484+
bool requiresEnqueue() const {
485+
return MNodeType != node_type::empty &&
486+
MNodeType != node_type::ext_oneapi_barrier &&
487+
MNodeType != node_type::async_malloc &&
488+
MNodeType != node_type::async_free;
489+
}
490+
476491
private:
477492
void rebuildArgStorage(std::vector<sycl::detail::ArgDesc> &Args,
478493
const std::vector<std::vector<char>> &OldArgStorage,
@@ -919,6 +934,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
919934
/// @return Context associated with graph.
920935
sycl::context getContext() const { return MContext; }
921936

937+
/// Query for the context impl tied to this graph.
938+
/// @return shared_ptr ref for the context impl associated with graph.
939+
const std::shared_ptr<sycl::detail::context_impl> &getContextImplPtr() const {
940+
return sycl::detail::getSyclObjImpl(MContext);
941+
}
942+
922943
/// Query for the device_impl tied to this graph.
923944
/// @return device_impl shared ptr reference associated with graph.
924945
const DeviceImplPtr &getDeviceImplPtr() const {
@@ -1139,6 +1160,24 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
11391160

11401161
unsigned long long getID() const { return MID; }
11411162

1163+
/// Get the memory pool used for graph-owned allocations.
1164+
graph_mem_pool &getMemPool() { return MGraphMemPool; }
1165+
1166+
/// Mark that an executable graph was created from this modifiable graph, used
1167+
/// for tracking live graphs for graph-owned allocations.
1168+
void markExecGraphCreated() { MExecGraphCount++; }
1169+
1170+
/// Mark that an executable graph created from this modifiable graph was
1171+
/// destroyed, used for tracking live graphs for graph-owned allocations.
1172+
void markExecGraphDestroyed() {
1173+
assert(MExecGraphCount != 0);
1174+
MExecGraphCount--;
1175+
}
1176+
1177+
/// Get the number of unique executable graph instances currently alive for
1178+
/// this graph.
1179+
size_t getExecGraphCount() const { return MExecGraphCount; }
1180+
11421181
private:
11431182
/// Check the graph for cycles by performing a depth-first search of the
11441183
/// graph. If a node is visited more than once in a given path through the
@@ -1206,10 +1245,17 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
12061245
std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
12071246
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
12081247
MBarrierDependencyMap;
1248+
/// Graph memory pool for handling graph-owned memory allocations for this
1249+
/// graph.
1250+
graph_mem_pool MGraphMemPool;
12091251

12101252
unsigned long long MID;
12111253
// Used for std::hash in order to create a unique hash for the instance.
12121254
inline static std::atomic<unsigned long long> NextAvailableID = 0;
1255+
1256+
// The number of live executable graphs that have been created from this
1257+
// modifiable graph
1258+
size_t MExecGraphCount = 0;
12131259
};
12141260

12151261
/// Class representing the implementation of command_graph<executable>.
@@ -1334,6 +1380,14 @@ class exec_graph_impl {
13341380

13351381
unsigned long long getID() const { return MID; }
13361382

1383+
/// Do any work required during finalization to finalize graph-owned memory
1384+
/// allocations.
1385+
void finalizeMemoryAllocations() {
1386+
// This call allocates physical memory and maps all virtual device
1387+
// allocations
1388+
MGraphImpl->getMemPool().allocateAndMapAll();
1389+
}
1390+
13371391
private:
13381392
/// Create a command-group for the node and add it to command-buffer by going
13391393
/// through the scheduler.

0 commit comments

Comments
 (0)