Skip to content

Commit

Permalink
remove enums
Browse files Browse the repository at this point in the history
  • Loading branch information
nachollorca committed Jan 20, 2025
1 parent 0a74fab commit 84e4e16
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 63 deletions.
4 changes: 2 additions & 2 deletions src/lamine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def get_answer(model: Model, conversation: list[Message], **kwargs) -> Answer:
"""
start = time()
module = import_module(f"lamine.providers.{model.provider}")
provider = getattr(module, f"{model.provider.capitalize()}")
answer = provider().get_answer(model=model, conversation=conversation, **kwargs)
provider = getattr(module, f"{model.provider.capitalize()}")()
answer = provider.get_answer(model=model, conversation=conversation, **kwargs)
answer.time = round(time() - start, 3)
return answer

Expand Down
11 changes: 6 additions & 5 deletions src/lamine/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pathlib import Path
import os
from ..utils import BaseEnum

dir = Path(__file__).parent.resolve()
provider_ids = [provider.split(".")[0] for provider in os.listdir(f"{dir}") if provider.endswith(".py")]
PROVIDERS = BaseEnum.from_list(name="providers", lst=provider_ids)
dir = os.path.dirname(os.path.abspath(__file__))
PROVIDER_IDS = list()
for script in os.listdir(dir):
file_name, extension = os.path.splitext(script)
if not file_name.startswith("_") and extension == ".py":
PROVIDER_IDS.append(file_name)
50 changes: 31 additions & 19 deletions src/lamine/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,32 @@

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from importlib import import_module
from typing import Optional

from .providers import PROVIDERS
from .utils import BaseDataclass, BaseEnum
from .providers import PROVIDER_IDS


@dataclass
class Message(BaseDataclass):
class Base:
"""Base class providing utility methods for other dataclasses."""

@property
def dict(self) -> dict:
return asdict(self)

@classmethod
def from_dict(cls, data: dict): # type: ignore
return cls(**data)

@property
def str(self) -> str:
return ""


@dataclass
class Message(Base):
"""
Represents a message in a conversation.
Expand All @@ -25,7 +41,7 @@ class Message(BaseDataclass):


@dataclass
class Answer(BaseDataclass):
class Answer(Base):
"""
Represents an answer generated by a language model.
Expand Down Expand Up @@ -53,7 +69,7 @@ def message(self) -> Message:


@dataclass
class Action(BaseDataclass):
class Action(Base):
"""
Represents an tool request by a Language Model.
Expand All @@ -67,7 +83,7 @@ class Action(BaseDataclass):


@dataclass
class Observation(BaseDataclass):
class Observation(Base):
"""
Represents the result of an Action.
Expand All @@ -81,7 +97,7 @@ class Observation(BaseDataclass):


@dataclass
class Model(BaseDataclass):
class Model(Base):
"""
Represents a language model.
Expand All @@ -97,18 +113,18 @@ class Model(BaseDataclass):

def __post_init__(self):
"""Checks whether the combination of provider, id and locations is valid."""
# check if provider is supported
# Crash if provider is not supported
self.provider = self.provider.lower()
if self.provider not in PROVIDERS:
raise ValueError(f"Provider {self.provider} is not supported: {PROVIDERS.values}.")
if self.provider not in PROVIDER_IDS:
raise ValueError(f"Provider {self.provider} is not supported: {PROVIDER_IDS}.")

# check if model_id is supported
# Warn if model_id is not supported
module = import_module(f"lamine.providers.{self.provider}")
provider = getattr(module, f"{self.provider.capitalize()}")()
if self.id not in provider.model_ids:
logging.warning(f"Provider {self.provider} does not support model {self.id}: {provider.model_ids.values}")
logging.warning(f"Provider {self.provider} does not support model {self.id}: {provider.model_ids}")

# check if locations are supported
# Warn if locations are not supported
if self.locations:
module = import_module(f"lamine.providers.{self.provider}")
if not provider.locations:
Expand All @@ -117,7 +133,7 @@ def __post_init__(self):
for location in self.locations:
if location not in provider.locations:
logging.warning(
f"Provider {self.provider} does not support location {location}: {provider.locations.values}"
f"Provider {self.provider} does not support location {location}: {provider.locations}"
)


Expand All @@ -134,10 +150,6 @@ class Provider(ABC):
model_ids: list[str]
locations: Optional[list[str]] = None

def __init__(self):
self.model_ids = BaseEnum.from_list("model_ids", self.model_ids) if self.model_ids else None
self.locations = BaseEnum.from_list("locations", self.locations) if self.locations else None

@abstractmethod
def get_answer(self, model: Model, conversation: list[Message], **kwargs) -> Answer:
"""
Expand Down
38 changes: 1 addition & 37 deletions src/lamine/utils.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,9 @@
"""Contains utility clases and functions."""
"""Contains utility clases and functions for the core modules."""

from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import asdict, dataclass
from enum import Enum, EnumMeta
from typing import Any, Callable


@dataclass
class BaseDataclass:
"""Base class providing utility methods for other dataclasses."""

@property
def dict(self) -> dict:
return asdict(self)

@classmethod
def from_dict(cls, data: dict): # type: ignore
return cls(**data)

@property
def str(self) -> str:
return ""


class BaseEnumMeta(EnumMeta):
@property
def values(cls):
"""Return a list of all the values in the enumeration."""
return [member.value for member in cls]


class BaseEnum(Enum, metaclass=BaseEnumMeta):
"""Enum that can be made from a list of strings and can easily print all values."""

@classmethod
def from_list(cls, name: str, lst: list[str]):
"""Turns a list of string ids into an Enum."""
enum_members = {model_id.upper().replace("-", "_").replace(".", "_"): model_id for model_id in lst}
return cls(name, enum_members) # type: ignore


def _run_batch(function: Callable, params_list: list[dict[str, Any]], max_workers: int = 10) -> list[Any]:
"""
Executes a batch of calls of a function with different parameters in parallel using threading.
Expand Down

0 comments on commit 84e4e16

Please sign in to comment.