Skip to content

Commit 5b0dac8

Browse files
Gabriel FerrateGabriel Ferrate
Gabriel Ferrate
authored and
Gabriel Ferrate
committed
Losses
1 parent 81a59c4 commit 5b0dac8

6 files changed

+127
-0
lines changed

results/lossess/extract_losses.py

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from plot import make_plot
2+
from copy import deepcopy
3+
4+
5+
def extract_training_results(fn):
6+
data = {}
7+
with open(fn, 'r') as f:
8+
epoch = 0
9+
for l in f:
10+
if 'Train:' in l:
11+
s = l.split('\t')
12+
loss = float(s[3].split(' ')[1])
13+
data.setdefault(epoch, {}).setdefault('loss', []).append(loss)
14+
15+
if 'DONE' in l:
16+
epoch += 1
17+
return data
18+
19+
20+
def extract_val_results(fn):
21+
data = {}
22+
clustering = False
23+
last_clustering = False
24+
with open(fn, 'r') as f:
25+
epoch = 0
26+
for l in f:
27+
if 'clustering...' in l:
28+
clustering = True
29+
if 'Test:' in l:
30+
s = l.split('\t')
31+
#epoch = int(s[0].split('[')[1].split(']')[0])
32+
top_1 = float(s[-2].split(' ')[1].split(' ')[0])
33+
top_3 = float(s[-1].split(' ')[1].split(' ')[0])
34+
time = float(s[1].split(' ')[1].split(' ')[0])
35+
loss = float(s[-3].split(' ')[1])
36+
if clustering:
37+
data.setdefault(epoch, {}).setdefault('loss-clustering', []).append(loss)
38+
#data[epoch].setdefault('test-top3-clustering', []).append(top_3)
39+
else:
40+
data.setdefault(epoch, {}).setdefault('loss', []).append(loss)
41+
#data[epoch].setdefault('test-top3', []).append(top_3)
42+
43+
if len(data[epoch].get('loss-clustering', [])) == 1612 and clustering:
44+
epoch += 1
45+
clustering = False
46+
47+
return data
48+
49+
def avg(_list):
50+
return sum(_list) / len(_list)
51+
52+
def average_epochs(data):
53+
for epoch, d in data.items():
54+
d['loss'] = avg(d['loss'])
55+
#d['test-top3'] = avg(d['test-top3'])
56+
if 'loss-clustering' in d:
57+
d['loss-clustering'] = avg(d['loss-clustering'])
58+
#d['test-top3-clustering'] = avg(d['test-top3-clustering'])
59+
return data
60+
61+
def get_all_val_data(fn):
62+
data = extract_val_results(fn)
63+
average_epochs(data)
64+
return data
65+
66+
def get_all_training_data(fn):
67+
data = extract_training_results(fn)
68+
average_epochs(data)
69+
return data
70+
71+
if __name__ == '__main__':
72+
filenames = (
73+
('plots/9_clients_7_frames_iid.png', '../9_clients_7_frames/raw/logs_iid/{}', 'sec_agg.log', 9),
74+
('plots/9_clients_7_frames_non_iid.png', '../9_clients_7_frames/raw/logs_non_iid/{}', 'sec_agg.log', 9),
75+
('plots/5_clients_7_frames_iid.png', '../9_clients_7_frames/raw/logs_iid/{}', 'sec_agg.log', 5),
76+
('plots/5_clients_7_frames_non_iid.png', '../9_clients_7_frames/raw/logs_non_iid/{}', 'sec_agg.log', 5),
77+
)
78+
79+
datas = []
80+
for name, base_fn, secagg_file, n_clients in filenames:
81+
datas = []
82+
for i in range(n_clients):
83+
# Training Data
84+
client_file = f'client_{8004+i}.log'
85+
client_path = base_fn.format(client_file)
86+
training_data = get_all_training_data(client_path)
87+
datas.append((deepcopy(training_data), f'training-loss-client_{i}'))
88+
89+
# Validation Data
90+
secagg_path = base_fn.format(secagg_file)
91+
validation_data = get_all_val_data(secagg_path)
92+
datas.append((validation_data, 'validation-loss'))
93+
94+
make_plot(name, datas)

results/lossess/plot.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import matplotlib
2+
matplotlib.use('agg')
3+
import matplotlib.pyplot as plt
4+
import sys
5+
sys.path.insert(0,'..')
6+
7+
MARKER_SIZE=0
8+
9+
10+
def make_plot(name, datas):
11+
plt.figure()
12+
epochs = 200
13+
X_AXIS = list(map(lambda x: x+1, range(epochs)))
14+
for data, label in datas:
15+
y_values = [data[i]['loss'] for i in range(len(data))]
16+
plt.plot(X_AXIS,
17+
y_values,
18+
linewidth=1,
19+
label=label,
20+
linestyle='solid',
21+
marker="^",
22+
markersize=MARKER_SIZE)
23+
24+
plt.legend(loc="upper right")
25+
# Limits
26+
plt.axis(xmin=0, xmax=epochs+1)#, ymin=2, ymax=4.5)
27+
# Labels
28+
#plt.title('1 input frame comparison with 9 clients')
29+
plt.ylabel('Loss')
30+
plt.xlabel('Communication Round')
31+
# Save
32+
plt.savefig(name, dpi=400)
33+
Loading
Loading
Loading
Loading

0 commit comments

Comments
 (0)