Skip to content

Commit

Permalink
Merge pull request #90 from harishmohanraj/fix-upstream-conflicts-1
Browse files: browse the repository at this point in the history
Fix upstream conflicts 1
  • Loading branch information
harishmohanraj authored Jan 21, 2025
2 parents: e0d7e90 + 94b0551; commit: 6fa3103
Show file tree
Hide file tree
Showing 86 changed files with 574 additions and 508 deletions.
3 changes: 1 addition & 2 deletions autogen/agentchat/assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
# SPDX-License-Identifier: MIT
from typing import Callable, Literal, Optional, Union

from autogen.runtime_logging import log_new_agent, logging_enabled

from ..runtime_logging import log_new_agent, logging_enabled
from .conversable_agent import ConversableAgent


Expand Down
27 changes: 16 additions & 11 deletions autogen/agentchat/contrib/capabilities/generate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
import re
from typing import Any, Literal, Optional, Protocol, Union

from PIL.Image import Image
from openai import OpenAI

from autogen import Agent, ConversableAgent, code_utils
from autogen.agentchat.contrib import img_utils
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent
from autogen.cache import AbstractCache
from .... import Agent, ConversableAgent, code_utils
from ....cache import AbstractCache
from ....import_utils import optional_import_block, require_optional_import
from .. import img_utils
from ..capabilities.agent_capability import AgentCapability
from ..text_analyzer_agent import TextAnalyzerAgent

with optional_import_block():
from PIL.Image import Image

SYSTEM_MESSAGE = "You've been given the special ability to generate images."
DESCRIPTION_MESSAGE = "This agent has the ability to generate images."
Expand All @@ -34,7 +37,7 @@ class ImageGenerator(Protocol):
NOTE: Current implementation does not allow you to edit a previously existing image.
"""

def generate_image(self, prompt: str) -> Image:
def generate_image(self, prompt: str) -> "Image":
"""Generates an image based on the provided prompt.
Args:
Expand Down Expand Up @@ -62,6 +65,7 @@ def cache_key(self, prompt: str) -> str:
...


@require_optional_import("PIL", "unknown")
class DalleImageGenerator:
"""Generates images using OpenAI's DALL-E models.
Expand Down Expand Up @@ -94,7 +98,7 @@ def __init__(
self._num_images = num_images
self._dalle_client = OpenAI(api_key=config_list[0]["api_key"])

def generate_image(self, prompt: str) -> Image:
def generate_image(self, prompt: str) -> "Image":
response = self._dalle_client.images.generate(
model=self._model,
prompt=prompt,
Expand All @@ -114,6 +118,7 @@ def cache_key(self, prompt: str) -> str:
return ",".join([str(k) for k in keys])


@require_optional_import("PIL", "unknown")
class ImageGeneration(AgentCapability):
"""This capability allows a ConversableAgent to generate images based on the message received from other Agents.
Expand Down Expand Up @@ -253,15 +258,15 @@ def _extract_prompt(self, last_message) -> str:
analysis = self._text_analyzer.analyze_text(last_message, self._text_analyzer_instructions)
return self._extract_analysis(analysis)

def _cache_get(self, prompt: str) -> Optional[Image]:
def _cache_get(self, prompt: str) -> Optional["Image"]:
if self._cache:
key = self._image_generator.cache_key(prompt)
cached_value = self._cache.get(key)

if cached_value:
return img_utils.get_pil_image(cached_value)

def _cache_set(self, prompt: str, image: Image):
def _cache_set(self, prompt: str, image: "Image"):
if self._cache:
key = self._image_generator.cache_key(prompt)
self._cache.set(key, img_utils.pil_to_data_uri(image))
Expand All @@ -272,7 +277,7 @@ def _extract_analysis(self, analysis: Union[str, dict, None]) -> str:
else:
return code_utils.content_str(analysis)

