Skip to content

Commit 68d45d6

Browse files
committed
mean_squared_difference
#38
1 parent 74f75ca commit 68d45d6

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

nn/loss.py

+12
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,15 @@ def kl_div(*, target: nn.Tensor, target_type: str,
105105
kl = nn.dot(nn.exp(log_target), log_target - log_est, reduce=axis)
106106

107107
return kl
108+
109+
110+
@nn.scoped
111+
def mean_squared_difference(a: nn.Tensor, b: nn.Tensor, *, axis: Optional[nn.Dim] = None) -> nn.Tensor:
112+
"""
113+
Mean squared difference between two tensors,
114+
i.e. mean_{axis}( (a - b) ** 2 ), where axis is the feature dim by default.
115+
"""
116+
if not axis:
117+
assert a.feature_dim
118+
axis = a.feature_dim
119+
return nn.reduce(nn.squared_difference(a, b), mode="mean", axis=axis)

0 commit comments

Comments
 (0)