Open
Description
Hi guys,
I implemented the make_s_curve function and corresponding tests for it in #967.
However, one of the tests is failing consistently, even though I have verified the values in the outputs and the computation graph. Appreciate any feedback or suggestions on how to resolve this issue.
The test that failed.
@pytest.mark.parametrize(
"generator",
[
dask_ml.datasets.make_blobs,
dask_ml.datasets.make_classification,
dask_ml.datasets.make_counts,
dask_ml.datasets.make_regression,
dask_ml.datasets.make_s_curve,
],
)
def test_deterministic(generator, scheduler):
a, t = generator(chunks=100, random_state=10)
b, u = generator(chunks=100, random_state=10)
assert_eq(a, b)
assert_eq(t, u)
Environment:
Dask version: 2023.3.2
Dask_ml versuib: 2022.5.27
Python version: 3.10.9
Operating System: OSX
Install method (conda, pip, source): pip
Reproducible example
import numpy as np
import dask.array as da
from dask.array.utils import assert_eq
import dask_ml
def make_s_curve(
n_samples=100,
noise=0.0,
random_state=None,
chunks=None,
):
rng = dask_ml.utils.check_random_state(random_state)
t_scale = 3 * np.pi * 0.5
t = rng.uniform(low=-t_scale, high=t_scale, size=(n_samples), chunks=(chunks,))
X = da.empty(shape=(n_samples, 3), chunks=(chunks, 3), dtype="f8")
X[:, 0] = da.sin(t)
X[:, 1] = rng.uniform(low=0, high=2, size=n_samples, chunks=(chunks,))
X[:, 2] = da.sign(t) * (da.cos(t) - 1)
if noise > 0:
X += rng.normal(scale=noise, size=X.shape, chunks=X.chunks)
return X, t
if __name__ == '__main__':
a, t = make_s_curve(chunks=100, random_state=10)
b, u = make_s_curve(chunks=100, random_state=10)
assert_eq(a, b)
assert_eq(t, u)
Traceback
Traceback (most recent call last):
File "/Users/dask_mltest.py", line 30, in <module>
assert_eq(a, b)
File "/Users/venv/lib/python3.9/site-packages/dask/array/utils.py", line 304, in assert_eq
a, adt, a_meta, a_computed = _get_dt_meta_computed(
File "/Users/venv/lib/python3.9/site-packages/dask/array/utils.py", line 259, in _get_dt_meta_computed
_check_dsk(x.dask)
File "/Users/venv/lib/python3.9/site-packages/dask/array/utils.py", line 210, in _check_dsk
assert not non_one, non_one
AssertionError: {('uniform-c022afab3445a4ac294ad46da60634e2', 0): 2}
Metadata
Metadata
Assignees
Labels
No labels