Skip to content

Commit b20080f

Browse files
committed
Refactor n frames
1 parent 6e1a9c7 commit b20080f

File tree

5 files changed

+89
-56
lines changed

5 files changed

+89
-56
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ You can see the results of the experiments done [here](results/README.md).
3939
- There are 3 env variables which can be set:
4040
- `DELETE_OLD_LOGS`: if set to 1, removes all logs in `logs/` folder. Useful for new run.
4141
- `RESTART_SCREEN`: if set to 1, kills all current screens. Useful for new run.
42+
- `N_FRAMES`: Number of input frames (1-8). Defaults to 1 when not set.
4243
- `SPLIT_TYPE`: Can be set to 3 values (more information in dissemination):
4344
- `no_split`: No split in the dataset is performed.
4445
- `iid`: IID split type is performed.

classification/trainer.py

+62-47
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010

1111
epochs = 1 # Before 30
12-
n_frames = 7 # Optimal == 7
1312
batch_size = 10
1413
workers = 0
1514

@@ -19,9 +18,31 @@
1918

2019

2120
class Trainer(object):
22-
def __init__(self):
23-
# __init__ overrided in child classes
24-
pass
21+
22+
def __init__(self, port, n_frames, client_id, num_clients, data_split_type):
23+
# Initialization common to both child classes
24+
self.port = port
25+
self.n_frames = n_frames # Optimal frames == 7
26+
self.client_id = client_id
27+
self.num_clients = num_clients
28+
self.data_split_type = data_split_type
29+
self.logger = 'logs/{}.log'.format(self.client_id)
30+
self.init_logger()
31+
self.metaFile = self.get_meta_file()
32+
33+
def get_meta_file(self):
34+
mf = 'data/classification'
35+
if self.data_split_type == 'iid':
36+
return '{}/metadata_{}_clients_iid.mat'.format(mf,
37+
self.num_clients)
38+
elif self.data_split_type == 'non-iid-a':
39+
return '{}/metadata_{}_clients_non_iid_a.mat'.format(
40+
mf, self.num_clients)
41+
elif self.data_split_type == 'no_split':
42+
return '{}/metadata.mat'.format(mf)
43+
else:
44+
raise Exception('Data split type "{}" not implemented'.format(
45+
self.data_split_type))
2546

