[SPARK-43176][CONNECT][PYTHON][TESTS] Deduplicate imports in Connect Tests

### What changes were proposed in this pull request?
Deduplicate imports in Connect tests: the `pyspark.sql` and `pyspark.sql.connect` imports that were repeated inside individual test methods are moved to a single module-level import block guarded by `should_test_connect`.
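A minimal sketch of the resulting layout (the class and test method below are hypothetical, not taken from this PR; the guard and the `SF`/`CF` aliases are the ones the PR uses):

```python
# A minimal sketch, not actual PR code: "ExampleConnectTests"/"test_lit"
# are hypothetical names; the guarded import block mirrors the PR.
from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect

if should_test_connect:
    # Guarded so the module still imports (and tests skip cleanly) when
    # Connect's optional dependencies (pandas, grpcio, ...) are missing.
    from pyspark.sql import functions as SF
    from pyspark.sql.connect import functions as CF


class ExampleConnectTests(ReusedConnectTestCase):
    def test_lit(self):
        # No per-method imports anymore; SF/CF come from module level.
        # Assumes self.connect (Connect session) and self.spark (classic
        # session), as the tests touched by this PR do.
        cdf = self.connect.range(0, 1).select(CF.lit(1))
        sdf = self.spark.range(0, 1).select(SF.lit(1))
        self.assertEqual(cdf.collect(), sdf.collect())
```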

### Why are the changes needed?
For simplicity: it removes roughly a hundred duplicated import statements across three test modules.

### Does this PR introduce _any_ user-facing change?
No, test-only

### How was this patch tested?
Updated existing unit tests.
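For reference, a sketch of invoking one of the touched modules with the stock unittest runner (assumes a Spark dev checkout with `python/` on `PYTHONPATH` and the Connect dependencies installed; Spark's own CI drives these through its `python/run-tests` tooling instead):

```python
import unittest

# Any of the three touched modules works here.
from pyspark.sql.tests.connect import test_connect_function

unittest.main(module=test_connect_function, exit=False, verbosity=2)
```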

Closes apache#40839 from zhengruifeng/connect_test_import.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
zhengruifeng authored and HyukjinKwon committed Apr 19, 2023
1 parent f8604ad commit cac6f58
Showing 3 changed files with 11 additions and 110 deletions.
python/pyspark/sql/tests/connect/test_connect_basic.py (10 changes: 0 additions & 10 deletions)
@@ -466,9 +466,6 @@ def test_collect(self):
         )
 
     def test_collect_timestamp(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             (TIMESTAMP('2022-12-25 10:30:00'), 1),
@@ -652,10 +649,6 @@ def test_with_atom_type(self):
 
     def test_with_none_and_nan(self):
         # SPARK-41855: make createDataFrame support None and NaN
-
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         # SPARK-41814: test with eqNullSafe
         data1 = [Row(id=1, value=float("NaN")), Row(id=2, value=42.0), Row(id=3, value=None)]
         data2 = [Row(id=1, value=np.nan), Row(id=2, value=42.0), Row(id=3, value=None)]
@@ -1662,9 +1655,6 @@ def test_random_split(self):
 
     def test_observe(self):
         # SPARK-41527: test DataFrame.observe()
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         observation_name = "my_metric"
 
         self.assert_eq(
python/pyspark/sql/tests/connect/test_connect_column.py (15 changes: 1 addition & 14 deletions)
@@ -18,7 +18,6 @@
 import decimal
 import datetime
 
-from pyspark.sql import functions as SF
 from pyspark.sql.types import (
     Row,
     StructField,
@@ -48,6 +47,7 @@
 
 if should_test_connect:
     import pandas as pd
+    from pyspark.sql import functions as SF
     from pyspark.sql.connect import functions as CF
     from pyspark.sql.connect.column import Column
     from pyspark.sql.connect.expressions import DistributedSequenceID, LiteralExpression
@@ -482,9 +482,6 @@ def test_literal_integers(self):
         cdf = self.connect.range(0, 1)
         sdf = self.spark.range(0, 1)
 
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         cdf1 = cdf.select(
             CF.lit(0),
             CF.lit(1),
@@ -679,9 +676,6 @@ def test_between(self):
 
     def test_column_bitwise_ops(self):
         # SPARK-41751: test bitwiseAND, bitwiseOR, bitwiseXOR
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             (1, 1, 0), (2, NULL, 1), (3, 3, 4)
@@ -718,9 +712,6 @@ def test_column_bitwise_ops(self):
         )
 
     def test_column_accessor(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT STRUCT(a, b, c) AS x, y, z, c FROM VALUES
             (float(1.0), double(1.0), '2022', MAP('b', '123', 'a', 'kk'), ARRAY(1, 2, 3)),
@@ -840,10 +831,6 @@ def test_column_arithmetic_ops(self):
 
     def test_column_field_ops(self):
         # SPARK-41767: test withField, dropFields
-
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT STRUCT(a, b, c, d) AS x, e FROM VALUES
             (float(1.0), double(1.0), '2022', 1, 0),
python/pyspark/sql/tests/connect/test_connect_function.py (96 changes: 10 additions & 86 deletions)
@@ -21,10 +21,19 @@
 from pyspark.sql import SparkSession as PySparkSession
 from pyspark.sql.types import StringType, StructType, StructField, ArrayType, IntegerType
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect
 from pyspark.testing.sqlutils import SQLTestUtils
 from pyspark.errors.exceptions.connect import AnalysisException, SparkConnectException
 
+if should_test_connect:
+    from pyspark.sql.connect.column import Column
+    from pyspark.sql import functions as SF
+    from pyspark.sql.window import Window as SW
+    from pyspark.sql.dataframe import DataFrame as SDF
+    from pyspark.sql.connect import functions as CF
+    from pyspark.sql.connect.window import Window as CW
+    from pyspark.sql.connect.dataframe import DataFrame as CDF
+
 
 class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, SQLTestUtils):
     """These test cases exercise the interface to the proto plan
@@ -47,9 +56,6 @@ def tearDownClass(cls):
         del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
 
     def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
-        from pyspark.sql.dataframe import DataFrame as SDF
-        from pyspark.sql.connect.dataframe import DataFrame as CDF
-
         assert isinstance(df1, (SDF, CDF))
         if isinstance(df1, SDF):
             str1 = df1._jdf.showString(n, truncate, False)
@@ -66,10 +72,6 @@ def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
 
     def test_count_star(self):
         # SPARK-42099: test count(*), count(col(*)) and count(expr(*))
-
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         data = [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")]
 
         cdf = self.connect.createDataFrame(data, schema=["age", "name"])
@@ -123,9 +125,6 @@ def test_count_star(self):
         )
 
     def test_broadcast(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             (0, float("NAN"), NULL), (1, NULL, 2.0), (2, 2.1, 3.5)
@@ -174,9 +173,6 @@ def test_broadcast(self):
         )
 
     def test_normal_functions(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             (0, float("NAN"), NULL), (1, NULL, 2.0), (2, 2.1, 3.5)
@@ -261,9 +257,6 @@ def test_normal_functions(self):
         )
 
     def test_when_otherwise(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             (0, float("NAN"), NULL), (1, NULL, 2.0), (2, 2.1, 3.5), (3, 3.1, float("NAN"))
@@ -375,9 +368,6 @@ def test_when_otherwise(self):
         )
 
     def test_sorting_functions_with_column(self):
-        from pyspark.sql.connect import functions as CF
-        from pyspark.sql.connect.column import Column
-
         funs = [
             CF.asc_nulls_first,
             CF.asc_nulls_last,
@@ -403,9 +393,6 @@ def test_sorting_functions_with_column(self):
         self.assertIn("""DESC NULLS LAST'""", str(res))
 
     def test_sort_with_nulls_order(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             (false, 1, NULL), (true, NULL, 2.0), (NULL, 3, 3.0)
@@ -449,9 +436,6 @@ def test_sort_with_nulls_order(self):
         )
 
     def test_math_functions(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             (false, 1, NULL), (true, NULL, 2.0), (NULL, 3, 3.5)
@@ -571,9 +555,6 @@ def test_math_functions(self):
         )
 
     def test_aggregation_functions(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             (0, float("NAN"), NULL), (1, NULL, 2.0), (1, 2.1, 3.5), (0, 0.5, 1.0)
@@ -694,11 +675,6 @@ def test_aggregation_functions(self):
         )
 
     def test_window_functions(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.window import Window as SW
-        from pyspark.sql.connect import functions as CF
-        from pyspark.sql.connect.window import Window as CW
-
         self.assertEqual(CW.unboundedPreceding, SW.unboundedPreceding)
 
         self.assertEqual(CW.unboundedFollowing, SW.unboundedFollowing)
@@ -950,12 +926,6 @@ def test_window_functions(self):
 
     def test_window_order(self):
         # SPARK-41773: test window function with order
-
-        from pyspark.sql import functions as SF
-        from pyspark.sql.window import Window as SW
-        from pyspark.sql.connect import functions as CF
-        from pyspark.sql.connect.window import Window as CW
-
         data = [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")]
         # +---+--------+
         # | id|category|
@@ -1000,9 +970,6 @@ def test_window_order(self):
         )
 
     def test_collection_functions(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3), 1, 2, 'a'),
@@ -1257,9 +1224,6 @@ def test_collection_functions(self):
         )
 
     def test_map_collection_functions(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             (MAP('a', 'ab'), MAP('x', 'ab'), MAP(1, 2, 3, 4), 1, 'a', ARRAY(1, 2), ARRAY('X', 'Y')),
@@ -1315,9 +1279,6 @@ def test_map_collection_functions(self):
         )
 
     def test_generator_functions(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3),
@@ -1442,9 +1403,6 @@ def test_generator_functions(self):
         )
 
     def test_lambda_functions(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3), 1, 2, 'a', NULL, MAP(0, 0)),
@@ -1619,10 +1577,6 @@ def test_lambda_functions(self):
 
     def test_nested_lambda_function(self):
         # SPARK-42089: test nested lambda function
-
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = "SELECT array(1, 2, 3) as numbers, array('a', 'b', 'c') as letters"
 
         cdf = self.connect.sql(query).select(
@@ -1652,9 +1606,6 @@ def test_nested_lambda_function(self):
         self.assertEqual(cdf.collect(), sdf.collect())
 
     def test_csv_functions(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             ('1,2,3', 'a,b,5.0'),
@@ -1732,9 +1683,6 @@ def test_csv_functions(self):
         )
 
     def test_json_functions(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             ('{"a": 1}', '[1, 2, 3]', '{"f1": "value1", "f2": "value2"}'),
@@ -1869,9 +1817,6 @@ def test_json_functions(self):
         )
 
     def test_string_functions_one_arg(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             (' ab ', 'ab ', NULL), (' ab', NULL, 'ab')
@@ -1913,9 +1858,6 @@ def test_string_functions_one_arg(self):
         )
 
     def test_string_functions_multi_args(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             (1, 'abcdef', 'ghij', 'hello world', 'a.b.c.d'),
@@ -2013,9 +1955,6 @@ def test_string_functions_multi_args(self):
 
     # TODO(SPARK-41283): To compare toPandas for test cases with dtypes marked
     def test_date_ts_functions(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             ('1997/02/28 10:30:00', '2023/03/01 06:00:00', 'JST', 1428476400, 2020, 12, 6),
@@ -2160,9 +2099,6 @@ def test_date_ts_functions(self):
         )
 
     def test_time_window_functions(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             (TIMESTAMP('2022-12-25 10:30:00'), 1),
@@ -2264,9 +2200,6 @@ def test_time_window_functions(self):
         )
 
     def test_misc_functions(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT a, b, c, BINARY(c) as d FROM VALUES
             (0, float("NAN"), 'x'), (1, NULL, 'y'), (1, 2.1, 'z'), (0, 0.5, NULL)
@@ -2329,9 +2262,6 @@ def test_misc_functions(self):
         )
 
     def test_call_udf(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT a, b, c, BINARY(c) as d FROM VALUES
             (-1.0, float("NAN"), 'x'), (-2.1, NULL, 'y'), (1, 2.1, 'z'), (0, 0.5, NULL)
@@ -2360,9 +2290,6 @@ def test_call_udf(self):
         )
 
     def test_udf(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT a, b, c FROM VALUES
             (1, 1.0, 'x'), (2, 2.0, 'y'), (3, 3.0, 'z')
@@ -2408,9 +2335,6 @@ def sfun(x):
         )
 
     def test_pandas_udf_import(self):
-        from pyspark.sql.connect import functions as CF
-        from pyspark.sql import functions as SF
-
         self.assert_eq(getattr(CF, "pandas_udf"), getattr(SF, "pandas_udf"))