Skip to content

Commit aeed505

Browse files
committed
State bug fixed, CR configurable from args, finish frontend endpoint
1 parent 7e481b8 commit aeed505

File tree

3 files changed

+32
-8
lines changed

3 files changed

+32
-8
lines changed

frontend/app.py

+10
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
'clients': {},
1515
'iteration': -1,
1616
'started_training': False
17+
'finished_training': False
1718
}
1819

1920
def check_clients_ok():
@@ -25,6 +26,7 @@ def check_clients_ok():
2526
else:
2627
client['check_ok'] = True
2728

29+
2830
@app.route('/')
2931
def index():
3032
check_clients_ok()
@@ -38,10 +40,18 @@ def restart():
3840
'clients': {},
3941
'iteration': -1,
4042
'started_training': False
43+
'finished_training': False
4144
}
4245
return jsonify(global_state)
4346

4447

48+
@app.route('/finish')
49+
def finish():
50+
global global_state
51+
global_state['finished_training'] = True
52+
return jsonify(global_state)
53+
54+
4555
@app.route('/fake')
4656
def fake_data():
4757
global global_state

orchestrator/orchestrator.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,16 @@ def send_iteration_to_frontend(i):
5151
logging.warning('Frontend may be down')
5252

5353

54+
def end_frontend():
55+
logging.info('Sending end signal to frontend')
56+
try:
57+
requests.post(
58+
url='http://{}:{}/finish'.format(hosts['frontend']['host'],
59+
hosts['frontend']['port']))
60+
except:
61+
logging.warning('Frontend may be down')
62+
63+
5464
def restart_frontend():
5565
logging.info('Restarting frontend')
5666
try:
@@ -64,16 +74,14 @@ def restart_frontend():
6474
logging.warning('Frontend may be down')
6575

6676

67-
def main(op_mode):
68-
# TODO: Configure epochs and everything from here
69-
NUM_ITERATIONS = 50
77+
def main(op_mode, communication_rounds):
7078
all_results = []
7179
ch = client_handler.ClientHandler(clients=hosts['clients'],
7280
OPERATION_MODE=op_mode)
7381
#train_accs = {}
7482
start = time.time()
7583
# restart_frontend()
76-
for i in range(NUM_ITERATIONS):
84+
for i in range(communication_rounds):
7785
logging.info('Iteration {}...'.format(i))
7886
send_iteration_to_frontend(i)
7987

@@ -131,6 +139,7 @@ def main(op_mode):
131139

132140
logging.info('All results:')
133141
logging.info(all_results)
142+
end_frontend()
134143

135144

136145
if __name__ == '__main__':
@@ -142,9 +151,15 @@ def main(op_mode):
142151
'Operation mode. '
143152
'Options: wait_all (default), n_firsts, timeout'
144153
))
154+
parser.add_argument('-c',
155+
'--communication-rounds',
156+
type=int,
157+
required=False,
158+
default=50,
159+
help='Number of communication rounds. Default: 50)
145160
args = parser.parse_args()
146161
try:
147-
main(op_mode=args.operation_mode)
162+
main(op_mode=args.operation_mode, args.communication_rounds)
148163
except Exception:
149164
logging.error("Fatal error in main loop", exc_info=True)
150165

shared/state.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@ def current_state(self):
4848

4949
@current_state.setter
5050
def current_state(self, state):
51+
self._current_state = state
5152
payload = {
5253
'client_type': self.client_type,
5354
'_id': self._id,
54-
'state': self.get_state_string(self.current_state),
55+
'state': self.get_state_string(self._current_state),
5556
'port': self.port,
5657
'host': self.host
5758
}
@@ -66,8 +67,6 @@ def current_state(self, state):
6667
except Exception as e:
6768
logging.warning('Frontend not reachable.\n{}'.format(e))
6869

69-
self._current_state = state
70-
7170
def idle(self):
7271
self.current_state = IDLE
7372

0 commit comments

Comments
 (0)