Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use google-genai package instead of old google-generativeai package #694

Merged
merged 13 commits into from
Jan 29, 2025
Merged
3 changes: 3 additions & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
"ANTHROPIC_API_KEY": {
"description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"GEMINI_API_KEY": {
"description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"AZURE_OPENAI_API_KEY": {
"description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal."
},
Expand Down
3 changes: 3 additions & 0 deletions .devcontainer/python-3.10/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
"ANTHROPIC_API_KEY": {
"description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"GEMINI_API_KEY": {
"description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"AZURE_OPENAI_API_KEY": {
"description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal."
},
Expand Down
3 changes: 3 additions & 0 deletions .devcontainer/python-3.11/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
"ANTHROPIC_API_KEY": {
"description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"GEMINI_API_KEY": {
"description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"AZURE_OPENAI_API_KEY": {
"description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal."
},
Expand Down
3 changes: 3 additions & 0 deletions .devcontainer/python-3.12/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
"ANTHROPIC_API_KEY": {
"description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"GEMINI_API_KEY": {
"description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"AZURE_OPENAI_API_KEY": {
"description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal."
},
Expand Down
3 changes: 3 additions & 0 deletions .devcontainer/python-3.13/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
"ANTHROPIC_API_KEY": {
"description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"GEMINI_API_KEY": {
"description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"AZURE_OPENAI_API_KEY": {
"description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal."
},
Expand Down
4 changes: 2 additions & 2 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
gemini_import_exception: Optional[ImportError] = None
else:
gemini_InternalServerError = gemini_ResourceExhausted = Exception # noqa: N816
gemini_import_exception = ImportError("google-generativeai not found")
gemini_import_exception = ImportError("google-genai not found")

