From f38288c42a3748d8474aee3dd82b3c3c53d052e5 Mon Sep 17 00:00:00 2001 From: Anush Date: Thu, 16 Jan 2025 18:34:47 +0530 Subject: [PATCH] feat: Optionally ensure upload (#36) * feat: wait to ensure upload Signed-off-by: Anush008 * docs: Updated README.md Signed-off-by: Anush008 --------- Signed-off-by: Anush008 --- README.md | 9 ++++++--- pom.xml | 2 +- src/main/java/io/qdrant/spark/QdrantDataWriter.java | 2 +- src/main/java/io/qdrant/spark/QdrantGrpc.java | 10 ++++++++-- src/main/java/io/qdrant/spark/QdrantOptions.java | 7 +++++++ src/test/java/io/qdrant/spark/TestQdrantGrpc.java | 3 +-- 6 files changed, 24 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 9d9526d..48c13c3 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,10 @@ ## Installation +To integrate the connector into your Spark environment, get the JAR file from one of the sources listed below. + > [!IMPORTANT] -> Requires Java 8 or above. +> Ensure your system is running Java 8. ### GitHub Releases @@ -20,11 +22,11 @@ Once the requirements have been satisfied, run the following command in the proj mvn package ``` -This will build and store the fat JAR in the `target` directory by default. +The JAR file will be written into the `target` directory by default. ### Maven Central -For use with Java and Scala projects, the package can be found [here](https://central.sonatype.com/artifact/io.qdrant/spark). +Find the project on Maven Central [here](https://central.sonatype.com/artifact/io.qdrant/spark). ## Usage @@ -257,6 +259,7 @@ The appropriate Spark data types are mapped to the Qdrant payload based on the p | `multi_vector_fields` | Comma-separated names of columns holding the multi-vector values. | `ArrayType(ArrayType(FloatType))` | ❌ | | `multi_vector_names` | Comma-separated names of the multi-vectors in the collection. | - | ❌ | | `shard_key_selector` | Comma-separated names of custom shard keys to use during upsert. | - | ❌ | +| `wait` | Wait for each batch upsert to complete. `true` or `false`. Defaults to `true`. | - | ❌ | ## LICENSE diff --git a/pom.xml b/pom.xml index 4b2e1af..bd5e130 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ 4.0.0 io.qdrant spark - 2.3.2 + 2.3.3 qdrant-spark https://github.com/qdrant/qdrant-spark An Apache Spark connector for the Qdrant vector database diff --git a/src/main/java/io/qdrant/spark/QdrantDataWriter.java b/src/main/java/io/qdrant/spark/QdrantDataWriter.java index bb34367..7a4fd15 100644 --- a/src/main/java/io/qdrant/spark/QdrantDataWriter.java +++ b/src/main/java/io/qdrant/spark/QdrantDataWriter.java @@ -68,7 +68,7 @@ private void doWriteBatch() throws Exception { // Instantiate QdrantGrpc client for each batch to maintain serializability QdrantGrpc qdrant = new QdrantGrpc(new URL(options.qdrantUrl), options.apiKey); - qdrant.upsert(options.collectionName, pointsBuffer, options.shardKeySelector); + qdrant.upsert(options.collectionName, pointsBuffer, options.shardKeySelector, options.wait); qdrant.close(); } diff --git a/src/main/java/io/qdrant/spark/QdrantGrpc.java b/src/main/java/io/qdrant/spark/QdrantGrpc.java index 2852ade..05770f3 100644 --- a/src/main/java/io/qdrant/spark/QdrantGrpc.java +++ b/src/main/java/io/qdrant/spark/QdrantGrpc.java @@ -26,10 +26,16 @@ public QdrantGrpc(URL url, String apiKey) throws MalformedURLException { } public void upsert( - String collectionName, List points, ShardKeySelector shardKeySelector) + String collectionName, + List points, + ShardKeySelector shardKeySelector, + boolean wait) throws InterruptedException, ExecutionException { UpsertPoints.Builder upsertPoints = - UpsertPoints.newBuilder().setCollectionName(collectionName).addAllPoints(points); + UpsertPoints.newBuilder() + .setCollectionName(collectionName) + .setWait(wait) + .addAllPoints(points); if (shardKeySelector != null) { upsertPoints.setShardKeySelector(shardKeySelector); } diff --git a/src/main/java/io/qdrant/spark/QdrantOptions.java b/src/main/java/io/qdrant/spark/QdrantOptions.java index dbf784d..a700f54 100644 --- a/src/main/java/io/qdrant/spark/QdrantOptions.java +++ b/src/main/java/io/qdrant/spark/QdrantOptions.java @@ -15,6 +15,7 @@ public class QdrantOptions implements Serializable { private static final int DEFAULT_BATCH_SIZE = 64; private static final int DEFAULT_RETRIES = 3; + private static final boolean DEFAULT_WAIT = true; public final String qdrantUrl; public final String apiKey; @@ -33,6 +34,7 @@ public class QdrantOptions implements Serializable { public final String[] multiVectorNames; public final List payloadFieldsToSkip; public final ShardKeySelector shardKeySelector; + public final boolean wait; public QdrantOptions(Map options) { Objects.requireNonNull(options); @@ -45,6 +47,7 @@ public QdrantOptions(Map options) { apiKey = options.getOrDefault("api_key", ""); embeddingField = options.getOrDefault("embedding_field", ""); vectorName = options.getOrDefault("vector_name", ""); + wait = getBooleanOption(options, "wait", DEFAULT_WAIT); sparseVectorValueFields = parseArray(options.get("sparse_vector_value_fields")); sparseVectorIndexFields = parseArray(options.get("sparse_vector_index_fields")); @@ -66,6 +69,10 @@ private int getIntOption(Map options, String key, int defaultVal return Integer.parseInt(options.getOrDefault(key, String.valueOf(defaultValue))); } + private boolean getBooleanOption(Map options, String key, boolean defaultValue) { + return Boolean.parseBoolean(options.getOrDefault(key, String.valueOf(defaultValue))); + } + private String[] parseArray(String input) { return input == null ? new String[0] diff --git a/src/test/java/io/qdrant/spark/TestQdrantGrpc.java b/src/test/java/io/qdrant/spark/TestQdrantGrpc.java index b1c1e3c..11f8e43 100644 --- a/src/test/java/io/qdrant/spark/TestQdrantGrpc.java +++ b/src/test/java/io/qdrant/spark/TestQdrantGrpc.java @@ -68,8 +68,7 @@ public void testUploadBatch() throws Exception { point2Builder.putPayload("rand_number", value(89)); points.add(point2Builder.build()); - // call the uploadBatch method - qdrantGrpc.upsert(collectionName, points, null); + qdrantGrpc.upsert(collectionName, points, null, true); qdrantGrpc.close(); }