2647
def init_logger(self):
2748
logging.basicConfig(
@@ -40,7 +61,7 @@ def loadDatasets(self,
4061
return torch.utils.data.DataLoader(
4162
ObjectClusterDataset(
4263
split=split, doAugment=(split == 'train'),
43-
doFilter=doFilter, sequenceLength=n_frames,
64+
doFilter=doFilter, sequenceLength=self.n_frames,
4465
metaFile=self.metaFile, useClusters=useClusterSampling
4566
),
4667
batch_size=batch_size,
@@ -114,7 +135,7 @@ def initModel(self):
114135

115136
self.model = Model(
116137
numClasses=len(self.val_loader.dataset.meta['objects']),
117-
sequenceLength=n_frames)
138+
sequenceLength=self.n_frames)
118139
self.model.epoch = 0
119140
self.model.bestPrec = -1e20
120141

@@ -213,28 +234,21 @@ def save_model(self):
213234

214235

215236
class SecAggTrainer(Trainer):
216-
def __init__(self, client_id, num_clients, data_split_type):
217-
self.client_id = client_id
218-
self.logger = 'logs/{}.log'.format(self.client_id)
219-
self.init_logger()
237+
def __init__(self,
238+
port,
239+
n_frames,
240+
num_clients,
241+
data_split_type):
242+
client_id = 'sec_agg'
243+
super().__init__(port,
244+
n_frames,
245+
client_id,
246+
num_clients,
247+
data_split_type)
220248
self.type = 'secure_aggregator'
221249
self.snapshotDir = 'secure_aggregator/persistent_storage'
222-
self.client_number = None
223250
self.train_split = 'train' # Shouldn't be needed since it doesn't train
224-
mf = 'data/classification'
225-
if data_split_type == 'iid':
226-
self.metaFile = '{}/metadata_{}_clients_iid.mat'.format(
227-
mf, num_clients)
228-
elif data_split_type == 'non-iid-a':
229-
self.metaFile = '{}/metadata_{}_clients_non_iid_a.mat'.format(
230-
mf, num_clients)
231-
elif data_split_type == 'no_split':
232-
self.metaFile = '{}/metadata.mat'.format(mf)
233-
else:
234-
raise Exception('Data split type "{}" not implemented'.format(
235-
data_split_type))
236251
self.init()
237-
super(Trainer, self).__init__()
238252

239253
def get_checkpoint_path(self):
240254
return os.path.join(self.snapshotDir, 'checkpoint.tar')
@@ -244,31 +258,25 @@ def get_best_model_path(self):
244258

245259

246260
class ClientTrainer(Trainer):
247-
def __init__(self, client_number, client_id, num_clients, data_split_type):
248-
self.client_id = client_id
249-
self.logger = 'logs/{}.log'.format(self.client_id)
250-
self.init_logger()
261+
def __init__(self,
262+
port,
263+
n_frames,
264+
client_number,
265+
num_clients,
266+
data_split_type,
267+
client_id):
268+
super().__init__(port,
269+
n_frames,
270+
client_id,
271+
num_clients,
272+
data_split_type)
251273
self.type = 'client'
252-
self.snapshotDir = 'client/snapshots_{}'.format(self.client_id)
253-
self.client_number = client_number
254-
mf = 'data/classification'
255-
if data_split_type == 'iid':
256-
self.train_split = 'train_{}'.format(client_number)
257-
self.metaFile = '{}/metadata_{}_clients_iid.mat'.format(
258-
mf, num_clients)
259-
elif data_split_type == 'non-iid-a':
260-
self.train_split = 'train_{}'.format(client_number)
261-
self.metaFile = '{}/metadata_{}_clients_non_iid_a.mat'.format(
262-
mf, num_clients)
263-
elif data_split_type == 'no_split':
264-
self.train_split = 'train'
265-
self.metaFile = '{}/metadata.mat'.format(mf)
266-
else:
267-
raise Exception('Data split type "{}" not implemented'.format(
268-
data_split_type))
274+
self.snapshotDir = 'client/snapshots_{}'.format(self.port)
275+
# Use the port rather than the client_id here, to avoid generating tons of folders
276+
self.train_split = self.get_train_split()
269277

270278
# Split dataset if file does not exist
271-
if data_split_type in ('iid', 'non-iid-a', 'non-iid-b'):
279+
if self.data_split_type in ('iid', 'non-iid-a', 'non-iid-b'):
272280
from shared import dataset_tools
273281
# TODO: Reimplement this with a lock file.
274282
# If multiple clients are spawned this can be a problem.
@@ -287,7 +295,14 @@ def __init__(self, client_number, client_id, num_clients, data_split_type):
287295
logging.info('File {} already exists. '
288296
'Not creating.'.format(self.metaFile))
289297
self.init()
290-
super(Trainer, self).__init__()
298+
299+
def get_train_split(self):
300+
if self.data_split_type == 'iid':
301+
return 'train_{}'.format(self.client_number)
302+
elif self.data_split_type == 'non-iid-a':
303+
return 'train_{}'.format(self.client_number)
304+
elif self.data_split_type == 'no_split':
305+
return 'train'
291306

292307

293308
class AverageMeter(object):

client/app.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
default='no_split',
2626
help=('Metadata split type. '
2727
'Example: no_split, iid, non-iid-a'))
28+
parser.add_argument('-f', '--n-frames', type=int, required=False,
29+
default=1,
30+
help='n_frames. [1-8]. Default=1')
2831

2932
rsa = rsa_utils.RSAUtils()
3033
args = parser.parse_args()
@@ -45,8 +48,9 @@
4548
# _id auto generated by State
4649
state = State('client', port, _id=None)
4750

48-
client = Client(args.client_number,
49-
port,
51+
client = Client(port,
52+
args.n_frames,
53+
args.client_number,
5054
num_clients,
5155
args.split_type,
5256
state._id)

initialize.sh

+16-6
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,31 @@ else
2727
TO_APPEND="-s no_split"
2828
fi
2929

30+
# N frames
31+
if [[ -z "$N_FRAMES" ]]
32+
then
33+
echo "No N_FRAMES DEFINED. DEFAULT = 1"
34+
N_FRAMES_APPEND=""
35+
else
36+
echo "N_FRAMES DEFINED: $N_FRAMES"
37+
N_FRAMES_APPEND=" -f $N_FRAMES"
38+
fi
39+
3040
# Restart the frontend
3141
curl http://95.179.192.253:8002/restart
3242

3343
# Start main server in a new screen.
3444
screen -dmS main_server bash -c "$PYTHON_PATH main_server/app.py -p 8000"
3545

3646
# Start secure aggregator in a new screen.
37-
screen -dmS secure_aggregator bash -c "$PYTHON_PATH secure_aggregator/app.py -p 8001 $TO_APPEND"
47+
screen -dmS secure_aggregator bash -c "$PYTHON_PATH secure_aggregator/app.py -p 8001 $TO_APPEND $N_FRAMES_APPEND"
3848

3949
# Start N clients in new screens. Add or comment lines as wanted.
40-
screen -dmS client_0 bash -c "$PYTHON_PATH client/app.py -p 8003 -n 0 $TO_APPEND"
41-
screen -dmS client_1 bash -c "$PYTHON_PATH client/app.py -p 8004 -n 1 $TO_APPEND"
42-
screen -dmS client_2 bash -c "$PYTHON_PATH client/app.py -p 8005 -n 2 $TO_APPEND"
43-
screen -dmS client_3 bash -c "$PYTHON_PATH client/app.py -p 8006 -n 3 $TO_APPEND"
44-
screen -dmS client_4 bash -c "$PYTHON_PATH client/app.py -p 8007 -n 4 $TO_APPEND"
50+
screen -dmS client_0 bash -c "$PYTHON_PATH client/app.py -p 8003 -n 0 $TO_APPEND $N_FRAMES_APPEND"
51+
screen -dmS client_1 bash -c "$PYTHON_PATH client/app.py -p 8004 -n 1 $TO_APPEND $N_FRAMES_APPEND"
52+
screen -dmS client_2 bash -c "$PYTHON_PATH client/app.py -p 8005 -n 2 $TO_APPEND $N_FRAMES_APPEND"
53+
screen -dmS client_3 bash -c "$PYTHON_PATH client/app.py -p 8006 -n 3 $TO_APPEND $N_FRAMES_APPEND"
54+
screen -dmS client_4 bash -c "$PYTHON_PATH client/app.py -p 8007 -n 4 $TO_APPEND $N_FRAMES_APPEND"
4555

4656
# Start the orchestrator (will start the training).
4757
echo "Waiting 1 minutes so all clients start"

secure_aggregator/app.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,16 @@
2323
default='no_split',
2424
help=('Metadata split type. '
2525
'Example: no_split, iid, non-iid-a'))
26+
parser.add_argument('-f', '--n-frames', type=int, required=False,
27+
default=1,
28+
help='n_frames. [1-8]. Default=1')
2629

2730
rsa = rsa_utils.RSAUtils()
2831
args = parser.parse_args()
2932
hosts = utils.read_hosts()
3033
num_clients = len(hosts['clients'])
3134

32-
sec_agg = SecAgg(args.port, num_clients, args.split_type)
35+
sec_agg = SecAgg(args.port, args.n_frames, num_clients, args.split_type)
3336

3437
# _id auto generated by State
3538
state = State('secure_aggregator', args.port, _id=None)

0 commit comments

Comments
 (0)