Skip to content

Commit

Permalink
Use google-genai package instead of old google-generativeai package (#…
Browse files Browse the repository at this point in the history
…694)

* WIP: Use new google-genai package instead of old package

* Use google-genai for chat

* Remove commented out old code

* Use google.genai in tests

* Update error messages

* Remove pydantic which is included already in base dependencies

* Fix retrieving parts from response

* Add GEMINI_API_KEY to codespace secrets

* Use type instead of type_ parameter which is obsolete

* Catch genai client error
  • Loading branch information
kumaranvpl authored Jan 29, 2025
1 parent ee8ff17 commit 3dddf74
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 50 deletions.
3 changes: 3 additions & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
"ANTHROPIC_API_KEY": {
"description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"GEMINI_API_KEY": {
"description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"AZURE_OPENAI_API_KEY": {
"description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal."
},
Expand Down
3 changes: 3 additions & 0 deletions .devcontainer/python-3.10/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
"ANTHROPIC_API_KEY": {
"description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"GEMINI_API_KEY": {
"description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"AZURE_OPENAI_API_KEY": {
"description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal."
},
Expand Down
3 changes: 3 additions & 0 deletions .devcontainer/python-3.11/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
"ANTHROPIC_API_KEY": {
"description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"GEMINI_API_KEY": {
"description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"AZURE_OPENAI_API_KEY": {
"description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal."
},
Expand Down
3 changes: 3 additions & 0 deletions .devcontainer/python-3.12/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
"ANTHROPIC_API_KEY": {
"description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"GEMINI_API_KEY": {
"description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"AZURE_OPENAI_API_KEY": {
"description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal."
},
Expand Down
3 changes: 3 additions & 0 deletions .devcontainer/python-3.13/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
"ANTHROPIC_API_KEY": {
"description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"GEMINI_API_KEY": {
"description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"AZURE_OPENAI_API_KEY": {
"description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal."
},
Expand Down
4 changes: 2 additions & 2 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
gemini_import_exception: Optional[ImportError] = None
else:
gemini_InternalServerError = gemini_ResourceExhausted = Exception # noqa: N816
gemini_import_exception = ImportError("google-generativeai not found")
gemini_import_exception = ImportError("google-genai not found")

