Skip to content

Commit 8e3a9d7

Browse files
committed
1 parent b22fab0 commit 8e3a9d7

File tree

1 file changed

+354
-0
lines changed

1 file changed

+354
-0
lines changed

graphs/travelling_salesman_problem.py

+354
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,354 @@
1+
import itertools
2+
from collections.abc import Generator, Hashable, Sequence
3+
from dataclasses import dataclass
4+
from typing import Generic, TypeVar
5+
6+
T = TypeVar("T", bound=int | str | Hashable)
7+
8+
9+
@dataclass(frozen=True)
10+
class TSPEdge(Generic[T]):
11+
"""
12+
Represents an edge in a graph for the Traveling Salesman Problem (TSP).
13+
14+
Attributes:
15+
vertices (frozenset[T]): A pair of vertices representing the edge.
16+
weight (float): The weight (or cost) of the edge.
17+
"""
18+
19+
vertices: frozenset[T]
20+
weight: float
21+
22+
def __str__(self) -> str:
23+
return f"({self.vertices}, {self.weight})"
24+
25+
def __post_init__(self):
26+
# Ensures that there is no loop in a vertex
27+
if len(self.vertices) != 2:
28+
raise ValueError("frozenset must have exactly 2 elements")
29+
30+
@classmethod
31+
def from_3_tuple(cls, x, y, w) -> "TSPEdge":
32+
"""
33+
Construct TSPEdge from a 3-tuple (x, y, w).
34+
x & y are vertices and w is the weight.
35+
"""
36+
return cls(frozenset([x, y]), w)
37+
38+
def __eq__(self, other: object) -> bool:
39+
if not isinstance(other, TSPEdge):
40+
return NotImplemented
41+
return self.vertices == other.vertices
42+
43+
def __add__(self, other: "TSPEdge") -> float:
44+
return self.weight + other.weight
45+
46+
47+
class TSPGraph(Generic[T]):
48+
"""
49+
Represents a graph for the Traveling Salesman Problem (TSP).
50+
The graph is:
51+
- Simple (no loops or multiple edges between vertices).
52+
- Undirected.
53+
- Connected.
54+
"""
55+
56+
def __init__(self, edges: frozenset[TSPEdge] | None = None):
57+
self._edges = edges or frozenset()
58+
59+
def __str__(self) -> str:
60+
return f"{[str(edge) for edge in self._edges]}"
61+
62+
@classmethod
63+
def from_3_tuples(cls, *edges) -> "TSPGraph":
64+
return cls(frozenset(TSPEdge.from_3_tuple(x, y, w) for x, y, w in edges))
65+
66+
@classmethod
67+
def from_weights(cls, weights: list) -> "TSPGraph":
68+
"""
69+
Create TSPGraph from Weights (List of Lists) where the vertices
70+
are labeled with integers.
71+
"""
72+
triples = [
73+
(x, y, weights[x][y])
74+
for x, y in itertools.product(range(len(weights)), range(len(weights[0])))
75+
if x != y # Filter out self-loops
76+
]
77+
# return cls.from_3_tuples(*cast(list[tuple[T, T, float]], triples))
78+
return cls.from_3_tuples(*triples)
79+
80+
@property
81+
def vertices(self) -> frozenset[T]:
82+
return frozenset(vertex for edge in self._edges for vertex in edge.vertices)
83+
84+
@property
85+
def edges(self) -> frozenset[TSPEdge]:
86+
return self._edges
87+
88+
@property
89+
def weight(self) -> float:
90+
"""Total Weight of TSPGraph."""
91+
return sum(edge.weight for edge in self._edges)
92+
93+
def __contains__(self, obj: T | TSPEdge) -> bool:
94+
if isinstance(obj, TSPEdge):
95+
return any(obj == edge_ for edge_ in self._edges)
96+
else:
97+
return obj in self.vertices
98+
99+
def is_edge_in_graph(self, x: T, y: T) -> bool:
100+
return frozenset([x, y]) in self.get_edges()
101+
102+
def add_edge(self, x: T, y: T, w: float) -> "TSPGraph":
103+
# Validator to check if either x or y is in the vertex set to ensure
104+
# that the graph would be connected
105+
# Only use this validator if there exist at least 1 edge in the edge set.
106+
if self._edges and x not in self and y not in self:
107+
error_message = f"Adding the edge ({x}, {y}) may form a disconnected graph."
108+
raise ValueError(error_message)
109+
110+
new_edge = TSPEdge.from_3_tuple(
111+
x, y, w
112+
) # This would raise Vertex Loop error if x == y
113+
114+
# Raise error if Multi-Edges
115+
if new_edge in self:
116+
error_message = f"({x}, {y}, {w}) is invalid."
117+
raise ValueError(error_message)
118+
119+
return TSPGraph(
120+
edges=frozenset(self._edges | frozenset([TSPEdge.from_3_tuple(x, y, w)]))
121+
)
122+
123+
def get_edges(self) -> list[frozenset[T]]:
124+
return [edge.vertices for edge in self.edges]
125+
126+
def get_edge_weight(self, x: T, y: T) -> float:
127+
if (x not in self) or (y not in self):
128+
error_message = f"{x} or {y} does not belong to the graph vertices."
129+
raise ValueError(error_message)
130+
131+
# Find the edge with vertices (x, y)
132+
edge = next(
133+
(edge for edge in self.edges if frozenset([x, y]) == edge.vertices), None
134+
)
135+
136+
if edge is None:
137+
error_message = f"No edge exists between {x} and {y}."
138+
raise ValueError(error_message)
139+
140+
return edge.weight
141+
142+
def get_vertex_neighbors(self, x: T) -> frozenset[T]:
143+
if x not in self.vertices:
144+
error_message = f"{x} does not belong to the graph vertex set."
145+
raise ValueError(error_message)
146+
return frozenset(
147+
next(iter(edge.vertices - frozenset([x])))
148+
for edge in self.edges
149+
if x in edge.vertices
150+
)
151+
152+
def get_vertex_degree(self, x: T) -> int:
153+
if x not in self.vertices:
154+
error_message = f"{x} does not belong to the graph vertices."
155+
raise ValueError(error_message)
156+
return sum(1 for edge in self.edges if x in edge.vertices)
157+
158+
def get_vertex_argmin(self, x: T) -> T:
159+
"""Returns the Neighbor of a Vertex with the Minimum Weight."""
160+
return min(
161+
[(y, self.get_edge_weight(x, y)) for y in self.get_vertex_neighbors(x)],
162+
key=lambda tup: tup[1],
163+
)[0]
164+
165+
def get_vertex_argmax(self, x: T) -> T:
166+
"""Returns the Neighbor of a Vertex with the Maximum Weight."""
167+
return max(
168+
[(y, self.get_edge_weight(x, y)) for y in self.get_vertex_neighbors(x)],
169+
key=lambda tup: tup[1],
170+
)[0]
171+
172+
def get_vertex_neighbor_weights(self, x: T) -> Sequence[tuple[T, float]]:
173+
# Sort by Smallest to Largest
174+
return sorted(
175+
[(y, self.get_edge_weight(x, y)) for y in self.get_vertex_neighbors(x)],
176+
key=lambda tup: tup[1], # pair[1] is the weight (float)
177+
)
178+
179+
180+
def adjacent_tuples(path: list[T]) -> zip:
181+
"""
182+
Generates adjacent pairs of elements from a path.
183+
184+
Args:
185+
path (list[T]): A list of vertices representing a path.
186+
187+
Returns:
188+
zip: A zip object containing tuples of adjacent vertices.
189+
190+
Examples
191+
>>> list(adjacent_tuples([1, 2, 3, 4, 5]))
192+
[(1, 2), (2, 3), (3, 4), (4, 5)]
193+
194+
>>> list(adjacent_tuples(["A", "B", "C", "D", "E"]))
195+
[('A', 'B'), ('B', 'C'), ('C', 'D'), ('D', 'E')]
196+
"""
197+
iter1, iter2 = itertools.tee(path)
198+
next(iter2, None)
199+
return zip(iter1, iter2)
200+
201+
202+
def path_weight(path: list[T], tsp_graph: TSPGraph) -> float:
203+
"""
204+
Calculates the total weight of a given path in the graph.
205+
206+
Args:
207+
path (list[T]): A list of vertices representing a path.
208+
tsp_graph (TSPGraph): The graph containing the edges and weights.
209+
210+
Returns:
211+
float: The total weight of the path.
212+
"""
213+
return sum(tsp_graph.get_edge_weight(x, y) for x, y in adjacent_tuples(path))
214+
215+
216+
def generate_paths(start: T, end: T, tsp_graph: TSPGraph) -> Generator[list[T]]:
217+
"""
218+
Generates all possible paths between two vertices in a
219+
TSPGraph using Depth-First Search (DFS).
220+
221+
Args:
222+
start (T): The starting vertex.
223+
end (T): The target vertex.
224+
tsp_graph (TSPGraph): The graph to traverse.
225+
226+
Yields:
227+
Generator[list[T]]: A generator yielding paths as lists of vertices.
228+
229+
Raises:
230+
AssertionError: If start or end is not in the graph, or if they are the same.
231+
"""
232+
233+
assert start in tsp_graph.vertices
234+
assert end in tsp_graph.vertices
235+
assert start != end
236+
237+
def dfs(
238+
current: T, target: T, visited: set[T], path: list[T]
239+
) -> Generator[list[T]]:
240+
visited.add(current)
241+
path.append(current)
242+
243+
# If we reach the target, yield the current path
244+
if current == target:
245+
yield list(path)
246+
else:
247+
# Recur for all unvisited neighbors
248+
for neighbor in tsp_graph.get_vertex_neighbors(current):
249+
if neighbor not in visited:
250+
yield from dfs(neighbor, target, visited, path)
251+
252+
# Backtrack
253+
path.pop()
254+
visited.remove(current)
255+
256+
# Initialize DFS
257+
yield from dfs(start, end, set(), [])
258+
259+
260+
def nearest_neighborhood(tsp_graph: TSPGraph, v, visited_=None) -> list[T] | None:
261+
"""
262+
Approximates a solution to the Traveling Salesman Problem
263+
using the Nearest Neighbor heuristic.
264+
265+
Args:
266+
tsp_graph (TSPGraph): The graph to traverse.
267+
v (T): The starting vertex.
268+
visited_ (list[T] | None): A list of already visited vertices.
269+
270+
Returns:
271+
list[T] | None: A complete Hamiltonian cycle if possible, otherwise None.
272+
"""
273+
# Initialize visited list on first call
274+
visited = visited_ or [v]
275+
276+
# Base case: if all vertices are visited
277+
if len(visited) == len(tsp_graph.vertices):
278+
# Check if there is an edge to return to the starting point
279+
if tsp_graph.is_edge_in_graph(visited[-1], visited[0]):
280+
return [*visited, visited[0]]
281+
else:
282+
return None
283+
284+
# Get unvisited neighbors
285+
filtered_neighbors = [
286+
tup for tup in tsp_graph.get_vertex_neighbor_weights(v) if tup[0] not in visited
287+
]
288+
289+
# If there are unvisited neighbors, continue to the nearest one
290+
if filtered_neighbors:
291+
next_v = min(filtered_neighbors, key=lambda tup: tup[1])[0]
292+
return nearest_neighborhood(tsp_graph, v=next_v, visited_=[*visited, next_v])
293+
else:
294+
# No more neighbors, return None (cannot form a complete tour)
295+
return None
296+
297+
298+
def sample_1():
299+
# Reference: https://graphicmaths.com/computer-science/graph-theory/travelling-salesman-problem/
300+
301+
edges = [
302+
("A", "B", 7),
303+
("A", "D", 1),
304+
("A", "E", 1),
305+
("B", "C", 3),
306+
("B", "E", 8),
307+
("C", "E", 2),
308+
("C", "D", 6),
309+
("D", "E", 7),
310+
]
311+
312+
# Create the graph
313+
graph = TSPGraph.from_3_tuples(*edges)
314+
315+
import random
316+
317+
init_v = random.choice(list(graph.vertices))
318+
optim_path = nearest_neighborhood(graph, init_v)
319+
# optim_path = nearest_neighborhood(graph, 'A')
320+
print(f"Optimal Cycle: {optim_path}")
321+
if optim_path:
322+
print(f"Optimal Weight: {path_weight(optim_path, graph)}")
323+
324+
325+
def sample_2():
326+
# Example 8x8 weight matrix (symmetric, no self-loops)
327+
weights = [
328+
[0, 1, 2, 3, 4, 5, 6, 7],
329+
[1, 0, 8, 9, 10, 11, 12, 13],
330+
[2, 8, 0, 14, 15, 16, 17, 18],
331+
[3, 9, 14, 0, 19, 20, 21, 22],
332+
[4, 10, 15, 19, 0, 23, 24, 25],
333+
[5, 11, 16, 20, 23, 0, 26, 27],
334+
[6, 12, 17, 21, 24, 26, 0, 28],
335+
[7, 13, 18, 22, 25, 27, 28, 0],
336+
]
337+
338+
graph = TSPGraph.from_weights(weights)
339+
340+
import random
341+
342+
init_v = random.choice(list(graph.vertices))
343+
optim_path = nearest_neighborhood(graph, init_v)
344+
print(f"Optimal Cycle: {optim_path}")
345+
if optim_path:
346+
print(f"Optimal Weight: {path_weight(optim_path, graph)}")
347+
348+
349+
if __name__ == "__main__":
350+
import doctest
351+
352+
doctest.testmod()
353+
sample_1()
354+
sample_2()

0 commit comments

Comments
 (0)