Skip to content

Commit

Permalink
Merge pull request #64 from willow-ahrens/dtype-sum-prod
Browse files Browse the repository at this point in the history
Refine `dtype` argument for `sum` and `prod`
  • Loading branch information
mtsokol authored Jun 6, 2024
2 parents 66802c0 + 3596a75 commit 12eaa4e
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "finch-tensor"
version = "0.1.24"
version = "0.1.25"
description = ""
authors = ["Willow Ahrens <[email protected]>"]
readme = "README.md"
Expand Down
70 changes: 60 additions & 10 deletions src/finch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,14 +743,14 @@ def astype(x: Tensor, dtype: DType, /, *, copy: bool = True):
if not copy:
if x.dtype == dtype:
return x
else:
if copy is False:
raise ValueError("Unable to avoid a copy while casting in no-copy mode.")
else:
finch_tns = x._obj.body
result = jl.copyto_b(
jl.similar(finch_tns, jc.convert(dtype, jl.default(finch_tns)), dtype), finch_tns
)
return Tensor(jl.swizzle(result, *x.get_order(zero_indexing=False)))

finch_tns = x._obj.body
result = jl.copyto_b(
jl.similar(finch_tns, jc.convert(dtype, jl.default(finch_tns)), dtype), finch_tns
)
return Tensor(jl.swizzle(result, *x.get_order(zero_indexing=False)))


def where(condition: Tensor, x1: Tensor, x2: Tensor, /) -> Tensor:
Expand Down Expand Up @@ -778,13 +778,63 @@ def nonzero(x: Tensor, /) -> tuple[np.ndarray, ...]:
return tuple(Tensor(i[sort_order]) for i in indices)


def _reduce(x: Tensor, fn: Callable, axis, dtype=None):
def _reduce_core(x: Tensor, fn: Callable, axis: int | tuple[int, ...] | None):
    """Apply the Julia reduction *fn* to ``x``, optionally along ``axis``.

    ``axis`` is normalized to a tuple of non-negative axes and then shifted
    to Julia's 1-based ``dims`` convention before invoking *fn*. When
    ``axis`` is ``None`` the reduction runs over every dimension.
    """
    if axis is None:
        return fn(x._obj)
    dims = tuple(a + 1 for a in normalize_axis_tuple(axis, x.ndim))
    return fn(x._obj, dims=dims)


def _reduce_sum_prod(
    x: Tensor,
    fn: Callable,
    axis: int | tuple[int, ...] | None,
    dtype: DType | None,
) -> Tensor:
    """Reduce ``x`` with ``fn`` (``jl.sum`` / ``jl.prod``), applying dtype rules.

    Integer inputs are widened to the default signed/unsigned integer width
    unless an explicit ``dtype`` is supplied; an explicit ``dtype`` triggers
    a cast of the result. ``dtype`` is rejected for lazy tensors.
    """
    result = _reduce_core(x, fn, axis)

    # A full reduction returns a bare Julia scalar; re-box it as a
    # 0-dimensional tensor. Integer scalars are boxed at the default
    # integer width so the widening below is already representable.
    if np.isscalar(result):
        if jl.seval(f"{x.dtype} <: Integer"):
            tmp_dtype = jl_dtypes.int_
        else:
            tmp_dtype = x.dtype
        result = jl.Tensor(
            jl.Element(
                jc.convert(tmp_dtype, 0),
                np.array(result, dtype=jl_dtypes.jl_to_np_dtype[tmp_dtype])
            )
        )

    result = Tensor(result)

    if jl.isa(result._obj, jl.Finch.LazyTensor):
        # Lazy results have no materialized dtype to cast yet.
        if dtype is not None:
            raise ValueError(
                "`dtype` keyword for `sum` and `prod` in the lazy mode isn't supported"
            )
    # dtype casting rules: explicit dtype wins; otherwise integer inputs
    # are promoted to the default width of their signedness class.
    elif dtype is not None:
        result = astype(result, dtype, copy=None)
    elif jl.seval(f"{x.dtype} <: Unsigned"):
        result = astype(result, jl_dtypes.uint, copy=None)
    elif jl.seval(f"{x.dtype} <: Signed"):
        result = astype(result, jl_dtypes.int_, copy=None)

    return result


def _reduce(x: Tensor, fn: Callable, axis: int | tuple[int, ...] | None):
    """Reduce ``x`` with ``fn`` and wrap the outcome in a ``Tensor``.

    A full reduction yields a bare Julia scalar; it is re-boxed as a
    0-dimensional tensor carrying the input's dtype before wrapping.
    """
    reduced = _reduce_core(x, fn, axis)
    if not np.isscalar(reduced):
        return Tensor(reduced)
    fill_value = jc.convert(x.dtype, 0)
    payload = np.array(reduced, dtype=jl_dtypes.jl_to_np_dtype[x.dtype])
    return Tensor(jl.Tensor(jl.Element(fill_value, payload)))


Expand All @@ -796,7 +846,7 @@ def sum(
dtype: DType | None = None,
keepdims: bool = False,
) -> Tensor:
return _reduce(x, jl.sum, axis, dtype)
return _reduce_sum_prod(x, jl.sum, axis, dtype)


def prod(
Expand All @@ -807,7 +857,7 @@ def prod(
dtype: DType | None = None,
keepdims: bool = False,
) -> Tensor:
return _reduce(x, jl.prod, axis, dtype)
return _reduce_sum_prod(x, jl.prod, axis, dtype)


def max(
Expand Down
23 changes: 21 additions & 2 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,7 @@ def test_elemwise_tensor_ops_2_args(arr3d, meth_name):

@pytest.mark.parametrize("func_name", ["sum", "prod", "max", "min", "any", "all"])
@pytest.mark.parametrize("axis", [None, -1, 1, (0, 1), (0, 1, 2)])
@pytest.mark.parametrize("dtype", [None]) # not supported yet
def test_reductions(arr3d, func_name, axis, dtype):
def test_reductions(arr3d, func_name, axis):
A_finch = finch.Tensor(arr3d)

actual = getattr(finch, func_name)(A_finch, axis=axis)
Expand All @@ -168,6 +167,26 @@ def test_reductions(arr3d, func_name, axis, dtype):
assert_equal(actual.todense(), expected)


@pytest.mark.parametrize("func_name", ["sum", "prod"])
@pytest.mark.parametrize("axis", [None, 0, 1])
@pytest.mark.parametrize(
    "in_dtype, dtype, expected_dtype",
    [
        # default promotion: signed ints widen to int64, unsigned to uint64
        (finch.int64, None, np.int64),
        (finch.int16, None, np.int64),
        (finch.uint8, None, np.uint64),
        # an explicit `dtype` argument overrides the promotion rules
        (finch.int64, finch.float32, np.float32),
        (finch.float64, finch.complex128, np.complex128),
    ],
)
def test_sum_prod_dtype_arg(arr3d, func_name, axis, in_dtype, dtype, expected_dtype):
    # `sum`/`prod` must honor the dtype rules of _reduce_sum_prod.
    # np.abs keeps values representable when casting to unsigned dtypes.
    arr_finch = finch.asarray(np.abs(arr3d), dtype=in_dtype)

    actual = getattr(finch, func_name)(arr_finch, axis=axis, dtype=dtype).todense()

    assert actual.dtype == expected_dtype


@pytest.mark.parametrize(
"storage",
[
Expand Down

0 comments on commit 12eaa4e

Please sign in to comment.