diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/pgvector/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/pgvector/PgVectorStore.java index 19b70507db8..de5f5f2912b 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/pgvector/PgVectorStore.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/pgvector/PgVectorStore.java @@ -35,10 +35,8 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; -import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; -import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; import org.springframework.ai.util.JacksonUtils; @@ -152,6 +150,7 @@ * @author Thomas Vitale * @author Soby Chacko * @author Sebastien Deleuze + * @author Jihoon Kim * @since 1.0.0 */ public class PgVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -162,6 +161,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini public static final String DEFAULT_TABLE_NAME = "vector_store"; + /** Id type the builder starts with when its idType setter is never called: UUID. */ public static final PgIdType DEFAULT_ID_TYPE = PgIdType.UUID; + public static final String DEFAULT_VECTOR_INDEX_NAME = "spring_ai_vector_index"; public static final String DEFAULT_SCHEMA_NAME = "public"; @@ -187,6 +188,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini private final String schemaName; + /** Id type copied from the builder in the constructor; read through getIdType() when document ids are converted. */ private final PgIdType idType; + private final boolean schemaValidation; private final boolean initializeSchema; @@ -224,6 +227,7 @@ protected 
PgVectorStore(PgVectorStoreBuilder builder) { : this.vectorTableName + "_index"; this.schemaName = builder.schemaName; + this.idType = builder.idType; this.schemaValidation = builder.vectorTableValidationsEnabled; this.jdbcTemplate = builder.jdbcTemplate; @@ -272,13 +276,13 @@ private void insertOrUpdateBatch(List batch, List documents, public void setValues(PreparedStatement ps, int i) throws SQLException { var document = batch.get(i); + /* The id is converted once here and bound as statement parameter 1 below. */ var id = convertIdToPgType(document.getId()); var content = document.getText(); var json = toJson(document.getMetadata()); var embedding = embeddings.get(documents.indexOf(document)); var pGvector = new PGvector(embedding); - StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN, - UUID.fromString(document.getId())); + StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN, id); StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content); StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json); StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector); @@ -303,6 +307,19 @@ private String toJson(Map map) { } } + /** Turns a document id string into the object that insertOrUpdateBatch binds as statement parameter 1. When initializeSchema is set the id is always parsed as a java.util.UUID and the configured id type is not consulted. Otherwise the configured PgIdType decides: UUID parses a UUID, TEXT passes the string through unchanged, INTEGER and SERIAL parse an Integer, BIGSERIAL parses a Long. A malformed id makes the parse call throw an unchecked exception (IllegalArgumentException from UUID.fromString, NumberFormatException from Integer.valueOf and Long.valueOf). NOTE(review): the initializeSchema shortcut presumably mirrors a UUID id column created by schema initialization, which is not visible in this hunk; confirm. */ private Object convertIdToPgType(String id) { + if (this.initializeSchema) { + return UUID.fromString(id); + } + + return switch (getIdType()) { + case UUID -> UUID.fromString(id); + case TEXT -> id; + case INTEGER, SERIAL -> Integer.valueOf(id); + case BIGSERIAL -> Long.valueOf(id); + }; + } + @Override public Optional doDelete(List idList) { int updateCount = 0; @@ -412,6 +429,10 @@ private String getFullyQualifiedTableName() { return this.schemaName + "." + this.vectorTableName; } + /** Returns the id type this store received from its builder. */ private PgIdType getIdType() { + return this.idType; + } + private String getVectorTableName() { return this.vectorTableName; } @@ -489,6 +510,12 @@ public enum PgIndexType { } + /** Kinds of id value the store can bind; each constant selects how convertIdToPgType parses a document id (UUID, unchanged text, Integer for INTEGER and SERIAL, Long for BIGSERIAL). */ public enum PgIdType { + + UUID, TEXT, INTEGER, SERIAL, BIGSERIAL + + } + /** * Defaults to CosineDistance. 
But if vectors are normalized to length 1 (like OpenAI * embeddings), use inner product (NegativeInnerProduct) for best performance. @@ -584,6 +611,8 @@ public static final class PgVectorStoreBuilder extends AbstractVectorStoreBuilde private String vectorTableName = PgVectorStore.DEFAULT_TABLE_NAME; + /** Id type handed to the built store; starts as PgVectorStore.DEFAULT_ID_TYPE. */ private PgIdType idType = PgVectorStore.DEFAULT_ID_TYPE; + private boolean vectorTableValidationsEnabled = PgVectorStore.DEFAULT_SCHEMA_VALIDATION; private int dimensions = PgVectorStore.INVALID_EMBEDDING_DIMENSION; @@ -614,6 +643,11 @@ public PgVectorStoreBuilder vectorTableName(String vectorTableName) { return this; } + /** Sets the id type the built store uses when converting document ids. The value is stored as given; no null check or other validation happens here. @param idType the id type to use @return this builder, for chaining */ public PgVectorStoreBuilder idType(PgIdType idType) { + this.idType = idType; + return this; + } + public PgVectorStoreBuilder vectorTableValidationsEnabled(boolean vectorTableValidationsEnabled) { this.vectorTableValidationsEnabled = vectorTableValidationsEnabled; return this; diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreIT.java index 82a70237a2f..ac0b7eb6aa6 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreIT.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Collections; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -29,6 +30,7 @@ import com.zaxxer.hikari.HikariDataSource; import org.junit.Assert; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -40,13 +42,15 @@ import 
org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; +import org.springframework.ai.document.id.RandomIdGenerator; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.openai.OpenAiEmbeddingModel; import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIdType; +import org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIndexType; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser.FilterExpressionParseException; -import org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIndexType; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; @@ -67,6 +71,7 @@ * @author Muthukumaran Navaneethakrishnan * @author Christian Tzolov * @author Thomas Vitale + * @author Jihoon Kim */ @Testcontainers @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") @@ -103,6 +108,27 @@ public static String getText(String uri) { } } + /** Creates the database objects by hand for tests that run with schema initialization disabled: the vector, hstore and uuid-ossp extensions, the default schema, and the default vector table with a text primary key (not a UUID one), text content, json metadata and an embedding column sized from the PgVectorStore bean. Every statement uses IF NOT EXISTS. @param context application context providing the PgVectorStore and JdbcTemplate beans */ private static void initSchema(ApplicationContext context) { + PgVectorStore vectorStore = context.getBean(PgVectorStore.class); + JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class); + // Enable the PGVector, JSONB and UUID support. 
+ jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector"); + jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore"); + jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\""); + + jdbcTemplate.execute(String.format("CREATE SCHEMA IF NOT EXISTS %s", PgVectorStore.DEFAULT_SCHEMA_NAME)); + + jdbcTemplate.execute(String.format(""" + CREATE TABLE IF NOT EXISTS %s.%s ( + id text PRIMARY KEY, + content text, + metadata json, + embedding vector(%d) + ) + """, PgVectorStore.DEFAULT_SCHEMA_NAME, PgVectorStore.DEFAULT_TABLE_NAME, + vectorStore.embeddingDimensions())); + } + private static void dropTable(ApplicationContext context) { JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class); jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store"); @@ -166,6 +192,35 @@ public void addAndSearch(String distanceType) { }); } + /** Adds one document whose id comes from RandomIdGenerator while only the distance type property is set, so the test configuration falls back to its defaults for id type and schema initialization. There are no explicit assertions: the test passes when add() does not throw. The table is dropped afterwards, but only if add() succeeded. NOTE(review): presumably RandomIdGenerator yields UUID strings; confirm. */ @Test + public void testToPgTypeWithUuidIdType() { + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE") + .run(context -> { + + VectorStore vectorStore = context.getBean(VectorStore.class); + + vectorStore.add(List.of(new Document(new RandomIdGenerator().generateId(), "TEXT", new HashMap<>()))); + + dropTable(context); + }); + } + + /** Runs with initializeSchema=false and idType=TEXT, creates the table through initSchema (text primary key) and adds one document whose id is the literal NOT_UUID, which is not a valid UUID. There are no explicit assertions: the test passes when add() does not throw. The table is dropped afterwards, but only if add() succeeded. */ @Test + public void testToPgTypeWithNonUuidIdType() { + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE") + .withPropertyValues("test.spring.ai.vectorstore.pgvector.initializeSchema=" + false) + .withPropertyValues("test.spring.ai.vectorstore.pgvector.idType=" + "TEXT") + .run(context -> { + + VectorStore vectorStore = context.getBean(VectorStore.class); + initSchema(context); + + vectorStore.add(List.of(new Document("NOT_UUID", "TEXT", new HashMap<>()))); + + dropTable(context); + }); + } + @ParameterizedTest(name = "Filter expression {0} should return {1} records ") @MethodSource("provideFilters") public void searchWithInFilter(String expression, Integer expectedRecords) { @@ 
-371,12 +426,19 @@ public static class TestApplication { @Value("${test.spring.ai.vectorstore.pgvector.distanceType}") PgVectorStore.PgDistanceType distanceType; + /** Whether the store under test initializes its own schema; true unless the test property overrides it. */ @Value("${test.spring.ai.vectorstore.pgvector.initializeSchema:true}") + boolean initializeSchema; + + /** Id type handed to the store builder; UUID unless the test property overrides it. */ @Value("${test.spring.ai.vectorstore.pgvector.idType:UUID}") + PgIdType idType; + @Bean public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { return PgVectorStore.builder(jdbcTemplate, embeddingModel) .dimensions(PgVectorStore.INVALID_EMBEDDING_DIMENSION) + .idType(this.idType) .distanceType(this.distanceType) - .initializeSchema(true) + .initializeSchema(this.initializeSchema) .indexType(PgIndexType.HNSW) .removeExistingVectorStoreTable(true) .build();