Skip to content

Commit db5bb79

Browse files
committed
add python loss plot file, put all scripts in folder
1 parent 39b2a31 commit db5bb79

9 files changed

+101
-0
lines changed
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

Diff for: crop_file.py renamed to scripts/crop_file.py

File renamed without changes.

Diff for: scripts/plot_loss.py

+101
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#!/usr/bin/python
2+
import os
3+
import re
4+
import sys
5+
import numpy as np
6+
import matplotlib.pyplot as plt
7+
8+
TRAIN_LOSS_PATTERN = r"Iteration (\d+), loss = (\d+\.\d*)"
9+
#TEST_LOSS_PATTERN = r"Iteration (\d+), Testing net \(#0\)\n.*\n.*\n.*\n.* rec_loss = (\d+\.\d+)"
10+
#TEST_LOSS_PATTERN = r"Iteration (\d+), Testing net \(#0\)\n.*\n.* loss = (\d+\.\d+)"
11+
#TEST_ACC_PATTERN = r"Iteration (\d+), Testing net \(#0\)\n.* accuracy = (\d+\.\d+)"
12+
13+
def main():
14+
if len(sys.argv) > 1:
15+
log_file_name = sys.argv[1]
16+
else:
17+
raise("please provide log file to process")
18+
19+
log_file = open(log_file_name, 'r')
20+
log_data = log_file.read()
21+
training_result = re.findall(TRAIN_LOSS_PATTERN,log_data)
22+
#testing_result = re.findall(TEST_LOSS_PATTERN, log_data)
23+
#testing_accuracy = re.findall(TEST_ACC_PATTERN, log_data)
24+
25+
train_iter = []
26+
train_loss = []
27+
test_iter = []
28+
test_loss = []
29+
test_acc_iter = []
30+
test_acc = []
31+
32+
# test_loss_length = len(testing_result[0]) - 1
33+
for train in training_result:
34+
train_iter.append(int(train[0]))
35+
train_loss.append(float(train[1]))
36+
'''
37+
for test in testing_result:
38+
test_iter.append(int(test[0]))
39+
temp_loss = 0
40+
for i in range(test_loss_length):
41+
temp_loss += float(test[i+1])
42+
test_loss.append(temp_loss)
43+
'''
44+
#for test in testing_accuracy:
45+
# test_acc_iter.append(int(test[0]))
46+
# test_acc.append(float(test[1]))
47+
48+
#print test_iter
49+
#print test_loss
50+
# display
51+
plt.plot(train_iter, train_loss, 'k', label='Train loss', linewidth=0.75)
52+
#plt.plot(test_iter, test_loss, 'r', label='Test loss', linewidth=1.0)
53+
#plt.plot(test_acc_iter, test_acc, 'b', label='Test accuracy', linewidth=1.0)
54+
plt.legend()
55+
#plt.minorticks_on()
56+
plt.ylabel('Loss')
57+
plt.xlabel('Iteration')
58+
#plt.yticks(np.arange(0, 2.5, 0.1))
59+
plt.grid()
60+
plt.savefig(os.path.join(os.path.dirname(log_file_name), log_file_name) +'.png')
61+
62+
63+
def disp_results(fig, ax1, ax2, loss_iterations, losses, accuracy_iterations, accuracies, accuracies_iteration_checkpoints_ind, fileName, color_ind=0):
64+
modula = len(plt.rcParams['axes.color_cycle'])
65+
acrIterations =[]
66+
top_acrs={}
67+
if accuracies.size:
68+
if accuracies.size>4:
69+
top_n = 4
70+
else:
71+
top_n = accuracies.size -1
72+
temp = np.argpartition(-accuracies, top_n)
73+
result_indexces = temp[:top_n]
74+
temp = np.partition(-accuracies, top_n)
75+
result = -temp[:top_n]
76+
for acr in result_indexces:
77+
acrIterations.append(accuracy_iterations[acr])
78+
top_acrs[str(accuracy_iterations[acr])]=str(accuracies[acr])
79+
80+
sorted_top4 = sorted(top_acrs.items(), key=operator.itemgetter(1))
81+
maxAcc = np.amax(accuracies, axis=0)
82+
iterIndx = np.argmax(accuracies)
83+
maxAccIter = accuracy_iterations[iterIndx]
84+
maxIter = accuracy_iterations[-1]
85+
consoleInfo = format('\n[%s]:maximum accuracy [from 0 to %s ] = [Iteration %s]: %s ' %(fileName,maxIter,maxAccIter ,maxAcc))
86+
plotTitle = format('max accuracy(%s) [Iteration %s]: %s ' % (fileName,maxAccIter, maxAcc))
87+
print (consoleInfo)
88+
#print (str(result))
89+
#print(acrIterations)
90+
# print 'Top 4 accuracies:'
91+
print ('Top 4 accuracies:'+str(sorted_top4))
92+
plt.title(plotTitle)
93+
ax1.plot(loss_iterations, losses, color=plt.rcParams['axes.color_cycle'][(color_ind * 2 + 0) % modula])
94+
ax2.plot(accuracy_iterations, accuracies, plt.rcParams['axes.color_cycle'][(color_ind * 2 + 1) % modula], label=str(fileName))
95+
ax2.plot(accuracy_iterations[accuracies_iteration_checkpoints_ind], accuracies[accuracies_iteration_checkpoints_ind], 'o', color=plt.rcParams['axes.color_cycle'][(color_ind * 2 + 1) % modula])
96+
plt.legend(loc='lower right')
97+
98+
99+
100+
if __name__ == "__main__":
101+
main()

Diff for: profiler.py renamed to scripts/profiler.py

File renamed without changes.

Diff for: random_crop.py renamed to scripts/random_crop.py

File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)