-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdriver_graphblast.py
88 lines (75 loc) · 3.97 KB
/
driver_graphblast.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import subprocess
import driver
import config
import dataset
# Names exported by ``from <this module> import *``.
__all__ = ["DriverGraphBLAST"]
class DriverGraphBLAST(driver.Driver):
    """Benchmark driver that shells out to the GraphBLAST executables.

    Every ``run_*`` method launches one executable from ``exec_dir`` via
    ``subprocess.check_output`` and turns the ``warmup`` / ``tight`` lines
    found in its output into a ``driver.ExecutionResult``.
    """

    def __init__(self):
        """Set executable names, PageRank graph exceptions and CLI flags."""
        super().__init__()

        # Directory holding the GraphBLAST binaries and their file names.
        self.exec_dir = config.DEPS / "graphblast" / "bin"
        self.gbfs = "gbfs"
        self.gsssp = "gsssp"
        self.gpr = "gpr"
        self.gtc = "gtc"

        # NOTE(review): ``graph_exceptions`` is not created in this class;
        # presumably ``driver.Driver.__init__`` provides it and these are the
        # graphs excluded from PageRank runs -- confirm against the base class.
        self.graph_exceptions[dataset.ALGORITHM_NAME_pr] = {
            dataset.GRAPH_NAME_mycielskian19,
            dataset.GRAPH_NAME_socLiveJournal,
            dataset.GRAPH_NAME_comOrkut,
            dataset.GRAPH_NAME_rgg_n_2_23_s0,
            dataset.GRAPH_NAME_road_central,
        }

        # 0: do not display per iteration timing, 1: display per iteration timing
        self.timing = 0
        # 0: follow mtx, 1: force undirected graph to be directed, 2: force directed graph to be undirected
        self.directed = 0
        # 0: run CPU verification, 1: skip CPU algorithm verification
        self.skip_cpu_verify = 0

    def name(self) -> str:
        """Return the identifier of this driver."""
        return "graphblast"

    def run_bfs(self, graph: dataset.Graph, source_vertex, num_iterations) -> driver.ExecutionResult:
        """Run the ``gbfs`` executable on ``graph`` starting at ``source_vertex``."""
        return self._run(self.gbfs, graph, num_iterations,
                         leading_args=(f"--source={source_vertex}",))

    def run_sssp(self, graph: dataset.Graph, source_vertex, num_iterations) -> driver.ExecutionResult:
        """Run the ``gsssp`` executable on ``graph`` starting at ``source_vertex``."""
        return self._run(self.gsssp, graph, num_iterations,
                         leading_args=(f"--source={source_vertex}",))

    def run_pr(self, graph: dataset.Graph, num_iterations) -> driver.ExecutionResult:
        """Run the ``gpr`` executable on ``graph``."""
        return self._run(self.gpr, graph, num_iterations)

    def run_tc(self, graph: dataset.Graph, num_iterations) -> driver.ExecutionResult:
        """Run the ``gtc`` executable on ``graph``."""
        return self._run(self.gtc, graph, num_iterations)

    def _run(self, binary, graph, num_iterations, leading_args=()):
        """Launch one GraphBLAST executable and parse the timings it prints.

        Args:
            binary: File name of the executable inside ``exec_dir``.
            graph: Graph whose ``path()`` is passed as the final argument.
            num_iterations: Value for ``--niter``; also the number of
                per-iteration entries in the returned result.
            leading_args: Extra options placed right after the executable and
                before ``--niter`` (used for ``--source=...``).

        Returns:
            The ``driver.ExecutionResult`` built by ``_parse_output``.

        Raises:
            subprocess.CalledProcessError: If the executable exits with a
                non-zero status.
        """
        command = [
            str(self.exec_dir / binary),
            *leading_args,
            f"--niter={num_iterations}",
            f"--timing={self.timing}",
            f"--directed={self.directed}",
            f"--skip_cpu_verify={self.skip_cpu_verify}",
            str(graph.path()),
        ]
        output = subprocess.check_output(command)
        return DriverGraphBLAST._parse_output(output, num_iterations)

    @staticmethod
    def _parse_output(output, n):
        """Extract the ``warmup`` and ``tight`` values from raw tool output.

        Args:
            output: Bytes captured from the executable's standard output.
            n: Number of times the ``tight`` value is repeated in the result.

        Returns:
            ``driver.ExecutionResult(warmup, [tight] * n)``. A value whose
            line is absent from the output stays ``0.0``; when a prefix
            occurs on several lines the last occurrence wins.

        Raises:
            IndexError: If a matching line has no second field.
            ValueError: If the second field of a matching line is not a number.
        """
        # Only the timing lines matter, so a stray non-ASCII byte elsewhere in
        # the output must not abort the whole parse.
        text = output.decode("ascii", errors="replace")

        warmup = 0.0
        tight = 0.0
        for line in text.replace("\r", "").split("\n"):
            # Commas are dropped and the line is split on any whitespace run,
            # so repeated spaces or tabs do not produce empty fields.
            if line.startswith("warmup"):
                warmup = float(line.replace(",", "").split()[1])
            elif line.startswith("tight"):
                tight = float(line.replace(",", "").split()[1])

        return driver.ExecutionResult(warmup, [tight] * n)