Skip to content
This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit a0fa7c7

Browse files
authored
Separated utils into backend and benchmark
1 parent bea2e0b commit a0fa7c7

File tree

7 files changed

+238
-249
lines changed

7 files changed

+238
-249
lines changed

dl_bench/utils.py renamed to dl_bench/backend.py

+28-237
Original file line numberDiff line numberDiff line change
@@ -1,91 +1,8 @@
1-
import time
21
from typing import Any
32

43
import numpy as np
54
import torch
65
from torch.nn import Module
7-
from torch.utils.data import DataLoader, Dataset
8-
9-
10-
def get_time():
11-
return time.perf_counter()
12-
13-
14-
class RandomInfDataset(Dataset):
15-
def __init__(self, n, in_shape, seed=42):
16-
super().__init__()
17-
np.random.seed(seed)
18-
19-
self.values = np.random.randn(n, *in_shape).astype(np.float32)
20-
21-
def __len__(self):
22-
return len(self.values)
23-
24-
def __getitem__(self, index):
25-
return self.values[index]
26-
27-
28-
def get_inf_loaders(n, in_shape, batch_size, device: str):
29-
# This speeds up data copy for cuda devices
30-
pin_memory = device == "cuda"
31-
32-
ds = RandomInfDataset(n, in_shape)
33-
train_loader = DataLoader(
34-
ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin_memory
35-
)
36-
test_loader = DataLoader(
37-
ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory
38-
)
39-
return train_loader, test_loader
40-
41-
42-
def recursively_convert_to_numpy(o: Any):
43-
if isinstance(o, torch.Tensor):
44-
return o.numpy()
45-
if isinstance(o, tuple):
46-
return tuple(recursively_convert_to_numpy(x) for x in o)
47-
if isinstance(o, list):
48-
return [recursively_convert_to_numpy(x) for x in o]
49-
if isinstance(o, dict):
50-
return {k: recursively_convert_to_numpy(v) for k, v in o.items()}
51-
# No-op cases. Explicitly enumerated to avoid things sneaking through.
52-
if isinstance(o, str):
53-
return o
54-
if isinstance(o, float):
55-
return o
56-
if isinstance(o, int):
57-
return o
58-
raise Exception(f"Unexpected Python function input: {o}")
59-
60-
61-
def recursively_convert_from_numpy(o: Any):
62-
if isinstance(o, np.ndarray):
63-
return torch.from_numpy(o)
64-
if isinstance(o, tuple):
65-
return tuple(recursively_convert_from_numpy(x) for x in o)
66-
if isinstance(o, list):
67-
return [recursively_convert_from_numpy(x) for x in o]
68-
if isinstance(o, dict):
69-
return {k: recursively_convert_from_numpy(v) for k, v in o.items()}
70-
# No-op cases. Explicitly enumerated to avoid things sneaking through.
71-
if isinstance(o, str):
72-
return o
73-
if isinstance(o, float):
74-
return o
75-
if isinstance(o, int):
76-
return o
77-
raise Exception(f"Unexpected Python function output: {o}")
78-
79-
80-
def refine_result_type(_result):
81-
if isinstance(_result, tuple):
82-
return tuple(refine_result_type(x) for x in _result)
83-
elif isinstance(_result, np.ndarray):
84-
return torch.from_numpy(_result)
85-
elif isinstance(_result, (bool, int, float)):
86-
return _result
87-
else:
88-
raise ValueError(f"Unhandled return type {type(_result)}")
896

907

918
def str_to_dtype(dtype: str):
@@ -337,157 +254,31 @@ def _get_device(device_name):
337254
raise ValueError(f"Unknown execution device {device_name}.")
338255

339256

