diff --git a/graphrag/index/operations/embed_text/embed_text.py b/graphrag/index/operations/embed_text/embed_text.py index 8e6cbbcbdb..2488be7fb9 100644 --- a/graphrag/index/operations/embed_text/embed_text.py +++ b/graphrag/index/operations/embed_text/embed_text.py @@ -50,7 +50,7 @@ async def embed_text( vector_store_config = strategy.get("vector_store") if vector_store_config: - index_name = _get_index_name(vector_store_config, embedding_name) + index_name = get_index_name(vector_store_config, embedding_name) vector_store: BaseVectorStore = _create_vector_store( vector_store_config, index_name, embedding_name ) @@ -217,7 +217,12 @@ def _create_vector_store( return vector_store -def _get_index_name(vector_store_config: dict, embedding_name: str) -> str: +def get_index_name(vector_store_config: dict, embedding_name: str) -> str: + collection_name = vector_store_config.get("collection_name") + if collection_name: + msg = f"using vector store {vector_store_config.get('type')} with user provided collection_name {collection_name} for embedding {embedding_name}" + logger.info(msg) + return collection_name container_name = vector_store_config.get("container_name", "default") index_name = create_index_name(container_name, embedding_name) diff --git a/graphrag/utils/api.py b/graphrag/utils/api.py index db3d94790d..a8b4a390ee 100644 --- a/graphrag/utils/api.py +++ b/graphrag/utils/api.py @@ -8,7 +8,7 @@ from graphrag.cache.factory import CacheFactory from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.config.embeddings import create_index_name +from graphrag.index.operations.embed_text.embed_text import get_index_name from graphrag.config.models.cache_config import CacheConfig from graphrag.config.models.storage_config import StorageConfig from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig @@ -104,10 +104,7 @@ def get_embedding_store( index_names = [] for index, store in config_args.items(): vector_store_type = store["type"] - index_name = create_index_name( - store.get("container_name", "default"), embedding_name - ) - + index_name = get_index_name(store, embedding_name) embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get( "embeddings_schema", {} )