
Commit 3d6204f

[SYCL][Graph] Initial implementation of graph-owned device allocations
- Allocations managed via new graph_mem_pool
- Device allocations use virtual memory
- Intercept async_alloc calls when adding nodes to graph
- New tests for functionality
- Design doc information about implementation
1 parent 599a9e3 commit 3d6204f

22 files changed: +875 -126 lines

sycl/doc/design/CommandGraph.md (+22)
@@ -305,6 +305,28 @@ from the same dynamic command-group object. This allows the SYCL runtime to
 access the list of alternative kernels when calling the UR API to append a
 kernel command to a command-buffer.
 
+## Graph-Owned Memory Allocations
+### Device Allocations
+
+Device allocations for graphs are implemented using virtual memory. Allocation
+commands perform a virtual reservation for the provided size, and physical
+memory is created and mapped only during graph finalization. This allows valid
+device addresses to be returned immediately when building the graph without the
+penalty of doing any memory allocations during graph building, which could have
+a negative impact on features such as whole-graph update through increased
+overhead.
+
+### Behaviour of async_free
+
+`async_free` nodes are treated as hints rather than an actual memory free
+operation. This is because deallocating during graph execution is both
+undesirable for performance and not feasible with the current
+implementation/backends. Instead, a free node represents a promise from the user
+that the memory is no longer in use. This enables optimizations such as
+potentially reusing that memory for subsequent allocation nodes in the graph.
+This allows us to reduce the total amount of concurrent memory required by a
+single graph.
+
 ## Optimizations
 ### Interactions with Profiling

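The design above corresponds to the user-facing `async_malloc`/`async_free` extension API used while a queue is recorded into a graph. The following is a minimal sketch, not taken from this commit, of how graph-owned allocations might be recorded; exact includes, namespaces, and device support requirements follow the sycl_ext_oneapi_graph and sycl_ext_oneapi_async_alloc extensions.

```cpp
#include <sycl/sycl.hpp>

namespace sycl_ext = sycl::ext::oneapi::experimental;

int main() {
  sycl::queue q;
  sycl_ext::command_graph graph{q.get_context(), q.get_device()};

  graph.begin_recording(q);

  // Allocation node: a valid device address is returned immediately; the
  // physical memory backing it is only created and mapped at finalize time.
  void *tmp = nullptr;
  q.submit([&](sycl::handler &h) {
    tmp = sycl_ext::async_malloc(h, sycl::usm::alloc::device, 1024);
  });

  // Kernel node that uses the graph-owned allocation.
  q.submit([&](sycl::handler &h) {
    h.parallel_for(sycl::range<1>{1024},
                   [=](sycl::id<1> i) { static_cast<char *>(tmp)[i] = 42; });
  });

  // Free node: a hint that `tmp` is no longer needed, allowing later
  // allocation nodes in the graph to reuse the memory.
  q.submit([&](sycl::handler &h) { sycl_ext::async_free(h, tmp); });

  graph.end_recording(q);

  auto exec = graph.finalize(); // physical memory is allocated and mapped here
  q.ext_oneapi_graph(exec).wait();
  return 0;
}
```
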
sycl/include/sycl/ext/oneapi/experimental/graph.hpp (+7 -1)
@@ -121,7 +121,9 @@ enum class node_type {
   memadvise = 7,
   ext_oneapi_barrier = 8,
   host_task = 9,
-  native_command = 10
+  native_command = 10,
+  async_malloc = 11,
+  async_free = 12
 };
 
 /// Class representing a node in the graph, returned by command_graph::add().
@@ -429,6 +431,10 @@ class __SYCL_EXPORT executable_command_graph
   /// @param Nodes The nodes to use for updating the graph.
   void update(const std::vector<node> &Nodes);
 
+  /// Return the total amount of memory required by this graph for graph-owned
+  /// memory allocations.
+  size_t get_required_mem_size() const;
+
   /// Common Reference Semantics
   friend bool operator==(const executable_command_graph &LHS,
                          const executable_command_graph &RHS) {

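The new `get_required_mem_size()` query can be read from a finalized graph. A small sketch of its use; the helper name and output formatting are illustrative and not part of this commit.

```cpp
#include <iostream>
#include <sycl/sycl.hpp>

namespace sycl_ext = sycl::ext::oneapi::experimental;

// Illustrative helper: report the device memory footprint of a finalized
// graph's graph-owned allocations.
void reportGraphMemory(
    const sycl_ext::command_graph<sycl_ext::graph_state::executable> &Exec) {
  std::cout << "Graph-owned allocations require "
            << Exec.get_required_mem_size() << " bytes\n";
}
```
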
sycl/source/detail/async_alloc.cpp (+40 -23)
@@ -44,23 +44,28 @@ void *async_malloc(sycl::handler &h, sycl::usm::alloc kind, size_t size) {
         sycl::make_error_code(sycl::errc::feature_not_supported),
         "Only device backed asynchronous allocations are supported!");
 
-  h.throwIfGraphAssociated<
-      ext::oneapi::experimental::detail::UnsupportedGraphFeatures::
-          sycl_ext_oneapi_async_alloc>();
-
   auto &Adapter = h.getContextImplPtr()->getAdapter();
-  auto &Q = h.MQueue->getHandleRef();
 
   // Get events to wait on.
   auto depEvents = getUrEvents(h.impl->CGData.MEvents);
   uint32_t numEvents = h.impl->CGData.MEvents.size();
 
   void *alloc = nullptr;
-  ur_event_handle_t Event;
-  Adapter->call<sycl::errc::runtime,
-                sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
-      Q, (ur_usm_pool_handle_t)0, size, nullptr, numEvents, depEvents.data(),
-      &alloc, &Event);
+
+  ur_event_handle_t Event = nullptr;
+  // If a graph is present do the allocation from the graph memory pool instead.
+  if (auto Graph = h.getCommandGraph(); Graph) {
+    // size may be modified to reflect the aligned size based on device
+    // granularity.
+    alloc = Graph->getMemPool().malloc(size, kind);
+
+  } else {
+    auto &Q = h.MQueue->getHandleRef();
+    Adapter->call<sycl::errc::runtime,
+                  sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
+        Q, (ur_usm_pool_handle_t)0, size, nullptr, numEvents, depEvents.data(),
+        &alloc, &Event);
+  }
 
   // Async malloc must return a void* immediately.
   // Set up CommandGroup which is a no-op and pass the
@@ -90,25 +95,31 @@ __SYCL_EXPORT void *async_malloc(const sycl::queue &q, sycl::usm::alloc kind,
 __SYCL_EXPORT void *async_malloc_from_pool(sycl::handler &h, size_t size,
                                            const memory_pool &pool) {
 
-  h.throwIfGraphAssociated<
-      ext::oneapi::experimental::detail::UnsupportedGraphFeatures::
-          sycl_ext_oneapi_async_alloc>();
-
   auto &Adapter = h.getContextImplPtr()->getAdapter();
-  auto &Q = h.MQueue->getHandleRef();
   auto &memPoolImpl = sycl::detail::getSyclObjImpl(pool);
 
   // Get events to wait on.
   auto depEvents = getUrEvents(h.impl->CGData.MEvents);
   uint32_t numEvents = h.impl->CGData.MEvents.size();
 
   void *alloc = nullptr;
-  ur_event_handle_t Event;
-  Adapter->call<sycl::errc::runtime,
-                sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
-      Q, memPoolImpl.get()->get_handle(), size, nullptr, numEvents,
-      depEvents.data(), &alloc, &Event);
 
+  ur_event_handle_t Event = nullptr;
+  // If a graph is present do the allocation from the graph memory pool instead.
+  if (auto Graph = h.getCommandGraph(); Graph) {
+    // size may be modified to reflect the aligned size based on device
+    // granularity.
+    // Memory pool is passed as the graph may use some properties of it.
+    alloc = Graph->getMemPool().malloc(size, pool.get_alloc_kind(),
+                                       sycl::detail::getSyclObjImpl(pool));
+
+  } else {
+    auto &Q = h.MQueue->getHandleRef();
+    Adapter->call<sycl::errc::runtime,
+                  sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
+        Q, memPoolImpl.get()->get_handle(), size, nullptr, numEvents,
+        depEvents.data(), &alloc, &Event);
+  }
   // Async malloc must return a void* immediately.
   // Set up CommandGroup which is a no-op and pass the event from the alloc.
   h.impl->MAsyncAllocEvent = Event;
@@ -135,9 +146,15 @@ async_malloc_from_pool(const sycl::queue &q, size_t size,
 }
 
 __SYCL_EXPORT void async_free(sycl::handler &h, void *ptr) {
-  h.throwIfGraphAssociated<
-      ext::oneapi::experimental::detail::UnsupportedGraphFeatures::
-          sycl_ext_oneapi_async_alloc>();
+  if (auto Graph = h.getCommandGraph(); Graph) {
+    // Check if the pointer to be freed has an associated allocation node, and
+    // error if not
+    if (!Graph->getMemPool().hasAllocation(ptr)) {
+      throw sycl::exception(sycl::make_error_code(sycl::errc::invalid),
+                            "Cannot add a free node to a graph for which "
+                            "there is no associated allocation node!");
+    }
+  }
 
   h.impl->MFreePtr = ptr;
   h.setType(detail::CGType::AsyncFree);

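In other words, when the handler is associated with a graph the allocation is served by the graph's memory pool and no UR enqueue happens; otherwise the existing `urEnqueueUSMDeviceAllocExp` path is taken. The rejection of `async_free` on an unknown pointer can be seen from user code; the sketch below is illustrative (function and variable names are not from this commit) and assumes exceptions thrown in the command group function propagate out of `submit`.

```cpp
#include <iostream>
#include <sycl/sycl.hpp>

namespace sycl_ext = sycl::ext::oneapi::experimental;

// While a queue records to a graph, async_free of a pointer that has no
// matching allocation node in the graph's memory pool throws errc::invalid.
void tryFreeForeignPointer(
    sycl::queue &q,
    sycl_ext::command_graph<sycl_ext::graph_state::modifiable> &graph) {
  // Plain USM allocation, not owned by the graph's memory pool.
  void *unrelated = sycl::malloc_device<char>(64, q);

  graph.begin_recording(q);
  try {
    q.submit([&](sycl::handler &h) { sycl_ext::async_free(h, unrelated); });
  } catch (const sycl::exception &e) {
    std::cerr << "async_free rejected: " << e.what() << "\n";
  }
  graph.end_recording(q);

  sycl::free(unrelated, q);
}
```
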
sycl/source/detail/graph_impl.cpp (+35 -6)
@@ -57,6 +57,10 @@ inline const char *nodeTypeToString(node_type NodeType) {
     return "host_task";
   case node_type::native_command:
     return "native_command";
+  case node_type::async_malloc:
+    return "async_malloc";
+  case node_type::async_free:
+    return "async_free";
   }
   assert(false && "Unhandled node type");
   return {};
@@ -340,7 +344,7 @@ graph_impl::graph_impl(const sycl::context &SyclContext,
                        const sycl::device &SyclDevice,
                        const sycl::property_list &PropList)
     : MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
-      MEventsMap(), MInorderQueueMap(),
+      MEventsMap(), MInorderQueueMap(), MGraphMemPool(SyclContext, SyclDevice),
       MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
   checkGraphPropertiesAndThrow(PropList);
   if (PropList.has_property<property::graph::no_cycle_check>()) {
@@ -750,12 +754,12 @@ void graph_impl::beginRecording(
   }
 }
 
-// Check if nodes are empty and if so loop back through predecessors until we
-// find the real dependency.
+// Check if nodes do not require enqueueing and if so loop back through
+// predecessors until we find the real dependency.
 void exec_graph_impl::findRealDeps(
     std::vector<ur_exp_command_buffer_sync_point_t> &Deps,
     std::shared_ptr<node_impl> CurrentNode, int ReferencePartitionNum) {
-  if (CurrentNode->isEmpty()) {
+  if (!CurrentNode->requiresEnqueue()) {
     for (auto &N : CurrentNode->MPredecessors) {
       auto NodeImpl = N.lock();
       findRealDeps(Deps, NodeImpl, ReferencePartitionNum);
@@ -875,9 +879,9 @@ void exec_graph_impl::createCommandBuffers(
   Partition->MCommandBuffers[Device] = OutCommandBuffer;
 
   for (const auto &Node : Partition->MSchedule) {
-    // Empty nodes are not processed as other nodes, but only their
+    // Some nodes are not scheduled like other nodes, and only their
     // dependencies are propagated in findRealDeps
-    if (Node->isEmpty())
+    if (!Node->requiresEnqueue())
       continue;
 
     sycl::detail::CGType type = Node->MCGType;
@@ -943,6 +947,8 @@ exec_graph_impl::exec_graph_impl(sycl::context Context,
 
 exec_graph_impl::~exec_graph_impl() {
   try {
+    MGraphImpl->markExecGraphDestroyed();
+
     const sycl::detail::AdapterPtr &Adapter =
         sycl::detail::getSyclObjImpl(MContext)->getAdapter();
     MSchedule.clear();
@@ -952,6 +958,9 @@ exec_graph_impl::~exec_graph_impl() {
       Event->wait(Event);
     }
 
+    // Clean up any graph-owned allocations that were allocated
+    MGraphImpl->getMemPool().deallocateAndUnmapAll();
+
     for (const auto &Partition : MPartitions) {
       Partition->MSchedule.clear();
      for (const auto &Iter : Partition->MCommandBuffers) {
@@ -1870,6 +1879,14 @@ modifiable_command_graph::finalize(const sycl::property_list &PropList) const {
   // Graph is read and written in this scope so we lock
   // this graph with full priviledges.
   graph_impl::WriteLock Lock(impl->MMutex);
+  // If the graph uses graph-owned allocations and an executable graph already
+  // exists we must throw an error.
+  if (impl->getMemPool().hasAllocations() && impl->getExecGraphCount() > 0) {
+    throw sycl::exception(sycl::make_error_code(errc::invalid),
+                          "Graphs containing allocations can only have a "
+                          "single executable graph alive at any one time.");
+  }
+
   return command_graph<graph_state::executable>{
       this->impl, this->impl->getContext(), PropList};
 }
@@ -1997,11 +2014,16 @@ executable_command_graph::executable_command_graph(
     const property_list &PropList)
     : impl(std::make_shared<detail::exec_graph_impl>(Ctx, Graph, PropList)) {
   finalizeImpl(); // Create backend representation for executable graph
+  // Mark that we have created an executable graph from the modifiable graph.
+  Graph->markExecGraphCreated();
 }
 
 void executable_command_graph::finalizeImpl() {
   impl->makePartitions();
 
+  // Handle any work required for graph-owned memory allocations
+  impl->finalizeMemoryAllocations();
+
   auto Device = impl->getGraphImpl()->getDevice();
   for (auto Partition : impl->getPartitions()) {
     if (!Partition->isHostTask()) {
@@ -2029,6 +2051,13 @@ void executable_command_graph::update(const std::vector<node> &Nodes) {
   impl->update(NodeImpls);
 }
 
+size_t executable_command_graph::get_required_mem_size() const {
+  // Since each graph has a unique mem pool, return the current memory usage
+  // for now. This call may change if we move to being able to share memory
+  // between unique graphs.
+  return impl->getGraphImpl()->getMemPool().getMemUseCurrent();
+}
+
 dynamic_parameter_base::dynamic_parameter_base(
     command_graph<graph_state::modifiable> Graph)
     : impl(std::make_shared<dynamic_parameter_impl>(

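The check added in `finalize()` means that a graph containing allocation nodes can only have one executable graph alive at a time; the memory is mapped when the executable graph is created and released when it is destroyed. A hedged sketch of what this looks like from user code (names are illustrative, not from this commit, and `graph` is assumed to already contain async_malloc nodes):

```cpp
#include <iostream>
#include <sycl/sycl.hpp>

namespace sycl_ext = sycl::ext::oneapi::experimental;

// A modifiable graph that owns allocations may only have one executable graph
// alive at a time, so a second finalize() while the first is alive throws.
void finalizeTwice(
    sycl_ext::command_graph<sycl_ext::graph_state::modifiable> &graph) {
  auto exec1 = graph.finalize(); // allocates and maps the graph-owned memory

  try {
    auto exec2 = graph.finalize(); // throws errc::invalid while exec1 is alive
  } catch (const sycl::exception &e) {
    std::cerr << "second finalize rejected: " << e.what() << "\n";
  }
  // Once exec1 is destroyed, its graph-owned memory is unmapped and freed and
  // the graph can be finalized again.
}
```
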
sycl/source/detail/graph_impl.hpp (+68)
@@ -17,6 +17,7 @@
 #include <detail/accessor_impl.hpp>
 #include <detail/cg.hpp>
 #include <detail/event_impl.hpp>
+#include <detail/graph_memory_pool.hpp>
 #include <detail/host_task.hpp>
 #include <detail/kernel_impl.hpp>
 #include <detail/sycl_mem_obj_t.hpp>
@@ -73,6 +74,11 @@ inline node_type getNodeTypeFromCG(sycl::detail::CGType CGType) {
     return node_type::subgraph;
   case sycl::detail::CGType::EnqueueNativeCommand:
     return node_type::native_command;
+  case sycl::detail::CGType::AsyncAlloc:
+    return node_type::async_malloc;
+  case sycl::detail::CGType::AsyncFree:
+    return node_type::async_free;
+
   default:
     assert(false && "Invalid Graph Node Type");
     return node_type::empty;
@@ -473,6 +479,21 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
     }
   }
 
+  /// Returns true if this node should be enqueued to the backend, if not only
+  /// its dependencies are considered.
+  bool requiresEnqueue() const {
+    switch (MNodeType) {
+    case node_type::empty:
+    case node_type::ext_oneapi_barrier:
+    case node_type::async_malloc:
+    case node_type::async_free:
+      return false;
+
+    default:
+      return true;
+    }
+  }
+
 private:
   void rebuildArgStorage(std::vector<sycl::detail::ArgDesc> &Args,
                          const std::vector<std::vector<char>> &OldArgStorage,
@@ -919,6 +940,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
   /// @return Context associated with graph.
   sycl::context getContext() const { return MContext; }
 
+  /// Query for the context impl tied to this graph.
+  /// @return shared_ptr ref for the context impl associated with graph.
+  const std::shared_ptr<sycl::detail::context_impl> &getContextImplPtr() const {
+    return sycl::detail::getSyclObjImpl(MContext);
+  }
+
   /// Query for the device_impl tied to this graph.
   /// @return device_impl shared ptr reference associated with graph.
   const DeviceImplPtr &getDeviceImplPtr() const {
@@ -1139,6 +1166,32 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
 
   unsigned long long getID() const { return MID; }
 
+  /// Get the memory pool used for graph-owned allocations.
+  graph_mem_pool &getMemPool() { return MGraphMemPool; }
+
+  /// Mark that an executable graph was created from this modifiable graph, used
+  /// for tracking live graphs for graph-owned allocations.
+  void markExecGraphCreated() { MExecGraphCount++; }
+
+  /// Mark that an executable graph created from this modifiable graph was
+  /// destroyed, used for tracking live graphs for graph-owned allocations.
+  void markExecGraphDestroyed() {
+    while (true) {
+      size_t CurrentVal = MExecGraphCount;
+      if (CurrentVal == 0) {
+        break;
+      }
+      if (MExecGraphCount.compare_exchange_strong(CurrentVal, CurrentVal - 1) ==
+          false) {
+        continue;
+      }
+    }
+  }
+
+  /// Get the number of unique executable graph instances currently alive for
+  /// this graph.
+  size_t getExecGraphCount() const { return MExecGraphCount; }
+
 private:
   /// Check the graph for cycles by performing a depth-first search of the
   /// graph. If a node is visited more than once in a given path through the
@@ -1206,10 +1259,17 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
   std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
            std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
       MBarrierDependencyMap;
+  /// Graph memory pool for handling graph-owned memory allocations for this
+  /// graph.
+  graph_mem_pool MGraphMemPool;
 
   unsigned long long MID;
   // Used for std::hash in order to create a unique hash for the instance.
   inline static std::atomic<unsigned long long> NextAvailableID = 0;
+
+  // The number of live executable graphs that have been created from this
+  // modifiable graph
+  std::atomic<size_t> MExecGraphCount = 0;
 };
 
 /// Class representing the implementation of command_graph<executable>.
@@ -1334,6 +1394,14 @@ class exec_graph_impl {
 
   unsigned long long getID() const { return MID; }
 
+  /// Do any work required during finalization to finalize graph-owned memory
+  /// allocations.
+  void finalizeMemoryAllocations() {
+    // This call allocates physical memory and maps all virtual device
+    // allocations
+    MGraphImpl->getMemPool().allocateAndMapAll();
+  }
+
 private:
   /// Create a command-group for the node and add it to command-buffer by going
   /// through the scheduler.

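The new `requiresEnqueue()` query drives the dependency handling shown earlier in `findRealDeps`: allocation and free nodes are never submitted to the command buffer themselves, so their predecessors' dependencies are forwarded instead. The sketch below is a simplified standalone model of that idea, not the runtime's actual types (which track UR sync points and partitions).

```cpp
#include <memory>
#include <vector>

// Simplified model of dependency forwarding for non-enqueued nodes such as
// async_malloc/async_free: skip the node and collect its predecessors'
// dependencies until an enqueued node is reached.
struct ToyNode {
  bool RequiresEnqueue = true;
  std::vector<std::shared_ptr<ToyNode>> Predecessors;
};

void collectRealDeps(std::vector<const ToyNode *> &Deps, const ToyNode &Node) {
  if (!Node.RequiresEnqueue) {
    // Not enqueued: forward to predecessors until real dependencies are found.
    for (const auto &Pred : Node.Predecessors)
      collectRealDeps(Deps, *Pred);
  } else {
    // Enqueued node: this is a real dependency for the command buffer.
    Deps.push_back(&Node);
  }
}
```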