diff --git a/src/main/java/com/mongodb/spark/sql/connector/read/partitioner/AutoBucketPartitioner.java b/src/main/java/com/mongodb/spark/sql/connector/read/partitioner/AutoBucketPartitioner.java
index 98efd267..3836fbd6 100644
--- a/src/main/java/com/mongodb/spark/sql/connector/read/partitioner/AutoBucketPartitioner.java
+++ b/src/main/java/com/mongodb/spark/sql/connector/read/partitioner/AutoBucketPartitioner.java
@@ -150,10 +150,6 @@ public List<MongoInputPartition> generatePartitions(final ReadConfig readConfig)
       return SINGLE_PARTITIONER.generatePartitions(readConfig);
     }

-    double avgObjSizeInBytes =
-        storageStats.get("avgObjSize", new BsonInt32(0)).asNumber().doubleValue();
-    double numDocumentsPerPartition = Math.floor(partitionSizeInBytes / avgObjSizeInBytes);
-
     BsonDocument usersCollectionFilter =
         PartitionerHelper.matchQuery(readConfig.getAggregationPipeline());
     long count;
@@ -164,6 +160,9 @@ public List<MongoInputPartition> generatePartitions(final ReadConfig readConfig)
               usersCollectionFilter, new CountOptions().comment(readConfig.getComment())));
     }

+    double avgObjSizeInBytes = PartitionerHelper.averageDocumentSize(storageStats, count);
+    double numDocumentsPerPartition = Math.floor(partitionSizeInBytes / avgObjSizeInBytes);
+
     if (numDocumentsPerPartition == 0 || numDocumentsPerPartition >= count) {
       LOGGER.info(
           "Fewer documents ({}) than the calculated number of documents per partition ({}). Returning a single partition",
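A note on the ordering in this hunk: the average-size calculation now has to come after the count lookup, because the new `PartitionerHelper.averageDocumentSize` falls back to dividing collection size by document count when `avgObjSize` is missing from the stats. Below is a minimal, self-contained sketch of that arithmetic; every number in it is made up for illustration (the 64 MB figure matches the connector's default partition size).

```java
// Standalone sketch of the fallback arithmetic the relocated code depends on.
// All values are illustrative, not taken from the change itself.
public final class FallbackAverageSketch {
  public static void main(final String[] args) {
    long size = 1_048_576; // storageStats "size" in bytes (assumed)
    long count = 2_048; // matching document count (assumed)

    // Without avgObjSize the average must be derived from size and count,
    // which is why count has to be known before the average can be computed.
    double avgObjSizeInBytes = Math.floor(size / (double) count); // 512.0

    double partitionSizeInBytes = 64.0 * 1024 * 1024; // 64 MB partitions
    double numDocumentsPerPartition = Math.floor(partitionSizeInBytes / avgObjSizeInBytes);
    System.out.println(avgObjSizeInBytes + " bytes/doc -> " + numDocumentsPerPartition + " docs/partition");
    // Prints: 512.0 bytes/doc -> 131072.0 docs/partition
  }
}
```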
- + "Returning a single partition.", - avgObjSizeInBytes, - partitionSizeBytes); - return SINGLE_PARTITIONER.generatePartitions(readConfig); - } - - int numDocumentsPerPartition = (int) Math.floor(partitionSizeBytes / avgObjSizeInBytes); BsonDocument matchQuery = PartitionerHelper.matchQuery(readConfig.getAggregationPipeline()); long count; if (matchQuery.isEmpty() && storageStats.containsKey("count")) { @@ -90,6 +79,18 @@ public List generatePartitions(final ReadConfig readConfig) coll.countDocuments(matchQuery, new CountOptions().comment(readConfig.getComment()))); } + double avgObjSizeInBytes = PartitionerHelper.averageDocumentSize(storageStats, count); + int numDocumentsPerPartition = (int) Math.floor(partitionSizeInBytes / avgObjSizeInBytes); + + if (avgObjSizeInBytes >= partitionSizeInBytes) { + LOGGER.warn( + "Average document size `{}` is greater than the partition size `{}`. Please increase the partition size." + + "Returning a single partition.", + avgObjSizeInBytes, + partitionSizeInBytes); + return SINGLE_PARTITIONER.generatePartitions(readConfig); + } + if (count <= numDocumentsPerPartition) { LOGGER.warn( "The calculated number of documents per partition {} is greater than or equal to the number of matching documents. " diff --git a/src/main/java/com/mongodb/spark/sql/connector/read/partitioner/PartitionerHelper.java b/src/main/java/com/mongodb/spark/sql/connector/read/partitioner/PartitionerHelper.java index 43ce9120..224832b6 100644 --- a/src/main/java/com/mongodb/spark/sql/connector/read/partitioner/PartitionerHelper.java +++ b/src/main/java/com/mongodb/spark/sql/connector/read/partitioner/PartitionerHelper.java @@ -18,7 +18,8 @@ package com.mongodb.spark.sql.connector.read.partitioner; import static com.mongodb.spark.sql.connector.read.partitioner.Partitioner.LOGGER; -import static java.util.Collections.singletonList; +import static java.lang.String.format; +import static java.util.Arrays.asList; import com.mongodb.MongoCommandException; import com.mongodb.client.MongoDatabase; @@ -37,9 +38,12 @@ /** Partitioner helper class, contains various utility methods used by the partitioner instances. 
diff --git a/src/main/java/com/mongodb/spark/sql/connector/read/partitioner/PartitionerHelper.java b/src/main/java/com/mongodb/spark/sql/connector/read/partitioner/PartitionerHelper.java
index 43ce9120..224832b6 100644
--- a/src/main/java/com/mongodb/spark/sql/connector/read/partitioner/PartitionerHelper.java
+++ b/src/main/java/com/mongodb/spark/sql/connector/read/partitioner/PartitionerHelper.java
@@ -18,7 +18,8 @@ package com.mongodb.spark.sql.connector.read.partitioner;

 import static com.mongodb.spark.sql.connector.read.partitioner.Partitioner.LOGGER;
-import static java.util.Collections.singletonList;
+import static java.lang.String.format;
+import static java.util.Arrays.asList;

 import com.mongodb.MongoCommandException;
 import com.mongodb.client.MongoDatabase;
@@ -37,9 +38,12 @@
 /** Partitioner helper class, contains various utility methods used by the partitioner instances.
  */
 public final class PartitionerHelper {
-  private static final List<BsonDocument> COLL_STATS_AGGREGATION_PIPELINE =
-      singletonList(BsonDocument.parse("{'$collStats': {'storageStats': { } } }"));
+  private static final List<BsonDocument> COLL_STATS_AGGREGATION_PIPELINE = asList(
+      BsonDocument.parse("{'$collStats': {'storageStats': { } } }"),
+      BsonDocument.parse(
+          "{'$project': {'size': '$storageStats.size', 'count': '$storageStats.count' } }"));
   private static final BsonDocument PING_COMMAND = BsonDocument.parse("{ping: 1}");
+  private static final BsonDocument BUILD_INFO_COMMAND = BsonDocument.parse("{buildInfo: 1}");
   public static final Partitioner SINGLE_PARTITIONER = new SinglePartitionPartitioner();

   /**
@@ -101,14 +105,34 @@ public static List<BsonDocument> createPartitionPipeline(
   public static BsonDocument storageStats(final ReadConfig readConfig) {
     LOGGER.info("Getting collection stats for: {}", readConfig.getNamespace().getFullName());
     try {
-      return readConfig
-          .withCollection(
-              coll -> Optional.ofNullable(coll.aggregate(COLL_STATS_AGGREGATION_PIPELINE)
-                  .allowDiskUse(readConfig.getAggregationAllowDiskUse())
-                  .comment(readConfig.getComment())
-                  .first())
-                  .orElseGet(BsonDocument::new))
-          .getDocument("storageStats", new BsonDocument());
+      BsonDocument buildInfo = readConfig.withClient(c -> {
+        MongoDatabase db = c.getDatabase(readConfig.getDatabaseName());
+        return db.runCommand(BUILD_INFO_COMMAND).toBsonDocument();
+      });
+
+      // Atlas Data Federation does not support the storageStats property and requires
+      // special handling to return the federated collection stats.
+      if (buildInfo.containsKey("dataLake")) {
+        return readConfig.withClient(c -> {
+          MongoDatabase db = c.getDatabase(readConfig.getDatabaseName());
+          BsonDocument command =
+              BsonDocument.parse(format("{ collStats: '%s' }", readConfig.getCollectionName()));
+          BsonDocument result = db.runCommand(command).toBsonDocument();

+          BsonDocument formattedResult = new BsonDocument();
+          formattedResult.append("count", result.get("count"));
+          formattedResult.append("size", result.get("size"));

+          return formattedResult;
+        });
+      }
+
+      return readConfig.withCollection(
+          coll -> Optional.ofNullable(coll.aggregate(COLL_STATS_AGGREGATION_PIPELINE)
+              .allowDiskUse(readConfig.getAggregationAllowDiskUse())
+              .comment(readConfig.getComment())
+              .first())
+              .orElseGet(BsonDocument::new));
     } catch (RuntimeException ex) {
       if (ex instanceof MongoCommandException
           && (ex.getMessage().contains("not found.")
@@ -138,5 +162,24 @@ public static List<String> getPreferredLocations(final ReadConfig readConfig) {
         .collect(Collectors.toList());
   }

+  /**
+   * Returns the average document size in a collection, either taken directly from {@code
+   * avgObjSize} or calculated from the document count and collection size.
+   *
+   * @param storageStats the storage stats of a collection
+   * @param documentCount the number of documents in a collection
+   * @return the average document size in bytes
+   */
+  public static double averageDocumentSize(final BsonDocument storageStats, final long documentCount) {
+    if (storageStats.containsKey("avgObjSize")) {
+      return storageStats.get("avgObjSize", new BsonInt32(0)).asNumber().doubleValue();
+    }
+
+    long size = storageStats.getNumber("size", new BsonInt32(0)).longValue();
+    double avgObjSizeInBytes = Math.floor(size / (double) documentCount);
+
+    return avgObjSizeInBytes;
+  }
+
   private PartitionerHelper() {}
 }
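The `buildInfo` probe is the detection mechanism here: Atlas Data Federation deployments include a `dataLake` section in their `buildInfo` response, while regular deployments do not. A hedged sketch of the same probe using the plain Java driver follows; the connection string, database, and collection name are placeholders.

```java
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.MongoDatabase;
import org.bson.BsonDocument;

// Sketch of the detection logic in storageStats(...) above, outside the
// connector. The URI and namespace are placeholders, not part of the change.
public final class DataFederationProbe {
  public static void main(final String[] args) {
    try (MongoClient client = MongoClients.create("mongodb://localhost:27017")) {
      MongoDatabase db = client.getDatabase("test");
      BsonDocument buildInfo = db.runCommand(BsonDocument.parse("{buildInfo: 1}")).toBsonDocument();

      if (buildInfo.containsKey("dataLake")) {
        // Federated collections report stats via the collStats command,
        // with top-level count and size fields.
        BsonDocument stats = db.runCommand(BsonDocument.parse("{collStats: 'coll'}")).toBsonDocument();
        System.out.println("ADF stats: count=" + stats.get("count") + ", size=" + stats.get("size"));
      } else {
        System.out.println("Regular deployment: the $collStats aggregation stage applies");
      }
    }
  }
}
```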
diff --git a/src/main/java/com/mongodb/spark/sql/connector/read/partitioner/SamplePartitioner.java b/src/main/java/com/mongodb/spark/sql/connector/read/partitioner/SamplePartitioner.java
index 88e13b42..91f13681 100644
--- a/src/main/java/com/mongodb/spark/sql/connector/read/partitioner/SamplePartitioner.java
+++ b/src/main/java/com/mongodb/spark/sql/connector/read/partitioner/SamplePartitioner.java
@@ -18,6 +18,7 @@ package com.mongodb.spark.sql.connector.read.partitioner;

 import static com.mongodb.spark.sql.connector.read.partitioner.PartitionerHelper.SINGLE_PARTITIONER;
+import static com.mongodb.spark.sql.connector.read.partitioner.PartitionerHelper.matchQuery;
 import static java.lang.String.format;
 import static java.util.Arrays.asList;

@@ -105,8 +106,8 @@ public List<MongoInputPartition> generatePartitions(final ReadConfig readConfig)
       count = readConfig.withCollection(coll ->
           coll.countDocuments(matchQuery, new CountOptions().comment(readConfig.getComment())));
     }
-    double avgObjSizeInBytes =
-        storageStats.get("avgObjSize", new BsonInt32(0)).asNumber().doubleValue();
+
+    double avgObjSizeInBytes = PartitionerHelper.averageDocumentSize(storageStats, count);
     double numDocumentsPerPartition = Math.floor(partitionSizeInBytes / avgObjSizeInBytes);

     if (numDocumentsPerPartition >= count) {
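With this change all three partitioners share a single code path for the average document size. A quick sketch exercising both branches of the new helper is below; the stats documents are fabricated, and it assumes the connector classes are on the classpath.

```java
import com.mongodb.spark.sql.connector.read.partitioner.PartitionerHelper;
import org.bson.BsonDocument;

// Exercises both branches of averageDocumentSize with fabricated stats documents.
public final class AverageDocumentSizeCheck {
  public static void main(final String[] args) {
    // Stats that still carry avgObjSize are used directly.
    BsonDocument withAvg = BsonDocument.parse("{avgObjSize: 250, size: 1000000, count: 5000}");
    // The projected and federated shapes carry only size and count, so the helper divides.
    BsonDocument withoutAvg = BsonDocument.parse("{size: 1000000, count: 5000}");

    System.out.println(PartitionerHelper.averageDocumentSize(withAvg, 5000)); // 250.0 (direct)
    System.out.println(PartitionerHelper.averageDocumentSize(withoutAvg, 5000)); // 200.0 (size / count)
  }
}
```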