bmm_flops.py
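"""Benchmark batched matrix multiplication (torch.bmm) throughput in TFLOP/s.

Runs fp16 bmm kernels on a CUDA device over a range of batch sizes and matrix
shapes and reports the mean, worst-case (slowest iteration), and best-case
(fastest iteration) throughput.
"""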
import time

import torch
import numpy as np


def benchmark_bmm(b, m, n, k, num_iterations=100):
    """Report the mean per-iteration time and throughput of a (b, m, n) x (b, n, k) bmm."""
    A = torch.randn((b, m, n)).half().to("cuda:0")
    B = torch.randn((b, n, k)).half().to("cuda:0")
    C = torch.empty((b, m, k)).half().to("cuda:0")
    num_warmup_iterations = 50
    for i in range(num_warmup_iterations + num_iterations):
        if i == num_warmup_iterations:
            # Warmup iterations are done; start timing from here.
            start_time = time.time()
        with torch.no_grad():
            torch.bmm(A, B, out=C)
        torch.cuda.synchronize()
    elapsed_time = (time.time() - start_time) / num_iterations
    # Each of the b independent (m x n) @ (n x k) products costs 2*m*n*k FLOPs.
    flops = (2 * b * m * n * k) / (elapsed_time * 10**12)
    print(f"Elapsed time for {b}x{m}x{n}x{k}: {elapsed_time:.3f} s")
    print(f"Throughput (in TFLOP/s) for {b}x{m}x{n}x{k}: {flops:.3f}")
    print("-" * 80)
    return flops
def benchmark_bmm_max(b, m, n, k, num_iterations=200):
    """Report the worst-case (slowest) per-iteration time and the corresponding throughput."""
    A = torch.randn((b, m, n)).half().to("cuda:0")
    B = torch.randn((b, n, k)).half().to("cuda:0")
    C = torch.empty((b, m, k)).half().to("cuda:0")
    num_warmup_iterations = 50
    times = np.zeros(num_iterations + num_warmup_iterations)
    start_time = time.time()
    for i in range(num_warmup_iterations + num_iterations):
        with torch.no_grad():
            torch.bmm(A, B, out=C)
        torch.cuda.synchronize()
        # Record a timestamp after each iteration completes.
        times[i] = time.time()
    times -= start_time  # constant offset; cancels in np.diff below
    # Per-iteration durations; drop the warmup iterations.
    times = np.diff(times)
    times = times[num_warmup_iterations:]
    elapsed_time = np.amax(times)
    flops = (2 * b * m * n * k) / (elapsed_time * 10**12)
    print(f"Elapsed time for {b}x{m}x{n}x{k}: {elapsed_time:.3f} s")
    print(f"Throughput (in TFLOP/s) for {b}x{m}x{n}x{k}: {flops:.3f}")
    print("-" * 80)
    return flops
def benchmark_bmm_min(b, m, n, k, num_iterations=200):
    """Report the best-case (fastest) per-iteration time and the corresponding throughput."""
    A = torch.randn((b, m, n)).half().to("cuda:0")
    B = torch.randn((b, n, k)).half().to("cuda:0")
    C = torch.empty((b, m, k)).half().to("cuda:0")
    num_warmup_iterations = 50
    times = np.zeros(num_iterations + num_warmup_iterations)
    start_time = time.time()
    for i in range(num_warmup_iterations + num_iterations):
        with torch.no_grad():
            torch.bmm(A, B, out=C)
        torch.cuda.synchronize()
        # Record a timestamp after each iteration completes.
        times[i] = time.time()
    times -= start_time  # constant offset; cancels in np.diff below
    # Per-iteration durations; drop the warmup iterations.
    times = np.diff(times)
    times = times[num_warmup_iterations:]
    elapsed_time = np.amin(times)
    flops = (2 * b * m * n * k) / (elapsed_time * 10**12)
    print(f"Elapsed time for {b}x{m}x{n}x{k}: {elapsed_time:.3f} s")
    print(f"Throughput (in TFLOP/s) for {b}x{m}x{n}x{k}: {flops:.3f}")
    print("-" * 80)
    return flops
def bench_list(b, m, N, k):
    """Sweep the shared (inner) dimension n over the values in N and collect throughputs."""
    benches = []
    for n in N:
        benches.append(benchmark_bmm(b, m, n, k))
    return benches
if __name__ == '__main__':
    torch.cuda.set_device("cuda:0")

    # Shared-dimension sweep.
    # N_values = range(64, 2**12, 64)
    # for logb in range(5, 9):
    #     bench_list(b=2**logb, m=2048, N=N_values, k=2048)

    # Try to determine the effect of b on throughput with square individual MMs.
    '''for log_b in range(7):
        b = 2**log_b
        benchmark_bmm(b, m=1024, n=1024, k=1024)
        benchmark_bmm(b, m=2048, n=2048, k=2048)
        benchmark_bmm(b, m=4096, n=4096, k=4096)
        benchmark_bmm(b, m=8192, n=8192, k=8192)
    '''

    # Try to determine the effect of b and outer_dim on throughput with
    # non-square individual MMs.
    for log_b in range(7):
        b = 2**log_b
        for log_outer_dim in range(5, 14):
            outer_dim = 2**log_outer_dim
            benchmark_bmm_min(b, m=outer_dim, n=4096, k=outer_dim)

    '''
    h = 2048
    m = 2048
    k = int(h)
    n = h
    b = 512
    A = torch.randn((b, m, n)).half().to("cuda:0")
    B = torch.randn((b, n, k)).half().to("cuda:0")
    C = torch.empty((b, m, k)).half().to("cuda:0")
    torch.bmm(A, B, out=C)
    '''
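
    # Hypothetical usage sketch (not in the original script): compare the mean,
    # worst-case, and best-case throughput estimates for a single shape. The
    # shape values below are illustrative only.
    # benchmark_bmm(32, m=2048, n=4096, k=2048)
    # benchmark_bmm_max(32, m=2048, n=4096, k=2048)
    # benchmark_bmm_min(32, m=2048, n=4096, k=2048)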