From 9144431ed6c2528f318c740ed398ddd82cb9f191 Mon Sep 17 00:00:00 2001
From: Alex Birch
Date: Sun, 12 Mar 2023 00:10:46 +0000
Subject: [PATCH] correct the special case where baddbmm should ignore the bias
 parameter (i.e. when the bias is scaled by a beta of 0); see
 https://pytorch.org/docs/stable/generated/torch.baddbmm.html

---
 coremltools/converters/mil/frontend/torch/ops.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py
index 370f4401a..afc68c2ee 100644
--- a/coremltools/converters/mil/frontend/torch/ops.py
+++ b/coremltools/converters/mil/frontend/torch/ops.py
@@ -5271,7 +5271,7 @@ def baddbmm(context, node):
     inputs = _get_inputs(context, node, expected=5)
     bias, batch1, batch2, beta, alpha = inputs
 
-    if beta.val != 1.0:
+    if beta.val != 0.0:
         # Apply scaling factor beta to the bias.
         bias = mb.mul(x=beta, y=bias, name=bias.name + "_scaled")
         context.add(bias)
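
For context (not part of the patch): the linked torch.baddbmm documentation defines
out = beta * input + alpha * (batch1 @ batch2), and states that when beta == 0 the
input (the bias here) is ignored entirely, so NaN/inf values in it do not propagate.
The minimal sketch below only illustrates that documented PyTorch behaviour; the
tensor shapes are arbitrary and it does not exercise the coremltools converter path
touched by this patch.

import torch

# Bias deliberately filled with NaN to show it is ignored when beta == 0.
bias = torch.full((2, 3, 4), float("nan"))
batch1 = torch.randn(2, 3, 5)
batch2 = torch.randn(2, 5, 4)

# out = beta * bias + alpha * (batch1 @ batch2); with beta=0 the bias
# (and its NaNs) is dropped, per the torch.baddbmm docs.
out = torch.baddbmm(bias, batch1, batch2, beta=0.0, alpha=1.0)
assert not torch.isnan(out).any()

# With beta=1 the bias is added as-is, so the NaNs propagate.
out_nan = torch.baddbmm(bias, batch1, batch2, beta=1.0, alpha=1.0)
assert torch.isnan(out_nan).all()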