
Commit f53c3e0

Unblock CI failures from scikit-learn 1.4.0, pandas 2.2.0 (#1295)
* Pin sklearn to <1.4
* Unpin sqlalchemy<2
* Refactor pyhive input/tests for sqlalchemy 2
* Use astype to normalize dtypes in _assert_query_gives_same_result
* Refine pd.NA normalization in _assert_query_gives_same_result
* Explicitly compute pandas result in test_join_reorder
* xfail tpot tests, unpin sklearn
* Linting
1 parent 3ffba21 commit f53c3e0

File tree

11 files changed: +74 -29 lines


continuous_integration/environment-3.10.yaml

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ dependencies:
 - python=3.10
 - scikit-learn>=1.0.0
 - sphinx
-- sqlalchemy<2
+- sqlalchemy
 - tpot>=0.12.0
 # FIXME: https://github.com/fugue-project/fugue/issues/526
 - triad<0.9.2

continuous_integration/environment-3.11.yaml

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ dependencies:
 - python=3.11
 - scikit-learn>=1.0.0
 - sphinx
-- sqlalchemy<2
+- sqlalchemy
 - tpot>=0.12.0
 # FIXME: https://github.com/fugue-project/fugue/issues/526
 - triad<0.9.2

continuous_integration/environment-3.12.yaml

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ dependencies:
 - python=3.12
 - scikit-learn>=1.0.0
 - sphinx
-- sqlalchemy<2
+- sqlalchemy
 - tpot>=0.12.0
 # FIXME: https://github.com/fugue-project/fugue/issues/526
 - triad<0.9.2

continuous_integration/environment-3.9.yaml

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ dependencies:
 - python=3.9
 - scikit-learn=1.0.0
 - sphinx
-- sqlalchemy<2
+- sqlalchemy
 - tpot>=0.12.0
 # FIXME: https://github.com/fugue-project/fugue/issues/526
 - triad<0.9.2

continuous_integration/gpuci/environment-3.10.yaml

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ dependencies:
 - python=3.10
 - scikit-learn>=1.0.0
 - sphinx
-- sqlalchemy<2
+- sqlalchemy
 - tpot>=0.12.0
 # FIXME: https://github.com/fugue-project/fugue/issues/526
 - triad<0.9.2

continuous_integration/gpuci/environment-3.9.yaml

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ dependencies:
 - python=3.9
 - scikit-learn>=1.0.0
 - sphinx
-- sqlalchemy<2
+- sqlalchemy
 - tpot>=0.12.0
 # FIXME: https://github.com/fugue-project/fugue/issues/526
 - triad<0.9.2
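
Note: with the sqlalchemy<2 pin dropped from every CI environment above, the solver is now free to install SQLAlchemy 2.x, which is what forced the pyhive refactor below. A quick sanity check for which major version actually landed (a minimal sketch, not part of the commit):

import sqlalchemy

# Expect "2.x" after the unpin; "1.4.x" would mean an old pin is still
# being resolved from elsewhere in the dependency tree.
print(sqlalchemy.__version__)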

dask_sql/input_utils/hive.py

Lines changed: 17 additions & 7 deletions
@@ -30,12 +30,12 @@ class HiveInputPlugin(BaseInputPlugin):
     def is_correct_input(
         self, input_item: Any, table_name: str, format: str = None, **kwargs
     ):
-        is_sqlalchemy_hive = sqlalchemy and isinstance(
-            input_item, sqlalchemy.engine.base.Connection
-        )
         is_hive_cursor = hive and isinstance(input_item, hive.Cursor)
 
-        return is_sqlalchemy_hive or is_hive_cursor or format == "hive"
+        return self.is_sqlalchemy_hive(input_item) or is_hive_cursor or format == "hive"
+
+    def is_sqlalchemy_hive(self, input_item: Any):
+        return sqlalchemy and isinstance(input_item, sqlalchemy.engine.base.Connection)
 
     def to_dc(
         self,
@@ -201,7 +201,11 @@ def _parse_hive_table_description(
         of the DESCRIBE FORMATTED call, which is unfortunately
         in a format not easily readable by machines.
         """
-        cursor.execute(f"USE {schema}")
+        cursor.execute(
+            sqlalchemy.text(f"USE {schema}")
+            if self.is_sqlalchemy_hive(cursor)
+            else f"USE {schema}"
+        )
         if partition:
             # Hive wants quoted, comma separated list of partition keys
             partition = partition.replace("=", '="')
@@ -283,7 +287,11 @@ def _parse_hive_partition_description(
         """
         Extract all partition information for a given table
         """
-        cursor.execute(f"USE {schema}")
+        cursor.execute(
+            sqlalchemy.text(f"USE {schema}")
+            if self.is_sqlalchemy_hive(cursor)
+            else f"USE {schema}"
+        )
         result = self._fetch_all_results(cursor, f"SHOW PARTITIONS {table_name}")
 
         return [row[0] for row in result]
@@ -298,7 +306,9 @@ def _fetch_all_results(
         The former has the fetchall method on the cursor,
         whereas the latter on the executed query.
         """
-        result = cursor.execute(sql)
+        result = cursor.execute(
+            sqlalchemy.text(sql) if self.is_sqlalchemy_hive(cursor) else sql
+        )
 
         try:
             return result.fetchall()
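
For context on the refactor above: SQLAlchemy 2 removed support for passing plain SQL strings to Connection.execute(), while pyhive cursors still take bare strings, so the plugin now branches on is_sqlalchemy_hive() before every execute. A minimal sketch of the 2.x behavior, using an in-memory SQLite engine purely for illustration:

import sqlalchemy

engine = sqlalchemy.create_engine("sqlite://")

with engine.connect() as conn:
    # conn.execute("SELECT 1")  # raises ObjectNotExecutableError on 2.x
    result = conn.execute(sqlalchemy.text("SELECT 1"))  # works on 1.4 and 2.x
    print(result.fetchall())  # [(1,)]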

tests/integration/fixtures.py

Lines changed: 4 additions & 0 deletions
@@ -335,6 +335,10 @@ def _assert_query_gives_same_result(query, sort_columns=None, **kwargs):
     # as expressions are handled differently
     dask_result.columns = sql_result.columns
 
+    # replace all pd.NA scalars, which are resistant to
+    # check_dtype=False and .astype()
+    dask_result = dask_result.replace({pd.NA: None})
+
     if sort_columns:
         sql_result = sql_result.sort_values(sort_columns)
         dask_result = dask_result.sort_values(sort_columns)
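
The replace() added above is needed because pd.NA is a distinct missing-value scalar that newer pandas surfaces more often: it survives astype(object), so frames that differ only in pd.NA vs. None still fail comparison even with check_dtype=False. A small illustration of the behavior (assumes a pandas version with nullable dtypes):

import pandas as pd

s = pd.Series([1, None], dtype="Int64")       # nullable dtype stores pd.NA
obj = s.astype(object)
print(obj[1] is pd.NA)                        # True: the cast keeps the scalar
print(obj.replace({pd.NA: None})[1] is None)  # True: replace() swaps it out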

tests/integration/test_hive.py

Lines changed: 27 additions & 9 deletions
@@ -142,25 +142,43 @@ def hive_cursor():
 
     # Create a non-partitioned column
     cursor.execute(
-        f"CREATE TABLE df (i INTEGER, j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir}'"
+        sqlalchemy.text(
+            f"CREATE TABLE df (i INTEGER, j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir}'"
+        )
     )
-    cursor.execute("INSERT INTO df (i, j) VALUES (1, 2)")
-    cursor.execute("INSERT INTO df (i, j) VALUES (2, 4)")
+    cursor.execute(sqlalchemy.text("INSERT INTO df (i, j) VALUES (1, 2)"))
+    cursor.execute(sqlalchemy.text("INSERT INTO df (i, j) VALUES (2, 4)"))
 
     cursor.execute(
-        f"CREATE TABLE df_part (i INTEGER) PARTITIONED BY (j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_parted}'"
+        sqlalchemy.text(
+            f"CREATE TABLE df_part (i INTEGER) PARTITIONED BY (j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_parted}'"
+        )
+    )
+    cursor.execute(
+        sqlalchemy.text("INSERT INTO df_part PARTITION (j=2) (i) VALUES (1)")
+    )
+    cursor.execute(
+        sqlalchemy.text("INSERT INTO df_part PARTITION (j=4) (i) VALUES (2)")
     )
-    cursor.execute("INSERT INTO df_part PARTITION (j=2) (i) VALUES (1)")
-    cursor.execute("INSERT INTO df_part PARTITION (j=4) (i) VALUES (2)")
 
     cursor.execute(
-        f"""
+        sqlalchemy.text(
+            f"""
         CREATE TABLE df_parts (i INTEGER) PARTITIONED BY (j INTEGER, k STRING)
         ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_multiparted}'
     """
+        )
+    )
+    cursor.execute(
+        sqlalchemy.text(
+            "INSERT INTO df_parts PARTITION (j=1, k='a') (i) VALUES (1)"
+        )
+    )
+    cursor.execute(
+        sqlalchemy.text(
+            "INSERT INTO df_parts PARTITION (j=2, k='b') (i) VALUES (2)"
+        )
     )
-    cursor.execute("INSERT INTO df_parts PARTITION (j=1, k='a') (i) VALUES (1)")
-    cursor.execute("INSERT INTO df_parts PARTITION (j=2, k='b') (i) VALUES (2)")
 
     # The data files are created as root user by default. Change that:
     hive_server.exec_run(["chmod", "a+rwx", "-R", tmpdir])
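
One caveat when wrapping literal SQL in sqlalchemy.text() as this fixture does: text() treats ":name" as a bind parameter, even inside quoted literals, so a verbatim colon must be backslash-escaped. None of the statements above contain one, but for future edits (illustrative only):

import sqlalchemy

# Without the escape, ":v" would be parsed as a bind parameter.
stmt = sqlalchemy.text(r"SELECT 'k\:v'")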

tests/integration/test_join.py

Lines changed: 13 additions & 6 deletions
@@ -463,7 +463,6 @@ def test_join_reorder(c):
         SELECT a1, b2, c3
         FROM a, b, c
         WHERE b1 < 3 AND c3 < 5 AND a1 = b1 AND b2 = c2
-        LIMIT 10
     """
 
     explain_string = c.explain(query)
@@ -491,15 +490,20 @@ def test_join_reorder(c):
     assert explain_string.index(second_join) < explain_string.index(first_join)
 
     result_df = c.sql(query)
-    expected_df = pd.DataFrame({"a1": [1] * 10, "b2": [2] * 10, "c3": [4] * 10})
-    assert_eq(result_df, expected_df)
+    merged_df = df.merge(df2, left_on="a1", right_on="b1").merge(
+        df3, left_on="b2", right_on="c2"
+    )
+    expected_df = merged_df[(merged_df["b1"] < 3) & (merged_df["c3"] < 5)][
+        ["a1", "b2", "c3"]
+    ]
+
+    assert_eq(result_df, expected_df, check_index=False)
 
     # By default, join reordering should NOT reorder unfiltered dimension tables
     query = """
         SELECT a1, b2, c3
         FROM a, b, c
         WHERE a1 = b1 AND b2 = c2
-        LIMIT 10
     """
 
     explain_string = c.explain(query)
@@ -510,8 +514,11 @@ def test_join_reorder(c):
     assert explain_string.index(second_join) < explain_string.index(first_join)
 
     result_df = c.sql(query)
-    expected_df = pd.DataFrame({"a1": [1] * 10, "b2": [2] * 10, "c3": [4, 5] * 5})
-    assert_eq(result_df, expected_df)
+    expected_df = df.merge(df2, left_on="a1", right_on="b1").merge(
+        df3, left_on="b2", right_on="c2"
+    )[["a1", "b2", "c3"]]
+
+    assert_eq(result_df, expected_df, check_index=False)
 
 
 @pytest.mark.xfail(
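
The rewritten assertions above compute the expected frame with pandas merges instead of hardcoding rows, which is also why the nondeterministic LIMIT 10 had to go. A self-contained sketch of the merge-based equivalent of the first query, with made-up frames standing in for the fixture tables a, b, c:

import pandas as pd

df = pd.DataFrame({"a1": [1, 2, 3]})
df2 = pd.DataFrame({"b1": [1, 2, 3], "b2": [2, 2, 2]})
df3 = pd.DataFrame({"c2": [2], "c3": [4]})

# SELECT a1, b2, c3 FROM a, b, c
# WHERE b1 < 3 AND c3 < 5 AND a1 = b1 AND b2 = c2
merged = df.merge(df2, left_on="a1", right_on="b1").merge(
    df3, left_on="b2", right_on="c2"
)
expected = merged[(merged["b1"] < 3) & (merged["c3"] < 5)][["a1", "b2", "c3"]]
print(expected)  # the rows with a1 in {1, 2}

check_index=False is passed to assert_eq because the SQL result and the merged frame carry different row indexes.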
