Skip to content

Commit 2ae3d85

Browse files
authored
Merge pull request #1779 from marklogic/feature/20906-vector-functions-fix
MLE-19374 Updated cosine and annTopK functions
2 parents 4a7f812 + e030819 commit 2ae3d85

File tree

5 files changed

+39
-38
lines changed

5 files changed

+39
-38
lines changed

marklogic-client-api/src/main/java/com/marklogic/client/expression/PlanBuilder.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,11 +1507,11 @@ public interface ModifyPlan extends PreparePlan, PlanBuilderBase.ModifyPlanBase
15071507
* @param vectorColumn The column representing the vector ann-indexed column to perform the index lookup against.
15081508
* @param queryVector Specifies the query vector to perform the index lookup with.
15091509
* @param distanceColumn Optional output column that captures the values of the distance metric of the vectors retrieved from the index associated with vectorColumn and the queryVector.
1510-
* @param queryTolerance Specifies the query tolerance to help balance recall and search time. The value is between 0.0 and 1.0. At 0.0, the recall will be highest. At 1.0 the recall will likely see a large degradation, but queries will be quick. The default value is 0.0.
1510+
* @param options Optional sequence of strings or a map containing keys and values for the options to this operator.
15111511
* @return
1512-
* @since 7.1.0
1512+
* @since 7.2.0
15131513
*/
1514-
ModifyPlan annTopK(int k, PlanColumn vectorColumn, ServerExpression queryVector, PlanColumn distanceColumn, float queryTolerance);
1514+
ModifyPlan annTopK(int k, PlanColumn vectorColumn, ServerExpression queryVector, PlanColumn distanceColumn, Map<String, Object> options);
15151515

