9
9
10
10
11
11
epochs = 1 # Before 30
12
- n_frames = 7 # Optimal == 7
13
12
batch_size = 10
14
13
workers = 0
15
14
19
18
20
19
21
20
class Trainer (object ):
22
- def __init__ (self ):
23
- # __init__ overrided in child classes
24
- pass
21
+
22
+ def __init__ (self , port , n_frames , client_id , num_clients , data_split_type ):
23
+ # Common init in both childs
24
+ self .port = port
25
+ self .n_frames = n_frames # Optimal frames == 7
26
+ self .client_id = client_id
27
+ self .num_clients = num_clients
28
+ self .data_split_type = data_split_type
29
+ self .logger = 'logs/{}.log' .format (self .client_id )
30
+ self .init_logger ()
31
+ self .metaFile = self .get_meta_file ()
32
+
33
+ def get_meta_file (self ):
34
+ mf = 'data/classification'
35
+ if self .data_split_type == 'iid' :
36
+ return '{}/metadata_{}_clients_iid.mat' .format (mf ,
37
+ self .num_clients )
38
+ elif self .data_split_type == 'non-iid-a' :
39
+ return '{}/metadata_{}_clients_non_iid_a.mat' .format (
40
+ mf , self .num_clients )
41
+ elif self .data_split_type == 'no_split' :
42
+ return '{}/metadata.mat' .format (mf )
43
+ else :
44
+ raise Exception ('Data split type "{}" not implemented' .format (
45
+ self .data_split_type ))
25
46
26
47
def init_logger (self ):
27
48
logging .basicConfig (
@@ -40,7 +61,7 @@ def loadDatasets(self,
40
61
return torch .utils .data .DataLoader (
41
62
ObjectClusterDataset (
42
63
split = split , doAugment = (split == 'train' ),
43
- doFilter = doFilter , sequenceLength = n_frames ,
64
+ doFilter = doFilter , sequenceLength = self . n_frames ,
44
65
metaFile = self .metaFile , useClusters = useClusterSampling
45
66
),
46
67
batch_size = batch_size ,
@@ -114,7 +135,7 @@ def initModel(self):
114
135
115
136
self .model = Model (
116
137
numClasses = len (self .val_loader .dataset .meta ['objects' ]),
117
- sequenceLength = n_frames )
138
+ sequenceLength = self . n_frames )
118
139
self .model .epoch = 0
119
140
self .model .bestPrec = - 1e20
120
141
@@ -213,28 +234,21 @@ def save_model(self):
213
234
214
235
215
236
class SecAggTrainer (Trainer ):
216
- def __init__ (self , client_id , num_clients , data_split_type ):
217
- self .client_id = client_id
218
- self .logger = 'logs/{}.log' .format (self .client_id )
219
- self .init_logger ()
237
+ def __init__ (self ,
238
+ port ,
239
+ n_frames ,
240
+ num_clients ,
241
+ data_split_type ):
242
+ client_id = 'sec_agg'
243
+ super ().__init__ (port ,
244
+ n_frames ,
245
+ client_id ,
246
+ num_clients ,
247
+ data_split_type )
220
248
self .type = 'secure_aggregator'
221
249
self .snapshotDir = 'secure_aggregator/persistent_storage'
222
- self .client_number = None
223
250
self .train_split = 'train' # Shouldn't be needed since it doesn't train
224
- mf = 'data/classification'
225
- if data_split_type == 'iid' :
226
- self .metaFile = '{}/metadata_{}_clients_iid.mat' .format (
227
- mf , num_clients )
228
- elif data_split_type == 'non-iid-a' :
229
- self .metaFile = '{}/metadata_{}_clients_non_iid_a.mat' .format (
230
- mf , num_clients )
231
- elif data_split_type == 'no_split' :
232
- self .metaFile = '{}/metadata.mat' .format (mf )
233
- else :
234
- raise Exception ('Data split type "{}" not implemented' .format (
235
- data_split_type ))
236
251
self .init ()
237
- super (Trainer , self ).__init__ ()
238
252
239
253
def get_checkpoint_path (self ):
240
254
return os .path .join (self .snapshotDir , 'checkpoint.tar' )
@@ -244,31 +258,25 @@ def get_best_model_path(self):
244
258
245
259
246
260
class ClientTrainer (Trainer ):
247
- def __init__ (self , client_number , client_id , num_clients , data_split_type ):
248
- self .client_id = client_id
249
- self .logger = 'logs/{}.log' .format (self .client_id )
250
- self .init_logger ()
261
+ def __init__ (self ,
262
+ port ,
263
+ n_frames ,
264
+ client_number ,
265
+ num_clients ,
266
+ data_split_type ,
267
+ client_id ):
268
+ super ().__init__ (port ,
269
+ n_frames ,
270
+ client_id ,
271
+ num_clients ,
272
+ data_split_type )
251
273
self .type = 'client'
252
- self .snapshotDir = 'client/snapshots_{}' .format (self .client_id )
253
- self .client_number = client_number
254
- mf = 'data/classification'
255
- if data_split_type == 'iid' :
256
- self .train_split = 'train_{}' .format (client_number )
257
- self .metaFile = '{}/metadata_{}_clients_iid.mat' .format (
258
- mf , num_clients )
259
- elif data_split_type == 'non-iid-a' :
260
- self .train_split = 'train_{}' .format (client_number )
261
- self .metaFile = '{}/metadata_{}_clients_non_iid_a.mat' .format (
262
- mf , num_clients )
263
- elif data_split_type == 'no_split' :
264
- self .train_split = 'train'
265
- self .metaFile = '{}/metadata.mat' .format (mf )
266
- else :
267
- raise Exception ('Data split type "{}" not implemented' .format (
268
- data_split_type ))
274
+ self .snapshotDir = 'client/snapshots_{}' .format (self .port )
275
+ # Dont do it with the client_id to avoid tons of folders generated
276
+ self .train_split = self .get_train_split ()
269
277
270
278
# Split dataset if file does not exist
271
- if data_split_type in ('iid' , 'non-iid-a' , 'non-iid-b' ):
279
+ if self . data_split_type in ('iid' , 'non-iid-a' , 'non-iid-b' ):
272
280
from shared import dataset_tools
273
281
# TODO: Reimplement this with a lock file.
274
282
# If multiple clients are spawned this can be a problem.
@@ -287,7 +295,14 @@ def __init__(self, client_number, client_id, num_clients, data_split_type):
287
295
logging .info ('File {} already exists. '
288
296
'Not creating.' .format (self .metaFile ))
289
297
self .init ()
290
- super (Trainer , self ).__init__ ()
298
+
299
+ def get_train_split (self ):
300
+ if self .data_split_type == 'iid' :
301
+ return 'train_{}' .format (self .client_number )
302
+ elif self .data_split_type == 'non-iid-a' :
303
+ return 'train_{}' .format (self .client_number )
304
+ elif self .data_split_type == 'no_split' :
305
+ return 'train'
291
306
292
307
293
308
class AverageMeter (object ):
0 commit comments