Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions openviking/models/embedder/cohere_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
config: Optional[Dict[str, Any]] = None,
):
super().__init__(model_name, config)
self.provider = "cohere"

self.api_key = api_key
self.api_base = (api_base or "https://api.cohere.com").rstrip("/")
Expand Down
1 change: 1 addition & 0 deletions openviking/models/embedder/gemini_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def __init__(
config: Optional[Dict[str, Any]] = None,
):
super().__init__(model_name, config)
self.provider = "gemini"
if not api_key:
raise ValueError("Gemini provider requires api_key")
if task_type and task_type not in _VALID_TASK_TYPES:
Expand Down
1 change: 1 addition & 0 deletions openviking/models/embedder/jina_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
ValueError: If api_key is not provided
"""
super().__init__(model_name, config)
self.provider = "jina"

self.api_key = api_key
self.api_base = api_base or "https://api.jina.ai/v1"
Expand Down
1 change: 1 addition & 0 deletions openviking/models/embedder/litellm_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
config: Additional configuration dict.
"""
super().__init__(model_name, config)
self.provider = "litellm"

os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True")

Expand Down
1 change: 1 addition & 0 deletions openviking/models/embedder/minimax_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
extra_headers: Extra headers, useful for passing GroupId for MiniMax API
"""
super().__init__(model_name, config)
self.provider = "minimax"

self.api_key = api_key
self.api_base = api_base or self.DEFAULT_API_BASE
Expand Down
2 changes: 2 additions & 0 deletions openviking/models/embedder/openai_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
extra_headers: Optional[Dict[str, str]] = None,
input_type: Optional[str] = None,
provider: str = "openai",
configured_provider: Optional[str] = None,
):
"""Initialize OpenAI-Compatible Dense Embedder

Expand Down Expand Up @@ -118,6 +119,7 @@ def __init__(
self.query_param = query_param
self.document_param = document_param
self._provider = provider.lower()
self.provider = (configured_provider or provider).lower()
self._client_kwargs: Dict[str, Any] = {"api_key": self.api_key or "no-key"}

# Allow missing api_key when api_base is set (e.g. local OpenAI-compatible servers)
Expand Down
3 changes: 3 additions & 0 deletions openviking/models/embedder/vikingdb_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def __init__(
config: Optional[Dict[str, Any]] = None,
):
DenseEmbedderBase.__init__(self, model_name, config)
self.provider = "vikingdb"
self._init_vikingdb_client(ak, sk, region, host)
self.model_version = model_version
self.dimension = dimension
Expand Down Expand Up @@ -299,6 +300,7 @@ def __init__(
config: Optional[Dict[str, Any]] = None,
):
SparseEmbedderBase.__init__(self, model_name, config)
self.provider = "vikingdb"
self._init_vikingdb_client(ak, sk, region, host)
self.model_version = model_version
self.sparse_model = {
Expand Down Expand Up @@ -438,6 +440,7 @@ def __init__(
config: Optional[Dict[str, Any]] = None,
):
HybridEmbedderBase.__init__(self, model_name, config)
self.provider = "vikingdb"
self._init_vikingdb_client(ak, sk, region, host)
self.model_version = model_version
self.dimension = dimension
Expand Down
3 changes: 3 additions & 0 deletions openviking/models/embedder/volcengine_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(
ValueError: If api_key is not provided
"""
super().__init__(model_name, config)
self.provider = "volcengine"

self.api_key = api_key
self.api_base = api_base or "https://ark.cn-beijing.volces.com/api/v3"
Expand Down Expand Up @@ -326,6 +327,7 @@ def __init__(
ValueError: If api_key is not provided
"""
super().__init__(model_name, config)
self.provider = "volcengine"

self.api_key = api_key
self.api_base = api_base
Expand Down Expand Up @@ -512,6 +514,7 @@ def __init__(
ValueError: If api_key is not provided
"""
super().__init__(model_name, config)
self.provider = "volcengine"
self.api_key = api_key
self.api_base = api_base
self.dimension = dimension
Expand Down
1 change: 1 addition & 0 deletions openviking/models/embedder/voyage_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
config: Optional[Dict[str, Any]] = None,
):
super().__init__(model_name, config)
self.provider = "voyage"

self.api_key = api_key
self.api_base = api_base or "https://api.voyageai.com/v1"
Expand Down
3 changes: 3 additions & 0 deletions openviking_cli/utils/config/embedding_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def _create_embedder(
"api_version": cfg.api_version,
"dimension": cfg.dimension,
"provider": "openai",
"configured_provider": "openai",
"config": dict(runtime_config),
**({"query_param": cfg.query_param} if cfg.query_param else {}),
**({"document_param": cfg.document_param} if cfg.document_param else {}),
Expand All @@ -391,6 +392,7 @@ def _create_embedder(
"api_version": cfg.api_version,
"dimension": cfg.dimension,
"provider": "azure",
"configured_provider": "azure",
"config": dict(runtime_config),
**({"query_param": cfg.query_param} if cfg.query_param else {}),
**({"document_param": cfg.document_param} if cfg.document_param else {}),
Expand Down Expand Up @@ -500,6 +502,7 @@ def _create_embedder(
or "no-key", # Ollama ignores the key, but client requires non-empty
"api_base": cfg.api_base or "http://localhost:11434/v1",
"dimension": cfg.dimension,
"configured_provider": "ollama",
"config": dict(runtime_config),
},
),
Expand Down
45 changes: 44 additions & 1 deletion tests/unit/test_extra_headers_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
4. api_key dead-code bug fix: no raise when api_base is set without api_key
"""

from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

Expand Down Expand Up @@ -115,6 +115,49 @@ def test_factory_injects_embedding_max_retries(self, mock_openai_class):

assert embedder.max_retries == 0

@pytest.mark.asyncio
@patch("openviking.models.embedder.openai_embedders.openai.AsyncOpenAI")
@patch("openviking.models.embedder.openai_embedders.openai.OpenAI")
async def test_factory_uses_configured_provider_for_slow_call_logging(
self,
mock_openai_class,
mock_async_openai_class,
):
"""Slow-call warnings should log the configured provider, not the transport client mode."""
mock_openai_class.return_value = _make_mock_client()

async_response = MagicMock(
data=[MagicMock(embedding=[0.1] * 8)],
usage=None,
)
mock_async_client = MagicMock()
mock_async_client.embeddings.create = AsyncMock(return_value=async_response)
mock_async_openai_class.return_value = mock_async_client

cfg = EmbeddingModelConfig(
provider="ollama",
model="nomic-embed-text",
api_base="http://localhost:11434/v1",
dimension=8,
)
embedder = EmbeddingConfig(dense=cfg)._create_embedder("ollama", "dense", cfg)

with (
patch(
"openviking.models.embedder.openai_embedders.logger.warning"
) as mock_warning,
patch(
"openviking.models.embedder.base.time.monotonic",
side_effect=[0.0, 0.0, 0.0, 1.2],
),
):
await embedder.embed_async("hello")

mock_warning.assert_called_once()
call_args = mock_warning.call_args.args
assert call_args[1] == "OpenAI async embedding"
assert call_args[2] == "ollama"


class TestEmbeddingModelConfigExtraHeaders:
"""Test that EmbeddingModelConfig accepts and stores the extra_headers field."""
Expand Down
Loading