
Commit df3706c

Merge pull request #6 from gferrate/client-disconnection
Client disconnection
2 parents: 9610fde + b319f2f

File tree: 8 files changed, +210 −48 lines

client/app.py
Lines changed: 3 additions & 1 deletion

@@ -42,7 +42,9 @@
 )

 client = Client(args.client_number, port, num_clients, args.split_type)
-state = State('client', client.client_id, port)
+
+# _id auto generated by State
+state = State('client', port, _id=None)

 app = Flask(__name__)

frontend/app.py
Lines changed: 9 additions & 6 deletions

@@ -25,7 +25,6 @@ def check_clients_ok():
         else:
             client['check_ok'] = True

-
 @app.route('/')
 def index():
     check_clients_ok()

@@ -57,7 +56,7 @@ def fake_data():
     client = random.choice(client_types)
     status = random.choice(statuses)
     check_ok = random.choice([True, False])
-    _id = 'client_{}'.format(port)
+    _id = 'fake_client_{}'.format(port)
     data = {
         'client_type': client,
         '_id': _id,

@@ -82,14 +81,16 @@ def index_fe():
 def keep_alive():
     global global_state
     data = json.loads(request.data)
+    host = request.remote_addr
+    port = data['port']
     _id = data['_id']
     if _id not in global_state['clients']:
         global_state['clients'][_id] = {
             'joined_at': datetime.now().strftime('%Y-%m-%d %H:%m'),
             'client_type': data['client_type'],
-            'port': data['port'],
+            'port': port,
             'check_ok': True,
-            'host': data['host'],
+            'host': host,
             'last_ping': datetime.now()
         }
     global_state['clients'][_id].update({'state': data['state']})

@@ -110,15 +111,17 @@ def iteration():
 @app.route('/send_state', methods=['POST'])
 def get_state():
     global global_state
+    host = request.remote_addr
     data = json.loads(request.data)
+    port = data['port']
     _id = data['_id']
     if _id not in global_state['clients']:
         global_state['clients'][_id] = {
             'joined_at': datetime.now().strftime('%Y-%m-%d %H:%m'),
             'client_type': data['client_type'],
-            'port': data['port'],
+            'port': port,
             'check_ok': True,
-            'host': data['host'],
+            'host': host,
             'last_ping': datetime.now()
         }
     global_state['clients'][_id].update({'state': data['state']})
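
Taken together, these frontend changes mean a client only reports its own port and id; the frontend now records the peer address from request.remote_addr instead of trusting a 'host' field in the payload. A minimal sketch of the JSON a client might POST to the updated keep_alive() handler (id, port, state value, and frontend address are all hypothetical):

import json
import requests

payload = {
    '_id': 'A1B2C3D4',        # ids are now auto-generated by State (see shared/state.py)
    'client_type': 'client',
    'port': 5001,             # hypothetical client port
    'state': 'IDLE',          # hypothetical state value
}
# Hypothetical frontend address; the handler derives 'host' from request.remote_addr.
requests.post('http://localhost:5000/keep_alive', data=json.dumps(payload))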

main_server/app.py
Lines changed: 3 additions & 1 deletion

@@ -33,7 +33,9 @@
 rsa = rsa_utils.RSAUtils()
 args = parser.parse_args()
 hosts = utils.read_hosts()
-state = State('main_server', 'client_{}'.format(args.port), args.port)
+
+# _id auto generated by State
+state = State('main_server', args.port, _id=None)


 app = Flask(__name__)

orchestrator/client_handler.py
Lines changed: 154 additions & 0 deletions (new file)

from multiprocessing.dummy import Pool
import requests
from time import time, sleep
import logging


class ClientHandler:
    """Performs concurrent requests with timeout

    :OPERATION_MODE n_firsts, timeout or wait_all.
    :clients list of clients' (host, port) tuples
    """

    def __init__(self, clients, OPERATION_MODE='wait_all', **kwargs):
        self.clients = self.parse_clients(clients)
        self.n_clients = len(clients)
        self.OPERATION_MODE = OPERATION_MODE
        # Set the pool as None, later on will be created
        self.pool = None
        logging.info(
            '[Client Handler] Operation mode: {}'.format(self.OPERATION_MODE))
        default_n_firsts = max(1, self.n_clients - 2)
        if self.OPERATION_MODE == 'n_firsts':
            self.N_FIRSTS = kwargs.get('n_firsts', default_n_firsts)
            import pudb; pudb.set_trace()
            assert self.N_FIRSTS <= self.n_clients, \
                'n_firsts must be <= than num clients'
            logging.info(
                '[Client Handler] n_firsts: {}'.format(self.N_FIRSTS))
        elif self.OPERATION_MODE == 'timeout':
            self.WAIT_FROM_N_FIRSTS = kwargs.get('wait_from_n_firsts',
                                                 default_n_firsts)
            self.TIMEOUT = kwargs.get('timoeut', 60)  # Seconds
        elif self.OPERATION_MODE == 'wait_all':
            self.N_FIRSTS = self.n_clients
            logging.info('[Client Handler] Will wait '
                         'until {} clients'.format(self.N_FIRSTS))
        else:
            raise Exception('Operation mode not accepted')
        self.operations_history = {}
        self.init_operations_history()

    def perform_requests_and_wait(self, endpoint):
        self.perform_parallel_requests(endpoint)
        if self.OPERATION_MODE == 'n_firsts':
            if endpoint == 'send_model':
                # TODO: Do this part with redundancy
                return self.wait_until_n_responses(wait_all=True)
            return self.wait_until_n_responses()
        elif self.OPERATION_MODE == 'timeout':
            self.started = time()
            return self.wait_until_timeout()
        elif self.OPERATION_MODE == 'wait_all':
            return self.wait_until_n_responses(wait_all=True)

    def init_operations_history(self):
        for host, port in self.clients:
            key = self.get_client_key(host, port)
            self.operations_history[key] = []

    @staticmethod
    def parse_clients(clients):
        p_clients = []
        for cl in clients:
            host = cl[list(cl.keys())[0]]['host']
            port = cl[list(cl.keys())[0]]['port']
            p_clients.append((host, port))
        return p_clients

    def perform_parallel_requests(self, endpoint):
        futures = []
        self.pool = Pool(self.n_clients)
        for host, port in self.clients:
            futures.append(
                self.pool.apply_async(self.perform_request,
                                      [host, port, endpoint]))
        self.pool.close()

    def wait_until_timeout(self):
        ended_clients = set()
        completed = False
        while not completed:
            for key in self.clients:
                try:
                    last_operation = self.operations_history[key][-1]
                except IndexError:
                    # Last operation still not computed
                    continue
                if last_operation['ended']:
                    # TODO: Handle exception when status code != 200
                    assert last_operation['res'].status_code == 200
                    logging.info(
                        '[Client Handler] client {} '
                        'finished performing operation {}'.format(
                            key, last_operation['op']
                        )
                    )
                    ended_clients.add(key)
            elapsed = time() - self.started
            if ((len(ended_clients) >= self.WAIT_FROM_N_FIRSTS) and
                    elapsed > self.TIMEOUT):
                self.pool.terminate()
                completed = True
            sleep(0.1)
        return list(ended_clients)

    def wait_until_n_responses(self, wait_all=False):
        # TODO: What to do in send model?
        ended_clients = set()
        completed = False
        while not completed:
            # Periodically check if the requests are ending
            for key in self.clients:
                try:
                    last_operation = self.operations_history[key][-1]
                except IndexError:
                    # Last operation still not computed
                    continue
                if last_operation['ended']:
                    # TODO: Handle exception when status code != 200
                    assert last_operation['res'].status_code == 200
                    logging.info(
                        '[Client Handler] client {} '
                        'finished performing operation {}'.format(
                            key, last_operation['op']
                        )
                    )
                    ended_clients.add(key)
            if ((not wait_all and (len(ended_clients) >= self.N_FIRSTS))
                    or (wait_all and len(ended_clients) == self.N_FIRSTS)):
                self.pool.terminate()
                completed = True
            sleep(0.1)
        return list(ended_clients)

    @staticmethod
    def get_client_key(host, port):
        return (host, port)

    def perform_request(self, host, port, endpoint):
        key = self.get_client_key(host, port)
        last_operation = {
            'started': time(),
            'op': endpoint,
            'status': 'started',
            'ended': None
        }
        url = 'http://{}:{}/{}'.format(host, port, endpoint)
        res = requests.get(url)
        last_operation.update({'status': 'ended',
                               'ended': time(),
                               'response': res})
        self.operations_history.setdefault(key, []).append(last_operation)

orchestrator/orchestrator.py
Lines changed: 28 additions & 35 deletions

@@ -4,8 +4,10 @@
 import requests
 import time
 import logging
+import argparse

 from shared import utils
+import client_handler

 logging.basicConfig(
     format='%(asctime)s %(message)s',

@@ -17,17 +19,8 @@
 )


-hosts = utils.read_hosts(is_docker=False)
-
-
-def get_client_urls(endpoint):
-    hp = []
-    for cl in hosts['clients']:
-        host = cl[list(cl.keys())[0]]['host']
-        port = cl[list(cl.keys())[0]]['port']
-        url = 'http://{}:{}/{}'.format(host, port, endpoint)
-        hp.append(url)
-    return hp
+hosts = utils.read_hosts(override_localhost=False)
+client_opertaions_history = {}


 def log_elapsed_time(start):

@@ -72,42 +65,33 @@ def restart_frontend():
     logging.warning('Frontend may be down')


-def main():
+def main(op_mode):
     # TODO: Configure epochs and everything from here
-    num_iterations = 50
+    NUM_ITERATIONS = 50
     all_results = []
+    ch = client_handler.ClientHandler(clients=hosts['clients'],
+                                      OPERATION_MODE=op_mode)
     #train_accs = {}
     start = time.time()
     # restart_frontend()
-    for i in range(num_iterations):
+    for i in range(NUM_ITERATIONS):
         logging.info('Iteration {}...'.format(i))
         send_iteration_to_frontend(i)

         logging.info('Sending /train_model request to clients...')
-        client_urls = get_client_urls('train_model')
-        rs = (grequests.get(u) for u in client_urls)
-        responses = grequests.map(rs)
-        # print('\nTrain acc:')
-        for res in responses:
-            check_response_ok(res)
-            #res_json = res.json()
-            #print(res_json)
-            #print('\n')
-            #train_accs.setdefault(res_json['client_id'], []).append(
-            #    res_json['results'])
+        performed_clients = ch.perform_requests_and_wait('train_model')
+        logging.info('Performed clients: {}'.format(performed_clients))
         logging.info('Done')
         log_elapsed_time(start)

         logging.info('Sending /send_model command to clients...')
-        client_urls = get_client_urls('send_model')
-        rs = (grequests.get(u) for u in client_urls)
-        responses = grequests.map(rs)
-        for res in responses:
-            check_response_ok(res)
+        performed_clients = ch.perform_requests_and_wait('train_model')
+        logging.info('Performed clients: {}'.format(performed_clients))
         logging.info('Done')
         log_elapsed_time(start)

-        logging.info('Sending /aggregate_models command to secure aggregator...')
+        logging.info('Sending /aggregate_models '
+                     'command to secure aggregator...')
         url = 'http://{}:{}/aggregate_models'.format(
             hosts['secure_aggregator']['host'],
             hosts['secure_aggregator']['port']

@@ -122,7 +106,9 @@ def main():
         logging.info('Done')
         log_elapsed_time(start)

-        logging.info('Sending /send_model_to_main_server command to secure aggregator...')
+        logging.info(
+            'Sending /send_model_to_main_server '
+            'command to secure aggregator...')
         url = 'http://{}:{}/send_model_to_main_server'.format(
             hosts['secure_aggregator']['host'],
             hosts['secure_aggregator']['port']

@@ -144,15 +130,22 @@ def main():
         logging.info('Test result: {}'.format(test_result))
         log_elapsed_time(start)

-    # logging.info('All train accuracies:')
-    # print(train_accs)
     logging.info('All results:')
     logging.info(all_results)


 if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Orchestrator')
+    # TODO: Add configuration for Client handlers in other modes (timeout, etc)
+    parser.add_argument('-o', '--operation-mode', type=str, required=False,
+                        default='wait_all',
+                        help=(
+                            'Operation mode. '
+                            'Options: wait_all (default), n_firsts, timeout'
+                        ))
+    args = parser.parse_args()
     try:
-        main()
+        main(op_mode=args.operation_mode)
     except Exception:
         logging.error("Fatal error in main loop", exc_info=True)

secure_aggregator/app.py
Lines changed: 3 additions & 1 deletion

@@ -26,7 +26,9 @@

 use_cuda = True
 sec_agg = SecAgg(args.port, use_cuda)
-state = State('secure_aggregator', sec_agg.client_id, args.port)
+
+# _id auto generated by State
+state = State('secure_aggregator', args.port, _id=None)


 app = Flask(__name__)

shared/state.py
Lines changed: 8 additions & 2 deletions

@@ -4,6 +4,8 @@
 import socket
 from time import sleep
 import multiprocessing
+import random
+from string import ascii_uppercase, digits


 IDLE = 0

@@ -25,17 +27,21 @@


 class State:
-    def __init__(self, client_type, _id, port):
+    def __init__(self, client_type, port, _id=None):
         assert client_type in ('client', 'secure_aggregator', 'main_server')
         self.client_type = client_type
         self.host = socket.gethostbyname(socket.gethostname())
         self.port = port
-        self._id = _id
+        self._id = _id if _id else self.generate_random_id()
         self._current_state = IDLE
         self.current_state = IDLE
         p = multiprocessing.Process(target=self.send_ping_continuously)
         p.start()

+    @staticmethod
+    def generate_random_id(N=8):
+        return ''.join(random.choices(ascii_uppercase + digits, k=N))
+
     @property
     def current_state(self):
         return self._current_state
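
A minimal sketch of the new constructor order, assuming shared.state is importable as laid out in the repo; the port values are hypothetical. State also spawns its keep-alive ping process on construction, so this is illustrative rather than something to run in isolation:

from shared.state import State

# _id is now optional and last; omitting it (or passing None) yields an
# 8-character random id such as 'K3QZ7P1D'.
auto = State('client', 5001)
explicit = State('main_server', 8000, _id='main_server_8000')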
