diff --git a/examples/audio/transformer_asr.py b/examples/audio/transformer_asr.py index f7b1d7130e..b661c73c83 100644 --- a/examples/audio/transformer_asr.py +++ b/examples/audio/transformer_asr.py @@ -247,7 +247,7 @@ def train_step(self, batch): preds = self([source, dec_input]) one_hot = tf.one_hot(dec_target, depth=self.num_classes) mask = tf.math.logical_not(tf.math.equal(dec_target, 0)) - loss = model.compute_loss(None, one_hot, preds, sample_weight=mask) + loss = self.compute_loss(None, one_hot, preds, sample_weight=mask) trainable_vars = self.trainable_variables gradients = tape.gradient(loss, trainable_vars) self.optimizer.apply_gradients(zip(gradients, trainable_vars)) @@ -262,7 +262,7 @@ def test_step(self, batch): preds = self([source, dec_input]) one_hot = tf.one_hot(dec_target, depth=self.num_classes) mask = tf.math.logical_not(tf.math.equal(dec_target, 0)) - loss = model.compute_loss(None, one_hot, preds, sample_weight=mask) + loss = self.compute_loss(None, one_hot, preds, sample_weight=mask) self.loss_metric.update_state(loss) return {"loss": self.loss_metric.result()}