diff --git a/client/app.py b/client/app.py index f6f08a1..5ff92c0 100644 --- a/client/app.py +++ b/client/app.py @@ -42,7 +42,9 @@ ) client = Client(args.client_number, port, num_clients, args.split_type) -state = State('client', client.client_id, port) + +# _id auto generated by State +state = State('client', port, _id=None) app = Flask(__name__) diff --git a/frontend/app.py b/frontend/app.py index 1cc6aba..08fc263 100644 --- a/frontend/app.py +++ b/frontend/app.py @@ -25,7 +25,6 @@ def check_clients_ok(): else: client['check_ok'] = True - @app.route('/') def index(): check_clients_ok() @@ -57,7 +56,7 @@ def fake_data(): client = random.choice(client_types) status = random.choice(statuses) check_ok = random.choice([True, False]) - _id = 'client_{}'.format(port) + _id = 'fake_client_{}'.format(port) data = { 'client_type': client, '_id': _id, @@ -82,14 +81,16 @@ def index_fe(): def keep_alive(): global global_state data = json.loads(request.data) + host = request.remote_addr + port = data['port'] _id = data['_id'] if _id not in global_state['clients']: global_state['clients'][_id] = { 'joined_at': datetime.now().strftime('%Y-%m-%d %H:%m'), 'client_type': data['client_type'], - 'port': data['port'], + 'port': port, 'check_ok': True, - 'host': data['host'], + 'host': host, 'last_ping': datetime.now() } global_state['clients'][_id].update({'state': data['state']}) @@ -110,15 +111,17 @@ def iteration(): @app.route('/send_state', methods=['POST']) def get_state(): global global_state + host = request.remote_addr data = json.loads(request.data) + port = data['port'] _id = data['_id'] if _id not in global_state['clients']: global_state['clients'][_id] = { 'joined_at': datetime.now().strftime('%Y-%m-%d %H:%m'), 'client_type': data['client_type'], - 'port': data['port'], + 'port': port, 'check_ok': True, - 'host': data['host'], + 'host': host, 'last_ping': datetime.now() } global_state['clients'][_id].update({'state': data['state']}) diff --git a/main_server/app.py 
b/main_server/app.py index c5f409a..4bd612e 100644 --- a/main_server/app.py +++ b/main_server/app.py @@ -33,7 +33,9 @@ rsa = rsa_utils.RSAUtils() args = parser.parse_args() hosts = utils.read_hosts() -state = State('main_server', 'client_{}'.format(args.port), args.port) + +# _id auto generated by State +state = State('main_server', args.port, _id=None) app = Flask(__name__) diff --git a/orchestrator/client_handler.py b/orchestrator/client_handler.py new file mode 100644 index 0000000..dc6da07 --- /dev/null +++ b/orchestrator/client_handler.py @@ -0,0 +1,153 @@ +from multiprocessing.dummy import Pool +import requests +from time import time, sleep +import logging + + +class ClientHandler: + """Performs concurrent requests with timeout + + :OPERATION_MODE n_firsts, timeout or wait_all. + :clients list of clients' (host, port) tuples + """ + + def __init__(self, clients, OPERATION_MODE='wait_all', **kwargs): + self.clients = self.parse_clients(clients) + self.n_clients = len(clients) + self.OPERATION_MODE = OPERATION_MODE + # Set the pool as None, later on will be created + self.pool = None + logging.info( + '[Client Handler] Operation mode: {}'.format(self.OPERATION_MODE)) + default_n_firsts = max(1, self.n_clients - 2) + if self.OPERATION_MODE == 'n_firsts': + self.N_FIRSTS = kwargs.get('n_firsts', default_n_firsts) + assert self.N_FIRSTS <= self.n_clients, \ + 'n_firsts must be <= than num clients' + logging.info( + '[Client Handler] n_firsts: {}'.format(self.N_FIRSTS)) + elif self.OPERATION_MODE == 'timeout': + self.WAIT_FROM_N_FIRSTS = kwargs.get('wait_from_n_firsts', + default_n_firsts) + self.TIMEOUT = kwargs.get('timeout', 60) # Seconds + elif self.OPERATION_MODE == 'wait_all': + self.N_FIRSTS = self.n_clients + logging.info('[Client Handler] Will wait ' + 'until {} clients'.format(self.N_FIRSTS)) + else: + raise Exception('Operation mode not accepted') + self.operations_history = {} + self.init_operations_history() + + def 
perform_requests_and_wait(self, endpoint): + self.perform_parallel_requests(endpoint) + if self.OPERATION_MODE == 'n_firsts': + if endpoint == 'send_model': + # TODO: Do this part with redundancy + return self.wait_until_n_responses(wait_all=True) + return self.wait_until_n_responses() + elif self.OPERATION_MODE == 'timeout': + self.started = time() + return self.wait_until_timeout() + elif self.OPERATION_MODE == 'wait_all': + return self.wait_until_n_responses(wait_all=True) + + def init_operations_history(self): + for host, port in self.clients: + key = self.get_client_key(host, port) + self.operations_history[key] = [] + + @staticmethod + def parse_clients(clients): + p_clients = [] + for cl in clients: + host = cl[list(cl.keys())[0]]['host'] + port = cl[list(cl.keys())[0]]['port'] + p_clients.append((host, port)) + return p_clients + + def perform_parallel_requests(self, endpoint): + futures = [] + self.pool = Pool(self.n_clients) + for host, port in self.clients: + futures.append( + self.pool.apply_async(self.perform_request, + [host, port, endpoint])) + self.pool.close() + + def wait_until_timeout(self): + ended_clients = set() + completed = False + while not completed: + for key in self.clients: + try: + last_operation = self.operations_history[key][-1] + except IndexError: + # Last operation still not computed + continue + if last_operation['ended']: + # TODO: Handle exception when status code != 200 + assert last_operation['res'].status_code == 200 + logging.info( + '[Client Handler] client {} ' + 'finished performing operation {}'.format( + key, last_operation['op'] + ) + ) + ended_clients.add(key) + elapsed = time() - self.started + if ((len(ended_clients) >= self.WAIT_FROM_N_FIRSTS) and + elapsed > self.TIMEOUT): + self.pool.terminate() + completed = True + sleep(0.1) + return list(ended_clients) + + def wait_until_n_responses(self, wait_all=False): + # TODO: What to do in send model? 
+ ended_clients = set() + completed = False + while not completed: + # Periodically check if the requests are ending + for key in self.clients: + try: + last_operation = self.operations_history[key][-1] + except IndexError: + # Last operation still not computed + continue + if last_operation['ended']: + # TODO: Handle exception when status code != 200 + assert last_operation['res'].status_code == 200 + logging.info( + '[Client Handler] client {} ' + 'finished performing operation {}'.format( + key, last_operation['op'] + ) + ) + ended_clients.add(key) + if ((not wait_all and (len(ended_clients) >= self.N_FIRSTS)) + or (wait_all and len(ended_clients) == self.N_FIRSTS)): + self.pool.terminate() + completed = True + sleep(0.1) + return list(ended_clients) + + @staticmethod + def get_client_key(host, port): + return (host, port) + + def perform_request(self, host, port, endpoint): + key = self.get_client_key(host, port) + last_operation = { + 'started': time(), + 'op': endpoint, + 'status': 'started', + 'ended': None + } + url = 'http://{}:{}/{}'.format(host, port, endpoint) + res = requests.get(url) + last_operation.update({'status': 'ended', + 'ended': time(), + 'res': res}) + self.operations_history.setdefault(key, []).append(last_operation) + diff --git a/orchestrator/orchestrator.py b/orchestrator/orchestrator.py index 7561c65..5457776 100644 --- a/orchestrator/orchestrator.py +++ b/orchestrator/orchestrator.py @@ -4,8 +4,10 @@ import requests import time import logging +import argparse from shared import utils +import client_handler logging.basicConfig( format='%(asctime)s %(message)s', @@ -17,17 +19,8 @@ ) -hosts = utils.read_hosts(is_docker=False) - - -def get_client_urls(endpoint): - hp = [] - for cl in hosts['clients']: - host = cl[list(cl.keys())[0]]['host'] - port = cl[list(cl.keys())[0]]['port'] - url = 'http://{}:{}/{}'.format(host, port, endpoint) - hp.append(url) - return hp +hosts = utils.read_hosts(override_localhost=False) 
+client_operations_history = {} def log_elapsed_time(start): @@ -72,42 +65,33 @@ def restart_frontend(): logging.warning('Frontend may be down') -def main(): +def main(op_mode): # TODO: Configure epochs and everything from here - num_iterations = 50 + NUM_ITERATIONS = 50 all_results = [] + ch = client_handler.ClientHandler(clients=hosts['clients'], + OPERATION_MODE=op_mode) #train_accs = {} start = time.time() # restart_frontend() - for i in range(num_iterations): + for i in range(NUM_ITERATIONS): logging.info('Iteration {}...'.format(i)) send_iteration_to_frontend(i) logging.info('Sending /train_model request to clients...') - client_urls = get_client_urls('train_model') - rs = (grequests.get(u) for u in client_urls) - responses = grequests.map(rs) - # print('\nTrain acc:') - for res in responses: - check_response_ok(res) - #res_json = res.json() - #print(res_json) - #print('\n') - #train_accs.setdefault(res_json['client_id'], []).append( - # res_json['results']) + performed_clients = ch.perform_requests_and_wait('train_model') + logging.info('Performed clients: {}'.format(performed_clients)) logging.info('Done') log_elapsed_time(start) logging.info('Sending /send_model command to clients...') - client_urls = get_client_urls('send_model') - rs = (grequests.get(u) for u in client_urls) - responses = grequests.map(rs) - for res in responses: - check_response_ok(res) + performed_clients = ch.perform_requests_and_wait('send_model') + logging.info('Performed clients: {}'.format(performed_clients)) logging.info('Done') log_elapsed_time(start) - logging.info('Sending /aggregate_models command to secure aggregator...') + logging.info('Sending /aggregate_models ' + 'command to secure aggregator...') url = 'http://{}:{}/aggregate_models'.format( hosts['secure_aggregator']['host'], hosts['secure_aggregator']['port'] @@ -122,7 +106,9 @@ logging.info('Done') log_elapsed_time(start) - logging.info('Sending /send_model_to_main_server command to secure 
aggregator...') + logging.info( + 'Sending /send_model_to_main_server ' + 'command to secure aggregator...') url = 'http://{}:{}/send_model_to_main_server'.format( hosts['secure_aggregator']['host'], hosts['secure_aggregator']['port'] @@ -144,15 +130,22 @@ def main(): logging.info('Test result: {}'.format(test_result)) log_elapsed_time(start) - # logging.info('All train accuracies:') - # print(train_accs) logging.info('All results:') logging.info(all_results) if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Orchestrator') + # TODO: Add configuration for Client handlers in other modes (timeout, etc) + parser.add_argument('-o', '--operation-mode', type=str, required=False, + default='wait_all', + help=( + 'Operation mode. ' + 'Options: wait_all (default), n_firsts, timeout' + )) + args = parser.parse_args() try: - main() + main(op_mode=args.operation_mode) except Exception: logging.error("Fatal error in main loop", exc_info=True) diff --git a/secure_aggregator/app.py b/secure_aggregator/app.py index 0b0c40e..4ecc169 100644 --- a/secure_aggregator/app.py +++ b/secure_aggregator/app.py @@ -26,7 +26,9 @@ use_cuda = True sec_agg = SecAgg(args.port, use_cuda) -state = State('secure_aggregator', sec_agg.client_id, args.port) + +# _id auto generated by State +state = State('secure_aggregator', args.port, _id=None) app = Flask(__name__) diff --git a/shared/state.py b/shared/state.py index 971e718..aee0826 100644 --- a/shared/state.py +++ b/shared/state.py @@ -4,6 +4,8 @@ import socket from time import sleep import multiprocessing +import random +from string import ascii_uppercase, digits IDLE = 0 @@ -25,17 +27,21 @@ class State: - def __init__(self, client_type, _id, port): + def __init__(self, client_type, port, _id=None): assert client_type in ('client', 'secure_aggregator', 'main_server') self.client_type = client_type self.host = socket.gethostbyname(socket.gethostname()) self.port = port - self._id = _id + self._id = _id if _id else 
self.generate_random_id() self._current_state = IDLE self.current_state = IDLE p = multiprocessing.Process(target=self.send_ping_continuously) p.start() + @staticmethod + def generate_random_id(N=8): + return ''.join(random.choices(ascii_uppercase + digits, k=N)) + @property def current_state(self): return self._current_state diff --git a/shared/utils.py b/shared/utils.py index e98c8de..31f3b9d 100644 --- a/shared/utils.py +++ b/shared/utils.py @@ -1,10 +1,10 @@ import yaml -def read_hosts(is_docker=True): +def read_hosts(override_localhost=False): with open('hosts.yml', 'r') as f: hosts = yaml.safe_load(f) - if not is_docker: + if override_localhost: # Change to hosts to localhost for x, vals in hosts.items(): if x != 'frontend' and x != 'clients':