-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstate.py
117 lines (98 loc) · 3.36 KB
/
state.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import logging
import multiprocessing
import random
import socket
import threading
from string import ascii_uppercase, digits
from time import sleep

import requests

from shared import utils
# Workflow state codes, numbered 0-8 in declaration order.
# State.get_state_string maps each code to its human-readable label.
(
    IDLE,
    CLIENT_TRAIN_MODEL,
    CLIENT_GET_AGG_MODEL,
    CLIENT_SEND_MODEL,
    SEC_AGG_GET_CLIENT_MODEL,
    SEC_AGG_AGGREGATE_MODELS,
    SEC_AGG_SEND_TO_MAIN_SERVER,
    MAIN_SERVER_SEND_MODEL_TO_CLIENTS,
    MAIN_SERVER_GET_SECAGG_MODEL,
) = range(9)

# Interval, in seconds, between heartbeat pings sent to the frontend.
PING_CADENCE = 20
# Address table loaded once, at import time, via the project's utils.read_hosts().
# State reads hosts['frontend']['host'] and hosts['frontend']['port'] from it;
# NOTE(review): presumably it also holds entries for other roles -- confirm in shared.utils.
hosts = utils.read_hosts()
class State:
    """Track a node's workflow state and report it to the frontend.

    A node is a 'client', 'secure_aggregator' or 'main_server'. Every state
    change is POSTed to the frontend's ``/send_state`` endpoint, and a
    background daemon thread POSTs a heartbeat to ``/ping`` every
    ``PING_CADENCE`` seconds. Frontend failures are logged, never raised.
    """

    # Seconds to wait for the frontend on a single request, so an
    # unresponsive frontend cannot block a state transition (or the
    # heartbeat loop) indefinitely.
    REQUEST_TIMEOUT = 5

    def __init__(self, client_type, port, _id=None):
        """Create the tracker, announce IDLE and start the heartbeat thread.

        Args:
            client_type: one of 'client', 'secure_aggregator', 'main_server'.
            port: port this node listens on; included in every report.
            _id: identifier to report; when falsy, a random 8-character id
                is generated.

        Raises:
            ValueError: if ``client_type`` is not one of the known roles.
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if client_type not in ('client', 'secure_aggregator', 'main_server'):
            raise ValueError('Unknown client_type: {!r}'.format(client_type))
        self.client_type = client_type
        self.host = socket.gethostbyname(socket.gethostname())
        self.port = port
        self._id = _id if _id else self.generate_random_id()
        # Seed the backing field first so the object is fully initialised
        # before the property setter runs.
        self._current_state = IDLE
        self.current_state = IDLE  # via the setter: announces IDLE
        # A thread shares this object's memory, so each heartbeat reports the
        # live state. A separate process would only ever see a frozen copy.
        # daemon=True lets the interpreter exit while the loop is running.
        self._ping_thread = threading.Thread(
            target=self.send_ping_continuously, daemon=True
        )
        self._ping_thread.start()

    @staticmethod
    def generate_random_id(N=8):
        """Return N random characters drawn from A-Z and 0-9.

        Uses the ``random`` module, so the id is not cryptographically secure.
        """
        return ''.join(random.choices(ascii_uppercase + digits, k=N))

    @property
    def current_state(self):
        """The state code most recently assigned to this node."""
        return self._current_state

    @current_state.setter
    def current_state(self, state):
        """Report ``state`` to the frontend, then record it locally.

        Raises:
            ValueError: if ``state`` is not a known state code. The error is
                raised before anything is sent or stored.
        """
        # Describe the NEW state. self.current_state still holds the previous
        # one at this point.
        payload = self._build_payload(state)
        self._post_to_frontend('send_state', payload)
        self._current_state = state

    def idle(self):
        """Transition to IDLE (and report it)."""
        self.current_state = IDLE

    def is_idle(self):
        """Return True when the current state is IDLE."""
        return self.current_state == IDLE

    def send_ping_continuously(self):
        """Send a heartbeat every PING_CADENCE seconds, forever."""
        while True:
            self.send_ping()
            sleep(PING_CADENCE)

    def send_ping(self):
        """POST a heartbeat describing the current state to ``/ping``."""
        self._post_to_frontend('ping', self._build_payload(self.current_state))

    def _build_payload(self, state):
        """Return the JSON body describing this node in ``state``."""
        return {
            'client_type': self.client_type,
            '_id': self._id,
            'state': self.get_state_string(state),
            'port': self.port,
            'host': self.host,
        }

    def _post_to_frontend(self, endpoint, payload):
        """POST ``payload`` as JSON to the frontend's ``/<endpoint>``.

        This is best-effort. Connection errors, timeouts and a missing
        'frontend' entry in ``hosts`` are logged as warnings and swallowed,
        so the node keeps working while the frontend is down.
        """
        try:
            url = 'http://{}:{}/{}'.format(
                hosts['frontend']['host'],
                hosts['frontend']['port'],
                endpoint,
            )
            requests.post(url=url, json=payload, timeout=self.REQUEST_TIMEOUT)
        except Exception as e:  # deliberately broad: reporting is optional
            logging.warning('Frontend not reachable.\n%s', e)

    @staticmethod
    def get_state_string(state):
        """Return the human-readable label for a state code.

        Raises:
            ValueError: if ``state`` is not one of the module's state codes.
        """
        strings = {
            IDLE: 'IDLE',
            CLIENT_TRAIN_MODEL: 'Training model',
            CLIENT_GET_AGG_MODEL: 'Getting aggregated model',
            CLIENT_SEND_MODEL: 'Sending model to secure aggregator',
            SEC_AGG_GET_CLIENT_MODEL: 'Getting client models',
            SEC_AGG_SEND_TO_MAIN_SERVER: 'Sending model to main server',
            SEC_AGG_AGGREGATE_MODELS: 'Aggregating models',
            MAIN_SERVER_GET_SECAGG_MODEL: 'Getting model from sec agg',
            MAIN_SERVER_SEND_MODEL_TO_CLIENTS: 'Sending model to clients',
        }
        try:
            return strings[state]
        except KeyError:
            raise ValueError('State not recognized') from None