-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathstats.py
68 lines (54 loc) · 2.47 KB
/
stats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
class StatsTracker:
    """
    Container for tracking the statistics associated with an epoch. For each of
    a training and validation pass, a new StatsTracker should be instantiated.
    The common use pattern of the class looks as follows::

        for e in range(num_epochs):
            stats = StatsTracker()

            # Add some loss stats
            stats.update_loss(loss)

            # Add accuracy metrics
            stats.update_accuracies(decoded_output, labels, true_labels, mask)

            # Get current average stats
            a, b, c = stats.averages()
    """

    def __init__(self):
        # Running sum of per-minibatch losses, always kept as a plain Python
        # number (see ``update_loss``) so no autograd graph is retained.
        self.loss = 0.

        # Number of correct samples from the view of reconstruction-accuracy.
        self.num_reconstruction_match = 0

        # Number of correct samples from the view of overall-accuracy.
        self.num_overall_match = 0

        # Hold different counters for the number of loss and accuracy attempts.
        # Losses are added in the unit of the average for a minibatch, while
        # accuracy metrics are added for individual samples.
        self.num_loss_attempts = 0
        self.num_match_attempts = 0

    @staticmethod
    def _safe_div(numerator, denominator):
        """Return ``numerator / denominator``, or NaN if ``denominator`` is 0."""
        # NaN (rather than 0) signals "nothing recorded" without pretending to
        # be a real loss/accuracy value, and avoids ZeroDivisionError when a
        # pass saw no data (e.g. an empty loader).
        if denominator == 0:
            return float("nan")
        return numerator / denominator

    def averages(self):
        """
        Returns average loss, reconstruction-accuracy, and overall-accuracy
        since this ``StatsTracker`` was instantiated.

        Any average whose counter is still zero is returned as ``float('nan')``
        instead of raising ``ZeroDivisionError``.

        Returns:
            tuple: ``(avg_loss, avg_recon_acc, avg_overall_acc)``.
        """
        avg_loss = self._safe_div(self.loss, self.num_loss_attempts)
        avg_recon_acc = self._safe_div(self.num_reconstruction_match,
                                       self.num_match_attempts)
        avg_overall_acc = self._safe_div(self.num_overall_match,
                                         self.num_match_attempts)
        return avg_loss, avg_recon_acc, avg_overall_acc

    def update_accuracies(self, decoded, base_model_outputs, true_labels, mask):
        """
        Calculates the number of decoded outputs that match (1) the outputs
        from the base model and (2) the true labels associated with the decoded
        sample. These results are maintained for later aggregate statistics.

        Args:
            decoded: scores reduced with ``argmax`` over dim 2; presumably
                shaped (batch, k, num_classes) -- TODO confirm with callers.
            base_model_outputs: scores with the same layout as ``decoded``.
            true_labels: class indices comparable to ``argmax(decoded, dim=2)``.
            mask: weights multiplied into the per-position match indicators;
                only positions where it is nonzero contribute to the counts.
        """
        # NOTE(review): attempts are counted per row (``size(0)``) while
        # matches are summed over masked positions; this yields a proper
        # fraction only if ``mask`` selects at most one position per row.
        # Confirm against callers before changing the denominator.
        self.num_match_attempts += decoded.size(0)

        predicted = torch.argmax(decoded, dim=2)
        reference = torch.argmax(base_model_outputs, dim=2)

        self.num_reconstruction_match += torch.sum(
            (predicted == reference) * mask).item()
        self.num_overall_match += torch.sum(
            (predicted == true_labels) * mask).item()

    def update_loss(self, loss):
        """
        Adds ``loss`` to the current aggregate loss for this epoch.

        Args:
            loss: the average loss of one minibatch, as a Python number or a
                single-element tensor. Tensors are converted with ``.item()``
                so the running total does not keep the autograd graph of
                every minibatch alive.
        """
        if torch.is_tensor(loss):
            loss = loss.item()
        self.loss += loss
        self.num_loss_attempts += 1