|
50 | 50 | import org.springframework.ai.converter.StructuredOutputConverter; |
51 | 51 | import org.springframework.ai.template.TemplateRenderer; |
52 | 52 | import org.springframework.ai.tool.ToolCallback; |
| 53 | +import org.springframework.ai.tool.ToolCallbackProvider; |
53 | 54 | import org.springframework.ai.tool.function.FunctionToolCallback; |
54 | 55 | import org.springframework.core.ParameterizedTypeReference; |
55 | 56 | import org.springframework.core.convert.support.DefaultConversionService; |
|
61 | 62 | import static org.assertj.core.api.Assertions.assertThatThrownBy; |
62 | 63 | import static org.mockito.BDDMockito.given; |
63 | 64 | import static org.mockito.Mockito.mock; |
| 65 | +import static org.mockito.Mockito.never; |
| 66 | +import static org.mockito.Mockito.times; |
| 67 | +import static org.mockito.Mockito.verify; |
64 | 68 | import static org.mockito.Mockito.when; |
65 | 69 |
|
66 | 70 | /** |
@@ -1474,24 +1478,24 @@ void buildChatClientRequestSpec() { |
1474 | 1478 | ChatModel chatModel = mock(ChatModel.class); |
1475 | 1479 | DefaultChatClient.DefaultChatClientRequestSpec spec = new DefaultChatClient.DefaultChatClientRequestSpec( |
1476 | 1480 | chatModel, null, Map.of(), Map.of(), null, Map.of(), Map.of(), List.of(), List.of(), List.of(), |
1477 | | - List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of(), null); |
| 1481 | + List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of(), null); |
1478 | 1482 | assertThat(spec).isNotNull(); |
1479 | 1483 | } |
1480 | 1484 |
|
1481 | 1485 | @Test |
1482 | 1486 | void whenChatModelIsNullThenThrow() { |
1483 | 1487 | assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(null, null, Map.of(), Map.of(), |
1484 | | - null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), |
1485 | | - ObservationRegistry.NOOP, null, Map.of(), null)) |
| 1488 | + null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), |
| 1489 | + Map.of(), ObservationRegistry.NOOP, null, Map.of(), null)) |
1486 | 1490 | .isInstanceOf(IllegalArgumentException.class) |
1487 | 1491 | .hasMessage("chatModel cannot be null"); |
1488 | 1492 | } |
1489 | 1493 |
|
1490 | 1494 | @Test |
1491 | 1495 | void whenObservationRegistryIsNullThenThrow() { |
1492 | 1496 | assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(mock(ChatModel.class), null, |
1493 | | - Map.of(), Map.of(), null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), null, |
1494 | | - List.of(), Map.of(), null, null, Map.of(), null)) |
| 1497 | + Map.of(), Map.of(), null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), List.of(), |
| 1498 | + null, List.of(), Map.of(), null, null, Map.of(), null)) |
1495 | 1499 | .isInstanceOf(IllegalArgumentException.class) |
1496 | 1500 | .hasMessage("observationRegistry cannot be null"); |
1497 | 1501 | } |
@@ -2197,6 +2201,115 @@ void whenUserConsumerWithNullParamValueThenThrow() { |
2197 | 2201 | .hasMessage("value cannot be null"); |
2198 | 2202 | } |
2199 | 2203 |
|
| 2204 | + @Test |
| 2205 | + void whenToolCallbackProviderThenNotEagerlyEvaluated() { |
| 2206 | + ChatModel chatModel = mock(ChatModel.class); |
| 2207 | + ToolCallbackProvider provider = mock(ToolCallbackProvider.class); |
| 2208 | + |
| 2209 | + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); |
| 2210 | + ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider); |
| 2211 | + |
| 2212 | + // Verify that getToolCallbacks() was NOT called during configuration |
| 2213 | + verify(provider, never()).getToolCallbacks(); |
| 2214 | + } |
| 2215 | + |
| 2216 | + @Test |
| 2217 | + void whenToolCallbackProviderThenLazilyEvaluatedOnCall() { |
| 2218 | + ChatModel chatModel = mock(ChatModel.class); |
| 2219 | + ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class); |
| 2220 | + given(chatModel.call(promptCaptor.capture())) |
| 2221 | + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); |
| 2222 | + |
| 2223 | + ToolCallbackProvider provider = mock(ToolCallbackProvider.class); |
| 2224 | + when(provider.getToolCallbacks()).thenReturn(new ToolCallback[] {}); |
| 2225 | + |
| 2226 | + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); |
| 2227 | + ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider); |
| 2228 | + |
| 2229 | + // Verify not called yet |
| 2230 | + verify(provider, never()).getToolCallbacks(); |
| 2231 | + |
| 2232 | + // Execute the call |
| 2233 | + spec.call().content(); |
| 2234 | + |
| 2235 | + // Verify getToolCallbacks() WAS called during execution |
| 2236 | + verify(provider, times(1)).getToolCallbacks(); |
| 2237 | + } |
| 2238 | + |
| 2239 | + @Test |
| 2240 | + void whenToolCallbackProviderThenLazilyEvaluatedOnStream() { |
| 2241 | + ChatModel chatModel = mock(ChatModel.class); |
| 2242 | + ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class); |
| 2243 | + given(chatModel.stream(promptCaptor.capture())) |
| 2244 | + .willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))))); |
| 2245 | + |
| 2246 | + ToolCallbackProvider provider = mock(ToolCallbackProvider.class); |
| 2247 | + when(provider.getToolCallbacks()).thenReturn(new ToolCallback[] {}); |
| 2248 | + |
| 2249 | + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); |
| 2250 | + ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider); |
| 2251 | + |
| 2252 | + // Verify not called yet |
| 2253 | + verify(provider, never()).getToolCallbacks(); |
| 2254 | + |
| 2255 | + // Execute the stream |
| 2256 | + spec.stream().content().blockLast(); |
| 2257 | + |
| 2258 | + // Verify getToolCallbacks() WAS called during execution |
| 2259 | + verify(provider, times(1)).getToolCallbacks(); |
| 2260 | + } |
| 2261 | + |
| 2262 | + @Test |
| 2263 | + void whenMultipleToolCallbackProvidersThenAllLazilyEvaluated() { |
| 2264 | + ChatModel chatModel = mock(ChatModel.class); |
| 2265 | + ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class); |
| 2266 | + given(chatModel.call(promptCaptor.capture())) |
| 2267 | + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); |
| 2268 | + |
| 2269 | + ToolCallbackProvider provider1 = mock(ToolCallbackProvider.class); |
| 2270 | + when(provider1.getToolCallbacks()).thenReturn(new ToolCallback[] {}); |
| 2271 | + |
| 2272 | + ToolCallbackProvider provider2 = mock(ToolCallbackProvider.class); |
| 2273 | + when(provider2.getToolCallbacks()).thenReturn(new ToolCallback[] {}); |
| 2274 | + |
| 2275 | + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); |
| 2276 | + ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider1, provider2); |
| 2277 | + |
| 2278 | + // Verify not called yet |
| 2279 | + verify(provider1, never()).getToolCallbacks(); |
| 2280 | + verify(provider2, never()).getToolCallbacks(); |
| 2281 | + |
| 2282 | + // Execute the call |
| 2283 | + spec.call().content(); |
| 2284 | + |
| 2285 | + // Verify both getToolCallbacks() were called during execution |
| 2286 | + verify(provider1, times(1)).getToolCallbacks(); |
| 2287 | + verify(provider2, times(1)).getToolCallbacks(); |
| 2288 | + } |
| 2289 | + |
| 2290 | + @Test |
| 2291 | + void whenToolCallbacksAndProvidersThenBothUsed() { |
| 2292 | + ChatModel chatModel = mock(ChatModel.class); |
| 2293 | + ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class); |
| 2294 | + given(chatModel.call(promptCaptor.capture())) |
| 2295 | + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); |
| 2296 | + |
| 2297 | + ToolCallbackProvider provider = mock(ToolCallbackProvider.class); |
| 2298 | + when(provider.getToolCallbacks()).thenReturn(new ToolCallback[] {}); |
| 2299 | + |
| 2300 | + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); |
| 2301 | + ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider); |
| 2302 | + |
| 2303 | + // Verify provider not called yet |
| 2304 | + verify(provider, never()).getToolCallbacks(); |
| 2305 | + |
| 2306 | + // Execute the call |
| 2307 | + spec.call().content(); |
| 2308 | + |
| 2309 | + // Verify provider was called during execution |
| 2310 | + verify(provider, times(1)).getToolCallbacks(); |
| 2311 | + } |
| 2312 | + |
2200 | 2313 | record Person(String name) { |
2201 | 2314 | } |
2202 | 2315 |
|
|
0 commit comments