Skip to content

Commit 751778e

Browse files
committed
Refactor ID handling for different IdType formats
- Add handling for UUID, TEXT, INTEGER, SERIAL, BIGSERIAL formats in `convertIdToPgType` function. - Implemented type conversion logic based on the IdType value (UUID, TEXT, INTEGER, SERIAL, BIGSERIAL). - Add unit tests to validate correct conversion for UUID and non-UUID IdType formats. - `testToPgTypeWithUuidIdType`: Validates UUID handling. - `testToPgTypeWithNonUuidIdType`: Validates handling for non-UUID IdTypes. Signed-off-by: jitokim <[email protected]>
1 parent 224191a commit 751778e

File tree

2 files changed

+102
-6
lines changed

2 files changed

+102
-6
lines changed

vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/pgvector/PgVectorStore.java

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,8 @@
3535

3636
import org.springframework.ai.document.Document;
3737
import org.springframework.ai.document.DocumentMetadata;
38-
import org.springframework.ai.embedding.BatchingStrategy;
3938
import org.springframework.ai.embedding.EmbeddingModel;
4039
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
41-
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
4240
import org.springframework.ai.observation.conventions.VectorStoreProvider;
4341
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
4442
import org.springframework.ai.util.JacksonUtils;
@@ -152,6 +150,7 @@
152150
* @author Thomas Vitale
153151
* @author Soby Chacko
154152
* @author Sebastien Deleuze
153+
* @author Jihoon Kim
155154
* @since 1.0.0
156155
*/
157156
public class PgVectorStore extends AbstractObservationVectorStore implements InitializingBean {
@@ -162,6 +161,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
162161

163162
public static final String DEFAULT_TABLE_NAME = "vector_store";
164163

164+
public static final PgIdType DEFAULT_ID_TYPE = PgIdType.UUID;
165+
165166
public static final String DEFAULT_VECTOR_INDEX_NAME = "spring_ai_vector_index";
166167

167168
public static final String DEFAULT_SCHEMA_NAME = "public";
@@ -187,6 +188,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
187188

188189
private final String schemaName;
189190

191+
private final PgIdType idType;
192+
190193
private final boolean schemaValidation;
191194

192195
private final boolean initializeSchema;
@@ -224,6 +227,7 @@ protected PgVectorStore(PgVectorStoreBuilder builder) {
224227
: this.vectorTableName + "_index";
225228

226229
this.schemaName = builder.schemaName;
230+
this.idType = builder.idType;
227231
this.schemaValidation = builder.vectorTableValidationsEnabled;
228232

229233
this.jdbcTemplate = builder.jdbcTemplate;
@@ -272,13 +276,13 @@ private void insertOrUpdateBatch(List<Document> batch, List<Document> documents,
272276
public void setValues(PreparedStatement ps, int i) throws SQLException {
273277

274278
var document = batch.get(i);
279+
var id = convertIdToPgType(document.getId());
275280
var content = document.getText();
276281
var json = toJson(document.getMetadata());
277282
var embedding = embeddings.get(documents.indexOf(document));
278283
var pGvector = new PGvector(embedding);
279284

280-
StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
281-
UUID.fromString(document.getId()));
285+
StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN, id);
282286
StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content);
283287
StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json);
284288
StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector);
@@ -303,6 +307,19 @@ private String toJson(Map<String, Object> map) {
303307
}
304308
}
305309

