diff --git a/src/zarr/testing/utils.py b/src/zarr/testing/utils.py index 0a93b93fdb..28d6774286 100644 --- a/src/zarr/testing/utils.py +++ b/src/zarr/testing/utils.py @@ -44,7 +44,7 @@ def has_cupy() -> bool: # Decorator for GPU tests def gpu_test(func: T_Callable) -> T_Callable: return cast( - T_Callable, + "T_Callable", pytest.mark.gpu( pytest.mark.skipif(not has_cupy(), reason="CuPy not installed or no GPU available")( func diff --git a/tests/test_buffer.py b/tests/test_buffer.py index 33ac0266eb..d6175ad506 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -148,6 +148,34 @@ async def test_codecs_use_of_gpu_prototype() -> None: assert cp.array_equal(expect, got) +@gpu_test +@pytest.mark.asyncio +async def test_sharding_use_of_gpu_prototype() -> None: + with zarr.config.enable_gpu(): + expect = cp.zeros((10, 10), dtype="uint16", order="F") + + a = await zarr.api.asynchronous.create_array( + StorePath(MemoryStore()) / "test_codecs_use_of_gpu_prototype", + shape=expect.shape, + chunks=(5, 5), + shards=(10, 10), + dtype=expect.dtype, + fill_value=0, + ) + expect[:] = cp.arange(100).reshape(10, 10) + + await a.setitem( + selection=(slice(0, 10), slice(0, 10)), + value=expect[:], + prototype=gpu.buffer_prototype, + ) + got = await a.getitem( + selection=(slice(0, 10), slice(0, 10)), prototype=gpu.buffer_prototype + ) + assert isinstance(got, cp.ndarray) + assert cp.array_equal(expect, got) + + def test_numpy_buffer_prototype() -> None: buffer = cpu.buffer_prototype.buffer.create_zero_length() ndbuffer = cpu.buffer_prototype.nd_buffer.create(shape=(1, 2), dtype=np.dtype("int64")) @@ -155,3 +183,13 @@ def test_numpy_buffer_prototype() -> None: assert isinstance(ndbuffer.as_ndarray_like(), np.ndarray) with pytest.raises(ValueError, match="Buffer does not contain a single scalar value"): ndbuffer.as_scalar() + + +@gpu_test +def test_gpu_buffer_prototype() -> None: + buffer = gpu.buffer_prototype.buffer.create_zero_length() + ndbuffer = gpu.buffer_prototype.nd_buffer.create(shape=(1, 2), dtype=cp.dtype("int64")) + assert isinstance(buffer.as_array_like(), cp.ndarray) + assert isinstance(ndbuffer.as_ndarray_like(), cp.ndarray) + with pytest.raises(ValueError, match="Buffer does not contain a single scalar value"): + ndbuffer.as_scalar()