15161516
/**
15171517
* This method restricts the left row set to rows where a row with the same columns and values doesn't exist in the right row set.

marklogic-client-api/src/main/java/com/marklogic/client/expression/VecExpr.java

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,6 @@
44

55
package com.marklogic.client.expression;
66

7-
import com.marklogic.client.type.XsAnyAtomicTypeSeqVal;
8-
import com.marklogic.client.type.XsDoubleVal;
9-
import com.marklogic.client.type.XsFloatVal;
10-
import com.marklogic.client.type.XsStringVal;
11-
import com.marklogic.client.type.XsUnsignedIntVal;
12-
import com.marklogic.client.type.XsUnsignedLongVal;
13-
147
import com.marklogic.client.type.ServerExpression;
158

169
// IMPORTANT: Do not edit. This file is generated.
@@ -59,15 +52,15 @@ public interface VecExpr {
5952
/**
6053
* Returns the cosine similarity between two vectors. The vectors must be of the same dimension.
6154
*
62-
* <a name="ml-server-type-cosine-similarity"></a>
55+
* <a name="ml-server-type-cosine"></a>
6356
6457
* <p>
65-
* Provides a client interface to the <a href="http://docs.marklogic.com/vec:cosine-similarity" target="mlserverdoc">vec:cosine-similarity</a> server function.
58+
* Provides a client interface to the <a href="http://docs.marklogic.com/vec:cosine" target="mlserverdoc">vec:cosine</a> server function.
6659
* @param vector1 The vector from which to calculate the cosine similarity with vector2. (of <a href="{@docRoot}/doc-files/types/vec_vector.html">vec:vector</a>)
6760
* @param vector2 The vector from which to calculate the cosine similarity with vector1. (of <a href="{@docRoot}/doc-files/types/vec_vector.html">vec:vector</a>)
6861
* @return a server expression with the <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a> server data type
6962
*/
70-
public ServerExpression cosineSimilarity(ServerExpression vector1, ServerExpression vector2);
63+
public ServerExpression cosine(ServerExpression vector1, ServerExpression vector2);
7164
/**
7265
* Returns the dimension of the vector passed in.
7366
*
@@ -187,7 +180,7 @@ public interface VecExpr {
187180
* <p>
188181
* Provides a client interface to the <a href="http://docs.marklogic.com/vec:vector-score" target="mlserverdoc">vec:vector-score</a> server function.
189182
* @param score The cts:score of the matching document. (of <a href="{@docRoot}/doc-files/types/xs_unsignedInt.html">xs:unsignedInt</a>)
190-
* @param similarity The similarity between the vector in the matching document and the query vector. The result of a call to ovec:cosine-similarity(). In the case that the vectors are normalized, pass ovec:dot-product(). Note that vec:euclidean-distance() should not be used here. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
183+
* @param similarity The similarity between the vector in the matching document and the query vector. The result of a call to ovec:cosine(). In the case that the vectors are normalized, pass ovec:dot-product(). Note that vec:euclidean-distance() should not be used here. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
191184
* @return a server expression with the <a href="{@docRoot}/doc-files/types/xs_unsignedLong.html">xs:unsignedLong</a> server data type
192185
*/
193186
public ServerExpression vectorScore(ServerExpression score, double similarity);
@@ -199,7 +192,7 @@ public interface VecExpr {
199192
* <p>
200193
* Provides a client interface to the <a href="http://docs.marklogic.com/vec:vector-score" target="mlserverdoc">vec:vector-score</a> server function.
201194
* @param score The cts:score of the matching document. (of <a href="{@docRoot}/doc-files/types/xs_unsignedInt.html">xs:unsignedInt</a>)
202-
* @param similarity The similarity between the vector in the matching document and the query vector. The result of a call to ovec:cosine-similarity(). In the case that the vectors are normalized, pass ovec:dot-product(). Note that vec:euclidean-distance() should not be used here. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
195+
* @param similarity The similarity between the vector in the matching document and the query vector. The result of a call to ovec:cosine(). In the case that the vectors are normalized, pass ovec:dot-product(). Note that vec:euclidean-distance() should not be used here. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
203196
* @return a server expression with the <a href="{@docRoot}/doc-files/types/xs_unsignedLong.html">xs:unsignedLong</a> server data type
204197
*/
205198
public ServerExpression vectorScore(ServerExpression score, ServerExpression similarity);
@@ -208,7 +201,7 @@ public interface VecExpr {
208201
* <p>
209202
* Provides a client interface to the <a href="http://docs.marklogic.com/vec:vector-score" target="mlserverdoc">vec:vector-score</a> server function.
210203
* @param score The cts:score of the matching document. (of <a href="{@docRoot}/doc-files/types/xs_unsignedInt.html">xs:unsignedInt</a>)
211-
* @param similarity The similarity between the vector in the matching document and the query vector. The result of a call to ovec:cosine-similarity(). In the case that the vectors are normalized, pass ovec:dot-product(). Note that vec:euclidean-distance() should not be used here. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
204+
* @param similarity The similarity between the vector in the matching document and the query vector. The result of a call to ovec:cosine(). In the case that the vectors are normalized, pass ovec:dot-product(). Note that vec:euclidean-distance() should not be used here. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
212205
* @param similarityWeight The weight of the vector similarity on the score. The default value is 0.1. If 0.0 is passed in, vector similarity has no effect. If passed a value less than 0.0 or greater than 1.0, throw VEC-VECTORSCORE. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
213206
* @return a server expression with the <a href="{@docRoot}/doc-files/types/xs_unsignedLong.html">xs:unsignedLong</a> server data type
214207
*/
@@ -218,7 +211,7 @@ public interface VecExpr {
218211
* <p>
219212
* Provides a client interface to the <a href="http://docs.marklogic.com/vec:vector-score" target="mlserverdoc">vec:vector-score</a> server function.
220213
* @param score The cts:score of the matching document. (of <a href="{@docRoot}/doc-files/types/xs_unsignedInt.html">xs:unsignedInt</a>)
221-
* @param similarity The similarity between the vector in the matching document and the query vector. The result of a call to ovec:cosine-similarity(). In the case that the vectors are normalized, pass ovec:dot-product(). Note that vec:euclidean-distance() should not be used here. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
214+
* @param similarity The similarity between the vector in the matching document and the query vector. The result of a call to ovec:cosine(). In the case that the vectors are normalized, pass ovec:dot-product(). Note that vec:euclidean-distance() should not be used here. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
222215
* @param similarityWeight The weight of the vector similarity on the score. The default value is 0.1. If 0.0 is passed in, vector similarity has no effect. If passed a value less than 0.0 or greater than 1.0, throw VEC-VECTORSCORE. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
223216
* @return a server expression with the <a href="{@docRoot}/doc-files/types/xs_unsignedLong.html">xs:unsignedLong</a> server data type
224217
*/

marklogic-client-api/src/main/java/com/marklogic/client/impl/PlanBuilderSubImpl.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -987,9 +987,9 @@ static class ModifyPlanSubImpl
987987
}
988988

989989
@Override
990-
public ModifyPlan annTopK(int k, PlanColumn vectorColumn, ServerExpression queryVector, PlanColumn distanceColumn, float queryTolerance) {
990+
public ModifyPlan annTopK(int k, PlanColumn vectorColumn, ServerExpression queryVector, PlanColumn distanceColumn, Map<String, Object> options) {
991991
return new PlanBuilderSubImpl.ModifyPlanSubImpl(this, "op", "annTopK", new Object[]{
992-
k, vectorColumn, queryVector, distanceColumn, queryTolerance
992+
k, vectorColumn, queryVector, distanceColumn, new BaseTypeImpl.BaseMapImpl(options)
993993
});
994994
}
995995

@@ -1029,7 +1029,7 @@ public ModifyPlan facetBy(PlanNamedGroupSeq keys) {
10291029
}
10301030
@Override
10311031
public ModifyPlan facetBy(PlanNamedGroupSeq keys, String countCol) {
1032-
return facetBy(keys, (countCol == null) ? (PlanExprCol) null : exprCol(countCol));
1032+
return facetBy(keys, (countCol == null) ? null : exprCol(countCol));
10331033
}
10341034
@Override
10351035
public ModifyPlan facetBy(PlanNamedGroupSeq keys, PlanExprCol countCol) {
@@ -1100,7 +1100,7 @@ public ModifyPlan remove(PlanColumn uriColumn) {
11001100
}
11011101

11021102
static class TemporalRemoval implements BaseArgImpl {
1103-
private String template;
1103+
private final String template;
11041104

11051105
public TemporalRemoval(PlanColumn temporalCollection, PlanColumn uriColumn) {
11061106
this.template = String.format("{\"temporalCollection\":%s, \"uri\": %s}",

marklogic-client-api/src/main/java/com/marklogic/client/impl/VecExprImpl.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,14 @@ public ServerExpression base64Encode(ServerExpression vector1) {
5858

5959

6060
@Override
61-
public ServerExpression cosineSimilarity(ServerExpression vector1, ServerExpression vector2) {
61+
public ServerExpression cosine(ServerExpression vector1, ServerExpression vector2) {
6262
if (vector1 == null) {
63-
throw new IllegalArgumentException("vector1 parameter for cosineSimilarity() cannot be null");
63+
throw new IllegalArgumentException("vector1 parameter for cosine() cannot be null");
6464
}
6565
if (vector2 == null) {
66-
throw new IllegalArgumentException("vector2 parameter for cosineSimilarity() cannot be null");
66+
throw new IllegalArgumentException("vector2 parameter for cosine() cannot be null");
6767
}
68-
return new XsExprImpl.DoubleCallImpl("vec", "cosine-similarity", new Object[]{ vector1, vector2 });
68+
return new XsExprImpl.DoubleCallImpl("vec", "cosine", new Object[]{ vector1, vector2 });
6969
}
7070

7171

marklogic-client-api/src/test/java/com/marklogic/client/test/rows/VectorTest.java

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
import org.junit.jupiter.api.Test;
1515
import org.junit.jupiter.api.extension.ExtendWith;
1616

17+
import java.util.HashMap;
1718
import java.util.List;
19+
import java.util.Map;
1820

1921
import static org.junit.jupiter.api.Assertions.*;
2022

@@ -38,7 +40,7 @@ void vectorFunctionsHappyPath() {
3840
PlanBuilder.ModifyPlan plan =
3941
op.fromView("vectors", "persons")
4042
.bind(op.as("sampleVector", op.vec.vector(sampleVector)))
41-
.bind(op.as("cosineSimilarity", op.vec.cosineSimilarity(op.col("embedding"), op.col("sampleVector"))))
43+
.bind(op.as("cosine", op.vec.cosine(op.col("embedding"), op.col("sampleVector"))))
4244
.bind(op.as("dotProduct", op.vec.dotProduct(op.col("embedding"), op.col("sampleVector"))))
4345
.bind(op.as("euclideanDistance", op.vec.euclideanDistance(op.col("embedding"), op.col("sampleVector"))))
4446
.bind(op.as("dimension", op.vec.dimension(op.col("sampleVector"))))
@@ -52,7 +54,7 @@ void vectorFunctionsHappyPath() {
5254
.bind(op.as("subVector", op.vec.subvector(op.col("sampleVector"), op.xs.integer(1), op.xs.integer(1))))
5355
.bind(op.as("vectorScore", op.vec.vectorScore(op.xs.unsignedInt(1), op.xs.doubleVal(0.5))))
5456
.select(
55-
op.col("cosineSimilarity"), op.col("dotProduct"), op.col("euclideanDistance"),
57+
op.col("cosine"), op.col("dotProduct"), op.col("euclideanDistance"),
5658
op.col("name"), op.col("dimension"), op.col("normalize"),
5759
op.col("magnitude"), op.col("get"), op.col("add"), op.col("subtract"),
5860
op.col("base64Encode"), op.col("base64Decode"), op.col("subVector"), op.col("vectorScore")
@@ -63,8 +65,8 @@ void vectorFunctionsHappyPath() {
6365

6466
rows.forEach(row -> {
6567
// Simple a sanity checks to verify that the functions ran. Very little concern about the actual return values.
66-
double cosineSimilarity = row.getDouble("cosineSimilarity");
67-
assertTrue((cosineSimilarity > 0) && (cosineSimilarity < 1), "Unexpected value: " + cosineSimilarity);
68+
double cosine = row.getDouble("cosine");
69+
assertTrue((cosine > 0) && (cosine < 1), "Unexpected value: " + cosine);
6870
double dotProduct = row.getDouble("dotProduct");
6971
Assertions.assertTrue(dotProduct > 0, "Unexpected value: " + dotProduct);
7072
double euclideanDistance = row.getDouble("euclideanDistance");
@@ -85,25 +87,25 @@ void vectorFunctionsHappyPath() {
8587
}
8688

8789
@Test
88-
void cosineSimilarity_DimensionMismatch() {
90+
void cosine_DimensionMismatch() {
8991
PlanBuilder.ModifyPlan plan =
9092
op.fromView("vectors", "persons")
9193
.bind(op.as("sampleVector", op.vec.vector(twoDimensionalVector)))
92-
.bind(op.as("cosineSimilarity", op.vec.cosineSimilarity(op.col("embedding"), op.col("sampleVector"))))
93-
.select(op.col("name"), op.col("summary"), op.col("cosineSimilarity"));
94+
.bind(op.as("cosine", op.vec.cosine(op.col("embedding"), op.col("sampleVector"))))
95+
.select(op.col("name"), op.col("summary"), op.col("cosine"));
9496
Exception exception = assertThrows(FailedRequestException.class, () -> resultRows(plan));
9597
String actualMessage = exception.getMessage();
9698
assertTrue(actualMessage.contains("Server Message: VEC-DIMMISMATCH"), "Unexpected message: " + actualMessage);
9799
assertTrue(actualMessage.contains("Mismatched dimension"), "Unexpected message: " + actualMessage);
98100
}
99101

100102
@Test
101-
void cosineSimilarity_InvalidVector() {
103+
void cosine_InvalidVector() {
102104
PlanBuilder.ModifyPlan plan =
103105
op.fromView("vectors", "persons")
104106
.bind(op.as("sampleVector", invalidVector))
105-
.bind(op.as("cosineSimilarity", op.vec.cosineSimilarity(op.col("embedding"), op.col("sampleVector"))))
106-
.select(op.col("name"), op.col("summary"), op.col("cosineSimilarity"));
107+
.bind(op.as("cosine", op.vec.cosine(op.col("embedding"), op.col("sampleVector"))))
108+
.select(op.col("name"), op.col("summary"), op.col("cosine"));
107109
Exception exception = assertThrows(FailedRequestException.class, () -> resultRows(plan));
108110
String actualMessage = exception.getMessage();
109111
assertTrue(actualMessage.contains("Server Message: XDMP-ARGTYPE"), "Unexpected message: " + actualMessage);
@@ -139,10 +141,16 @@ void vecVectorWithCol() {
139141
assertEquals(2, rows.size());
140142
}
141143

144+
/**
145+
* Updated after 2025-06-06, when the vector functions were updated. That includes annTopK being modified to accept
146+
* an options map as its 5th argument instead of a single query tolerance value.
147+
*/
142148
@Test
143-
void annTopK() {
149+
void annTopKWithOptionsMap() {
150+
Map<String, Object> options = new HashMap<>();
151+
options.put("distance", "cosine");
144152
PlanBuilder.ModifyPlan plan = op.fromView("vectors", "persons")
145-
.annTopK(10, op.col("embedding"), op.vec.vector(sampleVector), op.col("distance"), 0.5f);
153+
.annTopK(10, op.col("embedding"), op.vec.vector(sampleVector), op.col("distance"), options);
146154

147155
List<RowRecord> rows = resultRows(plan);
148156
assertEquals(2, rows.size(), "Verifying that annTopK worked and returned both rows from the view.");
@@ -158,7 +166,7 @@ void dslAnnTopK() {
158166
String query = "const qualityVector = vec.vector([ 1.1, 2.2, 3.3 ]);\n" +
159167
"op.fromView('vectors', 'persons')\n" +
160168
" .bind(op.as('myVector', op.vec.vector(op.col('embedding'))))\n" +
161-
" .annTopK(2, op.col('myVector'), qualityVector, op.col('distance'), 0.5)";
169+
" .annTopK(2, op.col('myVector'), qualityVector, op.col('distance'), {'distance':'cosine'})";
162170

163171
RawQueryDSLPlan plan = rowManager.newRawQueryDSLPlan(new StringHandle(query));
164172
List<RowRecord> rows = resultRows(plan);

0 commit comments

Comments
 (0)