with optional_import_block() as anthropic_result:
from anthropic import ( # noqa
Expand Down Expand Up @@ -756,7 +756,7 @@ def _register_default_client(self, config: dict[str, Any], openai_config: dict[s
self._clients.append(client)
elif api_type is not None and api_type.startswith("google"):
if gemini_import_exception:
raise ImportError("Please install `google-generativeai` and 'vertexai' to use Google's API.")
raise ImportError("Please install `google-genai` and 'vertexai' to use Google's API.")
client = GeminiClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("anthropic"):
Expand Down
59 changes: 35 additions & 24 deletions autogen/oai/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,22 @@
from .client_utils import FormatterProtocol

with optional_import_block():
import google.generativeai as genai
import google.genai as genai
import vertexai
from PIL import Image
from google.ai.generativelanguage import Content, FunctionCall, FunctionDeclaration, FunctionResponse, Part, Tool
from google.ai.generativelanguage_v1beta.types import Schema
from google.auth.credentials import Credentials
from google.generativeai.types import GenerateContentResponse
from google.genai.types import (
Content,
FunctionCall,
FunctionDeclaration,
FunctionResponse,
GenerateContentConfig,
GenerateContentResponse,
Part,
Schema,
Tool,
Type,
)
from jsonschema import ValidationError
from vertexai.generative_models import (
Content as VertexAIContent,
Expand Down Expand Up @@ -215,9 +224,9 @@ def create(self, params: dict) -> ChatCompletion:
if autogen_term in params
}
if self.use_vertexai:
safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {}))
safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", []))
else:
safety_settings = params.get("safety_settings", {})
safety_settings = params.get("safety_settings", [])

if stream:
warnings.warn(
Expand Down Expand Up @@ -254,27 +263,29 @@ def create(self, params: dict) -> ChatCompletion:
)

chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
response = chat.send_message(gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings)
else:
model = genai.GenerativeModel(
model_name,
generation_config=generation_config,
client = genai.Client(api_key=self.api_key)
generate_content_config = GenerateContentConfig(
safety_settings=safety_settings,
system_instruction=system_instruction,
tools=tools,
**generation_config,
)

genai.configure(api_key=self.api_key)
chat = model.start_chat(history=gemini_messages[:-1])

response = chat.send_message(gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings)
chat = client.chats.create(model=model_name, config=generate_content_config, history=gemini_messages[:-1])
response = chat.send_message(message=gemini_messages[-1].parts)

# Extract text and tools from response
ans = ""
random_id = random.randint(0, 10000)
prev_function_calls = []

if isinstance(response, GenerateContentResponse):
parts = response.parts
if len(response.candidates) != 1:
raise ValueError(
f"Unexpected number of candidates in the response. Expected 1, got {len(response.candidates)}"
)
parts = response.candidates[0].content.parts
elif isinstance(response, VertexAIGenerationResponse): # or hasattr(response, "candidates"):
# google.generativeai also raises an error len(candidates) != 1:
if len(response.candidates) != 1:
Expand Down Expand Up @@ -610,21 +621,21 @@ def _create_gemini_function_declaration_schema(json_data) -> Schema:
"""

if param_type == "integer":
param_schema.type_ = 2
param_schema.type = Type.INTEGER
elif param_type == "number":
param_schema.type_ = 3
param_schema.type = Type.NUMBER
elif param_type == "string":
param_schema.type_ = 1
param_schema.type = Type.STRING
elif param_type == "boolean":
param_schema.type_ = 6
param_schema.type = Type.BOOLEAN
elif param_type == "array":
param_schema.type_ = 5
param_schema.type = Type.ARRAY
if "items" in json_data:
param_schema.items = GeminiClient._create_gemini_function_declaration_schema(json_data["items"])
else:
print("Warning: Array schema missing 'items' definition.")
elif param_type == "object":
param_schema.type_ = 4
param_schema.type = Type.OBJECT
param_schema.properties = {}
if "properties" in json_data:
for prop_name, prop_data in json_data["properties"].items():
Expand All @@ -635,7 +646,7 @@ def _create_gemini_function_declaration_schema(json_data) -> Schema:
print("Warning: Object schema missing 'properties' definition.")

elif param_type in ("null", "any"):
param_schema.type_ = 1 # Treating these as strings for simplicity
param_schema.type = Type.STRING # Treating these as strings for simplicity
else:
print(f"Warning: Unsupported parameter type '{param_type}'.")

Expand All @@ -647,7 +658,7 @@ def _create_gemini_function_declaration_schema(json_data) -> Schema:
@staticmethod
def _create_gemini_function_parameters(function_parameter: dict[str, any]) -> dict[str, any]:
"""Convert function parameters to Gemini format, recursive"""
function_parameter["type_"] = function_parameter["type"].upper()
function_parameter["type"] = function_parameter["type"].upper()

# Parameter properties and items
if "properties" in function_parameter:
Expand All @@ -660,7 +671,7 @@ def _create_gemini_function_parameters(function_parameter: dict[str, any]) -> di
function_parameter["items"] = GeminiClient._create_gemini_function_parameters(function_parameter["items"])

# Remove any attributes not needed
for attr in ["type", "default"]:
for attr in ["default"]:
if attr in function_parameter:
del function_parameter[attr]

Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,10 @@ teachable = ["chromadb"]
lmm = ["replicate", "pillow"]
graph = ["networkx", "matplotlib"]
gemini = [
"google-generativeai>=0.5,<1",
"google-genai>=0.6,<0.7",
kumaranvpl marked this conversation as resolved.
Show resolved Hide resolved
"google-cloud-aiplatform",
"google-auth",
"pillow",
"pydantic",
"jsonschema",
]
together = ["together>=1.2"]
Expand Down
3 changes: 3 additions & 0 deletions scripts/devcontainer/templates/devcontainer.json.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
"ANTHROPIC_API_KEY": {
"description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"GEMINI_API_KEY": {
"description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"AZURE_OPENAI_API_KEY": {
"description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal."
},
Expand Down
30 changes: 24 additions & 6 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,13 @@ def credentials_from_test_param(request: pytest.FixtureRequest) -> Credentials:
T = TypeVar("T", bound=Callable[..., Any])


def suppress(exception: type[BaseException], *, retries: int = 0, timeout: int = 60) -> Callable[[T], T]:
def suppress(
exception: type[BaseException],
*,
retries: int = 0,
timeout: int = 60,
error_filter: Optional[Callable[[BaseException], bool]] = None,
) -> Callable[[T], T]:
"""Suppresses the specified exception and retries the function a specified number of times.

Args:
Expand All @@ -429,7 +435,11 @@ def suppress(exception: type[BaseException], *, retries: int = 0, timeout: int =
"""

def decorator(
func: T, exception: type[BaseException] = exception, retries: int = retries, timeout: int = timeout
func: T,
exception: type[BaseException] = exception,
retries: int = retries,
timeout: int = timeout,
error_filter: Optional[Callable[[BaseException], bool]] = error_filter,
) -> T:
if inspect.iscoroutinefunction(func):

Expand All @@ -444,7 +454,9 @@ async def wrapper(
for i in range(retries + 1):
try:
return await func(*args, **kwargs)
except exception:
except exception as e:
if error_filter and not error_filter(e): # type: ignore [arg-type]
raise
if i >= retries - 1:
pytest.xfail(f"Suppressed '{exception}' raised {i + 1} times")
raise
Expand All @@ -462,7 +474,9 @@ def wrapper(
for i in range(retries + 1):
try:
return func(*args, **kwargs)
except exception:
except exception as e:
if error_filter and not error_filter(e): # type: ignore [arg-type]
raise
if i >= retries - 1:
pytest.xfail(f"Suppressed '{exception}' raised {i + 1} times")
raise
Expand All @@ -475,9 +489,13 @@ def wrapper(

def suppress_gemini_resource_exhausted(func: T) -> T:
with optional_import_block():
from google.api_core.exceptions import ResourceExhausted
from google.genai.errors import ClientError

# Catch only code 429 which is RESOURCE_EXHAUSTED error instead of catching all the client errors
def is_resource_exhausted_error(e: BaseException) -> bool:
return isinstance(e, ClientError) and getattr(e, "code", None) == 429

return suppress(ResourceExhausted, retries=2)(func)
return suppress(ClientError, retries=2, error_filter=is_resource_exhausted_error)(func)

return func

Expand Down
Loading
Loading