24 changes: 24 additions & 0 deletions python/pyspark/sql/classic/dataframe.py
@@ -569,6 +569,30 @@ def repartitionByRange( # type: ignore[misc]
},
)

def repartitionById(
self, numPartitions: int, partitionIdCol: "ColumnOrName"
) -> ParentDataFrame:
if not isinstance(numPartitions, int) or isinstance(numPartitions, bool):
raise PySparkTypeError(
errorClass="NOT_INT",
messageParameters={
"arg_name": "numPartitions",
"arg_type": type(numPartitions).__name__,
},
)
if numPartitions <= 0:
raise PySparkValueError(
errorClass="VALUE_NOT_POSITIVE",
messageParameters={
"arg_name": "numPartitions",
"arg_value": str(numPartitions),
},
)
return DataFrame(
self._jdf.repartitionById(numPartitions, _to_java_column(partitionIdCol)),
self.sparkSession,
)

def distinct(self) -> ParentDataFrame:
return DataFrame(self._jdf.distinct(), self.sparkSession)

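For orientation, a minimal usage sketch of the classic-API method added above, assuming an active `SparkSession` bound to `spark` (it mirrors the docstring example further down in this diff):

```python
from pyspark.sql import functions as sf

# Route each row to the partition named by an integer column: here id % 4.
df = spark.range(100).withColumn("pid", (sf.col("id") % 4).cast("int"))
out = df.repartitionById(4, "pid")  # accepts a column name or a Column expression
out.select("id", "pid", sf.spark_partition_id()).show()
```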
39 changes: 39 additions & 0 deletions python/pyspark/sql/connect/dataframe.py
@@ -443,6 +443,45 @@ def repartitionByRange( # type: ignore[misc]
res._cached_schema = self._cached_schema
return res

def repartitionById(
self, numPartitions: int, partitionIdCol: "ColumnOrName"
) -> ParentDataFrame:
if not isinstance(numPartitions, int) or isinstance(numPartitions, bool):
raise PySparkTypeError(
errorClass="NOT_INT",
messageParameters={
"arg_name": "numPartitions",
"arg_type": type(numPartitions).__name__,
},
)
if numPartitions <= 0:
raise PySparkValueError(
errorClass="VALUE_NOT_POSITIVE",
messageParameters={
"arg_name": "numPartitions",
"arg_value": str(numPartitions),
},
)

from pyspark.sql.connect.expressions import DirectShufflePartitionID
from pyspark.sql.connect.column import Column

# Convert the partition column to a DirectShufflePartitionID expression
if isinstance(partitionIdCol, str):
partition_col = F.col(partitionIdCol)
else:
partition_col = partitionIdCol

direct_partition_expr = DirectShufflePartitionID(partition_col._expr)
direct_partition_col = Column(direct_partition_expr)

res = DataFrame(
plan.RepartitionByExpression(self._plan, numPartitions, [direct_partition_col]),
self._session,
)
res._cached_schema = self._cached_schema
return res

def dropDuplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame:
if subset is not None and not isinstance(subset, (list, tuple)):
raise PySparkTypeError(
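Unlike `repartition(n, col)`, which hash-partitions on the column, `repartitionById` places each row at the literal partition index (per the docstring and tests in this change). A quick sketch of how to see the difference, assuming a session bound to `spark` on either backend:

```python
from pyspark.sql import functions as sf

df = spark.range(6).withColumn("pid", (sf.col("id") % 3).cast("int"))

hashed = df.repartition(3, "pid").withColumn("p", sf.spark_partition_id())
direct = df.repartitionById(3, "pid").withColumn("p", sf.spark_partition_id())

# With repartitionById the physical partition equals pid for in-range values;
# with hash partitioning it generally does not.
direct.select("pid", "p").distinct().show()
hashed.select("pid", "p").distinct().show()
```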
21 changes: 21 additions & 0 deletions python/pyspark/sql/connect/expressions.py
@@ -1317,3 +1317,24 @@ def __repr__(self) -> str:
repr_parts.append(f"values={self._in_subquery_values}")

return f"SubqueryExpression({', '.join(repr_parts)})"


class DirectShufflePartitionID(Expression):
"""
Expression that takes a partition ID value and passes it through directly for use in
shuffle partitioning. This is used with RepartitionByExpression to allow users to
directly specify target partition IDs.
"""

def __init__(self, child: Expression):
super().__init__()
assert child is not None and isinstance(child, Expression)
self._child = child

def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = self._create_proto_expression()
expr.direct_shuffle_partition_id.child.CopyFrom(self._child.to_plan(session))
return expr

def __repr__(self) -> str:
return f"DirectShufflePartitionID(child={self._child})"
61 changes: 61 additions & 0 deletions python/pyspark/sql/dataframe.py
@@ -1887,6 +1887,67 @@ def repartitionByRange(
"""
...

@dispatch_df_method
def repartitionById(self, numPartitions: int, partitionIdCol: "ColumnOrName") -> "DataFrame":
"""
Returns a new :class:`DataFrame` partitioned by the given partition ID expression.
Each row's target partition is determined directly by the value of the partition ID column.

.. versionadded:: 4.1.0

.. versionchanged:: 4.1.0
Supports Spark Connect.

Parameters
----------
numPartitions : int
target number of partitions
partitionIdCol : str or :class:`Column`
column expression that evaluates to the target partition ID for each row.
Must be an integer type. Values are taken modulo numPartitions to determine
the final partition. Null values are sent to partition 0.

Returns
-------
:class:`DataFrame`
Repartitioned DataFrame.

Notes
-----
The partition ID expression must evaluate to an integer type.
Partition IDs are taken modulo numPartitions, so values outside the range [0, numPartitions)
are automatically mapped to valid partition IDs. If the partition ID expression evaluates to
a NULL value, the row is sent to partition 0.

This method provides direct control over partition placement, similar to RDD's
partitionBy with custom partitioners, but at the DataFrame level.

Examples
--------
Partition rows based on a computed partition ID:

>>> from pyspark.sql import functions as sf
>>> from pyspark.sql.functions import col
>>> df = spark.range(10).withColumn("partition_id", (col("id") % 3).cast("int"))
>>> repartitioned = df.repartitionById(3, "partition_id")
>>> repartitioned.select("id", "partition_id", sf.spark_partition_id()).orderBy("id").show()
+---+------------+--------------------+
| id|partition_id|SPARK_PARTITION_ID()|
+---+------------+--------------------+
| 0| 0| 0|
| 1| 1| 1|
| 2| 2| 2|
| 3| 0| 0|
| 4| 1| 1|
| 5| 2| 2|
| 6| 0| 0|
| 7| 1| 1|
| 8| 2| 2|
| 9| 0| 0|
+---+------------+--------------------+
"""
...

@dispatch_df_method
def distinct(self) -> "DataFrame":
"""Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.
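A worked illustration of the mapping described in the Notes above, in plain Python and matching the expectations encoded in the tests below: out-of-range and negative IDs wrap modulo `numPartitions`, and NULL lands in partition 0.

```python
num_partitions = 10

def target_partition(pid):
    # NULL partition IDs go to partition 0; everything else wraps modulo num_partitions.
    return 0 if pid is None else pid % num_partitions

print([target_partition(p) for p in (-5, 0, 3, 17, None)])
# [5, 0, 3, 7, 0]
```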
118 changes: 116 additions & 2 deletions python/pyspark/sql/tests/test_repartition.py
@@ -17,15 +17,15 @@

import unittest

from pyspark.sql.functions import spark_partition_id
from pyspark.sql.functions import spark_partition_id, col, lit, when
from pyspark.sql.types import (
StringType,
IntegerType,
DoubleType,
StructType,
StructField,
)
from pyspark.errors import PySparkTypeError
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.testing.sqlutils import ReusedSQLTestCase


@@ -84,6 +84,120 @@ def test_repartition_by_range(self):
messageParameters={"arg_name": "numPartitions", "arg_type": "list"},
)

def test_repartition_by_id(self):
# Test basic partition ID passthrough behavior
numPartitions = 10
df = self.spark.range(100).withColumn("expected_p_id", col("id") % numPartitions)
repartitioned = df.repartitionById(numPartitions, col("expected_p_id").cast("int"))
result = repartitioned.withColumn("actual_p_id", spark_partition_id())

# All rows should be in their expected partitions
self.assertEqual(result.filter(col("expected_p_id") != col("actual_p_id")).count(), 0)

def test_repartition_by_id_negative_values(self):
df = self.spark.range(10).toDF("id")
repartitioned = df.repartitionById(10, (col("id") - 5).cast("int"))
result = repartitioned.withColumn("actual_p_id", spark_partition_id()).collect()

for row in result:
actualPartitionId = row["actual_p_id"]
id_val = row["id"]
expectedPartitionId = int((id_val - 5) % 10)
self.assertEqual(
actualPartitionId,
expectedPartitionId,
f"Row with id={id_val} should be in partition {expectedPartitionId}, "
f"but was in partition {actualPartitionId}",
)

def test_repartition_by_id_null_values(self):
# Test that null partition ids go to partition 0
df = self.spark.range(10).toDF("id")
partitionExpr = when(col("id") < 5, col("id")).otherwise(lit(None)).cast("int")
repartitioned = df.repartitionById(10, partitionExpr)
result = repartitioned.withColumn("actual_p_id", spark_partition_id()).collect()

nullRows = [row for row in result if row["id"] >= 5]
self.assertTrue(len(nullRows) > 0, "Should have rows with null partition expression")
for row in nullRows:
self.assertEqual(
row["actual_p_id"],
0,
f"Row with null partition id should go to partition 0, "
f"but went to partition {row['actual_p_id']}",
)

nonNullRows = [row for row in result if row["id"] < 5]
for row in nonNullRows:
id_val = row["id"]
actualPartitionId = row["actual_p_id"]
expectedPartitionId = id_val % 10
self.assertEqual(
actualPartitionId,
expectedPartitionId,
f"Row with id={id_val} should be in partition {expectedPartitionId}, "
f"but was in partition {actualPartitionId}",
)

def test_repartition_by_id_error_non_int_type(self):
# Test error for non-integer partition column type
df = self.spark.range(5).withColumn("s", lit("a"))
with self.assertRaises(Exception): # Should raise analysis exception
df.repartitionById(5, col("s")).collect()

def test_repartition_by_id_error_invalid_num_partitions(self):
df = self.spark.range(5)

with self.assertRaises(PySparkTypeError) as pe:
df.repartitionById("5", col("id").cast("int"))
self.check_error(
exception=pe.exception,
errorClass="NOT_INT",
messageParameters={"arg_name": "numPartitions", "arg_type": "str"},
)

with self.assertRaises(PySparkValueError) as pe:
df.repartitionById(0, col("id").cast("int"))
self.check_error(
exception=pe.exception,
errorClass="VALUE_NOT_POSITIVE",
messageParameters={"arg_name": "numPartitions", "arg_value": "0"},
)

# Test negative numPartitions
with self.assertRaises(PySparkValueError) as pe:
df.repartitionById(-1, col("id").cast("int"))
self.check_error(
exception=pe.exception,
errorClass="VALUE_NOT_POSITIVE",
messageParameters={"arg_name": "numPartitions", "arg_value": "-1"},
)

def test_repartition_by_id_out_of_range(self):
numPartitions = 10
df = self.spark.range(20).toDF("id")
repartitioned = df.repartitionById(numPartitions, col("id").cast("int"))
result = repartitioned.collect()

self.assertEqual(len(result), 20)
# The RDD API is unavailable in Connect mode, so only verify the partition
# count when .rdd is accessible, without swallowing a failed assertion.
try:
num = repartitioned.rdd.getNumPartitions()
except Exception:
pass
else:
self.assertEqual(num, numPartitions)

def test_repartition_by_id_string_column_name(self):
numPartitions = 5
df = self.spark.range(25).withColumn(
"partition_id", (col("id") % numPartitions).cast("int")
)
repartitioned = df.repartitionById(numPartitions, "partition_id")
result = repartitioned.withColumn("actual_p_id", spark_partition_id())

mismatches = result.filter(col("partition_id") != col("actual_p_id")).count()
self.assertEqual(mismatches, 0)


class DataFrameRepartitionTests(
DataFrameRepartitionTestsMixin,