diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py
index 722f0615b370a..ad4013f4a753b 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -569,6 +569,30 @@ def repartitionByRange(  # type: ignore[misc]
                 },
             )
 
+    def repartitionById(
+        self, numPartitions: int, partitionIdCol: "ColumnOrName"
+    ) -> ParentDataFrame:
+        if not isinstance(numPartitions, int) or isinstance(numPartitions, bool):
+            raise PySparkTypeError(
+                errorClass="NOT_INT",
+                messageParameters={
+                    "arg_name": "numPartitions",
+                    "arg_type": type(numPartitions).__name__,
+                },
+            )
+        if numPartitions <= 0:
+            raise PySparkValueError(
+                errorClass="VALUE_NOT_POSITIVE",
+                messageParameters={
+                    "arg_name": "numPartitions",
+                    "arg_value": str(numPartitions),
+                },
+            )
+        return DataFrame(
+            self._jdf.repartitionById(numPartitions, _to_java_column(partitionIdCol)),
+            self.sparkSession,
+        )
+
     def distinct(self) -> ParentDataFrame:
         return DataFrame(self._jdf.distinct(), self.sparkSession)
 
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 7998167976026..fb33260689c3b 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -443,6 +443,45 @@ def repartitionByRange(  # type: ignore[misc]
         res._cached_schema = self._cached_schema
         return res
 
+    def repartitionById(
+        self, numPartitions: int, partitionIdCol: "ColumnOrName"
+    ) -> ParentDataFrame:
+        if not isinstance(numPartitions, int) or isinstance(numPartitions, bool):
+            raise PySparkTypeError(
+                errorClass="NOT_INT",
+                messageParameters={
+                    "arg_name": "numPartitions",
+                    "arg_type": type(numPartitions).__name__,
+                },
+            )
+        if numPartitions <= 0:
+            raise PySparkValueError(
+                errorClass="VALUE_NOT_POSITIVE",
+                messageParameters={
+                    "arg_name": "numPartitions",
+                    "arg_value": str(numPartitions),
+                },
+            )
+
+        from pyspark.sql.connect.expressions import DirectShufflePartitionID
+        from pyspark.sql.connect.column import Column
+
+        # Convert the partition column to a DirectShufflePartitionID expression
+        if isinstance(partitionIdCol, str):
+            partition_col = F.col(partitionIdCol)
+        else:
+            partition_col = partitionIdCol
+
+        direct_partition_expr = DirectShufflePartitionID(partition_col._expr)
+        direct_partition_col = Column(direct_partition_expr)
+
+        res = DataFrame(
+            plan.RepartitionByExpression(self._plan, numPartitions, [direct_partition_col]),
+            self._session,
+        )
+        res._cached_schema = self._cached_schema
+        return res
+
     def dropDuplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame:
         if subset is not None and not isinstance(subset, (list, tuple)):
             raise PySparkTypeError(
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 4ddf13757db41..b7a971d5a299d 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -1317,3 +1317,24 @@ def __repr__(self) -> str:
             repr_parts.append(f"values={self._in_subquery_values}")
 
         return f"SubqueryExpression({', '.join(repr_parts)})"
+
+
+class DirectShufflePartitionID(Expression):
+    """
+    Expression that takes a partition ID value and passes it through directly for use in
+    shuffle partitioning. This is used with RepartitionByExpression to allow users to
+    directly specify target partition IDs.
+    """
+
+    def __init__(self, child: Expression):
+        super().__init__()
+        assert child is not None and isinstance(child, Expression)
+        self._child = child
+
+    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
+        expr = self._create_proto_expression()
+        expr.direct_shuffle_partition_id.child.CopyFrom(self._child.to_plan(session))
+        return expr
+
+    def __repr__(self) -> str:
+        return f"DirectShufflePartitionID(child={self._child})"
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 675d972e3ef51..8450cb9873a89 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1887,6 +1887,67 @@ def repartitionByRange(
         """
         ...
 
+    @dispatch_df_method
+    def repartitionById(self, numPartitions: int, partitionIdCol: "ColumnOrName") -> "DataFrame":
+        """
+        Returns a new :class:`DataFrame` partitioned by the given partition ID expression.
+        Each row's target partition is determined directly by the value of the partition ID column.
+
+        .. versionadded:: 4.1.0
+
+        .. versionchanged:: 4.1.0
+            Supports Spark Connect.
+
+        Parameters
+        ----------
+        numPartitions : int
+            target number of partitions
+        partitionIdCol : str or :class:`Column`
+            column expression that evaluates to the target partition ID for each row.
+            Must be an integer type. Values are taken modulo numPartitions to determine
+            the final partition. Null values are sent to partition 0.
+
+        Returns
+        -------
+        :class:`DataFrame`
+            Repartitioned DataFrame.
+
+        Notes
+        -----
+        The partition ID expression must evaluate to an integer type.
+        Partition IDs are taken modulo numPartitions, so values outside the range [0, numPartitions)
+        are automatically mapped to valid partition IDs. If the partition ID expression evaluates to
+        a NULL value, the row is sent to partition 0.
+
+        This method provides direct control over partition placement, similar to RDD's
+        partitionBy with custom partitioners, but at the DataFrame level.
+
+        Examples
+        --------
+        Partition rows based on a computed partition ID:
+
+        >>> from pyspark.sql import functions as sf
+        >>> from pyspark.sql.functions import col
+        >>> df = spark.range(10).withColumn("partition_id", (col("id") % 3).cast("int"))
+        >>> repartitioned = df.repartitionById(3, "partition_id")
+        >>> repartitioned.select("id", "partition_id", sf.spark_partition_id()).orderBy("id").show()
+        +---+------------+--------------------+
+        | id|partition_id|SPARK_PARTITION_ID()|
+        +---+------------+--------------------+
+        |  0|           0|                   0|
+        |  1|           1|                   1|
+        |  2|           2|                   2|
+        |  3|           0|                   0|
+        |  4|           1|                   1|
+        |  5|           2|                   2|
+        |  6|           0|                   0|
+        |  7|           1|                   1|
+        |  8|           2|                   2|
+        |  9|           0|                   0|
+        +---+------------+--------------------+
+        """
+        ...
+
     @dispatch_df_method
     def distinct(self) -> "DataFrame":
         """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.
diff --git a/python/pyspark/sql/tests/test_repartition.py b/python/pyspark/sql/tests/test_repartition.py
index 058861e9c1615..862f14cb50b82 100644
--- a/python/pyspark/sql/tests/test_repartition.py
+++ b/python/pyspark/sql/tests/test_repartition.py
@@ -17,7 +17,7 @@
 
 import unittest
 
-from pyspark.sql.functions import spark_partition_id
+from pyspark.sql.functions import spark_partition_id, col, lit, when
 from pyspark.sql.types import (
     StringType,
     IntegerType,
@@ -25,7 +25,7 @@
     StructType,
     StructField,
 )
-from pyspark.errors import PySparkTypeError
+from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
@@ -84,6 +84,120 @@ def test_repartition_by_range(self):
             messageParameters={"arg_name": "numPartitions", "arg_type": "list"},
         )
 
+    def test_repartition_by_id(self):
+        # Test basic partition ID passthrough behavior
+        numPartitions = 10
+        df = self.spark.range(100).withColumn("expected_p_id", col("id") % numPartitions)
+        repartitioned = df.repartitionById(numPartitions, col("expected_p_id").cast("int"))
+        result = repartitioned.withColumn("actual_p_id", spark_partition_id())
+
+        # All rows should be in their expected partitions
+        self.assertEqual(result.filter(col("expected_p_id") != col("actual_p_id")).count(), 0)
+
+    def test_repartition_by_id_negative_values(self):
+        df = self.spark.range(10).toDF("id")
+        repartitioned = df.repartitionById(10, (col("id") - 5).cast("int"))
+        result = repartitioned.withColumn("actual_p_id", spark_partition_id()).collect()
+
+        for row in result:
+            actualPartitionId = row["actual_p_id"]
+            id_val = row["id"]
+            expectedPartitionId = int((id_val - 5) % 10)
+            self.assertEqual(
+                actualPartitionId,
+                expectedPartitionId,
+                f"Row with id={id_val} should be in partition {expectedPartitionId}, "
+                f"but was in partition {actualPartitionId}",
+            )
+
+    def test_repartition_by_id_null_values(self):
+        # Test that null partition ids go to partition 0
+        df = self.spark.range(10).toDF("id")
+        partitionExpr = when(col("id") < 5, col("id")).otherwise(lit(None)).cast("int")
+        repartitioned = df.repartitionById(10, partitionExpr)
+        result = repartitioned.withColumn("actual_p_id", spark_partition_id()).collect()
+
+        nullRows = [row for row in result if row["id"] >= 5]
+        self.assertTrue(len(nullRows) > 0, "Should have rows with null partition expression")
+        for row in nullRows:
+            self.assertEqual(
+                row["actual_p_id"],
+                0,
+                f"Row with null partition id should go to partition 0, "
+                f"but went to partition {row['actual_p_id']}",
+            )
+
+        nonNullRows = [row for row in result if row["id"] < 5]
+        for row in nonNullRows:
+            id_val = row["id"]
+            actualPartitionId = row["actual_p_id"]
+            expectedPartitionId = id_val % 10
+            self.assertEqual(
+                actualPartitionId,
+                expectedPartitionId,
+                f"Row with id={id_val} should be in partition {expectedPartitionId}, "
+                f"but was in partition {actualPartitionId}",
+            )
+
+    def test_repartition_by_id_error_non_int_type(self):
+        # Test error for non-integer partition column type
+        df = self.spark.range(5).withColumn("s", lit("a"))
+        with self.assertRaises(Exception):  # Should raise analysis exception
+            df.repartitionById(5, col("s")).collect()
+
+    def test_repartition_by_id_error_invalid_num_partitions(self):
+        df = self.spark.range(5)
+
+        with self.assertRaises(PySparkTypeError) as pe:
+            df.repartitionById("5", col("id").cast("int"))
+        self.check_error(
+            exception=pe.exception,
+            errorClass="NOT_INT",
+            messageParameters={"arg_name": "numPartitions", "arg_type": "str"},
+        )
+
+        with self.assertRaises(PySparkValueError) as pe:
+            df.repartitionById(0, col("id").cast("int"))
+        self.check_error(
+            exception=pe.exception,
+            errorClass="VALUE_NOT_POSITIVE",
+            messageParameters={"arg_name": "numPartitions", "arg_value": "0"},
+        )
+
+        # Test negative numPartitions
+        with self.assertRaises(PySparkValueError) as pe:
+            df.repartitionById(-1, col("id").cast("int"))
+        self.check_error(
+            exception=pe.exception,
+            errorClass="VALUE_NOT_POSITIVE",
+            messageParameters={"arg_name": "numPartitions", "arg_value": "-1"},
+        )
+
+    def test_repartition_by_id_out_of_range(self):
+        numPartitions = 10
+        df = self.spark.range(20).toDF("id")
+        repartitioned = df.repartitionById(numPartitions, col("id").cast("int"))
+        result = repartitioned.collect()
+
+        self.assertEqual(len(result), 20)
+        # Skip RDD partition count check for Connect mode since RDD is not available
+        try:
+            self.assertEqual(repartitioned.rdd.getNumPartitions(), numPartitions)
+        except Exception:
+            # Connect mode doesn't support RDD operations, so we skip this check
+            pass
+
+    def test_repartition_by_id_string_column_name(self):
+        numPartitions = 5
+        df = self.spark.range(25).withColumn(
+            "partition_id", (col("id") % numPartitions).cast("int")
+        )
+        repartitioned = df.repartitionById(numPartitions, "partition_id")
+        result = repartitioned.withColumn("actual_p_id", spark_partition_id())
+
+        mismatches = result.filter(col("partition_id") != col("actual_p_id")).count()
+        self.assertEqual(mismatches, 0)
+
 
 class DataFrameRepartitionTests(
     DataFrameRepartitionTestsMixin,
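
Usage sketch, not part of the diff: a minimal end-to-end example of the repartitionById API added above, assuming the patch is applied and run against a local Spark session. The session setup and the "bucket" column name are illustrative only; the behavior it checks (rows land in the partition named by the ID column, with modulo wrapping and NULLs going to partition 0) is the behavior documented and tested in the patch.

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, spark_partition_id

    spark = SparkSession.builder.master("local[4]").appName("repartition-by-id-demo").getOrCreate()

    # Derive an explicit target partition for each row. Per the docstring above,
    # out-of-range values wrap modulo numPartitions and NULLs land in partition 0.
    df = spark.range(100).withColumn("bucket", (col("id") % 4).cast("int"))
    out = df.repartitionById(4, "bucket")

    # Every row should end up in the partition named by its "bucket" column.
    mismatches = (
        out.withColumn("pid", spark_partition_id())
        .filter(col("bucket") != col("pid"))
        .count()
    )
    assert mismatches == 0

    spark.stop()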