diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index ed4d53236..ecfd4510f 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -19,6 +19,9 @@ "ANTHROPIC_API_KEY": { "description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal." }, + "GEMINI_API_KEY": { + "description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal." + }, "AZURE_OPENAI_API_KEY": { "description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal." }, diff --git a/.devcontainer/python-3.10/devcontainer.json b/.devcontainer/python-3.10/devcontainer.json index 8242b4f4b..84ccb97ed 100644 --- a/.devcontainer/python-3.10/devcontainer.json +++ b/.devcontainer/python-3.10/devcontainer.json @@ -19,6 +19,9 @@ "ANTHROPIC_API_KEY": { "description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal." }, + "GEMINI_API_KEY": { + "description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal." + }, "AZURE_OPENAI_API_KEY": { "description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal." }, diff --git a/.devcontainer/python-3.11/devcontainer.json b/.devcontainer/python-3.11/devcontainer.json index c099e5c78..b04d42580 100644 --- a/.devcontainer/python-3.11/devcontainer.json +++ b/.devcontainer/python-3.11/devcontainer.json @@ -19,6 +19,9 @@ "ANTHROPIC_API_KEY": { "description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal." }, + "GEMINI_API_KEY": { + "description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal." + }, "AZURE_OPENAI_API_KEY": { "description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal." }, diff --git a/.devcontainer/python-3.12/devcontainer.json b/.devcontainer/python-3.12/devcontainer.json index 6dbaedba8..d4d573320 100644 --- a/.devcontainer/python-3.12/devcontainer.json +++ b/.devcontainer/python-3.12/devcontainer.json @@ -19,6 +19,9 @@ "ANTHROPIC_API_KEY": { "description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal." }, + "GEMINI_API_KEY": { + "description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal." + }, "AZURE_OPENAI_API_KEY": { "description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal." }, diff --git a/.devcontainer/python-3.13/devcontainer.json b/.devcontainer/python-3.13/devcontainer.json index 3e32aa670..c407d82da 100644 --- a/.devcontainer/python-3.13/devcontainer.json +++ b/.devcontainer/python-3.13/devcontainer.json @@ -19,6 +19,9 @@ "ANTHROPIC_API_KEY": { "description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal." }, + "GEMINI_API_KEY": { + "description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal." + }, "AZURE_OPENAI_API_KEY": { "description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal." }, diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 0f7e78a4e..bc3f8c1b7 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -82,7 +82,7 @@ gemini_import_exception: Optional[ImportError] = None else: gemini_InternalServerError = gemini_ResourceExhausted = Exception # noqa: N816 - gemini_import_exception = ImportError("google-generativeai not found") + gemini_import_exception = ImportError("google-genai not found") with optional_import_block() as anthropic_result: from anthropic import ( # noqa @@ -756,7 +756,7 @@ def _register_default_client(self, config: dict[str, Any], openai_config: dict[s self._clients.append(client) elif api_type is not None and api_type.startswith("google"): if gemini_import_exception: - raise ImportError("Please install `google-generativeai` and 'vertexai' to use Google's API.") + raise ImportError("Please install `google-genai` and 'vertexai' to use Google's API.") client = GeminiClient(response_format=response_format, **openai_config) self._clients.append(client) elif api_type is not None and api_type.startswith("anthropic"): diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py index 050f1a13b..acc83db79 100644 --- a/autogen/oai/gemini.py +++ b/autogen/oai/gemini.py @@ -62,13 +62,22 @@ from .client_utils import FormatterProtocol with optional_import_block(): - import google.generativeai as genai + import google.genai as genai import vertexai from PIL import Image - from google.ai.generativelanguage import Content, FunctionCall, FunctionDeclaration, FunctionResponse, Part, Tool - from google.ai.generativelanguage_v1beta.types import Schema from google.auth.credentials import Credentials - from google.generativeai.types import GenerateContentResponse + from google.genai.types import ( + Content, + FunctionCall, + FunctionDeclaration, + FunctionResponse, + GenerateContentConfig, + GenerateContentResponse, + Part, + Schema, + Tool, + Type, + ) from jsonschema import ValidationError from vertexai.generative_models import ( Content as VertexAIContent, @@ -215,9 +224,9 @@ def create(self, params: dict) -> ChatCompletion: if autogen_term in params } if self.use_vertexai: - safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {})) + safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", [])) else: - safety_settings = params.get("safety_settings", {}) + safety_settings = params.get("safety_settings", []) if stream: warnings.warn( @@ -254,19 +263,17 @@ def create(self, params: dict) -> ChatCompletion: ) chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation) + response = chat.send_message(gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings) else: - model = genai.GenerativeModel( - model_name, - generation_config=generation_config, + client = genai.Client(api_key=self.api_key) + generate_content_config = GenerateContentConfig( safety_settings=safety_settings, system_instruction=system_instruction, tools=tools, + **generation_config, ) - - genai.configure(api_key=self.api_key) - chat = model.start_chat(history=gemini_messages[:-1]) - - response = chat.send_message(gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings) + chat = client.chats.create(model=model_name, config=generate_content_config, history=gemini_messages[:-1]) + response = chat.send_message(message=gemini_messages[-1].parts) # Extract text and tools from response ans = "" @@ -274,7 +281,11 @@ def create(self, params: dict) -> ChatCompletion: prev_function_calls = [] if isinstance(response, GenerateContentResponse): - parts = response.parts + if len(response.candidates) != 1: + raise ValueError( + f"Unexpected number of candidates in the response. Expected 1, got {len(response.candidates)}" + ) + parts = response.candidates[0].content.parts elif isinstance(response, VertexAIGenerationResponse): # or hasattr(response, "candidates"): # google.generativeai also raises an error len(candidates) != 1: if len(response.candidates) != 1: @@ -610,21 +621,21 @@ def _create_gemini_function_declaration_schema(json_data) -> Schema: """ if param_type == "integer": - param_schema.type_ = 2 + param_schema.type = Type.INTEGER elif param_type == "number": - param_schema.type_ = 3 + param_schema.type = Type.NUMBER elif param_type == "string": - param_schema.type_ = 1 + param_schema.type = Type.STRING elif param_type == "boolean": - param_schema.type_ = 6 + param_schema.type = Type.BOOLEAN elif param_type == "array": - param_schema.type_ = 5 + param_schema.type = Type.ARRAY if "items" in json_data: param_schema.items = GeminiClient._create_gemini_function_declaration_schema(json_data["items"]) else: print("Warning: Array schema missing 'items' definition.") elif param_type == "object": - param_schema.type_ = 4 + param_schema.type = Type.OBJECT param_schema.properties = {} if "properties" in json_data: for prop_name, prop_data in json_data["properties"].items(): @@ -635,7 +646,7 @@ def _create_gemini_function_declaration_schema(json_data) -> Schema: print("Warning: Object schema missing 'properties' definition.") elif param_type in ("null", "any"): - param_schema.type_ = 1 # Treating these as strings for simplicity + param_schema.type = Type.STRING # Treating these as strings for simplicity else: print(f"Warning: Unsupported parameter type '{param_type}'.") @@ -647,7 +658,7 @@ def _create_gemini_function_declaration_schema(json_data) -> Schema: @staticmethod def _create_gemini_function_parameters(function_parameter: dict[str, any]) -> dict[str, any]: """Convert function parameters to Gemini format, recursive""" - function_parameter["type_"] = function_parameter["type"].upper() + function_parameter["type"] = function_parameter["type"].upper() # Parameter properties and items if "properties" in function_parameter: @@ -660,7 +671,7 @@ def _create_gemini_function_parameters(function_parameter: dict[str, any]) -> di function_parameter["items"] = GeminiClient._create_gemini_function_parameters(function_parameter["items"]) # Remove any attributes not needed - for attr in ["type", "default"]: + for attr in ["default"]: if attr in function_parameter: del function_parameter[attr] diff --git a/pyproject.toml b/pyproject.toml index 7d6c63139..2adaef9c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -165,11 +165,10 @@ teachable = ["chromadb"] lmm = ["replicate", "pillow"] graph = ["networkx", "matplotlib"] gemini = [ - "google-generativeai>=0.5,<1", + "google-genai>=0.6,<0.7", "google-cloud-aiplatform", "google-auth", "pillow", - "pydantic", "jsonschema", ] together = ["together>=1.2"] diff --git a/scripts/devcontainer/templates/devcontainer.json.jinja b/scripts/devcontainer/templates/devcontainer.json.jinja index 503ed063f..0a63bfbbd 100644 --- a/scripts/devcontainer/templates/devcontainer.json.jinja +++ b/scripts/devcontainer/templates/devcontainer.json.jinja @@ -19,6 +19,9 @@ "ANTHROPIC_API_KEY": { "description": "This key is optional and only needed if you are working with Anthropic API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal." }, + "GEMINI_API_KEY": { + "description": "This key is optional and only needed if you are working with Gemini API-related code. Leave it blank if not required. You can always set it later as an environment variable in the codespace terminal." + }, "AZURE_OPENAI_API_KEY": { "description": "This key is optional and only needed if you are using Azure's OpenAI services. For it to work, you must also set the related environment variables: AZURE_API_ENDPOINT, AZURE_API_VERSION. Leave it blank if not required. You can always set these variables later in the codespace terminal." }, diff --git a/test/conftest.py b/test/conftest.py index ee912d524..a9e7b5b5e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -418,7 +418,13 @@ def credentials_from_test_param(request: pytest.FixtureRequest) -> Credentials: T = TypeVar("T", bound=Callable[..., Any]) -def suppress(exception: type[BaseException], *, retries: int = 0, timeout: int = 60) -> Callable[[T], T]: +def suppress( + exception: type[BaseException], + *, + retries: int = 0, + timeout: int = 60, + error_filter: Optional[Callable[[BaseException], bool]] = None, +) -> Callable[[T], T]: """Suppresses the specified exception and retries the function a specified number of times. Args: @@ -429,7 +435,11 @@ def suppress(exception: type[BaseException], *, retries: int = 0, timeout: int = """ def decorator( - func: T, exception: type[BaseException] = exception, retries: int = retries, timeout: int = timeout + func: T, + exception: type[BaseException] = exception, + retries: int = retries, + timeout: int = timeout, + error_filter: Optional[Callable[[BaseException], bool]] = error_filter, ) -> T: if inspect.iscoroutinefunction(func): @@ -444,7 +454,9 @@ async def wrapper( for i in range(retries + 1): try: return await func(*args, **kwargs) - except exception: + except exception as e: + if error_filter and not error_filter(e): # type: ignore [arg-type] + raise if i >= retries - 1: pytest.xfail(f"Suppressed '{exception}' raised {i + 1} times") raise @@ -462,7 +474,9 @@ def wrapper( for i in range(retries + 1): try: return func(*args, **kwargs) - except exception: + except exception as e: + if error_filter and not error_filter(e): # type: ignore [arg-type] + raise if i >= retries - 1: pytest.xfail(f"Suppressed '{exception}' raised {i + 1} times") raise @@ -475,9 +489,13 @@ def wrapper( def suppress_gemini_resource_exhausted(func: T) -> T: with optional_import_block(): - from google.api_core.exceptions import ResourceExhausted + from google.genai.errors import ClientError + + # Catch only code 429 which is RESOURCE_EXHAUSTED error instead of catching all the client errors + def is_resource_exhausted_error(e: BaseException) -> bool: + return isinstance(e, ClientError) and getattr(e, "code", None) == 429 - return suppress(ResourceExhausted, retries=2)(func) + return suppress(ClientError, retries=2, error_filter=is_resource_exhausted_error)(func) return func diff --git a/test/oai/test_gemini.py b/test/oai/test_gemini.py index 097e48d1c..10978ff39 100644 --- a/test/oai/test_gemini.py +++ b/test/oai/test_gemini.py @@ -19,7 +19,7 @@ from google.api_core.exceptions import InternalServerError from google.auth.credentials import Credentials from google.cloud.aiplatform.initializer import global_config as vertexai_global_config - from google.generativeai.types import GenerateContentResponse + from google.genai.types import GenerateContentResponse from vertexai.generative_models import GenerationResponse as VertexAIGenerationResponse from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold from vertexai.generative_models import HarmCategory as VertexAIHarmCategory @@ -64,7 +64,7 @@ def gemini_client_with_credentials(): # Test compute location initialization and configuration @skip_on_missing_imports( - ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.generativeai"], "gemini" + ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.genai"], "gemini" ) def test_compute_location_initialization(): with pytest.raises(AssertionError): @@ -75,7 +75,7 @@ def test_compute_location_initialization(): # Test project initialization and configuration @skip_on_missing_imports( - ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.generativeai"], "gemini" + ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.genai"], "gemini" ) def test_project_initialization(): with pytest.raises(AssertionError): @@ -85,14 +85,14 @@ def test_project_initialization(): @skip_on_missing_imports( - ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.generativeai"], "gemini" + ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.genai"], "gemini" ) def test_valid_initialization(gemini_client): assert gemini_client.api_key == "fake_api_key", "API Key should be correctly set" @skip_on_missing_imports( - ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.generativeai"], "gemini" + ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.genai"], "gemini" ) def test_google_application_credentials_initialization(): GeminiClient(google_application_credentials="credentials.json", project_id="fake-project-id") @@ -102,7 +102,7 @@ def test_google_application_credentials_initialization(): @skip_on_missing_imports( - ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.generativeai"], "gemini" + ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.genai"], "gemini" ) def test_vertexai_initialization(): mock_credentials = MagicMock(Credentials) @@ -113,7 +113,7 @@ def test_vertexai_initialization(): @skip_on_missing_imports( - ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.generativeai"], "gemini" + ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.genai"], "gemini" ) def test_gemini_message_handling(gemini_client): messages = [ @@ -152,7 +152,7 @@ def test_gemini_message_handling(gemini_client): @skip_on_missing_imports( - ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.generativeai"], "gemini" + ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.genai"], "gemini" ) def test_gemini_empty_message_handling(gemini_client): messages = [ @@ -172,7 +172,7 @@ def test_gemini_empty_message_handling(gemini_client): @skip_on_missing_imports( - ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.generativeai"], "gemini" + ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.genai"], "gemini" ) def test_vertexai_safety_setting_conversion(gemini_client): safety_settings = [ @@ -206,7 +206,7 @@ def compare_safety_settings(converted_safety_settings, expected_safety_settings) @skip_on_missing_imports( - ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.generativeai"], "gemini" + ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.genai"], "gemini" ) def test_vertexai_default_safety_settings_dict(gemini_client): safety_settings = { @@ -233,7 +233,7 @@ def compare_safety_settings(converted_safety_settings, expected_safety_settings) @skip_on_missing_imports( - ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.generativeai"], "gemini" + ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.genai"], "gemini" ) def test_vertexai_safety_setting_list(gemini_client): harm_categories = [ @@ -267,7 +267,7 @@ def compare_safety_settings(converted_safety_settings, expected_safety_settings) # Test error handling @patch("autogen.oai.gemini.genai") @skip_on_missing_imports( - ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.generativeai"], "gemini" + ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.genai"], "gemini" ) def test_internal_server_error_retry(mock_genai, gemini_client): mock_genai.GenerativeModel.side_effect = [InternalServerError("Test Error"), None] # First call fails @@ -283,7 +283,7 @@ def test_internal_server_error_retry(mock_genai, gemini_client): # Test cost calculation @skip_on_missing_imports( - ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.generativeai"], "gemini" + ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.genai"], "gemini" ) def test_cost_calculation(gemini_client, mock_response): response = mock_response( @@ -297,7 +297,7 @@ def test_cost_calculation(gemini_client, mock_response): @skip_on_missing_imports( - ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.generativeai"], "gemini" + ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.genai"], "gemini" ) @patch("autogen.oai.gemini.genai.GenerativeModel") # @patch("autogen.oai.gemini.genai.configure") @@ -350,7 +350,7 @@ def test_create_response_with_text(mock_calculate_cost, mock_generative_model, g @skip_on_missing_imports( - ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.generativeai"], "gemini" + ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.genai"], "gemini" ) @patch("autogen.oai.gemini.GenerativeModel") @patch("autogen.oai.gemini.vertexai.init") @@ -404,7 +404,7 @@ def test_vertexai_create_response( @skip_on_missing_imports( - ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.generativeai"], "gemini" + ["vertexai", "PIL", "google.ai", "google.auth", "google.api", "google.cloud", "google.genai"], "gemini" ) def test_extract_json_response(gemini_client): # Define test Pydantic model