Skip to content

Commit 7aeb4f6

Browse files
committed
Fix masking for density channel
1 parent b8009a9 commit 7aeb4f6

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

neuralprocesses/coders/setconv/density.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ def code(coder: PrependDensityChannel, xz, z: B.Numeric, x, **kw_args):
5858
@_dispatch
5959
def code(coder: PrependDensityChannel, xz, z: Masked, x, **kw_args):
6060
mask = z.mask
61-
# Set the missing values to zero. Zeros in the data channel do not affect the
62-
# encoding.
63-
z = z.y * mask
64-
return code(coder, xz, z, x, **kw_args)
61+
d = data_dims(xz)
62+
# Set the missing values to zero by multiplying with the mask. Zeros in the data
63+
# channel do not affect the encoding.
64+
return xz, B.concat(z.mask, z.y * z.mask, axis=-d - 1)
6565

6666

6767
@register_module

0 commit comments

Comments
 (0)