def _generate_content_message(self, prompt: str, image: Image) -> dict[str, Any]:
def _generate_content_message(self, prompt: str, image: "Image") -> dict[str, Any]:
return {
"content": [
{"type": "text", "text": f"I generated an image with the prompt: {prompt}"},
Expand Down
24 changes: 18 additions & 6 deletions autogen/agentchat/contrib/img_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
from typing import Union

import requests
from PIL import Image

from ...import_utils import optional_import_block, require_optional_import

with optional_import_block():
from PIL import Image

from autogen.agentchat import utils

Expand All @@ -37,7 +41,8 @@
}


def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image:
@require_optional_import("PIL", "unknown")
def get_pil_image(image_file: Union[str, "Image.Image"]) -> "Image.Image":
"""Loads an image from a file and returns a PIL Image object.
Parameters:
Expand Down Expand Up @@ -75,7 +80,8 @@ def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image:
return image.convert("RGB")


def get_image_data(image_file: Union[str, Image.Image], use_b64=True) -> bytes:
@require_optional_import("PIL", "unknown")
def get_image_data(image_file: Union[str, "Image.Image"], use_b64=True) -> bytes:
"""Loads an image and returns its data either as raw bytes or in base64-encoded format.
This function first loads an image from the specified file, URL, or base64 string using
Expand Down Expand Up @@ -105,6 +111,7 @@ def get_image_data(image_file: Union[str, Image.Image], use_b64=True) -> bytes:
return content


@require_optional_import("PIL", "unknown")
def llava_formatter(prompt: str, order_image_tokens: bool = False) -> tuple[str, list[str]]:
"""Formats the input prompt by replacing image tags and returns the new prompt along with image locations.
Expand Down Expand Up @@ -149,7 +156,8 @@ def llava_formatter(prompt: str, order_image_tokens: bool = False) -> tuple[str,
return new_prompt, images


def pil_to_data_uri(image: Image.Image) -> str:
@require_optional_import("PIL", "unknown")
def pil_to_data_uri(image: "Image.Image") -> str:
"""Converts a PIL Image object to a data URI.
Parameters:
Expand Down Expand Up @@ -184,6 +192,7 @@ def _get_mime_type_from_data_uri(base64_image):
return data_uri


@require_optional_import("PIL", "unknown")
def gpt4v_formatter(prompt: str, img_format: str = "uri") -> list[Union[str, dict]]:
"""Formats the input prompt by replacing image tags and returns a list of text and images.
Expand Down Expand Up @@ -251,7 +260,8 @@ def extract_img_paths(paragraph: str) -> list:
return img_paths


def _to_pil(data: str) -> Image.Image:
@require_optional_import("PIL", "unknown")
def _to_pil(data: str) -> "Image.Image":
"""Converts a base64 encoded image data string to a PIL Image object.
This function first decodes the base64 encoded string to bytes, then creates a BytesIO object from the bytes,
Expand All @@ -266,6 +276,7 @@ def _to_pil(data: str) -> Image.Image:
return Image.open(BytesIO(base64.b64decode(data)))


@require_optional_import("PIL", "unknown")
def message_formatter_pil_to_b64(messages: list[dict]) -> list[dict]:
"""Converts the PIL image URLs in the messages to base64 encoded data URIs.
Expand Down Expand Up @@ -321,8 +332,9 @@ def message_formatter_pil_to_b64(messages: list[dict]) -> list[dict]:
return new_messages


@require_optional_import("PIL", "unknown")
def num_tokens_from_gpt_image(
image_data: Union[str, Image.Image], model: str = "gpt-4-vision", low_quality: bool = False
image_data: Union[str, "Image.Image"], model: str = "gpt-4-vision", low_quality: bool = False
) -> int:
"""Calculate the number of tokens required to process an image based on its dimensions
after scaling for different GPT models. Supports "gpt-4-vision", "gpt-4o", and "gpt-4o-mini".
Expand Down
4 changes: 3 additions & 1 deletion autogen/agentchat/realtime_agent/oai_realtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
from openai import DEFAULT_MAX_RETRIES, NOT_GIVEN, AsyncOpenAI
from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection

from ...import_utils import optional_import_block
from .realtime_client import Role

if TYPE_CHECKING:
from fastapi.websockets import WebSocket
with optional_import_block():
from fastapi.websockets import WebSocket

from .realtime_client import RealtimeClientProtocol

Expand Down
10 changes: 7 additions & 3 deletions autogen/agentchat/realtime_agent/realtime_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
# SPDX-License-Identifier: Apache-2.0

from logging import Logger, getLogger
from typing import Any, Callable, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union

import anyio
from asyncer import asyncify, create_task_group, syncify
from fastapi import WebSocket

from ...import_utils import optional_import_block
from ...tools import Tool
from .. import SwarmAgent
from ..agent import Agent
Expand All @@ -19,6 +19,10 @@
from .realtime_client import RealtimeClientProtocol
from .realtime_observer import RealtimeObserver

if TYPE_CHECKING:
with optional_import_block():
from fastapi import WebSocket

F = TypeVar("F", bound=Callable[..., Any])

global_logger = getLogger(__name__)
Expand Down Expand Up @@ -49,7 +53,7 @@ def __init__(
llm_config: dict[str, Any],
voice: str = "alloy",
logger: Optional[Logger] = None,
websocket: Optional[WebSocket] = None,
websocket: Optional["WebSocket"] = None,
):
"""(Experimental) Agent for interacting with the Realtime Clients.
Expand Down
6 changes: 4 additions & 2 deletions autogen/agentchat/realtime_agent/twilio_audio_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from logging import Logger
from typing import TYPE_CHECKING, Any, Optional

from ...import_utils import optional_import_block
from .realtime_observer import RealtimeObserver

if TYPE_CHECKING:
from fastapi.websockets import WebSocket
with optional_import_block():
from fastapi.websockets import WebSocket


LOG_EVENT_TYPES = [
Expand Down Expand Up @@ -144,5 +146,5 @@ async def initialize_session(self) -> None:

if TYPE_CHECKING:

def twilio_audio_adapter(websocket: WebSocket) -> RealtimeObserver:
def twilio_audio_adapter(websocket: "WebSocket") -> RealtimeObserver:
return TwilioAudioAdapter(websocket)
11 changes: 6 additions & 5 deletions autogen/agentchat/realtime_agent/websocket_audio_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from logging import Logger
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
from fastapi.websockets import WebSocket


from ...import_utils import optional_import_block
from .realtime_observer import RealtimeObserver

if TYPE_CHECKING:
with optional_import_block():
from fastapi.websockets import WebSocket

LOG_EVENT_TYPES = [
"error",
"response.content.done",
Expand Down Expand Up @@ -138,5 +139,5 @@ async def run_loop(self) -> None:

if TYPE_CHECKING:

def websocket_audio_adapter(websocket: WebSocket) -> RealtimeObserver:
def websocket_audio_adapter(websocket: "WebSocket") -> RealtimeObserver:
return WebSocketAudioAdapter(websocket)
6 changes: 0 additions & 6 deletions autogen/cache/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,12 @@
# SPDX-License-Identifier: MIT
from __future__ import annotations

import sys
from types import TracebackType
from typing import Any

from .abstract_cache_base import AbstractCache
from .cache_factory import CacheFactory

if sys.version_info >= (3, 11):
pass
else:
pass


class Cache(AbstractCache):
"""A wrapper class for managing cache configuration and instances.
Expand Down
14 changes: 9 additions & 5 deletions autogen/cache/cosmos_db_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,24 @@
import pickle
from typing import Any, Optional, TypedDict, Union

from azure.cosmos import CosmosClient, PartitionKey
from azure.cosmos.exceptions import CosmosResourceNotFoundError
from ..import_utils import optional_import_block, require_optional_import
from .abstract_cache_base import AbstractCache

from autogen.cache.abstract_cache_base import AbstractCache
with optional_import_block():
from azure.cosmos import CosmosClient, PartitionKey
from azure.cosmos.exceptions import CosmosResourceNotFoundError


@require_optional_import("azure", "cosmosdb")
class CosmosDBConfig(TypedDict, total=False):
connection_string: str
database_id: str
container_id: str
cache_seed: Optional[Union[str, int]]
client: Optional[CosmosClient]
client: Optional["CosmosClient"]


@require_optional_import("azure", "cosmosdb")
class CosmosDBCache(AbstractCache):
"""Synchronous implementation of AbstractCache using Azure Cosmos DB NoSQL API.
Expand Down Expand Up @@ -75,7 +79,7 @@ def from_connection_string(cls, seed: Union[str, int], connection_string: str, d
return cls(str(seed), config)

@classmethod
def from_existing_client(cls, seed: Union[str, int], client: CosmosClient, database_id: str, container_id: str):
def from_existing_client(cls, seed: Union[str, int], client: "CosmosClient", database_id: str, container_id: str):
config = {"client": client, "database_id": database_id, "container_id": container_id}
return cls(str(seed), config)

Expand Down
11 changes: 7 additions & 4 deletions autogen/cache/redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@
from types import TracebackType
from typing import Any, Optional, Union

import redis

from .abstract_cache_base import AbstractCache

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

from ..import_utils import optional_import_block, require_optional_import
from .abstract_cache_base import AbstractCache

with optional_import_block():
import redis


@require_optional_import("redis", "redis")
class RedisCache(AbstractCache):
"""Implementation of AbstractCache using the Redis database.
Expand Down
3 changes: 1 addition & 2 deletions autogen/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@

import docker

from autogen import oai

from . import oai
from .types import UserMessageImageContentPart, UserMessageTextContentPart

SENTINEL = object()
Expand Down
4 changes: 1 addition & 3 deletions autogen/coding/jupyter/docker_jupyter_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@

import docker

from ..docker_commandline_code_executor import _wait_for_ready

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


from ..docker_commandline_code_executor import _wait_for_ready
from .base import JupyterConnectable, JupyterConnectionInfo
from .jupyter_client import JupyterClient

Expand Down
Loading

0 comments on commit 6fa3103

Please sign in to comment.