310+
private Object convertIdToPgType(String id) {
311+
if (this.initializeSchema) {
312+
return UUID.fromString(id);
313+
}
314+
315+
return switch (getIdType()) {
316+
case UUID -> UUID.fromString(id);
317+
case TEXT -> id;
318+
case INTEGER, SERIAL -> Integer.valueOf(id);
319+
case BIGSERIAL -> Long.valueOf(id);
320+
};
321+
}
322+
306323
@Override
307324
public Optional<Boolean> doDelete(List<String> idList) {
308325
int updateCount = 0;
@@ -412,6 +429,10 @@ private String getFullyQualifiedTableName() {
412429
return this.schemaName + "." + this.vectorTableName;
413430
}
414431

432+
private PgIdType getIdType() {
433+
return this.idType;
434+
}
435+
415436
private String getVectorTableName() {
416437
return this.vectorTableName;
417438
}
@@ -489,6 +510,12 @@ public enum PgIndexType {
489510

490511
}
491512

513+
public enum PgIdType {
514+
515+
UUID, TEXT, INTEGER, SERIAL, BIGSERIAL
516+
517+
}
518+
492519
/**
493520
* Defaults to CosineDistance. But if vectors are normalized to length 1 (like OpenAI
494521
* embeddings), use inner product (NegativeInnerProduct) for best performance.
@@ -584,6 +611,8 @@ public static final class PgVectorStoreBuilder extends AbstractVectorStoreBuilde
584611

585612
private String vectorTableName = PgVectorStore.DEFAULT_TABLE_NAME;
586613

614+
private PgIdType idType = PgVectorStore.DEFAULT_ID_TYPE;
615+
587616
private boolean vectorTableValidationsEnabled = PgVectorStore.DEFAULT_SCHEMA_VALIDATION;
588617

589618
private int dimensions = PgVectorStore.INVALID_EMBEDDING_DIMENSION;
@@ -614,6 +643,11 @@ public PgVectorStoreBuilder vectorTableName(String vectorTableName) {
614643
return this;
615644
}
616645

646+
public PgVectorStoreBuilder idType(PgIdType idType) {
647+
this.idType = idType;
648+
return this;
649+
}
650+
617651
public PgVectorStoreBuilder vectorTableValidationsEnabled(boolean vectorTableValidationsEnabled) {
618652
this.vectorTableValidationsEnabled = vectorTableValidationsEnabled;
619653
return this;

vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreIT.java

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.io.IOException;
2020
import java.nio.charset.StandardCharsets;
2121
import java.util.Collections;
22+
import java.util.HashMap;
2223
import java.util.Iterator;
2324
import java.util.List;
2425
import java.util.Map;
@@ -29,6 +30,7 @@
2930

3031
import com.zaxxer.hikari.HikariDataSource;
3132
import org.junit.Assert;
33+
import org.junit.jupiter.api.Test;
3234
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
3335
import org.junit.jupiter.params.ParameterizedTest;
3436
import org.junit.jupiter.params.provider.Arguments;
@@ -40,13 +42,15 @@
4042

4143
import org.springframework.ai.document.Document;
4244
import org.springframework.ai.document.DocumentMetadata;
45+
import org.springframework.ai.document.id.RandomIdGenerator;
4346
import org.springframework.ai.embedding.EmbeddingModel;
4447
import org.springframework.ai.openai.OpenAiEmbeddingModel;
4548
import org.springframework.ai.openai.api.OpenAiApi;
49+
import org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIdType;
50+
import org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIndexType;
4651
import org.springframework.ai.vectorstore.SearchRequest;
4752
import org.springframework.ai.vectorstore.VectorStore;
4853
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser.FilterExpressionParseException;
49-
import org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIndexType;
5054
import org.springframework.beans.factory.annotation.Value;
5155
import org.springframework.boot.SpringBootConfiguration;
5256
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
@@ -67,6 +71,7 @@
6771
* @author Muthukumaran Navaneethakrishnan
6872
* @author Christian Tzolov
6973
* @author Thomas Vitale
74+
* @author Jihoon Kim
7075
*/
7176
@Testcontainers
7277
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
@@ -103,6 +108,27 @@ public static String getText(String uri) {
103108
}
104109
}
105110

111+
private static void initSchema(ApplicationContext context) {
112+
PgVectorStore vectorStore = context.getBean(PgVectorStore.class);
113+
JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class);
114+
// Enable the PGVector, JSONB and UUID support.
115+
jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
116+
jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
117+
jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");
118+
119+
jdbcTemplate.execute(String.format("CREATE SCHEMA IF NOT EXISTS %s", PgVectorStore.DEFAULT_SCHEMA_NAME));
120+
121+
jdbcTemplate.execute(String.format("""
122+
CREATE TABLE IF NOT EXISTS %s.%s (
123+
id text PRIMARY KEY,
124+
content text,
125+
metadata json,
126+
embedding vector(%d)
127+
)
128+
""", PgVectorStore.DEFAULT_SCHEMA_NAME, PgVectorStore.DEFAULT_TABLE_NAME,
129+
vectorStore.embeddingDimensions()));
130+
}
131+
106132
private static void dropTable(ApplicationContext context) {
107133
JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class);
108134
jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store");
@@ -166,6 +192,35 @@ public void addAndSearch(String distanceType) {
166192
});
167193
}
168194

195+
@Test
196+
public void testToPgTypeWithUuidIdType() {
197+
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
198+
.run(context -> {
199+
200+
VectorStore vectorStore = context.getBean(VectorStore.class);
201+
202+
vectorStore.add(List.of(new Document(new RandomIdGenerator().generateId(), "TEXT", new HashMap<>())));
203+
204+
dropTable(context);
205+
});
206+
}
207+
208+
@Test
209+
public void testToPgTypeWithNonUuidIdType() {
210+
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
211+
.withPropertyValues("test.spring.ai.vectorstore.pgvector.initializeSchema=" + false)
212+
.withPropertyValues("test.spring.ai.vectorstore.pgvector.idType=" + "TEXT")
213+
.run(context -> {
214+
215+
VectorStore vectorStore = context.getBean(VectorStore.class);
216+
initSchema(context);
217+
218+
vectorStore.add(List.of(new Document("NOT_UUID", "TEXT", new HashMap<>())));
219+
220+
dropTable(context);
221+
});
222+
}
223+
169224
@ParameterizedTest(name = "Filter expression {0} should return {1} records ")
170225
@MethodSource("provideFilters")
171226
public void searchWithInFilter(String expression, Integer expectedRecords) {
@@ -371,12 +426,19 @@ public static class TestApplication {
371426
@Value("${test.spring.ai.vectorstore.pgvector.distanceType}")
372427
PgVectorStore.PgDistanceType distanceType;
373428

429+
@Value("${test.spring.ai.vectorstore.pgvector.initializeSchema:true}")
430+
boolean initializeSchema;
431+
432+
@Value("${test.spring.ai.vectorstore.pgvector.idType:UUID}")
433+
PgIdType idType;
434+
374435
@Bean
375436
public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
376437
return PgVectorStore.builder(jdbcTemplate, embeddingModel)
377438
.dimensions(PgVectorStore.INVALID_EMBEDDING_DIMENSION)
439+
.idType(idType)
378440
.distanceType(this.distanceType)
379-
.initializeSchema(true)
441+
.initializeSchema(initializeSchema)
380442
.indexType(PgIndexType.HNSW)
381443
.removeExistingVectorStoreTable(true)
382444
.build();

0 commit comments

Comments
 (0)