Skip to content

Commit

Permalink
Fuck big commit bc fuck mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
nachollorca committed Jan 20, 2025
1 parent ce28671 commit 283522e
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 67 deletions.
9 changes: 4 additions & 5 deletions src/lamine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@ def get_answer(model: Model, conversation: list[Message], **kwargs) -> Answer:
Answer: The response generated by the language model.
"""
start = time()
module = import_module(f"agilm.providers.{model.provider}")
module = import_module(f"lamine.providers.{model.provider}")
provider = getattr(module, f"{model.provider.capitalize()}")
answer = provider().get_answer(model=model, conversation=conversation, **kwargs)
answer.time = round(time() - start, 3)
return answer

def get_answers(
model: Model, conversations: list[list[Message]], **kwargs
) -> list[Answer]:

def get_answers(model: Model, conversations: list[list[Message]], **kwargs) -> list[Answer]: # type: ignore
"""
Request responses from a language model API for multiple conversations in parallel.
Expand All @@ -39,4 +38,4 @@ def get_answers(
Returns:
list[Answer]: A list of responses generated by the language model for each conversation.
"""
pass
...
7 changes: 7 additions & 0 deletions src/lamine/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from pathlib import Path
import os
from ..utils import BaseEnum

dir = Path(__file__).parent.resolve()
provider_ids = [provider.split(".")[0] for provider in os.listdir(f"{dir}") if provider.endswith(".py")]
PROVIDERS = BaseEnum.from_list(name="providers", lst=provider_ids)
95 changes: 47 additions & 48 deletions src/lamine/types.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,31 @@
"""Contains all types defined to represent interactions with LLM APIs."""
"""Contains all types defined to represent interactions with LLM APIs. Utility parent types are defined in .utils"""

import os
import logging
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from dataclasses import dataclass
from importlib import import_module
from pathlib import Path
from typing import Optional

from .utils import list_to_enum
from .providers import PROVIDERS
from .utils import BaseDataclass, BaseEnum


@dataclass
class Base:
"""
Base class providing utility methods for other dataclasses.
"""
@property
def dict(self) -> dict:
return asdict(self)

@classmethod
def from_dict(cls, data: dict):
return cls(**data)

@property
def str(self) -> str:
return ""


@dataclass
class Message(Base):
class Message(BaseDataclass):
"""
Represents a message in a conversation.
Attributes:
role (str): The role of the message sender (e.g., "user", "assistant").
content (str): The content of the message.
"""

role: str
content: str


@dataclass
class Answer(Base):
class Answer(BaseDataclass):
"""
Represents an answer generated by a language model.
Expand All @@ -51,6 +35,7 @@ class Answer(Base):
tokens_out (int): The number of output tokens generated in the answer.
time (float): The time taken to generate the answer in seconds.
"""

content: str
tokens_in: int
tokens_out: int
Expand All @@ -65,33 +50,38 @@ def message(self) -> Message:
Message: A Message object representing the answer.
"""
return Message("assistant", self.content)



@dataclass
class Action(Base):
class Action(BaseDataclass):
"""
Represents an tool request by a Language Model.
Attributes:
name (str): The name of the tool.
params (dict): A dictionary with input parameters.
"""

name: str
params: dict


@dataclass
class Observation(Base):
class Observation(BaseDataclass):
"""
Represents the result of an Action.
Attributes:
status (str): The status of the observation (e.g., "success", "failure").
content (str): The content of the observation.
"""

status: str
content: str


@dataclass
class Model(Base):
class Model(BaseDataclass):
"""
Represents a language model.
Expand All @@ -100,6 +90,7 @@ class Model(Base):
id (str): The unique identifier of the language model.
locations (list[str]): A list of host locations to randomly pick from.
"""

provider: str
id: str
locations: Optional[list[str]] = None
Expand All @@ -108,36 +99,44 @@ def __post_init__(self):
"""Checks whether the combination of provider, id and locations is valid."""
# check if provider is supported
self.provider = self.provider.lower()
dir = Path(__file__).parent.resolve()
# ToDo: make providers an Enum so user can access with `providers.`
providers = [provider.split(".")[0] for provider in os.listdir(f"{dir}/providers") if provider.endswith(".py")]
if self.provider not in providers:
raise ValueError(f"Provider {self.provider} is not supported: {providers}.")

if self.provider not in PROVIDERS:
raise ValueError(f"Provider {self.provider} is not supported: {PROVIDERS.values}.")

# check if model_id is supported
module = import_module(f"agilm.providers.{self.provider}")
module = import_module(f"lamine.providers.{self.provider}")
provider = getattr(module, f"{self.provider.capitalize()}")()
if self.id not in provider.model_ids:
raise ValueError(f"Provider {self.provider} does not support model {self.id}: {provider.model_ids}") # ToDo: warn instead of raise
logging.warning(f"Provider {self.provider} does not support model {self.id}: {provider.model_ids.values}")

# check if locations are supported
if self.locations:
module = import_module(f"agilm.providers.{self.provider}")
module = import_module(f"lamine.providers.{self.provider}")
if not provider.locations:
raise ValueError(f"Provider {self.provider} does not support `locations`.")
for location in self.locations:
if location not in provider.locations:
raise ValueError(f"Provider {self.provider} does not support location {location}: {provider.locations}")
logging.warning(f"Provider {self.provider} does not support `locations`.")
else:
for location in self.locations:
if location not in provider.locations:
logging.warning(
f"Provider {self.provider} does not support location {location}: {provider.locations.values}"
)


class Provider(ABC):
"""Abstract class defining the shared interface for API providers."""
"""
Abstract class defining the shared interface for API providers.
Attributes:
model_ids (list[str]): a list of valid model ids supported by the provider, e.g. "gemini-2-flash".
locations (list[str]): a list of valid locations where the provider is hosts LMs, e.g. "eu-central1".
This generally applies only to provideders Vertex, Bedrock and Azure.
"""

model_ids: list[str]
locations: Optional[list[str]] = None

def __init__(self):
self.model_ids = list_to_enum(self.model_ids) if self.model_ids else None
self.locations = list_to_enum(self.locations) if self.locations else None
self.model_ids = BaseEnum.from_list("model_ids", self.model_ids) if self.model_ids else None
self.locations = BaseEnum.from_list("locations", self.locations) if self.locations else None

@abstractmethod
def get_answer(self, model: Model, conversation: list[Message], **kwargs) -> Answer:
Expand All @@ -152,4 +151,4 @@ def get_answer(self, model: Model, conversation: list[Message], **kwargs) -> Ans
Returns:
Answer: The response generated by the language model.
"""
...
...
53 changes: 39 additions & 14 deletions src/lamine/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,46 @@
"""Contains utility functions."""
"""Contains utility clases and functions."""

from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import asdict, dataclass
from enum import Enum, EnumMeta
from typing import Any, Callable
from enum import Enum


def _run_batch(
function: Callable, params_list: list[dict[str, Any]], max_workers: int = 10
) -> list[Any]:
@dataclass
class BaseDataclass:
"""Base class providing utility methods for other dataclasses."""

@property
def dict(self) -> dict:
return asdict(self)

@classmethod
def from_dict(cls, data: dict): # type: ignore
return cls(**data)

@property
def str(self) -> str:
return ""


class BaseEnumMeta(EnumMeta):
@property
def values(cls):
"""Return a list of all the values in the enumeration."""
return [member.value for member in cls]


class BaseEnum(Enum, metaclass=BaseEnumMeta):
"""Enum that can be made from a list of strings and can easily print all values."""

@classmethod
def from_list(cls, name: str, lst: list[str]):
"""Turns a list of string ids into an Enum."""
enum_members = {model_id.upper().replace("-", "_").replace(".", "_"): model_id for model_id in lst}
return cls(name, enum_members) # type: ignore


def _run_batch(function: Callable, params_list: list[dict[str, Any]], max_workers: int = 10) -> list[Any]:
"""
Executes a batch of calls of a function with different parameters in parallel using threading.
Expand All @@ -22,16 +55,8 @@ def _run_batch(
"""
results = [None] * len(params_list)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures_to_indices = {
executor.submit(function, **params): index
for index, params in enumerate(params_list)
}
futures_to_indices = {executor.submit(function, **params): index for index, params in enumerate(params_list)}
for future in as_completed(futures_to_indices):
index = futures_to_indices[future]
results[index] = future.result()
return results

def list_to_enum(lst: list[str]) -> Enum:
"""Turns a list of strings ids into an Enum."""
enum_members = {model_id.upper().replace('-', '_').replace(".", "_"): model_id for model_id in lst}
return Enum("ModelID", enum_members)

0 comments on commit 283522e

Please sign in to comment.