-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtf_helpers.py
58 lines (42 loc) · 2.15 KB
/
tf_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import tensorflow as tf
from keras import backend as K
class CER(tf.keras.metrics.Metric):
    """Keras metric that tracks the mean Character Error Rate (CER).

    For every batch the predictions are CTC-decoded, compared with the
    ground-truth labels via a length-normalized edit distance, and the
    per-sample distances are accumulated. ``result()`` returns the sum of
    all distances divided by the number of samples seen so far.

    Args:
        name: Name under which the metric is reported.
        decode_greedy: If True, use greedy (best-path) CTC decoding;
            otherwise use the beam search of ``K.ctc_decode``.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(self, name='CER', decode_greedy=True, **kwargs):
        super().__init__(name=name, **kwargs)
        self.decode_greedy = decode_greedy
        # Running sum of the normalized edit distances of all samples seen.
        self.cer_accumulator = self.add_weight(name="total_cer", initializer="zeros")
        # Number of samples seen (kept as a float so it can be used as divisor).
        self.counter = self.add_weight(name="cer_count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the CER of one batch.

        Args:
            y_true: Dense label tensor, one row per sample.
                NOTE(review): entries equal to the batch-wide maximum value
                are dropped below, so the padding token is presumably the
                largest index -- confirm against the data pipeline.
            y_pred: Network output; axis 0 is used as the batch axis and
                axis 1 as the time axis (presumably (batch, time, classes)
                -- TODO confirm).
            sample_weight: Accepted for Keras API compatibility; ignored.
        """
        pred_shape = K.shape(y_pred)
        # Every sample is treated as using the full time axis of y_pred.
        input_length = tf.ones(shape=pred_shape[0]) * K.cast(pred_shape[1], 'float32')
        label_length = K.cast(input_length, 'int32')

        # Honor the constructor option instead of always decoding greedily.
        decoded, _ = K.ctc_decode(y_pred, input_length, greedy=self.decode_greedy)
        hypothesis = K.ctc_label_dense_to_sparse(decoded[0], label_length)
        truth = K.ctc_label_dense_to_sparse(y_true, label_length)

        # Drop entries equal to the largest label value in the batch
        # (presumably the padding token -- see the note in the docstring).
        truth = tf.sparse.retain(
            truth, tf.not_equal(truth.values, tf.math.reduce_max(truth.values))
        )
        # ctc_decode pads its output with -1; remove that padding as well.
        hypothesis = tf.sparse.retain(hypothesis, tf.not_equal(hypothesis.values, -1))

        # normalize=True divides each distance by the length of the truth.
        distance = tf.edit_distance(hypothesis, truth, normalize=True)

        self.cer_accumulator.assign_add(tf.reduce_sum(distance))
        # Use the dynamic batch size; len() on a tensor is only defined in
        # graph mode when autograph rewrites the call.
        self.counter.assign_add(K.cast(tf.shape(y_true)[0], 'float32'))

    def result(self):
        """Return the mean CER over all samples seen (0 if none were seen)."""
        return tf.math.divide_no_nan(self.cer_accumulator, self.counter)

    def reset_state(self):
        """Reset the accumulated distance and the sample counter to zero."""
        self.cer_accumulator.assign(0.0)
        self.counter.assign(0.0)

    def get_config(self):
        """Return the metric configuration, including ``decode_greedy``."""
        config = super().get_config()
        config["decode_greedy"] = self.decode_greedy
        return config
def CTCLoss(y_true, y_pred):
    """Return the per-sample CTC loss used during training.

    Every sample is assigned the full time axis of ``y_pred`` (axis 1) as
    its input length. The label length of a sample is the number of
    non-zero entries in its row of ``y_true``.

    NOTE(review): counting non-zero entries presumably means 0 is the
    padding value of ``y_true`` -- confirm against the data pipeline.

    Args:
        y_true: Dense label tensor, one row per sample.
        y_pred: Network output with the batch on axis 0 and time on axis 1.

    Returns:
        The tensor produced by ``K.ctc_batch_cost``.
    """
    num_samples = tf.cast(tf.shape(y_true)[0], dtype="int64")
    num_steps = tf.cast(tf.shape(y_pred)[1], dtype="int64")

    # Column vector (num_samples, 1) holding the time-step count per sample.
    input_length = tf.ones(shape=(num_samples, 1), dtype="int64") * num_steps
    # keepdims=True keeps axis 1 so the result is also a column vector.
    label_length = tf.math.count_nonzero(y_true, axis=1, keepdims=True)

    return K.ctc_batch_cost(y_true, y_pred, input_length, label_length)
def warmup_tf_model(model, input_shapes, *, num_runs=10):
    """Call ``model`` repeatedly on random inputs and return it.

    Each run feeds the model a tuple containing one ``tf.random.normal``
    tensor per entry of ``input_shapes``.

    Args:
        model: Callable that accepts a tuple of tensors.
        input_shapes: Iterable of shapes, one per model input.
        num_runs: Number of warm-up calls to make (default 10, the
            previously hard-coded value). Values <= 0 make no calls.

    Returns:
        The same ``model`` object that was passed in.
    """
    # Materialize once so a one-shot iterable can be reused on every run.
    shapes = tuple(input_shapes)
    for _ in range(num_runs):
        model(tuple(tf.random.normal(shape) for shape in shapes))
    return model