
Commit 45a0537

Fix imports in if statement for PyTests
1 parent e8bd3c6 commit 45a0537
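
The change moves the script's imports, together with the helper functions and classes that use them, inside the if __name__ == "__main__": guard, presumably so that importing the file (for example during pytest collection of scripts/) does not pull in heavy optional dependencies such as paderbox, psutil and torch. A minimal, hypothetical sketch of the pattern (not taken from this repository):

# heavy_script.py -- hypothetical illustration of the guard pattern used in this commit.
from __future__ import annotations

if __name__ == "__main__":
    # Imports live behind the guard: importing this module (e.g. from a test)
    # does not execute this block, so the dependencies are only needed when the
    # script is run directly.
    import json
    import numpy as np

    def main() -> None:
        # The script's actual work also stays behind the guard.
        print(json.dumps({"mean": float(np.mean([1, 2, 3]))}))

    main()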

1 file changed (+116 −117 lines)

Diff for: scripts/immutability_options.py

+116 −117
@@ -1,124 +1,123 @@
 # This is an integration test for the immutability functionality
 from __future__ import annotations
-import json
-import os
-import pickle
-import sys
-import time
-import paderbox
-from paderbox.io.cache import url_to_local_path
-from collections import defaultdict
-from typing import Any
-import lazy_dataset
-import numpy as np
-import psutil
-import torch
-from tabulate import tabulate
-
-
-# Download from https://huggingface.co/datasets/merve/coco/resolve/main/annotations/instances_train2017.json
-def create_coco() -> list[Any]:
-    json_path = url_to_local_path("https://huggingface.co/datasets/merve/coco/resolve/main/annotations/instances_train2017.json")
-    with open(json_path) as f:
-        obj = json.load(f)
-    return obj["annotations"]
-
-
-
-def get_mem_info(pid: int) -> dict[str, int]:
-    res = defaultdict(int)
-    for mmap in psutil.Process(pid).memory_maps():
-        res['rss'] += mmap.rss
-        res['pss'] += mmap.pss
-        res['uss'] += mmap.private_clean + mmap.private_dirty
-        res['shared'] += mmap.shared_clean + mmap.shared_dirty
-        if mmap.path.startswith('/'): # looks like a file path
-            res['shared_file'] += mmap.shared_clean + mmap.shared_dirty
-    return res
-
-
-class MemoryMonitor():
-    """Class used to monitor the memory usage of processes"""
-
-    def __init__(self, pids: list[int] = None):
-        if pids is None:
-            pids = [os.getpid()]
-        self.pids = pids
-
-    def add_pid(self, pid: int):
-        assert pid not in self.pids
-        self.pids.append(pid)
-
-    def _refresh(self):
-        self.data = {pid: get_mem_info(pid) for pid in self.pids}
-        return self.data
-
-    def table(self) -> str:
-        self._refresh()
-        table = []
-        keys = list(list(self.data.values())[0].keys())
-        now = str(int(time.perf_counter() % 1e5))
-        for pid, data in self.data.items():
-            table.append((now, str(pid)) + tuple(self.format(data[k]) for k in keys))
-        return tabulate(table, headers=["time", "PID"] + keys)
-
-    def str(self):
-        self._refresh()
-        keys = list(list(self.data.values())[0].keys())
-        res = []
-        for pid in self.pids:
-            s = f"PID={pid}"
-            for k in keys:
-                v = self.format(self.data[pid][k])
-                s += f", {k}={v}"
-            res.append(s)
-        return "\n".join(res)
-
-    @staticmethod
-    def format(size: int) -> str:
-        for unit in ('', 'K', 'M', 'G'):
-            if size < 1024:
-                break
-            size /= 1024.0
-        return "%.1f%s" % (size, unit)
-
-
-def read_sample(x):
-    """
-    A function that is supposed to read object x, incrementing its refcount.
-    This mimics what a real dataloader would do."""
-    if sys.version_info >= (3, 10, 6):
-        """Before this version, pickle does not increment refcount. This is a bug that's
-        fixed in https://github.com/python/cpython/pull/92931. """
-        return pickle.dumps(x)
-    else:
-        import msgpack
-        return msgpack.dumps(x)
-
-
-class DatasetFromList(torch.utils.data.Dataset):
-    def __init__(self, lst):
-        self.lst = lst
-
-    def __len__(self):
-        return len(self.lst)
-
-    def __getitem__(self, idx: int):
-        return self.lst[idx]
-
-
-def worker(_, dataset: torch.utils.data.Dataset):
-    while True:
-        for sample in dataset:
-            # read the data, with a fake latency
-            time.sleep(0.000001)
-            result = read_sample(sample)
-

 if __name__ == "__main__":
+    import json
+    import os
+    import pickle
+    import sys
+    import time
+    import paderbox
+    from paderbox.io.cache import url_to_local_path
+    from collections import defaultdict
+    from typing import Any
+    import lazy_dataset
+    import numpy as np
+    import psutil
+    from tabulate import tabulate
     import matplotlib.pyplot as plt
+    import torch
+
+    # Download from https://huggingface.co/datasets/merve/coco/resolve/main/annotations/instances_train2017.json
+    def create_coco() -> list[Any]:
+        json_path = url_to_local_path("https://huggingface.co/datasets/merve/coco/resolve/main/annotations/instances_train2017.json")
+        with open(json_path) as f:
+            obj = json.load(f)
+        return obj["annotations"]
+
+
+
+    def get_mem_info(pid: int) -> dict[str, int]:
+        res = defaultdict(int)
+        for mmap in psutil.Process(pid).memory_maps():
+            res['rss'] += mmap.rss
+            res['pss'] += mmap.pss
+            res['uss'] += mmap.private_clean + mmap.private_dirty
+            res['shared'] += mmap.shared_clean + mmap.shared_dirty
+            if mmap.path.startswith('/'): # looks like a file path
+                res['shared_file'] += mmap.shared_clean + mmap.shared_dirty
+        return res
+
+
+    class MemoryMonitor():
+        """Class used to monitor the memory usage of processes"""
+
+        def __init__(self, pids: list[int] = None):
+            if pids is None:
+                pids = [os.getpid()]
+            self.pids = pids
+
+        def add_pid(self, pid: int):
+            assert pid not in self.pids
+            self.pids.append(pid)
+
+        def _refresh(self):
+            self.data = {pid: get_mem_info(pid) for pid in self.pids}
+            return self.data
+
+        def table(self) -> str:
+            self._refresh()
+            table = []
+            keys = list(list(self.data.values())[0].keys())
+            now = str(int(time.perf_counter() % 1e5))
+            for pid, data in self.data.items():
+                table.append((now, str(pid)) + tuple(self.format(data[k]) for k in keys))
+            return tabulate(table, headers=["time", "PID"] + keys)
+
+        def str(self):
+            self._refresh()
+            keys = list(list(self.data.values())[0].keys())
+            res = []
+            for pid in self.pids:
+                s = f"PID={pid}"
+                for k in keys:
+                    v = self.format(self.data[pid][k])
+                    s += f", {k}={v}"
+                res.append(s)
+            return "\n".join(res)
+
+        @staticmethod
+        def format(size: int) -> str:
+            for unit in ('', 'K', 'M', 'G'):
+                if size < 1024:
+                    break
+                size /= 1024.0
+            return "%.1f%s" % (size, unit)
+
+
+    def read_sample(x):
+        """
+        A function that is supposed to read object x, incrementing its refcount.
+        This mimics what a real dataloader would do."""
+        if sys.version_info >= (3, 10, 6):
+            """Before this version, pickle does not increment refcount. This is a bug that's
+            fixed in https://github.com/python/cpython/pull/92931. """
+            return pickle.dumps(x)
+        else:
+            import msgpack
+            return msgpack.dumps(x)
+
+
+    class DatasetFromList(torch.utils.data.Dataset):
+
+        def __init__(self, lst):
+            self.lst = lst
+
+        def __len__(self):
+            return len(self.lst)
+
+        def __getitem__(self, idx: int):
+            return self.lst[idx]
+
+
+    def worker(_, dataset: torch.utils.data.Dataset):
+        while True:
+            for sample in dataset:
+                # read the data, with a fake latency
+                time.sleep(0.000001)
+                result = read_sample(sample)
     monitor = MemoryMonitor()
-    immutable_warranty = "pickle" # copy pickle wu
+    immutable_warranty = "wu" # copy pickle wu
     ds = lazy_dataset.new(create_coco(), immutable_warranty=immutable_warranty)
     print(monitor.table())

@@ -144,7 +143,7 @@ def worker(_, dataset: torch.utils.data.Dataset):
         axis.set_xlabel("Times (s)")
         axis.legend()
         axis.set_ylabel("Memory usage (MB)")
-        # plt.savefig(f"/net/vol/deegen/SHK/Lazy_dataset_test/{immutable_warranty}.svg", format="svg")#, dpi=600)
-        plt.show()
+        plt.savefig(f"/net/vol/deegen/SHK/Lazy_dataset_test/{immutable_warranty}.svg", format="svg")#, dpi=600)
+        # plt.show()
     finally:
         ctx.join()