# Integration test for the immutability (immutable_warranty) functionality
from __future__ import annotations
- import json
- import os
- import pickle
- import sys
- import time
- import paderbox
- from paderbox.io.cache import url_to_local_path
- from collections import defaultdict
- from typing import Any
- import lazy_dataset
- import numpy as np
- import psutil
- import torch
- from tabulate import tabulate
-
-
- # Download from https://huggingface.co/datasets/merve/coco/resolve/main/annotations/instances_train2017.json
- def create_coco() -> list[Any]:
-     json_path = url_to_local_path("https://huggingface.co/datasets/merve/coco/resolve/main/annotations/instances_train2017.json")
-     with open(json_path) as f:
-         obj = json.load(f)
-     return obj["annotations"]
-
-
-
- def get_mem_info(pid: int) -> dict[str, int]:
-     res = defaultdict(int)
-     for mmap in psutil.Process(pid).memory_maps():
-         res['rss'] += mmap.rss
-         res['pss'] += mmap.pss
-         res['uss'] += mmap.private_clean + mmap.private_dirty
-         res['shared'] += mmap.shared_clean + mmap.shared_dirty
-         if mmap.path.startswith('/'):  # looks like a file path
-             res['shared_file'] += mmap.shared_clean + mmap.shared_dirty
-     return res
-
-
- class MemoryMonitor():
-     """Class used to monitor the memory usage of processes"""
-
-     def __init__(self, pids: list[int] = None):
-         if pids is None:
-             pids = [os.getpid()]
-         self.pids = pids
-
-     def add_pid(self, pid: int):
-         assert pid not in self.pids
-         self.pids.append(pid)
-
-     def _refresh(self):
-         self.data = {pid: get_mem_info(pid) for pid in self.pids}
-         return self.data
-
-     def table(self) -> str:
-         self._refresh()
-         table = []
-         keys = list(list(self.data.values())[0].keys())
-         now = str(int(time.perf_counter() % 1e5))
-         for pid, data in self.data.items():
-             table.append((now, str(pid)) + tuple(self.format(data[k]) for k in keys))
-         return tabulate(table, headers=["time", "PID"] + keys)
-
-     def str(self):
-         self._refresh()
-         keys = list(list(self.data.values())[0].keys())
-         res = []
-         for pid in self.pids:
-             s = f"PID={pid}"
-             for k in keys:
-                 v = self.format(self.data[pid][k])
-                 s += f", {k}={v}"
-             res.append(s)
-         return "\n".join(res)
-
-     @staticmethod
-     def format(size: int) -> str:
-         for unit in ('', 'K', 'M', 'G'):
-             if size < 1024:
-                 break
-             size /= 1024.0
-         return "%.1f%s" % (size, unit)
-
-
- def read_sample(x):
-     """
-     A function that is supposed to read object x, incrementing its refcount.
-     This mimics what a real dataloader would do.
-     """
-     if sys.version_info >= (3, 10, 6):
-         # Before 3.10.6, pickle.dumps did not increment the refcount of the
-         # pickled object; fixed in https://github.com/python/cpython/pull/92931.
-         return pickle.dumps(x)
-     else:
-         import msgpack
-         return msgpack.dumps(x)
-
-
- class DatasetFromList(torch.utils.data.Dataset):
-     def __init__(self, lst):
-         self.lst = lst
-
-     def __len__(self):
-         return len(self.lst)
-
-     def __getitem__(self, idx: int):
-         return self.lst[idx]
-
-
- def worker(_, dataset: torch.utils.data.Dataset):
-     while True:
-         for sample in dataset:
-             # read the data, with a fake latency
-             time.sleep(0.000001)
-             result = read_sample(sample)
-
if __name__ == "__main__":
+     import json
+     import os
+     import pickle
+     import sys
+     import time
+     import paderbox
+     from paderbox.io.cache import url_to_local_path
+     from collections import defaultdict
+     from typing import Any
+     import lazy_dataset
+     import numpy as np
+     import psutil
+     from tabulate import tabulate
    import matplotlib.pyplot as plt
+     import torch
+
+     # Download from https://huggingface.co/datasets/merve/coco/resolve/main/annotations/instances_train2017.json
+     def create_coco() -> list[Any]:
+         json_path = url_to_local_path("https://huggingface.co/datasets/merve/coco/resolve/main/annotations/instances_train2017.json")
+         with open(json_path) as f:
+             obj = json.load(f)
+         return obj["annotations"]
+
+
+
+     def get_mem_info(pid: int) -> dict[str, int]:
+         res = defaultdict(int)
+         for mmap in psutil.Process(pid).memory_maps():
+             res['rss'] += mmap.rss
+             res['pss'] += mmap.pss
+             res['uss'] += mmap.private_clean + mmap.private_dirty
+             res['shared'] += mmap.shared_clean + mmap.shared_dirty
+             if mmap.path.startswith('/'):  # looks like a file path
+                 res['shared_file'] += mmap.shared_clean + mmap.shared_dirty
+         return res
+
+
+     class MemoryMonitor():
+         """Class used to monitor the memory usage of processes"""
+
+         def __init__(self, pids: list[int] = None):
+             if pids is None:
+                 pids = [os.getpid()]
+             self.pids = pids
+
+         def add_pid(self, pid: int):
+             assert pid not in self.pids
+             self.pids.append(pid)
+
+         def _refresh(self):
+             self.data = {pid: get_mem_info(pid) for pid in self.pids}
+             return self.data
+
+         def table(self) -> str:
+             self._refresh()
+             table = []
+             keys = list(list(self.data.values())[0].keys())
+             now = str(int(time.perf_counter() % 1e5))
+             for pid, data in self.data.items():
+                 table.append((now, str(pid)) + tuple(self.format(data[k]) for k in keys))
+             return tabulate(table, headers=["time", "PID"] + keys)
+
+         def str(self):
+             self._refresh()
+             keys = list(list(self.data.values())[0].keys())
+             res = []
+             for pid in self.pids:
+                 s = f"PID={pid}"
+                 for k in keys:
+                     v = self.format(self.data[pid][k])
+                     s += f", {k}={v}"
+                 res.append(s)
+             return "\n".join(res)
+
+         @staticmethod
+         def format(size: int) -> str:
+             for unit in ('', 'K', 'M', 'G'):
+                 if size < 1024:
+                     break
+                 size /= 1024.0
+             return "%.1f%s" % (size, unit)
+
+
+     def read_sample(x):
+         """
+         A function that is supposed to read object x, incrementing its refcount.
+         This mimics what a real dataloader would do.
+         """
+         if sys.version_info >= (3, 10, 6):
+             # Before 3.10.6, pickle.dumps did not increment the refcount of the
+             # pickled object; fixed in https://github.com/python/cpython/pull/92931.
+             return pickle.dumps(x)
+         else:
+             import msgpack
+             return msgpack.dumps(x)
+
+
+     class DatasetFromList(torch.utils.data.Dataset):
+
+         def __init__(self, lst):
+             self.lst = lst
+
+         def __len__(self):
+             return len(self.lst)
+
+         def __getitem__(self, idx: int):
+             return self.lst[idx]
+
+
+     def worker(_, dataset: torch.utils.data.Dataset):
+         while True:
+             for sample in dataset:
+                 # read the data, with a fake latency
+                 time.sleep(0.000001)
+                 result = read_sample(sample)
    monitor = MemoryMonitor()
-     immutable_warranty = "pickle"  # options: "copy", "pickle", "wu"
+     immutable_warranty = "wu"  # options: "copy", "pickle", "wu"
    ds = lazy_dataset.new(create_coco(), immutable_warranty=immutable_warranty)
    print(monitor.table())
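Note on the changed line above: `immutable_warranty` controls how `lazy_dataset.new` protects the stored examples from in-place modification; the "wu" variant is what this branch benchmarks, so its exact semantics are not shown here. A minimal sketch of the two pre-existing options, assuming the usual lazy_dataset behaviour ("pickle" stores the examples serialized and deserializes them on every access, "copy" deep-copies on every access):

    import lazy_dataset

    examples = [{"id": 0, "bbox": [1, 2, 3, 4]}]
    ds = lazy_dataset.new(examples, immutable_warranty="pickle")
    sample = ds[0]
    sample["bbox"][0] = 99        # mutate the returned example ...
    assert ds[0]["bbox"][0] == 1  # ... the stored example stays untouched

Either warranty trades extra per-access CPU and memory for this safety, which is exactly the cost the memory curves recorded by this script are meant to compare.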
@@ -144,7 +143,7 @@ def worker(_, dataset: torch.utils.data.Dataset):
        axis.set_xlabel("Time (s)")
        axis.legend()
        axis.set_ylabel("Memory usage (MB)")
-         # plt.savefig(f"/net/vol/deegen/SHK/Lazy_dataset_test/{immutable_warranty}.svg", format="svg")  # , dpi=600)
-         plt.show()
+         plt.savefig(f"/net/vol/deegen/SHK/Lazy_dataset_test/{immutable_warranty}.svg", format="svg")  # , dpi=600)
+         # plt.show()
    finally:
        ctx.join()
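The hunks above elide the part of the script that actually launches the workers and samples the memory monitor. For orientation only, a hypothetical driver in the same spirit (the process count, loop length, and start method are illustrative assumptions, not taken from this file):

    import torch.multiprocessing as mp

    # fork the workers without joining, so the parent can keep sampling memory
    ctx = mp.start_processes(worker, args=(ds,), nprocs=4, join=False, start_method="fork")
    for pid in ctx.pids():
        monitor.add_pid(pid)          # also track the children's rss/pss/uss
    try:
        for _ in range(100):
            time.sleep(1)
            print(monitor.table())    # one row per process, sampled once per second
    finally:
        ctx.join()

With copy-on-write forking, the children's pss/uss columns show how many pages each immutable_warranty setting ends up duplicating while the workers iterate the dataset.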