with optional_import_block() as anthropic_result:
from anthropic import ( # noqa
Expand Down Expand Up @@ -756,7 +756,7 @@ def _register_default_client(self, config: dict[str, Any], openai_config: dict[s
self._clients.append(client)
elif api_type is not None and api_type.startswith("google"):
if gemini_import_exception:
raise ImportError("Please install `google-generativeai` and 'vertexai' to use Google's API.")
raise ImportError("Please install `google-genai` and 'vertexai' to use Google's API.")
client = GeminiClient(response_format=response_format, **openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("anthropic"):
Expand Down
59 changes: 35 additions & 24 deletions autogen/oai/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,22 @@
from .client_utils import FormatterProtocol

with optional_import_block():
import google.generativeai as genai
import google.genai as genai
import vertexai
from PIL import Image
from google.ai.generativelanguage import Content, FunctionCall, FunctionDeclaration, FunctionResponse, Part, Tool
from google.ai.generativelanguage_v1beta.types import Schema
from google.auth.credentials import Credentials
from google.generativeai.types import GenerateContentResponse
from google.genai.types import (
Content,
FunctionCall,
FunctionDeclaration,
FunctionResponse,
GenerateContentConfig,
GenerateContentResponse,
Part,
Schema,
Tool,
Type,
)
from jsonschema import ValidationError
from vertexai.generative_models import (
Content as VertexAIContent,
Expand Down Expand Up @@ -215,9 +224,9 @@ def create(self, params: dict) -> ChatCompletion:
if autogen_term in params
}
if self.use_vertexai:
safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {}))
safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", []))
else:
safety_settings = params.get("safety_settings", {})
safety_settings = params.get("safety_settings", [])

if stream:
warnings.warn(
Expand Down Expand Up @@ -254,27 +263,29 @@ def create(self, params: dict) -> ChatCompletion:
)

chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
response = chat.send_message(gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings)
else:
model = genai.GenerativeModel(
model_name,
generation_config=generation_config,
client = genai.Client(api_key=self.api_key)
generate_content_config = GenerateContentConfig(
safety_settings=safety_settings,
system_instruction=system_instruction,
tools=tools,
**generation_config,
)

genai.configure(api_key=self.api_key)
chat = model.start_chat(history=gemini_messages[:-1])

response = chat.send_message(gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings)
chat = client.chats.create(model=model_name, config=generate_content_config, history=gemini_messages[:-1])
response = chat.send_message(message=gemini_messages[-1].parts)

# Extract text and tools from response
ans = ""
random_id = random.randint(0, 10000)
prev_function_calls = []

if isinstance(response, GenerateContentResponse):
parts = response.parts
if len(response.candidates) != 1:
raise ValueError(
f"Unexpected number of candidates in the response. Expected 1, got {len(response.candidates)}"
)
parts = response.candidates[0].content.parts
elif isinstance(response, VertexAIGenerationResponse): # or hasattr(response, "candidates"):
# google.generativeai also raises an error len(candidates) != 1:
if len(response.candidates) != 1:
Expand Down Expand Up @@ -610,21 +621,21 @@ def _create_gemini_function_declaration_schema(json_data) -> Schema:
"""

if param_type == "integer":
param_schema.type_ = 2
param_schema.type = Type.INTEGER
elif param_type == "number":
param_schema.type_ = 3
param_schema.type = Type.NUMBER
elif param_type == "string":
param_schema.type_ = 1
param_schema.type = Type.STRING
elif param_type == "boolean":
param_schema.type_ = 6
param_schema.type = Type.BOOLEAN
elif param_type == "array":
param_schema.type_ = 5
param_schema.type = Type.ARRAY
if "items" in json_data:
param_schema.items = GeminiClient._create_gemini_function_declaration_schema(json_data["items"])
else:
print("Warning: Array schema missing 'items' definition.")
elif param_type == "object":
param_schema.type_ = 4
param_schema.type = Type.OBJECT
param_schema.properties = {}
if "properties" in json_data:
for prop_name, prop_data in json_data["properties"].items():
Expand All @@ -635,7 +646,7 @@ def _create_gemini_function_declaration_schema(json_data) -> Schema:
print("Warning: Object schema missing 'properties' definition.")

elif param_type in ("null", "any"):
param_schema.type_ = 1 # Treating these as strings for simplicity
param_schema.type = Type.STRING # Treating these as strings for simplicity
else:
print(f"Warning: Unsupported parameter type '{param_type}'.")

Expand All @@ -647,7 +658,7 @@ def _create_gemini_function_declaration_schema(json_data) -> Schema:
@staticmethod
def _create_gemini_function_parameters(function_parameter: dict[str, any]) -> dict[str, any]:
"""Convert function parameters to Gemini format, recursive"""
function_parameter["type_"] = function_parameter["type"].upper()
function_parameter["type"] = function_parameter["type"].upper()

# Parameter properties and items
if "properties" in function_parameter:
Expand All @@ -660,7 +671,7 @@ def _create_gemini_function_parameters(function_parameter: dict[str, any]) -> di
function_parameter["items"] = GeminiClient._create_gemini_function_parameters(function_parameter["items"])

# Remove any attributes not needed
for attr in ["type", "default"]:
for attr in ["default"]:
if attr in function_parameter:
del function_parameter[attr]

Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,10 @@ teachable = ["chromadb"]
lmm = ["replicate", "pillow"]
graph = ["networkx", "matplotlib"]
gemini = [
"google-generativeai>=0.5,<1",
"google-genai>=0.6,<0.7",
"google-cloud-aiplatform",
"google-auth",
"pillow",
"pydantic",
"jsonschema",
]
together = ["together>=1.2"]
Expand Down
3 changes: 3 additions & 0 deletions scripts/devcontainer/templates/devcontainer.json.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
"ANTHROPIC_API_KEY": {
"description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"GEMINI_API_KEY": {
"description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal."
},
"AZURE_OPENAI_API_KEY": {
"description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal."
},
Expand Down
30 changes: 24 additions & 6 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,13 @@ def credentials_from_test_param(request: pytest.FixtureRequest) -> Credentials:
T = TypeVar("T", bound=Callable[..., Any])


def suppress(exception: type[BaseException], *, retries: int = 0, timeout: int = 60) -> Callable[[T], T]:
def suppress(
exception: type[BaseException],
*,
retries: int = 0,
timeout: int = 60,
error_filter: Optional[Callable[[BaseException], bool]] = None,
) -> Callable[[T], T]:
"""Suppresses the specified exception and retries the function a specified number of times.
Args:
Expand All @@ -429,7 +435,11 @@ def suppress(exception: type[BaseException], *, retries: int = 0, timeout: int =
"""

def decorator(
func: T, exception: type[BaseException] = exception, retries: int = retries, timeout: int = timeout
func: T,
exception: type[BaseException] = exception,
retries: int = retries,
timeout: int = timeout,
error_filter: Optional[Callable[[BaseException], bool]] = error_filter,
) -> T:
if inspect.iscoroutinefunction(func):

Expand All @@ -444,7 +454,9 @@ async def wrapper(
for i in range(retries + 1):
try:
return await func(*args, **kwargs)
except exception:
except exception as e:
if error_filter and not error_filter(e): # type: ignore [arg-type]
raise
if i >= retries - 1:
pytest.xfail(f"Suppressed '{exception}' raised {i + 1} times")
raise
Expand All @@ -462,7 +474,9 @@ def wrapper(
for i in range(retries + 1):
try:
return func(*args, **kwargs)
except exception:
except exception as e:
if error_filter and not error_filter(e): # type: ignore [arg-type]
raise
if i >= retries - 1:
pytest.xfail(f"Suppressed '{exception}' raised {i + 1} times")
raise
Expand All @@ -475,9 +489,13 @@ def wrapper(

def suppress_gemini_resource_exhausted(func: T) -> T:
with optional_import_block():
from google.api_core.exceptions import ResourceExhausted
from google.genai.errors import ClientError

# Catch only code 429 which is RESOURCE_EXHAUSTED error instead of catching all the client errors
def is_resource_exhausted_error(e: BaseException) -> bool:
return isinstance(e, ClientError) and getattr(e, "code", None) == 429

return suppress(ResourceExhausted, retries=2)(func)
return suppress(ClientError, retries=2, error_filter=is_resource_exhausted_error)(func)

return func

Expand Down
Loading

0 comments on commit 3dddf74

Please sign in to comment.