Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Client disconnection #6

Merged
merged 6 commits into from
Jan 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@
)

client = Client(args.client_number, port, num_clients, args.split_type)
state = State('client', client.client_id, port)

# _id auto generated by State
state = State('client', port, _id=None)

app = Flask(__name__)

Expand Down
15 changes: 9 additions & 6 deletions frontend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def check_clients_ok():
else:
client['check_ok'] = True


@app.route('/')
def index():
check_clients_ok()
Expand Down Expand Up @@ -57,7 +56,7 @@ def fake_data():
client = random.choice(client_types)
status = random.choice(statuses)
check_ok = random.choice([True, False])
_id = 'client_{}'.format(port)
_id = 'fake_client_{}'.format(port)
data = {
'client_type': client,
'_id': _id,
Expand All @@ -82,14 +81,16 @@ def index_fe():
def keep_alive():
global global_state
data = json.loads(request.data)
host = request.remote_addr
port = data['port']
_id = data['_id']
if _id not in global_state['clients']:
global_state['clients'][_id] = {
'joined_at': datetime.now().strftime('%Y-%m-%d %H:%m'),
'client_type': data['client_type'],
'port': data['port'],
'port': port,
'check_ok': True,
'host': data['host'],
'host': host,
'last_ping': datetime.now()
}
global_state['clients'][_id].update({'state': data['state']})
Expand All @@ -110,15 +111,17 @@ def iteration():
@app.route('/send_state', methods=['POST'])
def get_state():
global global_state
host = request.remote_addr
data = json.loads(request.data)
port = data['port']
_id = data['_id']
if _id not in global_state['clients']:
global_state['clients'][_id] = {
'joined_at': datetime.now().strftime('%Y-%m-%d %H:%m'),
'client_type': data['client_type'],
'port': data['port'],
'port': port,
'check_ok': True,
'host': data['host'],
'host': host,
'last_ping': datetime.now()
}
global_state['clients'][_id].update({'state': data['state']})
Expand Down
4 changes: 3 additions & 1 deletion main_server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
rsa = rsa_utils.RSAUtils()
args = parser.parse_args()
hosts = utils.read_hosts()
state = State('main_server', 'client_{}'.format(args.port), args.port)

# _id auto generated by State
state = State('main_server', args.port, _id=None)

app = Flask(__name__)

Expand Down
154 changes: 154 additions & 0 deletions orchestrator/client_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from multiprocessing.dummy import Pool
import requests
from time import time, sleep
import logging


class ClientHandler:
    """Performs concurrent requests against a set of clients, with a
    configurable completion criterion.

    :param clients: list of client dicts shaped ``{name: {'host': h, 'port': p}}``
    :param OPERATION_MODE: one of
        ``'n_firsts'``  – return once the first N clients have responded,
        ``'timeout'``   – wait for at least N clients AND a minimum elapsed time,
        ``'wait_all'``  – wait until every client has responded (default).
    :param kwargs: mode-specific tuning knobs:
        ``n_firsts`` (int, for ``'n_firsts'``),
        ``wait_from_n_firsts`` (int, for ``'timeout'``),
        ``timeout`` (seconds, for ``'timeout'``).
    """

    def __init__(self, clients, OPERATION_MODE='wait_all', **kwargs):
        self.clients = self.parse_clients(clients)
        self.n_clients = len(clients)
        self.OPERATION_MODE = OPERATION_MODE
        # The pool is (re)created for each batch of parallel requests.
        self.pool = None
        logging.info(
            '[Client Handler] Operation mode: {}'.format(self.OPERATION_MODE))
        # By default tolerate up to two stragglers, but never go below 1.
        default_n_firsts = max(1, self.n_clients - 2)
        if self.OPERATION_MODE == 'n_firsts':
            self.N_FIRSTS = kwargs.get('n_firsts', default_n_firsts)
            assert self.N_FIRSTS <= self.n_clients, \
                'n_firsts must be <= than num clients'
            logging.info(
                '[Client Handler] n_firsts: {}'.format(self.N_FIRSTS))
        elif self.OPERATION_MODE == 'timeout':
            self.WAIT_FROM_N_FIRSTS = kwargs.get('wait_from_n_firsts',
                                                 default_n_firsts)
            # 'timoeut' was a historical typo; keep accepting it so existing
            # callers don't silently lose their configured timeout.
            self.TIMEOUT = kwargs.get('timeout',
                                      kwargs.get('timoeut', 60))  # Seconds
        elif self.OPERATION_MODE == 'wait_all':
            self.N_FIRSTS = self.n_clients
            logging.info('[Client Handler] Will wait '
                         'until {} clients'.format(self.N_FIRSTS))
        else:
            raise Exception('Operation mode not accepted')
        self.operations_history = {}
        self.init_operations_history()

    def perform_requests_and_wait(self, endpoint):
        """Fire ``endpoint`` at every client and block per OPERATION_MODE.

        Returns the list of client keys that finished before the criterion
        was met.
        """
        self.perform_parallel_requests(endpoint)
        if self.OPERATION_MODE == 'n_firsts':
            if endpoint == 'send_model':
                # TODO: Do this part with redundancy
                # Model upload must reach every client, regardless of mode.
                return self.wait_until_n_responses(wait_all=True)
            return self.wait_until_n_responses()
        elif self.OPERATION_MODE == 'timeout':
            self.started = time()
            return self.wait_until_timeout()
        elif self.OPERATION_MODE == 'wait_all':
            return self.wait_until_n_responses(wait_all=True)

    def init_operations_history(self):
        """Create an empty per-client operation log."""
        for host, port in self.clients:
            key = self.get_client_key(host, port)
            self.operations_history[key] = []

    @staticmethod
    def parse_clients(clients):
        """Flatten the hosts-file client dicts into (host, port) tuples."""
        p_clients = []
        for cl in clients:
            cl_cfg = cl[list(cl.keys())[0]]
            p_clients.append((cl_cfg['host'], cl_cfg['port']))
        return p_clients

    def perform_parallel_requests(self, endpoint):
        """Launch one async GET per client on a fresh thread pool."""
        self.pool = Pool(self.n_clients)
        for host, port in self.clients:
            self.pool.apply_async(self.perform_request,
                                  [host, port, endpoint])
        self.pool.close()

    def _collect_ended(self, ended_clients):
        """Scan the operation history and add newly-finished clients to
        ``ended_clients`` (mutated in place)."""
        for host, port in self.clients:
            key = self.get_client_key(host, port)
            if key in ended_clients:
                # Already accounted for; avoid re-logging every poll cycle.
                continue
            try:
                last_operation = self.operations_history[key][-1]
            except IndexError:
                # Last operation still not computed
                continue
            if last_operation['ended']:
                # TODO: Handle exception when status code != 200
                assert last_operation['response'].status_code == 200
                logging.info(
                    '[Client Handler] client {} '
                    'finished performing operation {}'.format(
                        key, last_operation['op']
                    )
                )
                ended_clients.add(key)

    def wait_until_timeout(self):
        """Poll until at least WAIT_FROM_N_FIRSTS clients finished AND
        TIMEOUT seconds elapsed since the batch started."""
        ended_clients = set()
        completed = False
        while not completed:
            self._collect_ended(ended_clients)
            elapsed = time() - self.started
            if ((len(ended_clients) >= self.WAIT_FROM_N_FIRSTS) and
                    elapsed > self.TIMEOUT):
                # Abandon any stragglers still running in the pool.
                self.pool.terminate()
                completed = True
            sleep(0.1)
        return list(ended_clients)

    def wait_until_n_responses(self, wait_all=False):
        """Poll until N_FIRSTS clients finished (or every client when
        ``wait_all`` is True)."""
        # TODO: What to do in send model?
        ended_clients = set()
        completed = False
        while not completed:
            # Periodically check if the requests are ending
            self._collect_ended(ended_clients)
            if ((not wait_all and (len(ended_clients) >= self.N_FIRSTS))
                    or (wait_all and len(ended_clients) == self.N_FIRSTS)):
                self.pool.terminate()
                completed = True
            sleep(0.1)
        return list(ended_clients)

    @staticmethod
    def get_client_key(host, port):
        """History key for a client; kept as the raw (host, port) tuple."""
        return (host, port)

    def perform_request(self, host, port, endpoint):
        """Blocking GET to one client; records the result in the history.

        The entry is appended only after the request completes, so pollers
        either see nothing (IndexError) or a finished operation.
        """
        key = self.get_client_key(host, port)
        last_operation = {
            'started': time(),
            'op': endpoint,
            'status': 'started',
            'ended': None
        }
        url = 'http://{}:{}/{}'.format(host, port, endpoint)
        res = requests.get(url)
        last_operation.update({'status': 'ended',
                               'ended': time(),
                               'response': res})
        self.operations_history.setdefault(key, []).append(last_operation)

63 changes: 28 additions & 35 deletions orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import requests
import time
import logging
import argparse

from shared import utils
import client_handler

logging.basicConfig(
format='%(asctime)s %(message)s',
Expand All @@ -17,17 +19,8 @@
)


hosts = utils.read_hosts(is_docker=False)


def get_client_urls(endpoint):
hp = []
for cl in hosts['clients']:
host = cl[list(cl.keys())[0]]['host']
port = cl[list(cl.keys())[0]]['port']
url = 'http://{}:{}/{}'.format(host, port, endpoint)
hp.append(url)
return hp
hosts = utils.read_hosts(override_localhost=False)
client_opertaions_history = {}


def log_elapsed_time(start):
Expand Down Expand Up @@ -72,42 +65,33 @@ def restart_frontend():
logging.warning('Frontend may be down')


def main():
def main(op_mode):
# TODO: Configure epochs and everything from here
num_iterations = 50
NUM_ITERATIONS = 50
all_results = []
ch = client_handler.ClientHandler(clients=hosts['clients'],
OPERATION_MODE=op_mode)
#train_accs = {}
start = time.time()
# restart_frontend()
for i in range(num_iterations):
for i in range(NUM_ITERATIONS):
logging.info('Iteration {}...'.format(i))
send_iteration_to_frontend(i)

logging.info('Sending /train_model request to clients...')
client_urls = get_client_urls('train_model')
rs = (grequests.get(u) for u in client_urls)
responses = grequests.map(rs)
# print('\nTrain acc:')
for res in responses:
check_response_ok(res)
#res_json = res.json()
#print(res_json)
#print('\n')
#train_accs.setdefault(res_json['client_id'], []).append(
# res_json['results'])
performed_clients = ch.perform_requests_and_wait('train_model')
logging.info('Performed clients: {}'.format(performed_clients))
logging.info('Done')
log_elapsed_time(start)

logging.info('Sending /send_model command to clients...')
client_urls = get_client_urls('send_model')
rs = (grequests.get(u) for u in client_urls)
responses = grequests.map(rs)
for res in responses:
check_response_ok(res)
performed_clients = ch.perform_requests_and_wait('train_model')
logging.info('Performed clients: {}'.format(performed_clients))
logging.info('Done')
log_elapsed_time(start)

logging.info('Sending /aggregate_models command to secure aggregator...')
logging.info('Sending /aggregate_models '
'command to secure aggregator...')
url = 'http://{}:{}/aggregate_models'.format(
hosts['secure_aggregator']['host'],
hosts['secure_aggregator']['port']
Expand All @@ -122,7 +106,9 @@ def main():
logging.info('Done')
log_elapsed_time(start)

logging.info('Sending /send_model_to_main_server command to secure aggregator...')
logging.info(
'Sending /send_model_to_main_server '
'command to secure aggregator...')
url = 'http://{}:{}/send_model_to_main_server'.format(
hosts['secure_aggregator']['host'],
hosts['secure_aggregator']['port']
Expand All @@ -144,15 +130,22 @@ def main():
logging.info('Test result: {}'.format(test_result))
log_elapsed_time(start)

# logging.info('All train accuracies:')
# print(train_accs)
logging.info('All results:')
logging.info(all_results)


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Orchestrator')
# TODO: Add configuration for Client handlers in other modes (timeout, etc)
parser.add_argument('-o', '--operation-mode', type=str, required=False,
default='wait_all',
help=(
'Operation mode. '
'Options: wait_all (default), n_firsts, timeout'
))
args = parser.parse_args()
try:
main()
main(op_mode=args.operation_mode)
except Exception:
logging.error("Fatal error in main loop", exc_info=True)

4 changes: 3 additions & 1 deletion secure_aggregator/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@

use_cuda = True
sec_agg = SecAgg(args.port, use_cuda)
state = State('secure_aggregator', sec_agg.client_id, args.port)

# _id auto generated by State
state = State('secure_aggregator', args.port, _id=None)


app = Flask(__name__)
Expand Down
10 changes: 8 additions & 2 deletions shared/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import socket
from time import sleep
import multiprocessing
import random
from string import ascii_uppercase, digits


IDLE = 0
Expand All @@ -25,17 +27,21 @@


class State:
def __init__(self, client_type, _id, port):
def __init__(self, client_type, port, _id=None):
assert client_type in ('client', 'secure_aggregator', 'main_server')
self.client_type = client_type
self.host = socket.gethostbyname(socket.gethostname())
self.port = port
self._id = _id
self._id = _id if _id else self.generate_random_id()
self._current_state = IDLE
self.current_state = IDLE
p = multiprocessing.Process(target=self.send_ping_continuously)
p.start()

@staticmethod
def generate_random_id(N=8):
return ''.join(random.choices(ascii_uppercase + digits, k=N))

@property
def current_state(self):
return self._current_state
Expand Down
Loading