Skip to content

Commit 90b2086

Browse files
Pranav-Agarwal0612brenns10
authored andcommitted
Fixed Cycle detection algorithm from recursive to iterative
1 parent aef550e commit 90b2086

File tree

1 file changed

+43
-30
lines changed

1 file changed

+43
-30
lines changed

drgn_tools/deadlock.py

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import List
66
from typing import Optional
77
from typing import Set
8+
from typing import Tuple
89
from typing import Union
910

1011
from drgn import Object
@@ -23,14 +24,16 @@ def from_object(cls, object: Object):
2324
if type_.kind == TypeKind.POINTER:
2425
object = object[0]
2526
type_ = object.type_
26-
2727
if type_.kind != TypeKind.STRUCT or object.address_ is None:
2828
raise ValueError(
2929
"A reference object of type struct is expected"
3030
)
3131

32-
object_node = cls(name=type_.typename(), address=object.address_)
33-
object_node.object = object
32+
object_node = cls.get_node(
33+
name=type_.typename(), address=object.address_
34+
)
35+
if not object_node.object:
36+
object_node.object = object
3437
return object_node
3538

3639
def __init__(
@@ -50,50 +53,63 @@ def __hash__(self):
5053
# Will be useful in case of Objects which don't exist in memory (example : CPU) - can set address=0 and name='CPU10'
5154
return hash((self.name, self.address))
5255

56+
@classmethod
57+
def get_node(cls, name: str, address: int):
58+
hash_value = hash((name, address))
59+
if hash_value in DependencyGraph.node_map:
60+
return DependencyGraph.node_map[hash_value]
61+
return cls(name, address)
62+
63+
node_map: Dict[int, Node] = dict()
64+
5365
def __init__(self):
54-
self.node_map: Dict[int, self.Node] = dict()
66+
pass
5567

5668
def add_edge(self, src: Node, dst: Node) -> None:
5769
if hash(src) not in self.node_map:
5870
self.node_map[hash(src)] = src
5971
else:
6072
src = self.node_map[hash(src)]
61-
6273
if hash(dst) not in self.node_map:
6374
self.node_map[hash(dst)] = dst
6475
else:
6576
dst = self.node_map[hash(dst)]
6677

6778
if dst not in src.blocked_nodes:
6879
src.blocked_nodes.append(dst)
69-
7080
if src not in dst.depends_on:
7181
dst.depends_on.append(src)
7282

7383
def detect_cycle(self) -> Optional[List[List[Node]]]:
7484
visited: Set[DependencyGraph.Node] = set()
75-
path: List[DependencyGraph.Node] = []
7685
cycles: List[List[DependencyGraph.Node]] = []
77-
78-
def dfs(node: DependencyGraph.Node) -> None:
79-
# If the node is currently being visited (part of the current DFS path), we've found a cycle
80-
if node in path:
81-
cycle_start = path.index(node)
82-
cycles.append(path[cycle_start:] + [node])
83-
return
84-
85-
# If it's fully visited, no need to process this node again
86-
if node in visited:
87-
return
88-
89-
visited.add(node)
90-
path.append(node)
91-
92-
# Recursively visit all neighboring nodes
93-
for neighbor in node.depends_on:
94-
dfs(neighbor)
95-
96-
path.pop()
86+
parent: Dict[int, Optional[DependencyGraph.Node]] = dict()
87+
88+
def dfs(start_node: DependencyGraph.Node) -> None:
89+
stack: List[
90+
Tuple[DependencyGraph.Node, Optional[DependencyGraph.Node]]
91+
] = [(start_node, None)]
92+
while stack:
93+
node, parent_node = stack.pop()
94+
if node in visited:
95+
continue
96+
visited.add(node)
97+
parent[hash(node)] = parent_node
98+
99+
for neighbour in node.depends_on:
100+
if neighbour not in visited:
101+
stack.append((neighbour, node))
102+
else:
103+
cycle: List[DependencyGraph.Node] = []
104+
cycle.append(neighbour)
105+
temp: Optional[DependencyGraph.Node] = node
106+
while temp and temp != neighbour:
107+
cycle.append(temp)
108+
temp = parent[hash(temp)]
109+
cycle.append(neighbour)
110+
cycle.reverse()
111+
cycles.append(cycle)
112+
return
97113

98114
# Run DFS for all nodes in the graph
99115
for node in self.node_map.values():
@@ -112,12 +128,9 @@ def add_args(self, parser: argparse.ArgumentParser) -> None:
112128
def run(self, prog: Program, args: argparse.Namespace) -> None:
113129
graph: DependencyGraph = DependencyGraph()
114130
get_mutex_lock_info(prog, stack=False, graph=graph)
115-
116131
cycles = graph.detect_cycle()
117-
118132
if not cycles:
119133
print("No cycle found")
120134
return
121-
122135
for cycle in cycles:
123136
print(cycle)

0 commit comments

Comments
 (0)