3 changes: 3 additions & 0 deletions pytensor/link/numba/dispatch/blockwise.py
@@ -86,6 +86,7 @@ def impl(*inputs_and_core_shapes):
output_bc_patterns,
output_dtypes,
inplace_pattern,
False, # allow_core_scalar
(), # constant_inputs
inputs,
tuple_core_shapes,
@@ -98,6 +99,7 @@ def impl(*inputs_and_core_shapes):
# If the core op cannot be cached, the Blockwise wrapper cannot be cached either
blockwise_key = None
else:
blockwise_cache_version = 1
blockwise_key = "_".join(
map(
str,
Expand All @@ -108,6 +110,7 @@ def impl(*inputs_and_core_shapes):
blockwise_op.signature,
input_bc_patterns,
core_op_key,
blockwise_cache_version,
),
)
)
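Aside (not part of the diff): a minimal sketch of the cache-key scheme in the hunk above; the helper name and argument list below are illustrative only, not the actual API.

    # Hedged sketch of the Blockwise cache-key logic shown above.
    def make_blockwise_key(signature, input_bc_patterns, core_op_key, cache_version=1):
        if core_op_key is None:
            # An uncacheable core op makes the Blockwise wrapper uncacheable too.
            return None
        # Including a version number lets a bump invalidate previously cached entries.
        return "_".join(map(str, (signature, input_bc_patterns, core_op_key, cache_version)))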
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/compile_ops.py
@@ -68,7 +68,7 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
output_specs = [Out(o, borrow=False) for o in fgraph.outputs]
insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)
fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key(
fgraph, squeeze_output=True, **kwargs
fgraph, squeeze_output=True, fgraph_name="numba_ofg", **kwargs
)

if fgraph_cache_key is None:
1 change: 1 addition & 0 deletions pytensor/link/numba/dispatch/elemwise.py
@@ -365,6 +365,7 @@ def impl(*inputs):
output_bc_patterns_enc,
output_dtypes_enc,
inplace_pattern_enc,
True, # allow_core_scalar
(), # constant_inputs
inputs,
core_output_shapes, # core_shapes
1 change: 1 addition & 0 deletions pytensor/link/numba/dispatch/random.py
@@ -470,6 +470,7 @@ def impl(core_shape, rng, size, *dist_params):
output_bc_patterns,
output_dtypes,
inplace_pattern,
True, # allow_core_scalar
(rng,),
dist_params,
(numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/scalar.py
@@ -236,7 +236,7 @@ def numba_funcify_Composite(op, node, **kwargs):
_ = kwargs.pop("storage_map", None)

composite_fn, fgraph_key = numba_funcify_and_cache_key(
op.fgraph, squeeze_output=True, **kwargs
op.fgraph, squeeze_output=True, fgraph_name="numba_composite", **kwargs
)
if fgraph_key is None:
composite_key = None
4 changes: 3 additions & 1 deletion pytensor/link/numba/dispatch/scan.py
@@ -98,7 +98,9 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
output_specs = [Out(x, borrow=False) for x in fgraph.outputs]
insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)

scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key(op.fgraph)
scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key(
op.fgraph, fgraph_name="numba_scan"
)

outer_in_names_to_vars = {
(f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs)
27 changes: 19 additions & 8 deletions pytensor/link/numba/dispatch/vectorize_codegen.py
@@ -82,6 +82,7 @@ def _vectorized(
output_bc_patterns,
output_dtypes,
inplace_pattern,
allow_core_scalar,
constant_inputs_types,
input_types,
output_core_shape_types,
@@ -93,6 +94,7 @@
output_bc_patterns,
output_dtypes,
inplace_pattern,
allow_core_scalar,
constant_inputs_types,
input_types,
output_core_shape_types,
@@ -119,6 +121,10 @@
inplace_pattern = inplace_pattern.literal_value
inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode()))

if not isinstance(allow_core_scalar, types.Literal):
raise TypeError("allow_core_scalar must be literal.")
allow_core_scalar = allow_core_scalar.literal_value

batch_ndim = len(input_bc_patterns[0])
nin = len(constant_inputs_types) + len(input_types)
nout = len(output_bc_patterns)
@@ -142,8 +148,7 @@
core_input_types = []
for input_type, bc_pattern in zip(input_types, input_bc_patterns, strict=True):
core_ndim = input_type.ndim - len(bc_pattern)
# TODO: Reconsider this
if core_ndim == 0:
if allow_core_scalar and core_ndim == 0:
core_input_type = input_type.dtype
else:
core_input_type = types.Array(
@@ -196,7 +201,7 @@ def codegen(
sig,
args,
):
[_, _, _, _, _, constant_inputs, inputs, output_core_shapes, size] = args
[_, _, _, _, _, _, constant_inputs, inputs, output_core_shapes, size] = args

constant_inputs = cgutils.unpack_tuple(builder, constant_inputs)
inputs = cgutils.unpack_tuple(builder, inputs)
@@ -256,6 +261,7 @@ def codegen(
output_bc_patterns_val,
input_types,
output_types,
core_scalar=allow_core_scalar,
)

if len(outputs) == 1:
@@ -429,6 +435,7 @@ def make_loop_call(
output_bc: tuple[tuple[bool, ...], ...],
input_types: tuple[Any, ...],
output_types: tuple[Any, ...],
core_scalar: bool = True,
):
safe = (False, False)

@@ -486,7 +493,7 @@
idxs_bc,
*safe,
)
if core_ndim == 0:
if core_scalar and core_ndim == 0:
Review comment (Member): what if core_ndim == 0 but not core_scalar (the old case)?

Reply (Member Author): The old case is this branch; it's what happens with Elemwise, where the inner function gets a scalar, not a 0d array.
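A minimal NumPy-level sketch (not part of the thread) of the distinction discussed above; the function names are invented for illustration.

    import numpy as np

    def elemwise_style_core(x):
        # With allow_core_scalar=True (the Elemwise path), a "()" core input
        # arrives as a plain scalar.
        assert isinstance(x, (np.number, float))
        return x + 1

    def blockwise_style_core(x):
        # With allow_core_scalar=False (the Blockwise path), a "()" core input
        # arrives as a 0d ndarray.
        assert isinstance(x, np.ndarray) and x.ndim == 0
        return x + 1

    xs = np.zeros(3)
    print([elemwise_style_core(float(x)) for x in xs])        # scalars passed in
    print([blockwise_style_core(np.asarray(x)) for x in xs])  # 0d arrays passed in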

# Retrive scalar item at index
val = builder.load(ptr)
# val.set_metadata("alias.scope", input_scope_set)
@@ -499,15 +506,19 @@
dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout
)
core_array = context.make_array(core_arry_type)(context, builder)
core_shape = cgutils.unpack_tuple(builder, input.shape)[-core_ndim:]
core_strides = cgutils.unpack_tuple(builder, input.strides)[-core_ndim:]
core_shape = cgutils.unpack_tuple(builder, input.shape)[
input_type.ndim - core_ndim :
]
core_strides = cgutils.unpack_tuple(builder, input.strides)[
input_type.ndim - core_ndim :
]
itemsize = context.get_abi_sizeof(context.get_data_type(input_type.dtype))
context.populate_array(
core_array,
# TODO whey do we need to bitcast?
data=builder.bitcast(ptr, core_array.data.type),
shape=cgutils.pack_array(builder, core_shape),
strides=cgutils.pack_array(builder, core_strides),
shape=core_shape,
strides=core_strides,
itemsize=context.get_constant(types.intp, itemsize),
# TODO what is meminfo about?
meminfo=None,
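Aside (not from the PR discussion): the change from [-core_ndim:] to [input_type.ndim - core_ndim:] only makes a difference when core_ndim == 0, because a slice starting at negative zero keeps the whole sequence. A tiny sketch:

    shape = (5, 3)            # batch dims only; core_ndim == 0
    core_ndim = 0
    print(shape[-core_ndim:])              # (5, 3)  -- [-0:] is the whole tuple
    print(shape[len(shape) - core_ndim:])  # ()      -- empty, as intended for a 0d core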
42 changes: 40 additions & 2 deletions tests/link/numba/test_blockwise.py
@@ -2,10 +2,12 @@
import pytest

from pytensor import function
from pytensor.tensor import lvector, tensor, tensor3
from pytensor.graph import Apply
from pytensor.scalar import ScalarOp
from pytensor.tensor import TensorVariable, lvector, tensor, tensor3, vector
from pytensor.tensor.basic import Alloc, ARange, constant
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.nlinalg import SVD, Det
from pytensor.tensor.slinalg import Cholesky, cholesky
from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
@@ -90,3 +92,39 @@ def test_blockwise_scalar_dimshuffle():
)
out = blockwise_scalar_ds(x)
compare_numba_and_py([x], [out], [np.arange(9)], eval_obj_mode=False)


def test_blockwise_vs_elemwise_scalar_op():
# Regression test for https://github.com/pymc-devs/pytensor/issues/1760

class TestScalarOp(ScalarOp):
def make_node(self, x):
return Apply(self, [x], [x.type()])

def perform(self, node, inputs, outputs):
[x] = inputs
if isinstance(node.inputs[0], TensorVariable):
assert isinstance(x, np.ndarray)
else:
assert isinstance(x, np.number | float)
out = x + 1
if isinstance(node.outputs[0], TensorVariable):
out = np.asarray(out)
outputs[0][0] = out

x = vector("x")
y = Elemwise(TestScalarOp())(x)
with pytest.warns(
UserWarning,
match="Numba will use object mode to run TestScalarOp's perform method",
):
fn = function([x], y, mode="NUMBA")
np.testing.assert_allclose(fn(np.zeros((3,))), [1, 1, 1])

z = Blockwise(TestScalarOp(), signature="()->()")(x)
with pytest.warns(
UserWarning,
match="Numba will use object mode to run TestScalarOp's perform method",
):
fn = function([x], z, mode="NUMBA")
np.testing.assert_allclose(fn(np.zeros((3,))), [1, 1, 1])