Skip to content

Commit

Permalink
Propagate response_format attribute to the client
Browse files Browse the repository at this point in the history
Signed-off-by: Sternakt <[email protected]>
  • Loading branch information
sternakt committed Nov 29, 2024
1 parent 00f43aa commit 6db8b7d
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 33 deletions.
6 changes: 1 addition & 5 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,7 @@ def _validate_llm_config(self, llm_config):
raise ValueError(
"When using OpenAI or Azure OpenAI endpoints, specify a non-empty 'model' either in 'llm_config' or in each config of 'config_list'."
)
self.client = (
None
if self.llm_config is False
else OpenAIWrapper(**self.llm_config, response_format=self._response_format)
)
self.client = None if self.llm_config is False else OpenAIWrapper(**self.llm_config)

@staticmethod
def _is_silent(agent: Agent, silent: Optional[bool] = False) -> bool:
Expand Down
18 changes: 10 additions & 8 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging
import sys
import uuid
from pickle import PickleError, PicklingError
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union

from pydantic import BaseModel, schema_json_of
Expand Down Expand Up @@ -753,7 +754,6 @@ def yes_or_no_filter(context, response):
- RuntimeError: If all declared custom model clients are not registered
- APIError: If any model client create call raises an APIError
"""
print(f"{response_format=}, {config=}")
if ERROR:
raise ERROR
invocation_id = str(uuid.uuid4())
Expand All @@ -769,17 +769,13 @@ def yes_or_no_filter(context, response):
for i, client in enumerate(self._clients):
# merge the input config with the i-th config in the config list
full_config = {**config, **self._config_list[i]}
print(f"{full_config=}")
# separate the config into create_config and extra_kwargs
create_config, extra_kwargs = self._separate_create_config(full_config)
api_type = extra_kwargs.get("api_type")
if api_type and api_type.startswith("azure") and "model" in create_config:
create_config["model"] = create_config["model"].replace(".", "")
# construct the create params
params = self._construct_create_params(create_config, extra_kwargs)
if "response_format" in params:
params["response_format"] = schema_json_of(params["response_format"])
print(f"{params=}")
# get the cache_seed, filter_func and context
cache_seed = extra_kwargs.get("cache_seed", LEGACY_DEFAULT_CACHE_SEED)
cache = extra_kwargs.get("cache")
Expand Down Expand Up @@ -809,8 +805,11 @@ def yes_or_no_filter(context, response):
if cache_client is not None:
with cache_client as cache:
# Try to get the response from cache
print(f"{params=}")
key = get_key(params)
key = get_key(
{**params, **{"response_format": schema_json_of(response_format)}}
if response_format
else params
)
request_ts = get_current_ts()

response: ModelClient.ModelClientResponseProtocol = cache.get(key, None)
Expand Down Expand Up @@ -916,7 +915,10 @@ def yes_or_no_filter(context, response):
if cache_client is not None:
# Cache the response
with cache_client as cache:
cache.set(key, response)
try:
cache.set(key, response)
except (PicklingError, AttributeError) as e:
logger.info(f"Failed to cache response: {e}")

if logging_enabled():
# TODO: log the config_id and pass_filter etc.
Expand Down
Loading

0 comments on commit 6db8b7d

Please sign in to comment.