Skip to content

Commit

Permalink
Work around dmlc#10994
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Feb 5, 2025
1 parent 93d54ca commit 0ae1453
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest
from hypothesis import given, note, settings, strategies
from hypothesis._settings import duration
from packaging.version import parse as parse_version

import xgboost as xgb
from xgboost import testing as tm
Expand Down Expand Up @@ -41,14 +42,20 @@
try:
import cudf
import dask.dataframe as dd
from dask import __version__ as dask_version
from dask import array as da
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

from xgboost import dask as dxgb
from xgboost.testing.dask import check_init_estimation, check_uneven_nan
except ImportError:
pass
dask_version = None


dask_version_ge110 = dask_version and parse_version(dask_version) >= parse_version(
"2024.11.0"
)


def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
Expand Down Expand Up @@ -378,6 +385,9 @@ def test_early_stopping(self, local_cuda_client: Client) -> None:
dump = booster.get_dump(dump_format="json")
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1

@pytest.mark.xfail(
dask_version_ge110, reason="Test cannot pass with Dask 2024.11.0+"
)
@pytest.mark.skipif(**tm.no_cudf())
@pytest.mark.parametrize("model", ["boosting"])
def test_dask_classifier(self, model: str, local_cuda_client: Client) -> None:
Expand Down

0 comments on commit 0ae1453

Please sign in to comment.