Skip to content

[NFC][SYCL][Graph] Switch misc shared_ptr<node_impl> to raw ptr/ref #19487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sycl/source/detail/graph/dynamic_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ void dynamic_local_accessor_base::updateLocalAccessor(
->updateLocalAccessor(NewAllocationSize);
}

/// Registers a node with this dynamic parameter.
/// @param NodeImpl The node to be registered.
/// @param ArgIndex The arg index for the kernel arg associated with this
/// dynamic_parameter in NodeImpl.
void dynamic_parameter_impl::registerNode(node_impl &NodeImpl, int ArgIndex) {
// The entry is built from NodeImpl.weak_from_this() rather than from the
// reference itself. NOTE(review): presumably MNodes stores non-owning
// weak_ptr handles paired with the arg index - confirm against the MNodes
// declaration in dynamic_impl.hpp.
MNodes.emplace_back(NodeImpl.weak_from_this(), ArgIndex);
}

void dynamic_parameter_impl::updateValue(const raw_kernel_arg *NewRawValue,
size_t Size) {
// Number of bytes is taken from member of raw_kernel_arg object rather
Expand Down
4 changes: 1 addition & 3 deletions sycl/source/detail/graph/dynamic_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@ class dynamic_parameter_impl {
/// @param NodeImpl The node to be registered
/// @param ArgIndex The arg index for the kernel arg associated with this
/// dynamic_parameter in NodeImpl
void registerNode(std::shared_ptr<node_impl> NodeImpl, int ArgIndex) {
MNodes.emplace_back(NodeImpl, ArgIndex);
}
void registerNode(node_impl &NodeImpl, int ArgIndex);

/// Struct detailing an instance of the usage of the dynamic parameter in a
/// dynamic CG.
Expand Down
151 changes: 69 additions & 82 deletions sycl/source/detail/graph/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,11 @@ void propagatePartitionUp(node_impl &Node, int PartitionNum) {
/// @param PartitionNum Number to propagate.
/// @param HostTaskList List of host tasks that have already been processed and
/// are encountered as successors to the node Node.
void propagatePartitionDown(
node_impl &Node, int PartitionNum,
std::list<std::shared_ptr<node_impl>> &HostTaskList) {
void propagatePartitionDown(node_impl &Node, int PartitionNum,
std::list<node_impl *> &HostTaskList) {
if (Node.MCGType == sycl::detail::CGType::CodeplayHostTask) {
if (Node.MPartitionNum != -1) {
HostTaskList.push_front(Node.shared_from_this());
HostTaskList.push_front(&Node);
}
return;
}
Expand Down Expand Up @@ -181,11 +180,11 @@ void partition::updateSchedule() {

void exec_graph_impl::makePartitions() {
int CurrentPartition = -1;
std::list<std::shared_ptr<node_impl>> HostTaskList;
std::list<node_impl *> HostTaskList;
// find all the host-tasks in the graph
for (auto &Node : MNodeStorage) {
if (Node->MCGType == sycl::detail::CGType::CodeplayHostTask) {
HostTaskList.push_back(Node);
for (node_impl &Node : nodes()) {
if (Node.MCGType == sycl::detail::CGType::CodeplayHostTask) {
HostTaskList.push_back(&Node);
}
}

Expand Down Expand Up @@ -215,29 +214,29 @@ void exec_graph_impl::makePartitions() {
// group that includes the predecessor of `B` can be merged with the group of
// the predecessors of the node `A`.
while (HostTaskList.size() > 0) {
auto Node = HostTaskList.front();
node_impl &Node = *HostTaskList.front();
HostTaskList.pop_front();
CurrentPartition++;
for (node_impl &Predecessor : Node->predecessors()) {
for (node_impl &Predecessor : Node.predecessors()) {
propagatePartitionUp(Predecessor, CurrentPartition);
}
CurrentPartition++;
Node->MPartitionNum = CurrentPartition;
Node.MPartitionNum = CurrentPartition;
CurrentPartition++;
auto TmpSize = HostTaskList.size();
for (node_impl &Successor : Node->successors()) {
for (node_impl &Successor : Node.successors()) {
propagatePartitionDown(Successor, CurrentPartition, HostTaskList);
}
if (HostTaskList.size() > TmpSize) {
// At least one HostTask has been re-numbered so group merge opportunities
for (const auto &HT : HostTaskList) {
for (node_impl *HT : HostTaskList) {
auto HTPartitionNum = HT->MPartitionNum;
if (HTPartitionNum != -1) {
// can merge predecessors of node `Node` with predecessors of node
// `HT` (HTPartitionNum-1) since HT must be reprocessed
for (const auto &NodeImpl : MNodeStorage) {
if (NodeImpl->MPartitionNum == Node->MPartitionNum - 1) {
NodeImpl->MPartitionNum = HTPartitionNum - 1;
for (node_impl &NodeImpl : nodes()) {
if (NodeImpl.MPartitionNum == Node.MPartitionNum - 1) {
NodeImpl.MPartitionNum = HTPartitionNum - 1;
}
}
} else {
Expand All @@ -251,12 +250,12 @@ void exec_graph_impl::makePartitions() {
int PartitionFinalNum = 0;
for (int i = -1; i <= CurrentPartition; i++) {
const std::shared_ptr<partition> &Partition = std::make_shared<partition>();
for (auto &Node : MNodeStorage) {
if (Node->MPartitionNum == i) {
MPartitionNodes[Node.get()] = PartitionFinalNum;
if (isPartitionRoot(*Node)) {
Partition->MRoots.insert(Node.get());
if (Node->MCGType == CGType::CodeplayHostTask) {
for (node_impl &Node : nodes()) {
if (Node.MPartitionNum == i) {
MPartitionNodes[&Node] = PartitionFinalNum;
if (isPartitionRoot(Node)) {
Partition->MRoots.insert(&Node);
if (Node.MCGType == CGType::CodeplayHostTask) {
Partition->MIsHostTask = true;
}
}
Expand Down Expand Up @@ -295,8 +294,8 @@ void exec_graph_impl::makePartitions() {
}

// Reset node groups (if node have to be re-processed - e.g. subgraph)
for (auto &Node : MNodeStorage) {
Node->MPartitionNum = -1;
for (node_impl &Node : nodes()) {
Node.MPartitionNum = -1;
}
}

Expand Down Expand Up @@ -376,19 +375,19 @@ std::set<node_impl *> graph_impl::getCGEdges(
// A unique set of dependencies obtained by checking requirements and events
for (auto &Req : Requirements) {
// Look through the graph for nodes which share this requirement
for (auto &Node : MNodeStorage) {
if (Node->hasRequirementDependency(Req)) {
for (node_impl &Node : nodes()) {
if (Node.hasRequirementDependency(Req)) {
bool ShouldAddDep = true;
// If any of this node's successors have this requirement then we skip
// adding the current node as a dependency.
for (node_impl &Succ : Node->successors()) {
for (node_impl &Succ : Node.successors()) {
if (Succ.hasRequirementDependency(Req)) {
ShouldAddDep = false;
break;
}
}
if (ShouldAddDep) {
UniqueDeps.insert(Node.get());
UniqueDeps.insert(&Node);
}
}
}
Expand Down Expand Up @@ -487,7 +486,7 @@ node_impl &graph_impl::add(std::function<void(handler &)> CGF,
}

for (auto &[DynamicParam, ArgIndex] : DynamicParams) {
DynamicParam->registerNode(NodeImpl.shared_from_this(), ArgIndex);
DynamicParam->registerNode(NodeImpl, ArgIndex);
}

return NodeImpl;
Expand Down Expand Up @@ -611,10 +610,9 @@ void graph_impl::setLastInorderNode(sycl::detail::queue_impl &Queue,
MInorderQueueMap[Queue.weak_from_this()] = &Node;
}

void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
std::shared_ptr<node_impl> Dest) {
void graph_impl::makeEdge(node_impl &Src, node_impl &Dest) {
throwIfGraphRecordingQueue("make_edge()");
if (Src == Dest) {
if (&Src == &Dest) {
throw sycl::exception(
make_error_code(sycl::errc::invalid),
"make_edge() cannot be called when Src and Dest are the same.");
Expand All @@ -624,8 +622,8 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
bool DestFound = false;
for (const auto &Node : MNodeStorage) {

SrcFound |= Node == Src;
DestFound |= Node == Dest;
SrcFound |= Node.get() == &Src;
DestFound |= Node.get() == &Dest;

if (SrcFound && DestFound) {
break;
Expand All @@ -641,49 +639,49 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
"Dest must be a node inside the graph.");
}

bool DestWasGraphRoot = Dest->MPredecessors.size() == 0;
bool DestWasGraphRoot = Dest.MPredecessors.size() == 0;

// We need to add the edges first before checking for cycles
Src->registerSuccessor(*Dest);
Src.registerSuccessor(Dest);

bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors.size() == 1;
bool DestLostRootStatus = DestWasGraphRoot && Dest.MPredecessors.size() == 1;
if (DestLostRootStatus) {
// Dest is no longer a Root node, so we need to remove it from MRoots.
MRoots.erase(Dest.get());
MRoots.erase(&Dest);
}

// We can skip cycle checks if either Dest has no successors (cycle not
// possible) or cycle checks have been disabled with the no_cycle_check
// property;
if (Dest->MSuccessors.empty() || !MSkipCycleChecks) {
if (Dest.MSuccessors.empty() || !MSkipCycleChecks) {
bool CycleFound = checkForCycles();

if (CycleFound) {
// Remove the added successor and predecessor.
Src->MSuccessors.pop_back();
Dest->MPredecessors.pop_back();
Src.MSuccessors.pop_back();
Dest.MPredecessors.pop_back();
if (DestLostRootStatus) {
// Add Dest back into MRoots.
MRoots.insert(Dest.get());
MRoots.insert(&Dest);
}

throw sycl::exception(make_error_code(sycl::errc::invalid),
"Command graphs cannot contain cycles.");
}
}
removeRoot(*Dest); // remove receiver from root node list
removeRoot(Dest); // remove receiver from root node list
}

std::vector<sycl::detail::EventImplPtr> graph_impl::getExitNodesEvents(
std::weak_ptr<sycl::detail::queue_impl> RecordedQueue) {
std::vector<sycl::detail::EventImplPtr> Events;

auto RecordedQueueSP = RecordedQueue.lock();
for (auto &Node : MNodeStorage) {
if (Node->MSuccessors.empty()) {
auto EventForNode = getEventForNode(*Node);
for (node_impl &Node : nodes()) {
if (Node.MSuccessors.empty()) {
auto EventForNode = getEventForNode(Node);
if (EventForNode->getSubmittedQueue() == RecordedQueueSP) {
Events.push_back(getEventForNode(*Node));
Events.push_back(getEventForNode(Node));
}
}
}
Expand Down Expand Up @@ -1433,15 +1431,14 @@ void exec_graph_impl::update(std::shared_ptr<graph_impl> GraphImpl) {
std::make_pair(GraphImpl->MNodeStorage[i]->MID, MNodeStorage[i].get()));
}

update(GraphImpl->MNodeStorage);
update(GraphImpl->nodes());
}

void exec_graph_impl::update(std::shared_ptr<node_impl> Node) {
this->update(std::vector<std::shared_ptr<node_impl>>{Node});
void exec_graph_impl::update(node_impl &Node) {
this->update(std::vector<node_impl *>{&Node});
}

void exec_graph_impl::update(
const std::vector<std::shared_ptr<node_impl>> &Nodes) {
void exec_graph_impl::update(nodes_range Nodes) {
if (!MIsUpdatable) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"update() cannot be called on a executable graph "
Expand Down Expand Up @@ -1502,7 +1499,7 @@ void exec_graph_impl::update(
}

bool exec_graph_impl::needsScheduledUpdate(
const std::vector<std::shared_ptr<node_impl>> &Nodes,
nodes_range Nodes,
std::vector<sycl::detail::AccessorImplHost *> &UpdateRequirements) {
// If there are any accessor requirements, we have to update through the
// scheduler to ensure that any allocations have taken place before trying to
Expand All @@ -1511,30 +1508,30 @@ bool exec_graph_impl::needsScheduledUpdate(
// At worst we may have as many requirements as there are for the entire graph
// for updating.
UpdateRequirements.reserve(MRequirements.size());
for (auto &Node : Nodes) {
for (node_impl &Node : Nodes) {
// Check if node(s) derived from this modifiable node exists in this graph
if (MIDCache.count(Node->getID()) == 0) {
if (MIDCache.count(Node.getID()) == 0) {
throw sycl::exception(
sycl::make_error_code(errc::invalid),
"Node passed to update() is not part of the graph.");
}

if (!Node->isUpdatable()) {
if (!Node.isUpdatable()) {
std::string ErrorString = "node_type::";
ErrorString += nodeTypeToString(Node->MNodeType);
ErrorString += nodeTypeToString(Node.MNodeType);
ErrorString +=
" nodes are not supported for update. Only kernel, host_task, "
"barrier and empty nodes are supported.";
throw sycl::exception(errc::invalid, ErrorString);
}

if (const auto &CG = Node->MCommandGroup;
if (const auto &CG = Node.MCommandGroup;
CG && CG->getRequirements().size() != 0) {
NeedScheduledUpdate = true;

UpdateRequirements.insert(UpdateRequirements.end(),
Node->MCommandGroup->getRequirements().begin(),
Node->MCommandGroup->getRequirements().end());
Node.MCommandGroup->getRequirements().begin(),
Node.MCommandGroup->getRequirements().end());
}
}

Expand Down Expand Up @@ -1740,18 +1737,17 @@ exec_graph_impl::getURUpdatableNodes(nodes_range Nodes) const {
return PartitionedNodes;
}

void exec_graph_impl::updateHostTasksImpl(
const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
for (auto &Node : Nodes) {
if (Node->MNodeType != node_type::host_task) {
void exec_graph_impl::updateHostTasksImpl(nodes_range Nodes) const {
for (node_impl &Node : Nodes) {
if (Node.MNodeType != node_type::host_task) {
continue;
}
// Query the ID cache to find the equivalent exec node for the node passed
// to this function.
auto ExecNode = MIDCache.find(Node->MID);
auto ExecNode = MIDCache.find(Node.MID);
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");

ExecNode->second->updateFromOtherNode(*Node);
ExecNode->second->updateFromOtherNode(Node);
}
}

Expand Down Expand Up @@ -1852,21 +1848,18 @@ node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
void modifiable_command_graph::addGraphLeafDependencies(node Node) {
// Find all exit nodes in the current graph and add them to the dependency
// vector
std::shared_ptr<detail::node_impl> DstImpl =
sycl::detail::getSyclObjImpl(Node);
detail::node_impl &DstImpl = *sycl::detail::getSyclObjImpl(Node);
graph_impl::WriteLock Lock(impl->MMutex);
for (auto &NodeImpl : impl->MNodeStorage) {
if ((NodeImpl->MSuccessors.size() == 0) && (NodeImpl != DstImpl)) {
impl->makeEdge(NodeImpl, DstImpl);
if ((NodeImpl->MSuccessors.size() == 0) && (NodeImpl.get() != &DstImpl)) {
impl->makeEdge(*NodeImpl, DstImpl);
}
}
}

void modifiable_command_graph::make_edge(node &Src, node &Dest) {
std::shared_ptr<detail::node_impl> SenderImpl =
sycl::detail::getSyclObjImpl(Src);
std::shared_ptr<detail::node_impl> ReceiverImpl =
sycl::detail::getSyclObjImpl(Dest);
detail::node_impl &SenderImpl = *sycl::detail::getSyclObjImpl(Src);
detail::node_impl &ReceiverImpl = *sycl::detail::getSyclObjImpl(Dest);

graph_impl::WriteLock Lock(impl->MMutex);
impl->makeEdge(SenderImpl, ReceiverImpl);
Expand Down Expand Up @@ -2030,17 +2023,11 @@ void executable_command_graph::update(
}

void executable_command_graph::update(const node &Node) {
impl->update(sycl::detail::getSyclObjImpl(Node));
impl->update(*sycl::detail::getSyclObjImpl(Node));
}

void executable_command_graph::update(const std::vector<node> &Nodes) {
std::vector<std::shared_ptr<node_impl>> NodeImpls{};
NodeImpls.reserve(Nodes.size());
for (auto &Node : Nodes) {
NodeImpls.push_back(sycl::detail::getSyclObjImpl(Node));
}

impl->update(NodeImpls);
impl->update(Nodes);
}

size_t executable_command_graph::get_required_mem_size() const {
Expand Down
Loading