Skip to content

Commit 2566769

Browse files
committed
clean: Remove unused code and update docs
1 parent 3e8dcc6 commit 2566769

3 files changed

Lines changed: 42 additions & 21 deletions

File tree

docs/explanation/infrastructure.md

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,38 @@ The fixed point rewriter transforms a pass by iteratively applying the pass till
204204
///
205205
///
206206

207-
## Dataflow
208-
`dataflow` defines a reusable forward dataflow analysis framework.
209-
210207
## Lattice
211-
`lattice` defines ordering and merge semantics used by analysis states. Use `LatticeBase` to model the state domain, then use `ForwardDataflowAnalysis` to run transfer and merge over a graph. This keeps transfer logic in the analysis class and ordering logic in the lattice class.
208+
209+
The `Lattice` class in [`lattice.py`](../../src/oqd_compiler_infrastructure/lattice.py) defines a generic lattice interface with all the methods it requires. The following methods are defined:
210+
- top(): Returns the top element of the lattice.
211+
- bottom(): Returns the bottom element of the lattice.
212+
- leq(): Returns True if `t1 <= t2` in the lattice.
213+
- join(): Returns the least upper bound of `t1` and `t2`.
214+
- meet(): Returns the greatest lower bound of `t1` and `t2`.
215+
216+
These methods allow analysis to be done on a concrete instance of the lattice.
217+
218+
The `LatticeBase` class defines a simple concrete implementation of a `Lattice`. It stores a dictionary that maps each node of the lattice to its immediate parent(s). It defines `LatticeTop` as the top element, and `LatticeBottom` as the bottom element of the lattice. This class defines the following helper methods:
219+
- is_class_node(t): Returns True if `t` is a valid lattice node.
220+
- add_node(t, parent): Adds a node to the lattice, by tracking the parent(s) of the node.
221+
- atomic_ancestors(t): Returns the atomic ancestors of a given node.
222+
223+
These helper methods are used in the concrete implementation of the lattice operation methods: `leq`, `join`, and `meet`.
224+
225+
You can define your own lattice using the `LatticeBase` class.
226+
227+
## Dataflow
228+
229+
The `GraphProtocol` class in [`dataflow.py`](../../src/oqd_compiler_infrastructure/dataflow.py) defines a generic Graph Protocol interface that provides the nodes in the graph, the predecessors of a given node, and the successors of a given node. This protocol can be applied on any graph object for analysis: control flow graphs, dependency graphs, IR graphs, etc.
230+
231+
The `DataflowAnalysis` class requires a Lattice to implement the analysis on. This class provides a dataflow analysis framework that can be used to implement a specific dataflow analysis The following methods are defined:
232+
- transfer(node, state_in): Returns the state of a given node after transfer.
233+
- bottom(): Returns the default starting state for all nodes.
234+
- merge(states): Joins incoming states using the lattice's join operation.
235+
- states_equal(t1, t2): Returns True if two states are equal in the lattice.
236+
237+
The `ForwardDataflowAnalysis` class implements the forward dataflow analysis using the worklist algorithm with the `analyze` method. The output of the analysis is an instance of the `DataflowResult` class which contains the `in_states`, `out_states`, and the `iterations`.
238+
239+
The `MapForwardDataflowAnalysis` is a helper instance of `ForwardDataflowAnalysis` for states that use a `dict[str, LatticeType]` type for analysis.
240+
241+

src/oqd_compiler_infrastructure/dataflow.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from collections import deque
1818
from dataclasses import dataclass
1919
from typing import Generic, Iterable, TypeVar, Protocol
20-
2120
from oqd_compiler_infrastructure.lattice import Lattice, LatticeType
2221

2322
NodeType = TypeVar("NodeType")
@@ -66,10 +65,6 @@ def bottom(self) -> LatticeType:
6665
"""Returns the default starting state for all nodes."""
6766
return self.lattice.bottom()
6867

69-
def boundary_state(self, node: NodeType) -> LatticeType:
70-
"""Returns the extra state injected at a given node."""
71-
return self.bottom()
72-
7368
def merge(self, states: Iterable[LatticeType]) -> LatticeType:
7469
"""Joins incoming states using the lattice's join operation."""
7570
states_list = list(states)
@@ -88,7 +83,6 @@ def states_equal(self, t1: LatticeType, t2: LatticeType) -> bool:
8883
class ForwardDataflowAnalysis(DataflowAnalysis[NodeType, LatticeType], Generic[NodeType, LatticeType]):
8984
"""
9085
Forward dataflow analysis framework.
91-
This class implements the fixed point loop.
9286
"""
9387
def __init__(self, lattice: Lattice[LatticeType]):
9488
super().__init__(lattice)
@@ -117,7 +111,7 @@ def analyze(self, graph: GraphProtocol[NodeType]) -> DataflowResult[NodeType, La
117111
iterations += 1
118112

119113
pred_states = [out_states[pred] for pred in graph.predecessors(node)]
120-
merged_input = self.merge([self.boundary_state(node), *pred_states])
114+
merged_input = self.merge(pred_states)
121115

122116
if not self.states_equal(in_states[node], merged_input):
123117
in_states[node] = merged_input
@@ -137,23 +131,20 @@ def analyze(self, graph: GraphProtocol[NodeType]) -> DataflowResult[NodeType, La
137131

138132
class MapForwardDataflowAnalysis(ForwardDataflowAnalysis[NodeType, dict[str, LatticeType]], Generic[NodeType, LatticeType]):
139133
"""
140-
Helper instance of ForwardDataflowAnalysis for states that need a dict Type
134+
Helper instance of ForwardDataflowAnalysis for states that need a dict Type.
141135
"""
142136
def __init__(self, lattice: Lattice[LatticeType]):
143137
super().__init__(lattice)
144138

145139
def bottom(self) -> dict[str, LatticeType]:
146140
return {}
147141

148-
def boundary_state(self, node: NodeType) -> dict[str, LatticeType]:
149-
return {}
150-
151142
def merge(self, states: Iterable[dict[str, LatticeType]]) -> dict[str, LatticeType]:
152143
states_list = list(states)
153144
if not states_list:
154145
return {}
155146

156-
bottom = self.lattice.bottom()
147+
bottom = self.bottom()
157148
all_keys = set().union(*(state.keys() for state in states_list))
158149

159150
merged = {}
@@ -166,7 +157,7 @@ def merge(self, states: Iterable[dict[str, LatticeType]]) -> dict[str, LatticeTy
166157
return merged
167158

168159
def states_equal(self, t1: dict[str, LatticeType], t2: dict[str, LatticeType]) -> bool:
169-
bottom = self.lattice.bottom()
160+
bottom = self.bottom()
170161
all_keys = set(t1.keys()).union(t2.keys())
171162

172163
for key in all_keys:

tests/test_dataflow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ def test_reachability(self):
7878
)
7979
result = Reachability().analyze(graph)
8080

81-
assert result.in_states["entry"] == {"ENTRY"}
82-
assert result.out_states["entry"] == {"ENTRY", "entry"}
83-
assert result.out_states["mid"] == {"ENTRY", "entry", "mid"}
84-
assert result.out_states["exit"] == {"ENTRY", "entry", "mid", "exit"}
81+
assert result.in_states["entry"] == set()
82+
assert result.out_states["entry"] == {"entry"}
83+
assert result.out_states["mid"] == {"entry", "mid"}
84+
assert result.out_states["exit"] == {"entry", "mid", "exit"}
8585
assert result.iterations >= 3
8686

8787

0 commit comments

Comments
 (0)