340-
def get_report(fw_times, duration_s, n_items, flops_per_sample):
341-
return {
342-
"duration_s": duration_s,
343-
"samples_per_s": n_items / sum(fw_times),
344-
"samples_per_s_dirty": n_items / duration_s,
345-
"flops_per_sample": flops_per_sample,
346-
"n_items": n_items,
347-
"p00": np.percentile(fw_times, 0),
348-
"p50": np.percentile(fw_times, 50),
349-
"p90": np.percentile(fw_times, 90),
350-
"p100": max(fw_times),
351-
}
352-
353-
354-
class Benchmark:
355-
def __init__(
356-
self,
357-
net,
358-
in_shape,
359-
dataset,
360-
batch_size,
361-
min_batches=10,
362-
min_seconds=10,
363-
warmup_batches=3,
364-
) -> None:
365-
self.model = net
366-
self.in_shape = in_shape
367-
self.dataset = dataset
368-
self.batch_size = batch_size
369-
self.warmup_batches = warmup_batches
370-
self.min_batches = min_batches
371-
self.min_seconds = min_seconds
372-
373-
def compile(self, sample, backend: Backend):
374-
self.model = backend.prepare_eval_model(self.model, sample_input=sample)
375-
376-
def inference(self, backend: Backend):
377-
# timeout if running for more than 3 minutes already
378-
max_time = 180
379-
380-
test_loader = torch.utils.data.DataLoader(
381-
self.dataset,
382-
batch_size=self.batch_size,
383-
shuffle=False,
384-
num_workers=0,
385-
pin_memory=backend.device_name == "cuda",
386-
)
387-
388-
try:
389-
print("Torch cpu capability:", torch.backends.cpu.get_cpu_capability())
390-
except:
391-
pass
392-
393-
flops_per_sample = get_macs(self.model, self.in_shape, backend) * 2
394-
395-
sample = next(iter(test_loader))
396-
self.compile(sample, backend)
397-
398-
n_items = 0
399-
outputs = []
400-
fw_times = []
401-
402-
self.model.eval()
403-
with torch.inference_mode():
404-
start = get_time()
405-
for i, x in enumerate(test_loader):
406-
backend.sync()
407-
s = get_time()
408-
x = backend.to_device(x)
409-
if backend.dtype != torch.float32:
410-
with torch.autocast(
411-
device_type=backend.device_name,
412-
dtype=backend.dtype,
413-
):
414-
y = self.model(x)
415-
else:
416-
y = self.model(x)
417-
418-
backend.sync()
419-
420-
if i < self.warmup_batches:
421-
# We restart timer because that was just a warmup
422-
start = time.perf_counter()
423-
continue
424-
425-
fw_times.append(get_time() - s)
426-
n_items += len(x)
427-
outputs.append(y)
428-
429-
# early stopping if we have 10+ batches and were running for 10+ seconds
430-
if (
431-
(time.perf_counter() - start) > self.min_seconds
432-
and n_items >= self.batch_size * self.min_batches
433-
):
434-
break
435-
436-
if (get_time() - start) > max_time:
437-
break
438-
439-
stop = get_time()
440-
441-
report = get_report(
442-
fw_times=fw_times,
443-
duration_s=stop - start,
444-
n_items=n_items,
445-
flops_per_sample=flops_per_sample,
446-
)
447-
return report, outputs
448-
449-
def train(self):
450-
# We are not interested in training yet.
451-
# criterion = nn.CrossEntropyLoss()
452-
# optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
453-
454-
# N_EPOCHS = 3
455-
# epoch_stats = {}
456-
# n_report = 10
457-
# for epoch in range(n_epochs): # loop over the dataset multiple times
458-
# running_loss = 0.0
459-
460-
# n_items = 0
461-
# start = get_time()
462-
# for i, (x, y) in enumerate(trainloader):
463-
# optimizer.zero_grad()
464-
465-
# outputs = net(x)
466-
# loss = criterion(outputs, y)
467-
# loss.backward()
468-
# optimizer.step()
469-
470-
# n_items += len(x)
471-
472-
# running_loss += loss.item()
473-
# if i % n_report == (n_report - 1):
474-
# print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / n_report:.3f}')
475-
# running_loss = 0.0
476-
477-
# stop = get_time()
478-
# print(f"{n_items} took {stop - start}")
479-
480-
# print('Finished Training')
481-
pass
482-
483-
484-
def get_macs(model, in_shape, backend):
485-
"""Calculate MACs, conventional FLOPS = MACs * 2."""
486-
from ptflops import get_model_complexity_info
487-
488-
model.eval()
489-
with torch.no_grad():
490-
macs, params = get_model_complexity_info(
491-
model, in_shape, as_strings=False, print_per_layer_stat=False, verbose=True
492-
)
493-
return macs
257+
def recursively_convert_to_numpy(o: Any):
258+
if isinstance(o, torch.Tensor):
259+
return o.numpy()
260+
if isinstance(o, tuple):
261+
return tuple(recursively_convert_to_numpy(x) for x in o)
262+
if isinstance(o, list):
263+
return [recursively_convert_to_numpy(x) for x in o]
264+
if isinstance(o, dict):
265+
return {k: recursively_convert_to_numpy(v) for k, v in o.items()}
266+
# No-op cases. Explicitly enumerated to avoid things sneaking through.
267+
if isinstance(o, str):
268+
return o
269+
if isinstance(o, float):
270+
return o
271+
if isinstance(o, int):
272+
return o
273+
raise Exception(f"Unexpected Python function input: {o}")
274+
275+
276+
def refine_result_type(_result):
277+
if isinstance(_result, tuple):
278+
return tuple(refine_result_type(x) for x in _result)
279+
elif isinstance(_result, np.ndarray):
280+
return torch.from_numpy(_result)
281+
elif isinstance(_result, (bool, int, float)):
282+
return _result
283+
else:
284+
raise ValueError(f"Unhandled return type {type(_result)}")

