feat: Optionally ensure upload (#36)
* feat: wait to ensure upload

Signed-off-by: Anush008 <[email protected]>

* docs: Updated README.md

Signed-off-by: Anush008 <[email protected]>

---------

Signed-off-by: Anush008 <[email protected]>
Anush008 authored Jan 16, 2025
1 parent b66b5ff commit f38288c
Showing 6 changed files with 24 additions and 9 deletions.
9 changes: 6 additions & 3 deletions README.md
@@ -4,8 +4,10 @@

## Installation

To integrate the connector into your Spark environment, get the JAR file from one of the sources listed below.

> [!IMPORTANT]
> Requires Java 8 or above.
> Ensure your system is running Java 8.
### GitHub Releases

@@ -20,11 +22,11 @@ Once the requirements have been satisfied, run the following command in the proj
mvn package
```

This will build and store the fat JAR in the `target` directory by default.
The JAR file will be written into the `target` directory by default.

### Maven Central

For use with Java and Scala projects, the package can be found [here](https://central.sonatype.com/artifact/io.qdrant/spark).
Find the project on Maven Central [here](https://central.sonatype.com/artifact/io.qdrant/spark).

## Usage

@@ -257,6 +259,7 @@ The appropriate Spark data types are mapped to the Qdrant payload based on the p
| `multi_vector_fields` | Comma-separated names of columns holding the multi-vector values. | `ArrayType(ArrayType(FloatType))` ||
| `multi_vector_names` | Comma-separated names of the multi-vectors in the collection. | - ||
| `shard_key_selector` | Comma-separated names of custom shard keys to use during upsert. | - ||
| `wait` | Wait for each batch upsert to complete. `true` or `false`. Defaults to `true`. | - ||

## LICENSE

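For context, the `wait` option documented in the README table above is set like any other connector option on a Spark write. Below is a minimal sketch in Java; the `io.qdrant.spark.Qdrant` format name, the `schema`, `qdrant_url`, and `collection_name` keys, and the local endpoint come from the connector's existing documentation or are assumed for illustration — they are not part of this diff.

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;

public class QdrantWriteExample {
  // df is assumed to carry an "embedding" column of ArrayType(FloatType).
  static void writeToQdrant(Dataset<Row> df) {
    df.write()
        .format("io.qdrant.spark.Qdrant")
        .option("schema", df.schema().json())            // serialized dataframe schema
        .option("qdrant_url", "http://localhost:6334")   // hypothetical local gRPC endpoint
        .option("collection_name", "demo")               // hypothetical collection
        .option("embedding_field", "embedding")
        .option("wait", "false")                         // new in this commit: don't block on each batch upsert
        .mode(SaveMode.Append)
        .save();
  }
}
```

Leaving `wait` at its default of `true` makes each batch upsert block until Qdrant acknowledges the points; setting it to `false` trades that guarantee for write throughput.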
2 changes: 1 addition & 1 deletion pom.xml
@@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>io.qdrant</groupId>
<artifactId>spark</artifactId>
<version>2.3.2</version>
<version>2.3.3</version>
<name>qdrant-spark</name>
<url>https://github.com/qdrant/qdrant-spark</url>
<description>An Apache Spark connector for the Qdrant vector database</description>
2 changes: 1 addition & 1 deletion src/main/java/io/qdrant/spark/QdrantDataWriter.java
@@ -68,7 +68,7 @@ private void doWriteBatch() throws Exception {

// Instantiate QdrantGrpc client for each batch to maintain serializability
QdrantGrpc qdrant = new QdrantGrpc(new URL(options.qdrantUrl), options.apiKey);
qdrant.upsert(options.collectionName, pointsBuffer, options.shardKeySelector);
qdrant.upsert(options.collectionName, pointsBuffer, options.shardKeySelector, options.wait);
qdrant.close();
}

10 changes: 8 additions & 2 deletions src/main/java/io/qdrant/spark/QdrantGrpc.java
@@ -26,10 +26,16 @@ public QdrantGrpc(URL url, String apiKey) throws MalformedURLException {
}

public void upsert(
String collectionName, List<PointStruct> points, ShardKeySelector shardKeySelector)
String collectionName,
List<PointStruct> points,
ShardKeySelector shardKeySelector,
boolean wait)
throws InterruptedException, ExecutionException {
UpsertPoints.Builder upsertPoints =
UpsertPoints.newBuilder().setCollectionName(collectionName).addAllPoints(points);
UpsertPoints.newBuilder()
.setCollectionName(collectionName)
.setWait(wait)
.addAllPoints(points);
if (shardKeySelector != null) {
upsertPoints.setShardKeySelector(shardKeySelector);
}
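For illustration, a caller-side sketch of the widened `upsert` signature. The endpoint, collection name, and the pre-built `points` batch are hypothetical, and the snippet assumes access to the `io.qdrant.spark` package where `QdrantGrpc` lives; the real call site is in `QdrantDataWriter` above.

```java
import io.qdrant.client.grpc.Points.PointStruct;
import java.net.URL;
import java.util.List;

class UpsertSketch {
  static void run(List<PointStruct> points) throws Exception {
    // points: a batch assembled elsewhere (hypothetical input).
    QdrantGrpc qdrant = new QdrantGrpc(new URL("http://localhost:6334"), /* apiKey = */ "");
    // New fourth argument: true blocks until Qdrant has applied the batch,
    // false returns as soon as the request is accepted.
    qdrant.upsert("demo", points, /* shardKeySelector = */ null, /* wait = */ true);
    qdrant.close();
  }
}
```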
7 changes: 7 additions & 0 deletions src/main/java/io/qdrant/spark/QdrantOptions.java
@@ -15,6 +15,7 @@
public class QdrantOptions implements Serializable {
private static final int DEFAULT_BATCH_SIZE = 64;
private static final int DEFAULT_RETRIES = 3;
private static final boolean DEFAULT_WAIT = true;

public final String qdrantUrl;
public final String apiKey;
@@ -33,6 +34,7 @@ public class QdrantOptions implements Serializable {
public final String[] multiVectorNames;
public final List<String> payloadFieldsToSkip;
public final ShardKeySelector shardKeySelector;
public final boolean wait;

public QdrantOptions(Map<String, String> options) {
Objects.requireNonNull(options);
@@ -45,6 +47,7 @@ public QdrantOptions(Map<String, String> options) {
apiKey = options.getOrDefault("api_key", "");
embeddingField = options.getOrDefault("embedding_field", "");
vectorName = options.getOrDefault("vector_name", "");
wait = getBooleanOption(options, "wait", DEFAULT_WAIT);

sparseVectorValueFields = parseArray(options.get("sparse_vector_value_fields"));
sparseVectorIndexFields = parseArray(options.get("sparse_vector_index_fields"));
@@ -66,6 +69,10 @@ private int getIntOption(Map<String, String> options, String key, int defaultVal
return Integer.parseInt(options.getOrDefault(key, String.valueOf(defaultValue)));
}

private boolean getBooleanOption(Map<String, String> options, String key, boolean defaultValue) {
return Boolean.parseBoolean(options.getOrDefault(key, String.valueOf(defaultValue)));
}

private String[] parseArray(String input) {
return input == null
? new String[0]
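As a side note on the option parsing above, `Boolean.parseBoolean` returns `true` only for a case-insensitive `"true"`, so a mistyped value such as `wait=ture` silently disables waiting, while an omitted key keeps the default. A standalone sketch of that lookup behaviour (not the connector's code path):

```java
import java.util.HashMap;
import java.util.Map;

public class WaitOptionSketch {
  public static void main(String[] args) {
    Map<String, String> options = new HashMap<>();

    // Key absent: falls back to the default of "true".
    System.out.println(Boolean.parseBoolean(options.getOrDefault("wait", "true"))); // true

    // Explicitly disabled.
    options.put("wait", "false");
    System.out.println(Boolean.parseBoolean(options.getOrDefault("wait", "true"))); // false

    // Anything that is not "true" (case-insensitive) parses as false.
    options.put("wait", "yes");
    System.out.println(Boolean.parseBoolean(options.getOrDefault("wait", "true"))); // false
  }
}
```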
3 changes: 1 addition & 2 deletions src/test/java/io/qdrant/spark/TestQdrantGrpc.java
@@ -68,8 +68,7 @@ public void testUploadBatch() throws Exception {
point2Builder.putPayload("rand_number", value(89));
points.add(point2Builder.build());

// call the uploadBatch method
qdrantGrpc.upsert(collectionName, points, null);
qdrantGrpc.upsert(collectionName, points, null, true);

qdrantGrpc.close();
}
