diff --git a/extensions/Postgres/Postgres.TestApplication/Program.cs b/extensions/Postgres/Postgres.TestApplication/Program.cs index 915dd409b..37fb7b5ec 100644 --- a/extensions/Postgres/Postgres.TestApplication/Program.cs +++ b/extensions/Postgres/Postgres.TestApplication/Program.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using Microsoft.KernelMemory; +using Microsoft.KernelMemory.AI.Ollama; using Microsoft.KernelMemory.DocumentStorage.DevTools; using Microsoft.KernelMemory.FileSystem.DevTools; @@ -26,16 +27,13 @@ private static async Task Test1() var postgresConfig = cfg.GetSection("KernelMemory:Services:Postgres").Get(); ArgumentNullExceptionEx.ThrowIfNull(postgresConfig, nameof(postgresConfig), "Postgres config not found"); - var azureOpenAIEmbeddingConfig = cfg.GetSection("KernelMemory:Services:AzureOpenAIEmbedding").Get(); - ArgumentNullExceptionEx.ThrowIfNull(azureOpenAIEmbeddingConfig, nameof(azureOpenAIEmbeddingConfig), "AzureOpenAIEmbedding config not found"); - - var azureOpenAITextConfig = cfg.GetSection("KernelMemory:Services:AzureOpenAIText").Get(); - ArgumentNullExceptionEx.ThrowIfNull(azureOpenAITextConfig, nameof(azureOpenAITextConfig), "AzureOpenAIText config not found"); + var ollamaConfig = cfg.GetSection("KernelMemory:Services:Ollama").Get(); + ArgumentNullExceptionEx.ThrowIfNull(ollamaConfig, nameof(ollamaConfig), "Ollama config not found"); // Concatenate our 'WithPostgresMemoryDb()' after 'WithOpenAIDefaults()' from the core nuget var mem1 = new KernelMemoryBuilder() - .WithAzureOpenAITextGeneration(azureOpenAITextConfig) - .WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig) + .WithOllamaTextEmbeddingGeneration(ollamaConfig) + .WithOllamaTextGeneration(ollamaConfig) .WithPostgresMemoryDb(postgresConfig) .WithSimpleFileStorage(SimpleFileStorageConfig.Persistent) .Build(); @@ -44,16 +42,16 @@ private static async Task Test1() var mem2 = new KernelMemoryBuilder() .WithPostgresMemoryDb(postgresConfig) 
.WithSimpleFileStorage(SimpleFileStorageConfig.Persistent) - .WithAzureOpenAITextGeneration(azureOpenAITextConfig) - .WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig) + .WithOllamaTextEmbeddingGeneration(ollamaConfig) + .WithOllamaTextGeneration(ollamaConfig) .Build(); // Concatenate our 'WithPostgresMemoryDb()' before and after KM builder extension methods from the core nuget var mem3 = new KernelMemoryBuilder() .WithSimpleFileStorage(SimpleFileStorageConfig.Persistent) - .WithAzureOpenAITextGeneration(azureOpenAITextConfig) + .WithOllamaTextEmbeddingGeneration(ollamaConfig) + .WithOllamaTextGeneration(ollamaConfig) .WithPostgresMemoryDb(postgresConfig) - .WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig) .Build(); await mem1.DeleteIndexAsync("index1"); @@ -92,8 +90,7 @@ private static async Task Test1() private static async Task Test2() { var postgresConfig = new PostgresConfig(); - var azureOpenAIEmbeddingConfig = new AzureOpenAIConfig(); - var azureOpenAITextConfig = new AzureOpenAIConfig(); + var ollamaConfig = new OllamaConfig(); new ConfigurationBuilder() .AddJsonFile("appsettings.json") @@ -101,13 +98,12 @@ private static async Task Test2() .AddJsonFile("appsettings.Development.json", optional: true) .Build() .BindSection("KernelMemory:Services:Postgres", postgresConfig) - .BindSection("KernelMemory:Services:AzureOpenAIEmbedding", azureOpenAIEmbeddingConfig) - .BindSection("KernelMemory:Services:AzureOpenAIText", azureOpenAITextConfig); + .BindSection("KernelMemory:Services:Ollama", ollamaConfig); var memory = new KernelMemoryBuilder() .WithPostgresMemoryDb(postgresConfig) - .WithAzureOpenAITextGeneration(azureOpenAITextConfig) - .WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig) + .WithOllamaTextGeneration(ollamaConfig) + .WithOllamaTextEmbeddingGeneration(ollamaConfig) .WithSimpleFileStorage(new SimpleFileStorageConfig { StorageType = FileSystemTypes.Disk, @@ -140,8 +136,7 @@ private static async Task 
Test2() private static async Task Test3() { var postgresConfig = new PostgresConfig(); - var azureOpenAIEmbeddingConfig = new AzureOpenAIConfig(); - var azureOpenAITextConfig = new AzureOpenAIConfig(); + var ollamaConfig = new OllamaConfig(); // Note: using appsettings.custom-sql.json new ConfigurationBuilder() @@ -151,13 +146,12 @@ private static async Task Test3() .AddJsonFile("appsettings.custom-sql.json") .Build() .BindSection("KernelMemory:Services:Postgres", postgresConfig) - .BindSection("KernelMemory:Services:AzureOpenAIEmbedding", azureOpenAIEmbeddingConfig) - .BindSection("KernelMemory:Services:AzureOpenAIText", azureOpenAITextConfig); + .BindSection("KernelMemory:Services:Ollama", ollamaConfig); var memory = new KernelMemoryBuilder() .WithPostgresMemoryDb(postgresConfig) - .WithAzureOpenAITextGeneration(azureOpenAITextConfig) - .WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig) + .WithOllamaTextGeneration(ollamaConfig) + .WithOllamaTextEmbeddingGeneration(ollamaConfig) .WithSimpleFileStorage(new SimpleFileStorageConfig { StorageType = FileSystemTypes.Disk, diff --git a/extensions/Postgres/Postgres/Internals/PostgresDbClient.cs b/extensions/Postgres/Postgres/Internals/PostgresDbClient.cs index 9a646842c..8660c1a5c 100644 --- a/extensions/Postgres/Postgres/Internals/PostgresDbClient.cs +++ b/extensions/Postgres/Postgres/Internals/PostgresDbClient.cs @@ -42,6 +42,8 @@ public PostgresDbClient(PostgresConfig config, ILoggerFactory? loggerFactory = n this._dbNamePresent = config.ConnectionString.Contains("Database=", StringComparison.OrdinalIgnoreCase); this._schema = config.Schema; this._tableNamePrefix = config.TableNamePrefix; + this._textSearchLanguage = config.TextSearchLanguage; + this._rrfK = config.RRFK; this._colId = config.Columns[PostgresConfig.ColumnId]; this._colEmbedding = config.Columns[PostgresConfig.ColumnEmbedding]; @@ -59,6 +61,14 @@ public PostgresDbClient(PostgresConfig config, ILoggerFactory? 
loggerFactory = n this._columnsListNoEmbeddings = $"{this._colId},{this._colTags},{this._colContent},{this._colPayload}"; this._columnsListWithEmbeddings = $"{this._colId},{this._colTags},{this._colContent},{this._colPayload},{this._colEmbedding}"; + this._columnsListHybrid = $"{this._colId},{this._colTags},{this._colContent},{this._colPayload},{this._colEmbedding}"; + this._columnsListHybridCoalesce = $@" + COALESCE(semantic_search.{this._colId}, keyword_search.{this._colId}) AS {this._colId}, + COALESCE(semantic_search.{this._colTags}, keyword_search.{this._colTags}) AS {this._colTags}, + COALESCE(semantic_search.{this._colContent}, keyword_search.{this._colContent}) AS {this._colContent}, + COALESCE(semantic_search.{this._colPayload}, keyword_search.{this._colPayload}) AS {this._colPayload}, + COALESCE(semantic_search.{this._colEmbedding}, keyword_search.{this._colEmbedding}) AS {this._colEmbedding} + "; this._createTableSql = string.Empty; if (config.CreateTableSql?.Count > 0) @@ -138,6 +148,8 @@ public async Task CreateTableAsync( CancellationToken cancellationToken = default) { var origInputTableName = tableName; + var indexTags = this.WithTableNamePrefix(tableName) + "_idx_tags"; + var indexContent = this.WithTableNamePrefix(tableName) + "_idx_content"; tableName = this.WithSchemaAndTableNamePrefix(tableName); this._log.LogTrace("Creating table: {0}", tableName); @@ -175,7 +187,8 @@ public async Task CreateTableAsync( {this._colContent} TEXT DEFAULT '' NOT NULL, {this._colPayload} JSONB DEFAULT '{{}}'::JSONB NOT NULL ); - CREATE INDEX IF NOT EXISTS idx_tags ON {tableName} USING GIN({this._colTags}); + CREATE INDEX IF NOT EXISTS ""{indexTags}"" ON {tableName} USING GIN({this._colTags}); + CREATE INDEX IF NOT EXISTS ""{indexContent}"" ON {tableName} USING GIN(to_tsvector('{this._textSearchLanguage}',{this._colContent})); COMMIT; "; #pragma warning restore CA2100 @@ -388,6 +401,7 @@ DO UPDATE SET /// Get a list of records /// /// Table containing the records to 
fetch + /// Prompt query. Only used in the case of hybrid search /// Source vector to compare for similarity /// Minimum similarity threshold /// SQL filter to apply @@ -395,9 +409,11 @@ DO UPDATE SET /// Max number of records to retrieve /// Records to skip from the top /// Whether to include embedding vectors + /// Whether to use hybrid search or vector search /// Async task cancellation token public async IAsyncEnumerable<(PostgresMemoryRecord record, double similarity)> GetSimilarAsync( string tableName, + string query, Vector target, double minSimilarity, string? filterSql = null, @@ -405,6 +421,7 @@ DO UPDATE SET int limit = 1, int offset = 0, bool withEmbeddings = false, + bool useHybridSearch = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) { tableName = this.WithSchemaAndTableNamePrefix(tableName); @@ -415,12 +432,15 @@ DO UPDATE SET string columns = withEmbeddings ? this._columnsListWithEmbeddings : this._columnsListNoEmbeddings; // Filtering logic, including filter by similarity + // filterSql = filterSql?.Trim().Replace(PostgresSchema.PlaceholdersTags, this._colTags, StringComparison.Ordinal); if (string.IsNullOrWhiteSpace(filterSql)) { filterSql = "TRUE"; } + string filterSqlHybridText = filterSql; + var maxDistance = 1 - minSimilarity; filterSql += $" AND {this._colEmbedding} <=> @embedding < @maxDistance"; @@ -440,16 +460,51 @@ DO UPDATE SET #pragma warning disable CA2100 // SQL reviewed string colDistance = "__distance"; - // When using 1 - (embedding <=> target) the index is not being used, therefore we calculate - // the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause. 
- cmd.CommandText = @$" - SELECT {columns}, {this._colEmbedding} <=> @embedding AS {colDistance} - FROM {tableName} - WHERE {filterSql} - ORDER BY {colDistance} ASC + if (useHybridSearch) + { + // When using 1 - (embedding <=> target) the index is not being used, therefore we calculate + // the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause. + cmd.CommandText = @$" + WITH semantic_search AS ( + SELECT {this._columnsListHybrid}, RANK () OVER (ORDER BY {this._colEmbedding} <=> @embedding) AS rank + FROM {tableName} + WHERE {filterSql} + ORDER BY {this._colEmbedding} <=> @embedding + LIMIT @limit + ), + keyword_search AS ( + SELECT {this._columnsListHybrid}, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('{this._textSearchLanguage}', {this._colContent}), query) DESC) + FROM {tableName}, plainto_tsquery('{this._textSearchLanguage}', @query) query + WHERE {filterSqlHybridText} AND to_tsvector('{this._textSearchLanguage}', {this._colContent}) @@ query + ORDER BY ts_rank_cd(to_tsvector('{this._textSearchLanguage}', {this._colContent}), query) DESC + LIMIT @limit + ) + SELECT + {this._columnsListHybridCoalesce}, + COALESCE(1.0 / ({this._rrfK} + semantic_search.rank), 0.0) + + COALESCE(1.0 / ({this._rrfK} + keyword_search.rank), 0.0) AS {colDistance} + FROM semantic_search + FULL OUTER JOIN keyword_search ON semantic_search.{this._colId} = keyword_search.{this._colId} + ORDER BY {colDistance} DESC LIMIT @limit OFFSET @offset "; + cmd.Parameters.AddWithValue("@query", query); + cmd.Parameters.AddWithValue("@minSimilarity", minSimilarity); + } + else + { + // When using 1 - (embedding <=> target) the index is not being used, therefore we calculate + // the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause. 
+ cmd.CommandText = @$" + SELECT {columns}, {this._colEmbedding} <=> @embedding AS {colDistance} + FROM {tableName} + WHERE {filterSql} + ORDER BY {colDistance} ASC + LIMIT @limit + OFFSET @offset + "; + } cmd.Parameters.AddWithValue("@embedding", target); cmd.Parameters.AddWithValue("@maxDistance", maxDistance); @@ -692,7 +747,11 @@ public async ValueTask DisposeAsync() private readonly string _colPayload; private readonly string _columnsListNoEmbeddings; private readonly string _columnsListWithEmbeddings; + private readonly string _columnsListHybrid; + private readonly string _columnsListHybridCoalesce; private readonly bool _dbNamePresent; + private readonly string _textSearchLanguage; + private readonly int _rrfK; /// /// Try to connect to PG, handling exceptions in case the DB doesn't exist diff --git a/extensions/Postgres/Postgres/PostgresConfig.cs b/extensions/Postgres/Postgres/PostgresConfig.cs index f34be9637..2d4822161 100644 --- a/extensions/Postgres/Postgres/PostgresConfig.cs +++ b/extensions/Postgres/Postgres/PostgresConfig.cs @@ -107,6 +107,24 @@ public class PostgresConfig /// public List CreateTableSql { get; set; } = []; + /// + /// Important: when using hybrid search, relevance scores + /// are very different from when using just vector search. + /// + public bool UseHybridSearch { get; set; } = false; + + /// + /// Defines the dictionary language used for the textual part of hybrid search. 
+ /// see: https://www.postgresql.org/docs/current/textsearch-dictionaries.html + /// This query can help you to get the list of dictionaries: SELECT * FROM pg_catalog.pg_ts_dict; + /// + public string TextSearchLanguage { get; set; } = "english"; + + /// + /// Reciprocal Ranked Fusion to score results of Hybrid Search + /// + public int RRFK { get; set; } = 50; + /// /// Create a new instance of the configuration /// diff --git a/extensions/Postgres/Postgres/PostgresMemory.cs b/extensions/Postgres/Postgres/PostgresMemory.cs index 135034e18..97b5265d1 100644 --- a/extensions/Postgres/Postgres/PostgresMemory.cs +++ b/extensions/Postgres/Postgres/PostgresMemory.cs @@ -29,6 +29,8 @@ public sealed class PostgresMemory : IMemoryDb, IDisposable, IAsyncDisposable private readonly ITextEmbeddingGenerator _embeddingGenerator; private readonly ILogger _log; + private readonly bool _useHybridSearch; + /// /// Create a new instance of Postgres KM connector /// @@ -41,6 +43,7 @@ public PostgresMemory( ILoggerFactory? loggerFactory = null) { this._log = (loggerFactory ?? 
DefaultLogger.Factory).CreateLogger(); + this._useHybridSearch = config.UseHybridSearch; this._embeddingGenerator = embeddingGenerator; if (this._embeddingGenerator == null) @@ -160,12 +163,14 @@ await this._db.UpsertAsync( var records = this._db.GetSimilarAsync( index, + query: text, target: new Vector(textEmbedding.Data), minSimilarity: minRelevance, filterSql: sql, sqlUserValues: unsafeSqlUserValues, limit: limit, withEmbeddings: withEmbeddings, + useHybridSearch: this._useHybridSearch, cancellationToken: cancellationToken).ConfigureAwait(false); await foreach ((PostgresMemoryRecord record, double similarity) result in records) diff --git a/extensions/Postgres/README.md b/extensions/Postgres/README.md index 44786f6f6..0a93a394d 100644 --- a/extensions/Postgres/README.md +++ b/extensions/Postgres/README.md @@ -31,7 +31,10 @@ To use Postgres with Kernel Memory: "KernelMemory": { "Services": { "Postgres": { - "ConnectionString": "Host=localhost;Port=5432;Username=myuser;Password=mypassword;Database=mydatabase" + "ConnectionString": "Host=localhost;Port=5432;Username=myuser;Password=mypassword;Database=mydatabase", + "UseHybridSearch": true, + "TextSearchLanguage": "english", + "RRFK": 60 } } } @@ -42,16 +45,16 @@ To use Postgres with Kernel Memory: // using Microsoft.KernelMemory; // using Microsoft.KernelMemory.Postgres; // using Microsoft.Extensions.Configuration; - + var postgresConfig = new PostgresConfig(); - + new ConfigurationBuilder() .AddJsonFile("appsettings.json") .AddJsonFile("appsettings.development.json", optional: true) .AddJsonFile("appsettings.Development.json", optional: true) .Build() .BindSection("KernelMemory:Services:Postgres", postgresConfig); - + var memory = new KernelMemoryBuilder() .WithPostgresMemoryDb(postgresConfig) .WithSimpleFileStorage(SimpleFileStorageConfig.Persistent) @@ -103,6 +106,35 @@ types supported by Kernel Memory. Overall we recommend not mixing external tables in the same DB used for Kernel Memory. 
+## Hybrid Search + +The Postgres memory connector supports Hybrid Search. + +Hybrid Search configuration parameters: + +- **UseHybridSearch**: This parameter enables (true) or disables (false) hybrid search. +- **TextSearchLanguage**: This parameter sets the language used during text search. +- **RRFK**: This parameter configures [RRF](https://en.wikipedia.org/wiki/Mean_reciprocal_rank) (Reciprocal Rank Fusion) for the hybrid search. + A smaller value of `RRFK` gives more weight to higher ranked items, whereas a + larger value of `RRFK` gives more weight to lower ranked items. For hybrid search, + this impacts the final score when combining the two scores from the search. + It defaults to 50. The value should be in the range 1-100. + +For more details, check out the +[pgvector GitHub project](https://github.com/pgvector/pgvector) and +[this article](https://jkatz05.com/post/postgres/hybrid-search-postgres-pgvector) +on hybrid search in Postgres. + +The connector creates the text search index automatically on table creation. + +If you enable hybrid search on PostgreSQL tables that were created with an older +version of the connector, or if you want to change the TextSearchLanguage, you will need +to manually create the text search index using the column names and table name that +you have configured. + +**SQL to add Text Search Index:** +`CREATE INDEX IF NOT EXISTS {indexName} ON {tableName} USING GIN(to_tsvector('{TextSearchLanguage}',{contentColumnName}));` + ## Column names and table schema The connector uses a default schema with predefined columns and indexes. diff --git a/service/Service/appsettings.json b/service/Service/appsettings.json index 1e7a0b415..b92d7b6eb 100644 --- a/service/Service/appsettings.json +++ b/service/Service/appsettings.json @@ -621,7 +621,14 @@ "ConnectionString": "Host=localhost;Port=5432;Username=public;Password=;Database=public", // Mandatory prefix to add to the name of table managed by KM, // e.g. to exclude other tables in the same schema. 
- "TableNamePrefix": "km-" + "TableNamePrefix": "km-", + // Hybrid search is not enabled by default. Note that when using hybrid search + // relevance scores are different, usually lower, than when using just vector search + "UseHybridSearch": false, + // Defines the dictionary language of hybrid search + "TextSearchLanguage": "english", + // Reciprocal Ranked Fusion to score results of hybrid search + "RRFK": 50, }, "Qdrant": { // Qdrant endpoint