@@ -2430,6 +2430,28 @@ object SQLConf {
.checkValue(v => v > 0, "The maximum number of partitions must be a positive integer.")
.createOptional

val FILES_PARTITION_STRATEGY = buildConf("spark.sql.files.partitionStrategy")
.doc("The strategy to coalesce small files into larger partitions when reading files. " +
"Options are `size_based` (coalesce based on size of files), and `file_based` "
+ "(coalesce based on number of files). The number of output partitions depends on " +
"`spark.sql.files.maxPartitionBytes` and `spark.sql.files.maxPartitionNum`. " +
"This configuration is effective only when using file-based sources such as " +
"Parquet, JSON and ORC.")
.version("3.5.0")
.stringConf
.checkValues(Set("size_based", "file_based"))
.createWithDefault("size_based")
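
If this lands, usage would presumably look like the sketch below; the table path and the partition-count check are placeholders for illustration, not part of this PR:

// Hedged usage sketch for the config proposed above; the path is a placeholder.
// "file_based" balances partitions by file count, "size_based" keeps the
// existing size-based packing.
spark.conf.set("spark.sql.files.partitionStrategy", "file_based")
spark.conf.set("spark.sql.files.maxPartitionBytes", "134217728") // 128 MB
val df = spark.read.parquet("/path/to/small-files-table")
println(df.rdd.getNumPartitions)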

val SMALL_FILE_THRESHOLD =
buildConf("spark.sql.files.smallFileThreshold")
.doc(
"Defines the total size threshold for small files in a table scan. If the cumulative size " +
"of small files falls below this threshold, they are distributed across multiple " +
"partitions to avoid concentrating them in a single partition. This configuration is " +
"used when `spark.sql.files.coalesceStrategy` is set to `file_based`.")
.doubleConf
.createWithDefault(0.5)
Contributor


May I ask how we arrived at 0.5 as the default smallFileThreshold?


@marin-ma If I understand correctly, does 0.5 mean "if the total size of the small files is less than 50% of a partition, distribute them across multiple partitions"?
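
To make the question concrete, here is my reading of how the threshold is applied by getFilePartitionsByFileNum further down in this diff; this is illustrative arithmetic only, not a statement of the intended semantics:

// With smallFileThreshold = 0.5 and a scan totalling 1,000 MB, files are taken
// in ascending size order and treated as "small" while their cumulative size
// stays within totalScanBytes * 0.5; those files are then spread across
// partitions by file count rather than by size.
val totalScanBytes = 1000L * 1024 * 1024 // 1,048,576,000 bytes
val smallFileBudget = totalScanBytes * 0.5 // 524,288,000.0 bytes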


val IGNORE_CORRUPT_FILES = buildConf("spark.sql.files.ignoreCorruptFiles")
.doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
"encountering corrupted files and the contents that have been read will still be returned. " +
@@ -6949,6 +6971,10 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def filesMaxPartitionNum: Option[Int] = getConf(FILES_MAX_PARTITION_NUM)

def filesPartitionStrategy: String = getConf(FILES_PARTITION_STRATEGY)

def smallFileThreshold: Double = getConf(SMALL_FILE_THRESHOLD)

def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES)

def ignoreMissingFiles: Boolean = getConf(IGNORE_MISSING_FILES)
@@ -55,7 +55,7 @@ case class FilePartition(index: Int, files: Array[PartitionedFile])

object FilePartition extends SessionStateHelper with Logging {

private def getFilePartitions(
private def getFilePartitionsBySize(
partitionedFiles: Seq[PartitionedFile],
maxSplitBytes: Long,
openCostInBytes: Long): Seq[FilePartition] = {
@@ -75,7 +75,7 @@ object FilePartition extends SessionStateHelper with Logging {
}

// Assign files to partitions using "Next Fit Decreasing"
partitionedFiles.foreach { file =>
partitionedFiles.sortBy(_.length)(implicitly[Ordering[Long]].reverse).foreach { file =>
if (currentSize + file.length > maxSplitBytes) {
closePartition()
}
@@ -87,28 +87,116 @@ object FilePartition extends SessionStateHelper with Logging {
partitions.toSeq
}
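
For context, a toy sketch of the "Next Fit Decreasing" packing used by the size-based path above, simplified to plain Longs and ignoring openCostInBytes:

import scala.collection.mutable

// Sort sizes descending, then walk them once, opening a new bin whenever the
// next item would overflow the current one (the current bin is never revisited).
val maxSplitBytes = 10L
val sizes = Seq(7L, 2L, 6L, 3L, 1L).sorted(Ordering[Long].reverse) // 7, 6, 3, 2, 1
val bins = mutable.ArrayBuffer(mutable.ArrayBuffer.empty[Long])
sizes.foreach { s =>
  if (bins.last.sum + s > maxSplitBytes) bins += mutable.ArrayBuffer.empty[Long]
  bins.last += s
}
// bins ends up as [7], [6, 3], [2, 1]; note bin 1 is never revisited for the 3.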

private def getFilePartitionsByFileNum(
partitionedFiles: Seq[PartitionedFile],
outputPartitions: Int,
smallFileThreshold: Double): Seq[FilePartition] = {
// Pair each file with its size and sort descending by size.
val filesSorted: Seq[(PartitionedFile, Long)] =
partitionedFiles
.map(f => (f, f.length))
.sortBy(_._2)(Ordering.Long.reverse)

val partitions = Seq.fill(outputPartitions)(mutable.ArrayBuffer.empty[PartitionedFile])

def addToBucket(
heap: mutable.PriorityQueue[(Long, Int, Int)],
file: PartitionedFile,
sz: Long): Unit = {
val (load, numFiles, idx) = heap.dequeue()
partitions(idx) += file
heap.enqueue((load + sz, numFiles + 1, idx))
}

// First by load, then by numFiles.
val heapByFileSize =
mutable.PriorityQueue.empty[(Long, Int, Int)](
Ordering
.by[(Long, Int, Int), (Long, Int)] {
case (load, numFiles, _) =>
(load, numFiles)
}
.reverse
)

if (smallFileThreshold > 0) {
val smallFileTotalSize = filesSorted.map(_._2).sum * smallFileThreshold
// First by numFiles, then by load.
val heapByFileNum =
mutable.PriorityQueue.empty[(Long, Int, Int)](
Ordering
.by[(Long, Int, Int), (Int, Long)] {
case (load, numFiles, _) =>
(numFiles, load)
}
.reverse
)

(0 until outputPartitions).foreach(i => heapByFileNum.enqueue((0L, 0, i)))

var numSmallFiles = 0
var smallFileSize = 0L
// Walk files in ascending size order and treat them as "small" while their
// cumulative size stays within the threshold, assigning each to the bucket with
// the fewest files (ties broken by load). Use an iterator so the cumulative
// check sees the running total; a strict takeWhile would evaluate the predicate
// before the counters are updated.
filesSorted.reverseIterator.takeWhile(f => f._2 + smallFileSize <= smallFileTotalSize).foreach {
case (file, sz) =>
addToBucket(heapByFileNum, file, sz)
numSmallFiles += 1
smallFileSize += sz
}

// Move buckets from heapByFileNum to heapByFileSize.
while (heapByFileNum.nonEmpty) {
heapByFileSize.enqueue(heapByFileNum.dequeue())
}

// Finally, assign the remaining (larger) files to the least-loaded buckets.
filesSorted.take(filesSorted.size - numSmallFiles).foreach {
case (file, sz) =>
addToBucket(heapByFileSize, file, sz)
}
} else {
(0 until outputPartitions).foreach(i => heapByFileSize.enqueue((0L, 0, i)))

filesSorted.foreach {
case (file, sz) =>
addToBucket(heapByFileSize, file, sz)
}
}

partitions.zipWithIndex.map {
case (p, idx) => FilePartition(idx, p.toArray)
}
}
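
And a toy sketch of the balancing idea in getFilePartitionsByFileNum as I read it: each file goes to the currently least-loaded bucket, tracked with a priority queue. This is simplified to plain Longs with a single heap and no small-file phase:

import scala.collection.mutable

// Spread sizes, largest first, across a fixed number of buckets, always giving
// the next item to the least-loaded bucket. PriorityQueue is a max-heap, so the
// ordering is reversed to pop the minimum load.
val numBuckets = 2
val sizes = Seq(7L, 6L, 3L, 2L, 1L) // already descending
val buckets = Seq.fill(numBuckets)(mutable.ArrayBuffer.empty[Long])
val heap = mutable.PriorityQueue.empty[(Long, Int)](Ordering.by[(Long, Int), Long](_._1).reverse)
(0 until numBuckets).foreach(i => heap.enqueue((0L, i)))
sizes.foreach { s =>
  val (load, idx) = heap.dequeue()
  buckets(idx) += s
  heap.enqueue((load + s, idx))
}
// Bucket loads end up 10 and 9 (exact membership depends on tie-breaking).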

def getFilePartitions(
sparkSession: SparkSession,
partitionedFiles: Seq[PartitionedFile],
maxSplitBytes: Long): Seq[FilePartition] = {
val conf = getSqlConf(sparkSession)
val openCostBytes = conf.filesOpenCostInBytes
val maxPartNum = conf.filesMaxPartitionNum
val partitions = getFilePartitions(partitionedFiles, maxSplitBytes, openCostBytes)
if (maxPartNum.exists(partitions.size > _)) {
val totalSizeInBytes =
partitionedFiles.map(_.length + openCostBytes).map(BigDecimal(_)).sum[BigDecimal]
val desiredSplitBytes =
(totalSizeInBytes / BigDecimal(maxPartNum.get)).setScale(0, RoundingMode.UP).longValue
val desiredPartitions = getFilePartitions(partitionedFiles, desiredSplitBytes, openCostBytes)
logWarning(log"The number of partitions is ${MDC(NUM_PARTITIONS, partitions.size)}, " +
log"which exceeds the maximum number configured: " +
log"${MDC(MAX_NUM_PARTITIONS, maxPartNum.get)}. Spark rescales it to " +
log"${MDC(DESIRED_NUM_PARTITIONS, desiredPartitions.size)} by ignoring the " +
log"configuration of ${MDC(CONFIG, SQLConf.FILES_MAX_PARTITION_BYTES.key)}.")
desiredPartitions
} else {
partitions
val partitions = getFilePartitionsBySize(partitionedFiles, maxSplitBytes, openCostBytes)
conf.filesPartitionStrategy match {
case "file_based" =>
getFilePartitionsByFileNum(partitionedFiles, Math.min(partitions.size,
maxPartNum.getOrElse(Int.MaxValue)), conf.smallFileThreshold)
case "size_based" =>
if (maxPartNum.exists(partitions.size > _)) {
val totalSizeInBytes =
partitionedFiles.map(_.length + openCostBytes).map(BigDecimal(_)).sum[BigDecimal]
val desiredSplitBytes =
(totalSizeInBytes / BigDecimal(maxPartNum.get)).setScale(0, RoundingMode.UP).longValue
val desiredPartitions = getFilePartitionsBySize(
partitionedFiles, desiredSplitBytes, openCostBytes)
logWarning(log"The number of partitions is ${MDC(NUM_PARTITIONS, partitions.size)}, " +
log"which exceeds the maximum number configured: " +
log"${MDC(MAX_NUM_PARTITIONS, maxPartNum.get)}. Spark rescales it to " +
log"${MDC(DESIRED_NUM_PARTITIONS, desiredPartitions.size)} by ignoring the " +
log"configuration of ${MDC(CONFIG, SQLConf.FILES_MAX_PARTITION_BYTES.key)}.")
desiredPartitions
} else {
partitions
}
}
}

@@ -627,6 +627,21 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession {
}
}

test(s"Test ${SQLConf.FILES_PARTITION_STRATEGY.key} works as expected") {
val files = {
Range(0, 20000 - 10).map(p => PartitionedFile(InternalRow.empty, sp(s"$p"), 0, 50000))
} ++ Range(0, 10).map(p => PartitionedFile(InternalRow.empty, sp(s"small_$p"), 0, 5000))

withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "50000",
SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "0",
SQLConf.FILES_PARTITION_STRATEGY.key -> "file_based"
) {
val partitions = FilePartition.getFilePartitions(
spark, files, conf.filesMaxPartitionBytes)
// With the file_based strategy the 10 small files should be spread out,
// so no partition ends up with 10 or more files.
assert(!partitions.exists(_.files.length >= 10))
}
}

// Helpers for checking the arguments passed to the FileFormat.

protected val checkPartitionSchema =