diff --git a/sycl/source/detail/graph/dynamic_impl.cpp b/sycl/source/detail/graph/dynamic_impl.cpp index d48ee31814461..3944dc6a2cfc3 100644 --- a/sycl/source/detail/graph/dynamic_impl.cpp +++ b/sycl/source/detail/graph/dynamic_impl.cpp @@ -90,6 +90,10 @@ void dynamic_local_accessor_base::updateLocalAccessor( ->updateLocalAccessor(NewAllocationSize); } +void dynamic_parameter_impl::registerNode(node_impl &NodeImpl, int ArgIndex) { + MNodes.emplace_back(NodeImpl.weak_from_this(), ArgIndex); +} + void dynamic_parameter_impl::updateValue(const raw_kernel_arg *NewRawValue, size_t Size) { // Number of bytes is taken from member of raw_kernel_arg object rather diff --git a/sycl/source/detail/graph/dynamic_impl.hpp b/sycl/source/detail/graph/dynamic_impl.hpp index 420dd8e112aa6..bd4cc7e13e54a 100644 --- a/sycl/source/detail/graph/dynamic_impl.hpp +++ b/sycl/source/detail/graph/dynamic_impl.hpp @@ -115,9 +115,7 @@ class dynamic_parameter_impl { /// @param NodeImpl The node to be registered /// @param ArgIndex The arg index for the kernel arg associated with this /// dynamic_parameter in NodeImpl - void registerNode(std::shared_ptr NodeImpl, int ArgIndex) { - MNodes.emplace_back(NodeImpl, ArgIndex); - } + void registerNode(node_impl &NodeImpl, int ArgIndex); /// Struct detailing an instance of the usage of the dynamic parameter in a /// dynamic CG. diff --git a/sycl/source/detail/graph/graph_impl.cpp b/sycl/source/detail/graph/graph_impl.cpp index 2e58661b4c40b..0026cadadb4c7 100644 --- a/sycl/source/detail/graph/graph_impl.cpp +++ b/sycl/source/detail/graph/graph_impl.cpp @@ -142,12 +142,11 @@ void propagatePartitionUp(node_impl &Node, int PartitionNum) { /// @param PartitionNum Number to propagate. /// @param HostTaskList List of host tasks that have already been processed and /// are encountered as successors to the node Node. -void propagatePartitionDown( - node_impl &Node, int PartitionNum, - std::list> &HostTaskList) { +void propagatePartitionDown(node_impl &Node, int PartitionNum, + std::list &HostTaskList) { if (Node.MCGType == sycl::detail::CGType::CodeplayHostTask) { if (Node.MPartitionNum != -1) { - HostTaskList.push_front(Node.shared_from_this()); + HostTaskList.push_front(&Node); } return; } @@ -181,11 +180,11 @@ void partition::updateSchedule() { void exec_graph_impl::makePartitions() { int CurrentPartition = -1; - std::list> HostTaskList; + std::list HostTaskList; // find all the host-tasks in the graph - for (auto &Node : MNodeStorage) { - if (Node->MCGType == sycl::detail::CGType::CodeplayHostTask) { - HostTaskList.push_back(Node); + for (node_impl &Node : nodes()) { + if (Node.MCGType == sycl::detail::CGType::CodeplayHostTask) { + HostTaskList.push_back(&Node); } } @@ -215,29 +214,29 @@ void exec_graph_impl::makePartitions() { // group that includes the predecessor of `B` can be merged with the group of // the predecessors of the node `A`. while (HostTaskList.size() > 0) { - auto Node = HostTaskList.front(); + node_impl &Node = *HostTaskList.front(); HostTaskList.pop_front(); CurrentPartition++; - for (node_impl &Predecessor : Node->predecessors()) { + for (node_impl &Predecessor : Node.predecessors()) { propagatePartitionUp(Predecessor, CurrentPartition); } CurrentPartition++; - Node->MPartitionNum = CurrentPartition; + Node.MPartitionNum = CurrentPartition; CurrentPartition++; auto TmpSize = HostTaskList.size(); - for (node_impl &Successor : Node->successors()) { + for (node_impl &Successor : Node.successors()) { propagatePartitionDown(Successor, CurrentPartition, HostTaskList); } if (HostTaskList.size() > TmpSize) { // At least one HostTask has been re-numbered so group merge opportunities - for (const auto &HT : HostTaskList) { + for (node_impl *HT : HostTaskList) { auto HTPartitionNum = HT->MPartitionNum; if (HTPartitionNum != -1) { // can merge predecessors of node `Node` with predecessors of node // `HT` (HTPartitionNum-1) since HT must be reprocessed - for (const auto &NodeImpl : MNodeStorage) { - if (NodeImpl->MPartitionNum == Node->MPartitionNum - 1) { - NodeImpl->MPartitionNum = HTPartitionNum - 1; + for (node_impl &NodeImpl : nodes()) { + if (NodeImpl.MPartitionNum == Node.MPartitionNum - 1) { + NodeImpl.MPartitionNum = HTPartitionNum - 1; } } } else { @@ -251,12 +250,12 @@ void exec_graph_impl::makePartitions() { int PartitionFinalNum = 0; for (int i = -1; i <= CurrentPartition; i++) { const std::shared_ptr &Partition = std::make_shared(); - for (auto &Node : MNodeStorage) { - if (Node->MPartitionNum == i) { - MPartitionNodes[Node.get()] = PartitionFinalNum; - if (isPartitionRoot(*Node)) { - Partition->MRoots.insert(Node.get()); - if (Node->MCGType == CGType::CodeplayHostTask) { + for (node_impl &Node : nodes()) { + if (Node.MPartitionNum == i) { + MPartitionNodes[&Node] = PartitionFinalNum; + if (isPartitionRoot(Node)) { + Partition->MRoots.insert(&Node); + if (Node.MCGType == CGType::CodeplayHostTask) { Partition->MIsHostTask = true; } } @@ -295,8 +294,8 @@ void exec_graph_impl::makePartitions() { } // Reset node groups (if node have to be re-processed - e.g. subgraph) - for (auto &Node : MNodeStorage) { - Node->MPartitionNum = -1; + for (node_impl &Node : nodes()) { + Node.MPartitionNum = -1; } } @@ -376,19 +375,19 @@ std::set graph_impl::getCGEdges( // A unique set of dependencies obtained by checking requirements and events for (auto &Req : Requirements) { // Look through the graph for nodes which share this requirement - for (auto &Node : MNodeStorage) { - if (Node->hasRequirementDependency(Req)) { + for (node_impl &Node : nodes()) { + if (Node.hasRequirementDependency(Req)) { bool ShouldAddDep = true; // If any of this node's successors have this requirement then we skip // adding the current node as a dependency. - for (node_impl &Succ : Node->successors()) { + for (node_impl &Succ : Node.successors()) { if (Succ.hasRequirementDependency(Req)) { ShouldAddDep = false; break; } } if (ShouldAddDep) { - UniqueDeps.insert(Node.get()); + UniqueDeps.insert(&Node); } } } @@ -487,7 +486,7 @@ node_impl &graph_impl::add(std::function CGF, } for (auto &[DynamicParam, ArgIndex] : DynamicParams) { - DynamicParam->registerNode(NodeImpl.shared_from_this(), ArgIndex); + DynamicParam->registerNode(NodeImpl, ArgIndex); } return NodeImpl; @@ -611,10 +610,9 @@ void graph_impl::setLastInorderNode(sycl::detail::queue_impl &Queue, MInorderQueueMap[Queue.weak_from_this()] = &Node; } -void graph_impl::makeEdge(std::shared_ptr Src, - std::shared_ptr Dest) { +void graph_impl::makeEdge(node_impl &Src, node_impl &Dest) { throwIfGraphRecordingQueue("make_edge()"); - if (Src == Dest) { + if (&Src == &Dest) { throw sycl::exception( make_error_code(sycl::errc::invalid), "make_edge() cannot be called when Src and Dest are the same."); @@ -624,8 +622,8 @@ void graph_impl::makeEdge(std::shared_ptr Src, bool DestFound = false; for (const auto &Node : MNodeStorage) { - SrcFound |= Node == Src; - DestFound |= Node == Dest; + SrcFound |= Node.get() == &Src; + DestFound |= Node.get() == &Dest; if (SrcFound && DestFound) { break; @@ -641,37 +639,37 @@ void graph_impl::makeEdge(std::shared_ptr Src, "Dest must be a node inside the graph."); } - bool DestWasGraphRoot = Dest->MPredecessors.size() == 0; + bool DestWasGraphRoot = Dest.MPredecessors.size() == 0; // We need to add the edges first before checking for cycles - Src->registerSuccessor(*Dest); + Src.registerSuccessor(Dest); - bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors.size() == 1; + bool DestLostRootStatus = DestWasGraphRoot && Dest.MPredecessors.size() == 1; if (DestLostRootStatus) { // Dest is no longer a Root node, so we need to remove it from MRoots. - MRoots.erase(Dest.get()); + MRoots.erase(&Dest); } // We can skip cycle checks if either Dest has no successors (cycle not // possible) or cycle checks have been disabled with the no_cycle_check // property; - if (Dest->MSuccessors.empty() || !MSkipCycleChecks) { + if (Dest.MSuccessors.empty() || !MSkipCycleChecks) { bool CycleFound = checkForCycles(); if (CycleFound) { // Remove the added successor and predecessor. - Src->MSuccessors.pop_back(); - Dest->MPredecessors.pop_back(); + Src.MSuccessors.pop_back(); + Dest.MPredecessors.pop_back(); if (DestLostRootStatus) { // Add Dest back into MRoots. - MRoots.insert(Dest.get()); + MRoots.insert(&Dest); } throw sycl::exception(make_error_code(sycl::errc::invalid), "Command graphs cannot contain cycles."); } } - removeRoot(*Dest); // remove receiver from root node list + removeRoot(Dest); // remove receiver from root node list } std::vector graph_impl::getExitNodesEvents( @@ -679,11 +677,11 @@ std::vector graph_impl::getExitNodesEvents( std::vector Events; auto RecordedQueueSP = RecordedQueue.lock(); - for (auto &Node : MNodeStorage) { - if (Node->MSuccessors.empty()) { - auto EventForNode = getEventForNode(*Node); + for (node_impl &Node : nodes()) { + if (Node.MSuccessors.empty()) { + auto EventForNode = getEventForNode(Node); if (EventForNode->getSubmittedQueue() == RecordedQueueSP) { - Events.push_back(getEventForNode(*Node)); + Events.push_back(getEventForNode(Node)); } } } @@ -1433,15 +1431,14 @@ void exec_graph_impl::update(std::shared_ptr GraphImpl) { std::make_pair(GraphImpl->MNodeStorage[i]->MID, MNodeStorage[i].get())); } - update(GraphImpl->MNodeStorage); + update(GraphImpl->nodes()); } -void exec_graph_impl::update(std::shared_ptr Node) { - this->update(std::vector>{Node}); +void exec_graph_impl::update(node_impl &Node) { + this->update(std::vector{&Node}); } -void exec_graph_impl::update( - const std::vector> &Nodes) { +void exec_graph_impl::update(nodes_range Nodes) { if (!MIsUpdatable) { throw sycl::exception(sycl::make_error_code(errc::invalid), "update() cannot be called on a executable graph " @@ -1502,7 +1499,7 @@ void exec_graph_impl::update( } bool exec_graph_impl::needsScheduledUpdate( - const std::vector> &Nodes, + nodes_range Nodes, std::vector &UpdateRequirements) { // If there are any accessor requirements, we have to update through the // scheduler to ensure that any allocations have taken place before trying to @@ -1511,30 +1508,30 @@ bool exec_graph_impl::needsScheduledUpdate( // At worst we may have as many requirements as there are for the entire graph // for updating. UpdateRequirements.reserve(MRequirements.size()); - for (auto &Node : Nodes) { + for (node_impl &Node : Nodes) { // Check if node(s) derived from this modifiable node exists in this graph - if (MIDCache.count(Node->getID()) == 0) { + if (MIDCache.count(Node.getID()) == 0) { throw sycl::exception( sycl::make_error_code(errc::invalid), "Node passed to update() is not part of the graph."); } - if (!Node->isUpdatable()) { + if (!Node.isUpdatable()) { std::string ErrorString = "node_type::"; - ErrorString += nodeTypeToString(Node->MNodeType); + ErrorString += nodeTypeToString(Node.MNodeType); ErrorString += " nodes are not supported for update. Only kernel, host_task, " "barrier and empty nodes are supported."; throw sycl::exception(errc::invalid, ErrorString); } - if (const auto &CG = Node->MCommandGroup; + if (const auto &CG = Node.MCommandGroup; CG && CG->getRequirements().size() != 0) { NeedScheduledUpdate = true; UpdateRequirements.insert(UpdateRequirements.end(), - Node->MCommandGroup->getRequirements().begin(), - Node->MCommandGroup->getRequirements().end()); + Node.MCommandGroup->getRequirements().begin(), + Node.MCommandGroup->getRequirements().end()); } } @@ -1740,18 +1737,17 @@ exec_graph_impl::getURUpdatableNodes(nodes_range Nodes) const { return PartitionedNodes; } -void exec_graph_impl::updateHostTasksImpl( - const std::vector> &Nodes) const { - for (auto &Node : Nodes) { - if (Node->MNodeType != node_type::host_task) { +void exec_graph_impl::updateHostTasksImpl(nodes_range Nodes) const { + for (node_impl &Node : Nodes) { + if (Node.MNodeType != node_type::host_task) { continue; } // Query the ID cache to find the equivalent exec node for the node passed // to this function. - auto ExecNode = MIDCache.find(Node->MID); + auto ExecNode = MIDCache.find(Node.MID); assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache"); - ExecNode->second->updateFromOtherNode(*Node); + ExecNode->second->updateFromOtherNode(Node); } } @@ -1852,21 +1848,18 @@ node modifiable_command_graph::addImpl(std::function CGF, void modifiable_command_graph::addGraphLeafDependencies(node Node) { // Find all exit nodes in the current graph and add them to the dependency // vector - std::shared_ptr DstImpl = - sycl::detail::getSyclObjImpl(Node); + detail::node_impl &DstImpl = *sycl::detail::getSyclObjImpl(Node); graph_impl::WriteLock Lock(impl->MMutex); for (auto &NodeImpl : impl->MNodeStorage) { - if ((NodeImpl->MSuccessors.size() == 0) && (NodeImpl != DstImpl)) { - impl->makeEdge(NodeImpl, DstImpl); + if ((NodeImpl->MSuccessors.size() == 0) && (NodeImpl.get() != &DstImpl)) { + impl->makeEdge(*NodeImpl, DstImpl); } } } void modifiable_command_graph::make_edge(node &Src, node &Dest) { - std::shared_ptr SenderImpl = - sycl::detail::getSyclObjImpl(Src); - std::shared_ptr ReceiverImpl = - sycl::detail::getSyclObjImpl(Dest); + detail::node_impl &SenderImpl = *sycl::detail::getSyclObjImpl(Src); + detail::node_impl &ReceiverImpl = *sycl::detail::getSyclObjImpl(Dest); graph_impl::WriteLock Lock(impl->MMutex); impl->makeEdge(SenderImpl, ReceiverImpl); @@ -2030,17 +2023,11 @@ void executable_command_graph::update( } void executable_command_graph::update(const node &Node) { - impl->update(sycl::detail::getSyclObjImpl(Node)); + impl->update(*sycl::detail::getSyclObjImpl(Node)); } void executable_command_graph::update(const std::vector &Nodes) { - std::vector> NodeImpls{}; - NodeImpls.reserve(Nodes.size()); - for (auto &Node : Nodes) { - NodeImpls.push_back(sycl::detail::getSyclObjImpl(Node)); - } - - impl->update(NodeImpls); + impl->update(Nodes); } size_t executable_command_graph::get_required_mem_size() const { diff --git a/sycl/source/detail/graph/graph_impl.hpp b/sycl/source/detail/graph/graph_impl.hpp index 9b97861736797..eedfcf0506bf3 100644 --- a/sycl/source/detail/graph/graph_impl.hpp +++ b/sycl/source/detail/graph/graph_impl.hpp @@ -330,8 +330,7 @@ class graph_impl : public std::enable_shared_from_this { /// this edge. /// @param Src The source of the new edge. /// @param Dest The destination of the new edge. - void makeEdge(std::shared_ptr Src, - std::shared_ptr Dest); + void makeEdge(node_impl &Src, node_impl &Dest); /// Throws an invalid exception if this function is called /// while a queue is recording commands to the graph. @@ -692,8 +691,8 @@ class exec_graph_impl { } void update(std::shared_ptr GraphImpl); - void update(std::shared_ptr Node); - void update(const std::vector> &Nodes); + void update(node_impl &Node); + void update(nodes_range Nodes); /// Calls UR entry-point to update nodes in command-buffer. /// @param CommandBuffer The UR command-buffer to update commands in. @@ -706,8 +705,7 @@ class exec_graph_impl { /// Update host-task nodes /// @param Nodes List of nodes to update, any node that is not a host-task /// will be ignored. - void updateHostTasksImpl( - const std::vector> &Nodes) const; + void updateHostTasksImpl(nodes_range Nodes) const; /// Splits a list of nodes into separate lists of nodes for each /// command-buffer partition. @@ -834,14 +832,14 @@ class exec_graph_impl { std::fstream Stream(FilePath, std::ios::out); Stream << "digraph dot {" << std::endl; - std::vector> Roots; - for (auto &Node : MNodeStorage) { - if (Node->MPredecessors.size() == 0) { - Roots.push_back(Node); + std::vector Roots; + for (node_impl &Node : nodes()) { + if (Node.MPredecessors.size() == 0) { + Roots.push_back(&Node); } } - for (std::shared_ptr Node : Roots) + for (node_impl *Node : Roots) Node->printDotRecursive(Stream, VisitedNodes, Verbose); Stream << "}" << std::endl; @@ -854,7 +852,7 @@ class exec_graph_impl { /// @param[out] UpdateRequirements Accessor requirements found in /p Nodes. /// return True if update should be done through the scheduler. bool needsScheduledUpdate( - const std::vector> &Nodes, + nodes_range Nodes, std::vector &UpdateRequirements); /// Sets the UR struct values required to update a graph node. diff --git a/sycl/source/detail/scheduler/graph_builder.cpp b/sycl/source/detail/scheduler/graph_builder.cpp index 3a2a8c8b46381..0423c0b29f52a 100644 --- a/sycl/source/detail/scheduler/graph_builder.cpp +++ b/sycl/source/detail/scheduler/graph_builder.cpp @@ -1274,9 +1274,8 @@ Command *Scheduler::GraphBuilder::connectDepEvent( Command *Scheduler::GraphBuilder::addCommandGraphUpdate( ext::oneapi::experimental::detail::exec_graph_impl *Graph, - std::vector> - Nodes, - queue_impl *Queue, std::vector Requirements, + ext::oneapi::experimental::detail::nodes_range Nodes, queue_impl *Queue, + std::vector Requirements, std::vector &Events, std::vector &ToEnqueue) { auto NewCmd = diff --git a/sycl/source/detail/scheduler/scheduler.cpp b/sycl/source/detail/scheduler/scheduler.cpp index 2f5d4c5ca5e6a..365159c4b8cb7 100644 --- a/sycl/source/detail/scheduler/scheduler.cpp +++ b/sycl/source/detail/scheduler/scheduler.cpp @@ -643,9 +643,8 @@ ur_kernel_handle_t Scheduler::completeSpecConstMaterialization( EventImplPtr Scheduler::addCommandGraphUpdate( ext::oneapi::experimental::detail::exec_graph_impl *Graph, - std::vector> - Nodes, - queue_impl *Queue, std::vector Requirements, + ext::oneapi::experimental::detail::nodes_range Nodes, queue_impl *Queue, + std::vector Requirements, std::vector &Events) { std::vector AuxiliaryCmds; EventImplPtr NewCmdEvent = nullptr; diff --git a/sycl/source/detail/scheduler/scheduler.hpp b/sycl/source/detail/scheduler/scheduler.hpp index ee1eeabd6a99b..0a1d33731f064 100644 --- a/sycl/source/detail/scheduler/scheduler.hpp +++ b/sycl/source/detail/scheduler/scheduler.hpp @@ -178,6 +178,7 @@ inline namespace _V1 { namespace ext::oneapi::experimental::detail { class exec_graph_impl; class node_impl; +class nodes_range; } // namespace ext::oneapi::experimental::detail namespace detail { class queue_impl; @@ -477,9 +478,8 @@ class Scheduler { /// \param Events List of events that this update operation depends on EventImplPtr addCommandGraphUpdate( ext::oneapi::experimental::detail::exec_graph_impl *Graph, - std::vector> - Nodes, - queue_impl *Queue, std::vector Requirements, + ext::oneapi::experimental::detail::nodes_range Nodes, queue_impl *Queue, + std::vector Requirements, std::vector &Events); static bool CheckEventReadiness(context_impl &Context, @@ -654,10 +654,8 @@ class Scheduler { /// \param ToEnqueue List of commands which need to be enqueued. Command *addCommandGraphUpdate( ext::oneapi::experimental::detail::exec_graph_impl *Graph, - std::vector< - std::shared_ptr> - Nodes, - queue_impl *Queue, std::vector Requirements, + ext::oneapi::experimental::detail::nodes_range Nodes, queue_impl *Queue, + std::vector Requirements, std::vector &Events, std::vector &ToEnqueue); diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index dc5d2f9df6758..3ee06bdd9ca07 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -868,8 +868,7 @@ event handler::finalize() { if (auto GraphImpl = Queue->getCommandGraph(); GraphImpl) { auto EventImpl = detail::event_impl::create_completed_host_event(); EventImpl->setSubmittedQueue(Queue->weak_from_this()); - std::shared_ptr NodeImpl = - nullptr; + ext::oneapi::experimental::detail::node_impl *NodeImpl = nullptr; // GraphImpl is read and written in this scope so we lock this graph // with full priviledges. @@ -891,8 +890,7 @@ event handler::finalize() { GraphImpl->getLastInorderNode(Queue)) { Deps.push_back(DependentNode); } - NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps) - .shared_from_this(); + NodeImpl = &GraphImpl->add(NodeType, std::move(CommandGroup), Deps); // If we are recording an in-order queue remember the new node, so it // can be used as a dependency for any more nodes recorded from this @@ -907,8 +905,7 @@ event handler::finalize() { if (LastBarrierRecordedFromQueue) { Deps.push_back(LastBarrierRecordedFromQueue); } - NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps) - .shared_from_this(); + NodeImpl = &GraphImpl->add(NodeType, std::move(CommandGroup), Deps); if (NodeImpl->MCGType == sycl::detail::CGType::Barrier) { GraphImpl->setBarrierDep(Queue->weak_from_this(), *NodeImpl);