|
20 | 20 | from google.adk import version as adk_version |
21 | 21 | from google.adk.models import anthropic_llm |
22 | 22 | from google.adk.models.anthropic_llm import Claude |
| 23 | +from google.adk.models.anthropic_llm import content_to_message_param |
23 | 24 | from google.adk.models.anthropic_llm import function_declaration_to_tool_param |
24 | 25 | from google.adk.models.llm_request import LlmRequest |
25 | 26 | from google.adk.models.llm_response import LlmResponse |
|
32 | 33 |
|
33 | 34 | @pytest.fixture |
34 | 35 | def generate_content_response(): |
35 | | - return anthropic_types.Message( |
36 | | - id="msg_vrtx_testid", |
37 | | - content=[ |
38 | | - anthropic_types.TextBlock( |
39 | | - citations=None, text="Hi! How can I help you today?", type="text" |
40 | | - ) |
41 | | - ], |
42 | | - model="claude-3-5-sonnet-v2-20241022", |
43 | | - role="assistant", |
44 | | - stop_reason="end_turn", |
45 | | - stop_sequence=None, |
46 | | - type="message", |
47 | | - usage=anthropic_types.Usage( |
48 | | - cache_creation_input_tokens=0, |
49 | | - cache_read_input_tokens=0, |
50 | | - input_tokens=13, |
51 | | - output_tokens=12, |
52 | | - server_tool_use=None, |
53 | | - service_tier=None, |
54 | | - ), |
55 | | - ) |
| 36 | + return anthropic_types.Message( |
| 37 | + id="msg_vrtx_testid", |
| 38 | + content=[ |
| 39 | + anthropic_types.TextBlock( |
| 40 | + citations=None, text="Hi! How can I help you today?", type="text" |
| 41 | + ) |
| 42 | + ], |
| 43 | + model="claude-3-5-sonnet-v2-20241022", |
| 44 | + role="assistant", |
| 45 | + stop_reason="end_turn", |
| 46 | + stop_sequence=None, |
| 47 | + type="message", |
| 48 | + usage=anthropic_types.Usage( |
| 49 | + cache_creation_input_tokens=0, |
| 50 | + cache_read_input_tokens=0, |
| 51 | + input_tokens=13, |
| 52 | + output_tokens=12, |
| 53 | + server_tool_use=None, |
| 54 | + service_tier=None, |
| 55 | + ), |
| 56 | + ) |
56 | 57 |
|
57 | 58 |
|
58 | 59 | @pytest.fixture |
59 | 60 | def generate_llm_response(): |
60 | | - return LlmResponse.create( |
61 | | - types.GenerateContentResponse( |
62 | | - candidates=[ |
63 | | - types.Candidate( |
64 | | - content=Content( |
65 | | - role="model", |
66 | | - parts=[Part.from_text(text="Hello, how can I help you?")], |
67 | | - ), |
68 | | - finish_reason=types.FinishReason.STOP, |
69 | | - ) |
70 | | - ] |
71 | | - ) |
72 | | - ) |
| 61 | + return LlmResponse.create( |
| 62 | + types.GenerateContentResponse( |
| 63 | + candidates=[ |
| 64 | + types.Candidate( |
| 65 | + content=Content( |
| 66 | + role="model", |
| 67 | + parts=[Part.from_text(text="Hello, how can I help you?")], |
| 68 | + ), |
| 69 | + finish_reason=types.FinishReason.STOP, |
| 70 | + ) |
| 71 | + ] |
| 72 | + ) |
| 73 | + ) |
73 | 74 |
|
74 | 75 |
|
75 | 76 | @pytest.fixture |
76 | 77 | def claude_llm(): |
77 | | - return Claude(model="claude-3-5-sonnet-v2@20241022") |
| 78 | + return Claude(model="claude-3-5-sonnet-v2@20241022") |
78 | 79 |
|
79 | 80 |
|
80 | 81 | @pytest.fixture |
81 | 82 | def llm_request(): |
82 | | - return LlmRequest( |
83 | | - model="claude-3-5-sonnet-v2@20241022", |
84 | | - contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], |
85 | | - config=types.GenerateContentConfig( |
86 | | - temperature=0.1, |
87 | | - response_modalities=[types.Modality.TEXT], |
88 | | - system_instruction="You are a helpful assistant", |
89 | | - ), |
90 | | - ) |
| 83 | + return LlmRequest( |
| 84 | + model="claude-3-5-sonnet-v2@20241022", |
| 85 | + contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], |
| 86 | + config=types.GenerateContentConfig( |
| 87 | + temperature=0.1, |
| 88 | + response_modalities=[types.Modality.TEXT], |
| 89 | + system_instruction="You are a helpful assistant", |
| 90 | + ), |
| 91 | + ) |
91 | 92 |
|
92 | 93 |
|
93 | 94 | def test_supported_models(): |
94 | | - models = Claude.supported_models() |
95 | | - assert len(models) == 2 |
96 | | - assert models[0] == r"claude-3-.*" |
97 | | - assert models[1] == r"claude-.*-4.*" |
| 95 | + models = Claude.supported_models() |
| 96 | + assert len(models) == 2 |
| 97 | + assert models[0] == r"claude-3-.*" |
| 98 | + assert models[1] == r"claude-.*-4.*" |
98 | 99 |
|
99 | 100 |
|
100 | 101 | function_declaration_test_cases = [ |
@@ -133,9 +134,7 @@ def test_supported_models(): |
133 | 134 | "properties": { |
134 | 135 | "location": { |
135 | 136 | "type": "string", |
136 | | - "description": ( |
137 | | - "City and state, e.g., San Francisco, CA" |
138 | | - ), |
| 137 | + "description": ("City and state, e.g., San Francisco, CA"), |
139 | 138 | } |
140 | 139 | }, |
141 | 140 | }, |
@@ -284,65 +283,142 @@ def test_supported_models(): |
284 | 283 | async def test_function_declaration_to_tool_param( |
285 | 284 | _, function_declaration, expected_tool_param |
286 | 285 | ): |
287 | | - """Test function_declaration_to_tool_param.""" |
288 | | - assert ( |
289 | | - function_declaration_to_tool_param(function_declaration) |
290 | | - == expected_tool_param |
291 | | - ) |
| 286 | + """Test function_declaration_to_tool_param.""" |
| 287 | + assert ( |
| 288 | + function_declaration_to_tool_param(function_declaration) == expected_tool_param |
| 289 | + ) |
292 | 290 |
|
293 | 291 |
|
294 | 292 | @pytest.mark.asyncio |
295 | 293 | async def test_generate_content_async( |
296 | 294 | claude_llm, llm_request, generate_content_response, generate_llm_response |
297 | 295 | ): |
298 | | - with mock.patch.object(claude_llm, "_anthropic_client") as mock_client: |
299 | | - with mock.patch.object( |
300 | | - anthropic_llm, |
301 | | - "message_to_generate_content_response", |
302 | | - return_value=generate_llm_response, |
303 | | - ): |
304 | | - # Create a mock coroutine that returns the generate_content_response. |
305 | | - async def mock_coro(): |
306 | | - return generate_content_response |
| 296 | + with mock.patch.object(claude_llm, "_anthropic_client") as mock_client: |
| 297 | + with mock.patch.object( |
| 298 | + anthropic_llm, |
| 299 | + "message_to_generate_content_response", |
| 300 | + return_value=generate_llm_response, |
| 301 | + ): |
| 302 | + # Create a mock coroutine that returns the generate_content_response. |
| 303 | + async def mock_coro(): |
| 304 | + return generate_content_response |
307 | 305 |
|
308 | | - # Assign the coroutine to the mocked method |
309 | | - mock_client.messages.create.return_value = mock_coro() |
| 306 | + # Assign the coroutine to the mocked method |
| 307 | + mock_client.messages.create.return_value = mock_coro() |
310 | 308 |
|
311 | | - responses = [ |
312 | | - resp |
313 | | - async for resp in claude_llm.generate_content_async( |
314 | | - llm_request, stream=False |
315 | | - ) |
316 | | - ] |
317 | | - assert len(responses) == 1 |
318 | | - assert isinstance(responses[0], LlmResponse) |
319 | | - assert responses[0].content.parts[0].text == "Hello, how can I help you?" |
| 309 | + responses = [ |
| 310 | + resp |
| 311 | + async for resp in claude_llm.generate_content_async( |
| 312 | + llm_request, stream=False |
| 313 | + ) |
| 314 | + ] |
| 315 | + assert len(responses) == 1 |
| 316 | + assert isinstance(responses[0], LlmResponse) |
| 317 | + assert responses[0].content.parts[0].text == "Hello, how can I help you?" |
320 | 318 |
|
321 | 319 |
|
322 | 320 | @pytest.mark.asyncio |
323 | 321 | async def test_generate_content_async_with_max_tokens( |
324 | 322 | llm_request, generate_content_response, generate_llm_response |
325 | 323 | ): |
326 | | - claude_llm = Claude(model="claude-3-5-sonnet-v2@20241022", max_tokens=4096) |
327 | | - with mock.patch.object(claude_llm, "_anthropic_client") as mock_client: |
328 | | - with mock.patch.object( |
329 | | - anthropic_llm, |
330 | | - "message_to_generate_content_response", |
331 | | - return_value=generate_llm_response, |
332 | | - ): |
333 | | - # Create a mock coroutine that returns the generate_content_response. |
334 | | - async def mock_coro(): |
335 | | - return generate_content_response |
| 324 | + claude_llm = Claude(model="claude-3-5-sonnet-v2@20241022", max_tokens=4096) |
| 325 | + with mock.patch.object(claude_llm, "_anthropic_client") as mock_client: |
| 326 | + with mock.patch.object( |
| 327 | + anthropic_llm, |
| 328 | + "message_to_generate_content_response", |
| 329 | + return_value=generate_llm_response, |
| 330 | + ): |
| 331 | + # Create a mock coroutine that returns the generate_content_response. |
| 332 | + async def mock_coro(): |
| 333 | + return generate_content_response |
| 334 | + |
| 335 | + # Assign the coroutine to the mocked method |
| 336 | + mock_client.messages.create.return_value = mock_coro() |
| 337 | + |
| 338 | + _ = [ |
| 339 | + resp |
| 340 | + async for resp in claude_llm.generate_content_async( |
| 341 | + llm_request, stream=False |
| 342 | + ) |
| 343 | + ] |
| 344 | + mock_client.messages.create.assert_called_once() |
| 345 | + _, kwargs = mock_client.messages.create.call_args |
| 346 | + assert kwargs["max_tokens"] == 4096 |
| 347 | + |
| 348 | + |
| 349 | +content_to_message_param_test_cases = [ |
| 350 | + ( |
| 351 | + "user_role_with_text_and_image", |
| 352 | + Content( |
| 353 | + role="user", |
| 354 | + parts=[ |
| 355 | + Part.from_text(text="What's in this image?"), |
| 356 | + Part( |
| 357 | + inline_data=types.Blob( |
| 358 | + mime_type="image/jpeg", data=b"fake_image_data" |
| 359 | + ) |
| 360 | + ), |
| 361 | + ], |
| 362 | + ), |
| 363 | + "user", |
| 364 | + 2, # Expected content length |
| 365 | + False, # Should not log warning |
| 366 | + ), |
| 367 | + ( |
| 368 | + "model_role_with_text_and_image", |
| 369 | + Content( |
| 370 | + role="model", |
| 371 | + parts=[ |
| 372 | + Part.from_text(text="I see a cat."), |
| 373 | + Part( |
| 374 | + inline_data=types.Blob( |
| 375 | + mime_type="image/png", data=b"fake_image_data" |
| 376 | + ) |
| 377 | + ), |
| 378 | + ], |
| 379 | + ), |
| 380 | + "assistant", |
| 381 | + 1, # Image filtered out, only text remains |
| 382 | + True, # Should log warning |
| 383 | + ), |
| 384 | + ( |
| 385 | + "assistant_role_with_text_and_image", |
| 386 | + Content( |
| 387 | + role="assistant", |
| 388 | + parts=[ |
| 389 | + Part.from_text(text="Here's what I found."), |
| 390 | + Part( |
| 391 | + inline_data=types.Blob( |
| 392 | + mime_type="image/webp", data=b"fake_image_data" |
| 393 | + ) |
| 394 | + ), |
| 395 | + ], |
| 396 | + ), |
| 397 | + "assistant", |
| 398 | + 1, # Image filtered out, only text remains |
| 399 | + True, # Should log warning |
| 400 | + ), |
| 401 | +] |
| 402 | + |
| 403 | + |
| 404 | +@pytest.mark.parametrize( |
| 405 | + "_, content, expected_role, expected_content_length, should_log_warning", |
| 406 | + content_to_message_param_test_cases, |
| 407 | + ids=[case[0] for case in content_to_message_param_test_cases], |
| 408 | +) |
| 409 | +def test_content_to_message_param_with_images( |
| 410 | + _, content, expected_role, expected_content_length, should_log_warning |
| 411 | +): |
| 412 | + """Test content_to_message_param handles images correctly based on role.""" |
| 413 | + with mock.patch("google.adk.models.anthropic_llm.logger") as mock_logger: |
| 414 | + result = content_to_message_param(content) |
336 | 415 |
|
337 | | - # Assign the coroutine to the mocked method |
338 | | - mock_client.messages.create.return_value = mock_coro() |
| 416 | + assert result["role"] == expected_role |
| 417 | + assert len(result["content"]) == expected_content_length |
339 | 418 |
|
340 | | - _ = [ |
341 | | - resp |
342 | | - async for resp in claude_llm.generate_content_async( |
343 | | - llm_request, stream=False |
344 | | - ) |
345 | | - ] |
346 | | - mock_client.messages.create.assert_called_once() |
347 | | - _, kwargs = mock_client.messages.create.call_args |
348 | | - assert kwargs["max_tokens"] == 4096 |
| 419 | + if should_log_warning: |
| 420 | + mock_logger.warning.assert_called_once_with( |
| 421 | + "Image data is not supported in Claude for model turns." |
| 422 | + ) |
| 423 | + else: |
| 424 | + mock_logger.warning.assert_not_called() |
0 commit comments