Skip to content

Commit 981998b

Browse files
committed
Fixed errors
1 parent b20080f commit 981998b

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

classification/trainer.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def __init__(self, port, n_frames, client_id, num_clients, data_split_type):
2626
self.client_id = client_id
2727
self.num_clients = num_clients
2828
self.data_split_type = data_split_type
29-
self.logger = 'logs/{}.log'.format(self.client_id)
30-
self.init_logger()
29+
#self.logger = 'logs/{}.log'.format(self.client_id)
30+
#self.init_logger() # Logged from parent
3131
self.metaFile = self.get_meta_file()
3232

3333
def get_meta_file(self):
@@ -238,8 +238,8 @@ def __init__(self,
238238
port,
239239
n_frames,
240240
num_clients,
241+
client_id,
241242
data_split_type):
242-
client_id = 'sec_agg'
243243
super().__init__(port,
244244
n_frames,
245245
client_id,

client/client.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
class Client:
77

8-
def __init__(self, client_number, port, num_clients, split_type, client_id):
8+
def __init__(self, port, n_frames,
9+
client_number, num_clients, split_type, client_id):
10+
self.n_frames = n_frames
911
self.split_type = split_type
1012
self.client_number = client_number
1113
self.port = port
@@ -18,10 +20,12 @@ def init_model(self):
1820
del self.trainer
1921
except:
2022
pass
21-
self.trainer = trainer.ClientTrainer(self.client_number,
22-
self.client_id,
23+
self.trainer = trainer.ClientTrainer(self.port,
24+
self.n_frames,
25+
self.client_number,
2326
self.num_clients,
24-
self.split_type)
27+
self.split_type,
28+
self.client_id)
2529

2630
def train(self):
2731
self.trainer.train()

secure_aggregator/sec_agg.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010

1111
class SecAgg:
1212

13-
def __init__(self, port, num_clients, split_type):
13+
def __init__(self, port, n_frames, num_clients, split_type):
14+
self.n_frames = n_frames
1415
self.init_logger()
1516
self.port = port
1617
self.num_clients = num_clients
1718
self.split_type = split_type
18-
# TODO: Num clients don't have to be in SecAgg
1919
self.client_id = 'client_{}'.format(self.port)
2020
self.client_models_folder = 'secure_aggregator/client_models'
2121
self.init_model()
@@ -32,8 +32,10 @@ def init_logger(self):
3232
)
3333

3434
def init_model(self):
35-
self.trainer = trainer.SecAggTrainer(self.client_id,
35+
self.trainer = trainer.SecAggTrainer(self.port,
36+
self.n_frames,
3637
self.num_clients,
38+
self.client_id,
3739
self.split_type)
3840

3941
def load_models(self):

0 commit comments

Comments
 (0)