-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Change tests so they can run in CI/CD
- Loading branch information
1 parent
93185a0
commit 083b569
Showing
2 changed files
with
40 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from lamine.types import Provider | ||
|
||
|
||
class Mock(Provider): | ||
model_ids = ["model1", "model2"] | ||
locations = ["location1", "location2"] | ||
env_vars = ["var1"] | ||
|
||
def get_answer(self, model, conversation, **kwargs): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,50 +1,46 @@ | ||
import pytest | ||
from lamine.types import Model | ||
import os | ||
|
||
import pytest | ||
|
||
def test_valid_provider_no_locations(): | ||
model = Model(provider="anthropic", id="claude-3.5-sonnet-latest") | ||
assert model.provider == "anthropic" | ||
assert model.locations is None | ||
from lamine.types import Model | ||
|
||
|
||
def test_valid_provider_with_supported_locations(): | ||
model = Model(provider="vertex", id="gemini-1.5-flash-002", locations=["us-central1"]) | ||
assert model.provider == "vertex" | ||
assert model.locations == ["us-central1"] | ||
def test_unsupported_provider(): | ||
with pytest.raises(ValueError) as excinfo: | ||
Model(provider="unsupported-provider", id="model-id") | ||
assert str(excinfo.value).startswith("Provider 'unsupported-provider' is not supported:") | ||
|
||
|
||
def test_valid_provider_with_multiple_supported_locations(): | ||
model = Model(provider="vertex", id="gemini-1.5-flash-002", locations=["us-central1", "eu-central1"]) | ||
assert model.provider == "vertex" | ||
assert model.locations == ["us-central1", "eu-central1"] | ||
def test_false_env_vars(): | ||
if os.getenv("var1"): | ||
del os.environ["var1"] | ||
with pytest.raises(EnvironmentError) as excinfo: | ||
Model("mock", "model1") | ||
assert str(excinfo.value) == "Provider 'mock' requires environmental variable 'var1'" | ||
|
||
|
||
def test_invalid_location_for_provider(): | ||
model = Model(provider="anthropic", id="claude-3.5-sonnet-latest", locations=["us-central1"]) | ||
assert model | ||
# assert "Provider anthropic does not support `locations`." in caplog.text | ||
def test_correct_env_vars(): | ||
os.environ["var1"] = "mock" | ||
model = Model("mock", "model1") | ||
assert model.provider == "mock" | ||
|
||
|
||
def test_unsupported_provider(): | ||
with pytest.raises(ValueError) as excinfo: | ||
Model(provider="unsupported-provider", id="model-id") | ||
assert str(excinfo.value).startswith("Provider 'unsupported-provider' is not supported:") | ||
def test_valid_model(): | ||
model = Model(provider="mock", id="model1") | ||
assert model.provider == "mock" | ||
assert model.id == "model1" | ||
|
||
|
||
def test_invalid_location_for_vertex_provider(): | ||
model = Model(provider="vertex", id="gemini-2-flash", locations=["invalid-location"]) | ||
assert model | ||
# assert "Provider vertex does not support location invalid-location:" in caplog.text | ||
def test_invalid_model(): | ||
Model(provider="mock", id="invalid-model") | ||
# assert "Provider 'mock' does not support model 'invalid-model'" | ||
|
||
|
||
def test_valid_model_id_for_vertex_provider(): | ||
model = Model(provider="vertex", id="gemini-1.5-pro-002") | ||
assert model.provider == "vertex" | ||
assert model.id == "gemini-1.5-pro-002" | ||
def test_valid_locations(): | ||
model = Model(provider="mock", id="model1", locations=["location1"]) | ||
assert model.locations == ["location1"] | ||
|
||
|
||
def test_invalid_model_id_for_vertex_provider(): | ||
model = Model(provider="vertex", id="invalid-model-id") | ||
assert model | ||
# assert "Provider vertex does not support model invalid-model-id:" in caplog.text | ||
def test_invalid_locations(): | ||
Model(provider="mock", id="invalid-model", locations=["location3"]) | ||
# assert "Provider 'mock' does not support model 'location3'" |