dl_bench/bench/cnn.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dl_bench.utils import Benchmark, RandomInfDataset
1+
from dl_bench.benchmark import Benchmark, RandomInfDataset
22

33

44
def get_cnn(name):
@@ -43,5 +43,11 @@ def __init__(self, params) -> None:
4343
net = get_cnn(name=name)
4444

4545
super().__init__(
46-
net=net, in_shape=in_shape, dataset=dataset, batch_size=batch_size, min_batches=min_batches, warmup_batches=warmup, min_seconds=min_seconds
46+
net=net,
47+
in_shape=in_shape,
48+
dataset=dataset,
49+
batch_size=batch_size,
50+
min_batches=min_batches,
51+
warmup_batches=warmup,
52+
min_seconds=min_seconds,
4753
)

dl_bench/bench/llm.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
import math
44

55
import torch
6-
import numpy as np
76
from transformers import (
87
AutoModelForCausalLM,
98
AutoTokenizer,
109
LlamaForCausalLM,
1110
LlamaTokenizer,
1211
)
1312

14-
from dl_bench.utils import Benchmark, get_report, get_time, str_to_dtype
13+
from dl_bench.benchmark import Benchmark, get_report, get_time
14+
from dl_bench.backend import str_to_dtype
1515

1616

1717
def get_llm(name, dtype):
@@ -81,7 +81,6 @@ def inference(self, backend):
8181
outputs = []
8282
fw_times = []
8383

84-
8584
# Ipex gives error with eval, other backends have no effect
8685
# self.model.eval()
8786
for i in range(self.n_iter):

dl_bench/bench/mlp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
import torch.nn as nn
55

6-
from dl_bench.utils import Benchmark, RandomInfDataset
6+
from dl_bench.benchmark import Benchmark, RandomInfDataset
77

88

99
size2_struct = [512, 1024, 2048, 512]

dl_bench/bench/mlp_basic.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
import time
1+
from typing import List
22

33
import torch
44
from torch.nn import Module, Linear
55
import torch.nn.functional as F
66

7-
from dl_bench.utils import Benchmark
8-
from dl_bench.bench.mlp import RandomInfDataset
9-
from typing import List
7+
from dl_bench.benchmark import Benchmark, RandomInfDataset
108

119

1210
class MLP(Module):
@@ -48,7 +46,7 @@ def __init__(self, params) -> None:
4846

4947

5048
def train(model: Module, device):
51-
from tools import train, validate_accuracy
49+
from dl_bench.tools import train, validate_accuracy
5250

5351
epochs = 2
5452

0 commit comments

Comments
 (0)