Skip to content

Commit 99ccf9b

Browse files
committed
Defer ToolCallbackProvider resolution to execution time
Changed DefaultChatClient to store ToolCallbackProvider instances instead of eagerly resolving them during configuration. Providers are now resolved lazily when call() or stream() is invoked. - Added toolCallbackProviders field to DefaultChatClientRequestSpec - Updated DefaultChatClientUtils to resolve providers at execution time - Added tests verifying lazy evaluation behavior Resolves #4748 Signed-off-by: Christian Tzolov <[email protected]>
1 parent 0523dd7 commit 99ccf9b

File tree

4 files changed

+156
-21
lines changed

4 files changed

+156
-21
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,8 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
618618

619619
private final List<ToolCallback> toolCallbacks = new ArrayList<>();
620620

621+
private final List<ToolCallbackProvider> toolCallbackProviders = new ArrayList<>();
622+
621623
private final List<Message> messages = new ArrayList<>();
622624

623625
private final Map<String, Object> userParams = new HashMap<>();
@@ -648,16 +650,17 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
648650
/* copy constructor */
649651
DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) {
650652
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.userMetadata, ccr.systemText, ccr.systemParams,
651-
ccr.systemMetadata, ccr.toolCallbacks, ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions,
652-
ccr.advisors, ccr.advisorParams, ccr.observationRegistry, ccr.observationConvention,
653-
ccr.toolContext, ccr.templateRenderer);
653+
ccr.systemMetadata, ccr.toolCallbacks, ccr.toolCallbackProviders, ccr.messages, ccr.toolNames,
654+
ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams, ccr.observationRegistry,
655+
ccr.observationConvention, ccr.toolContext, ccr.templateRenderer);
654656
}
655657

656658
public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText,
657659
Map<String, Object> userParams, Map<String, Object> userMetadata, @Nullable String systemText,
658660
Map<String, Object> systemParams, Map<String, Object> systemMetadata, List<ToolCallback> toolCallbacks,
659-
List<Message> messages, List<String> toolNames, List<Media> media, @Nullable ChatOptions chatOptions,
660-
List<Advisor> advisors, Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
661+
List<ToolCallbackProvider> toolCallbackProviders, List<Message> messages, List<String> toolNames,
662+
List<Media> media, @Nullable ChatOptions chatOptions, List<Advisor> advisors,
663+
Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
661664
@Nullable ChatClientObservationConvention observationConvention, Map<String, Object> toolContext,
662665
@Nullable TemplateRenderer templateRenderer) {
663666

@@ -667,6 +670,7 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe
667670
Assert.notNull(systemParams, "systemParams cannot be null");
668671
Assert.notNull(systemMetadata, "systemMetadata cannot be null");
669672
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
673+
Assert.notNull(toolCallbackProviders, "toolCallbackProviders cannot be null");
670674
Assert.notNull(messages, "messages cannot be null");
671675
Assert.notNull(toolNames, "toolNames cannot be null");
672676
Assert.notNull(media, "media cannot be null");
@@ -689,6 +693,7 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe
689693

690694
this.toolNames.addAll(toolNames);
691695
this.toolCallbacks.addAll(toolCallbacks);
696+
this.toolCallbackProviders.addAll(toolCallbackProviders);
692697
this.messages.addAll(messages);
693698
this.media.addAll(media);
694699
this.advisors.addAll(advisors);
@@ -755,6 +760,10 @@ public List<ToolCallback> getToolCallbacks() {
755760
return this.toolCallbacks;
756761
}
757762

763+
public List<ToolCallbackProvider> getToolCallbackProviders() {
764+
return this.toolCallbackProviders;
765+
}
766+
758767
public Map<String, Object> getToolContext() {
759768
return this.toolContext;
760769
}
@@ -773,6 +782,7 @@ public Builder mutate() {
773782
.builder(this.chatModel, this.observationRegistry, this.observationConvention)
774783
.defaultTemplateRenderer(this.templateRenderer)
775784
.defaultToolCallbacks(this.toolCallbacks)
785+
.defaultToolCallbacks(this.toolCallbackProviders.toArray(new ToolCallback[0]))
776786
.defaultToolContext(this.toolContext)
777787
.defaultToolNames(StringUtils.toStringArray(this.toolNames));
778788

@@ -885,9 +895,7 @@ public ChatClientRequestSpec tools(Object... toolObjects) {
885895
public ChatClientRequestSpec toolCallbacks(ToolCallbackProvider... toolCallbackProviders) {
886896
Assert.notNull(toolCallbackProviders, "toolCallbackProviders cannot be null");
887897
Assert.noNullElements(toolCallbackProviders, "toolCallbackProviders cannot contain null elements");
888-
for (ToolCallbackProvider toolCallbackProvider : toolCallbackProviders) {
889-
this.toolCallbacks.addAll(List.of(toolCallbackProvider.getToolCallbacks()));
890-
}
898+
this.toolCallbackProviders.addAll(List.of(toolCallbackProviders));
891899
return this;
892900
}
893901

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observa
6565
Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null");
6666
Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null");
6767
this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), Map.of(), null, Map.of(),
68-
Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry,
69-
customObservationConvention, Map.of(), null);
68+
Map.of(), List.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(),
69+
observationRegistry, customObservationConvention, Map.of(), null);
7070
}
7171

7272
public ChatClient build() {

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,16 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient
106106

107107
ChatOptions processedChatOptions = inputRequest.getChatOptions();
108108

109-
if (processedChatOptions instanceof DefaultChatOptions defaultChatOptions) {
110-
if (!inputRequest.getToolNames().isEmpty() || !inputRequest.getToolCallbacks().isEmpty()
111-
|| !CollectionUtils.isEmpty(inputRequest.getToolContext())) {
109+
// If we have tool-related configuration but no tool or non-tool options,
110+
// create ToolCallingChatOptions
111+
if (!inputRequest.getToolNames().isEmpty() || !inputRequest.getToolCallbacks().isEmpty()
112+
|| !inputRequest.getToolCallbackProviders().isEmpty()
113+
|| !CollectionUtils.isEmpty(inputRequest.getToolContext())) {
114+
115+
if (processedChatOptions == null) {
116+
processedChatOptions = new DefaultToolCallingChatOptions();
117+
}
118+
else if (processedChatOptions instanceof DefaultChatOptions defaultChatOptions) {
112119
processedChatOptions = ModelOptionsUtils.copyToTarget(defaultChatOptions, ChatOptions.class,
113120
DefaultToolCallingChatOptions.class);
114121
}
@@ -120,9 +127,16 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient
120127
.mergeToolNames(new HashSet<>(inputRequest.getToolNames()), toolCallingChatOptions.getToolNames());
121128
toolCallingChatOptions.setToolNames(toolNames);
122129
}
123-
if (!inputRequest.getToolCallbacks().isEmpty()) {
124-
List<ToolCallback> toolCallbacks = ToolCallingChatOptions
125-
.mergeToolCallbacks(inputRequest.getToolCallbacks(), toolCallingChatOptions.getToolCallbacks());
130+
131+
// Lazily resolve ToolCallbackProvider instances to ToolCallback instances
132+
List<ToolCallback> allToolCallbacks = new ArrayList<>(inputRequest.getToolCallbacks());
133+
for (var provider : inputRequest.getToolCallbackProviders()) {
134+
allToolCallbacks.addAll(java.util.List.of(provider.getToolCallbacks()));
135+
}
136+
137+
if (!allToolCallbacks.isEmpty()) {
138+
List<ToolCallback> toolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(allToolCallbacks,
139+
toolCallingChatOptions.getToolCallbacks());
126140
ToolCallingChatOptions.validateToolCallbacks(toolCallbacks);
127141
toolCallingChatOptions.setToolCallbacks(toolCallbacks);
128142
}

spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java

Lines changed: 118 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import org.springframework.ai.converter.StructuredOutputConverter;
5151
import org.springframework.ai.template.TemplateRenderer;
5252
import org.springframework.ai.tool.ToolCallback;
53+
import org.springframework.ai.tool.ToolCallbackProvider;
5354
import org.springframework.ai.tool.function.FunctionToolCallback;
5455
import org.springframework.core.ParameterizedTypeReference;
5556
import org.springframework.core.convert.support.DefaultConversionService;
@@ -61,6 +62,9 @@
6162
import static org.assertj.core.api.Assertions.assertThatThrownBy;
6263
import static org.mockito.BDDMockito.given;
6364
import static org.mockito.Mockito.mock;
65+
import static org.mockito.Mockito.never;
66+
import static org.mockito.Mockito.times;
67+
import static org.mockito.Mockito.verify;
6468
import static org.mockito.Mockito.when;
6569

6670
/**
@@ -1474,24 +1478,24 @@ void buildChatClientRequestSpec() {
14741478
ChatModel chatModel = mock(ChatModel.class);
14751479
DefaultChatClient.DefaultChatClientRequestSpec spec = new DefaultChatClient.DefaultChatClientRequestSpec(
14761480
chatModel, null, Map.of(), Map.of(), null, Map.of(), Map.of(), List.of(), List.of(), List.of(),
1477-
List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of(), null);
1481+
List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of(), null);
14781482
assertThat(spec).isNotNull();
14791483
}
14801484

14811485
@Test
14821486
void whenChatModelIsNullThenThrow() {
14831487
assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(null, null, Map.of(), Map.of(),
1484-
null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(),
1485-
ObservationRegistry.NOOP, null, Map.of(), null))
1488+
null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(),
1489+
Map.of(), ObservationRegistry.NOOP, null, Map.of(), null))
14861490
.isInstanceOf(IllegalArgumentException.class)
14871491
.hasMessage("chatModel cannot be null");
14881492
}
14891493

14901494
@Test
14911495
void whenObservationRegistryIsNullThenThrow() {
14921496
assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(mock(ChatModel.class), null,
1493-
Map.of(), Map.of(), null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), null,
1494-
List.of(), Map.of(), null, null, Map.of(), null))
1497+
Map.of(), Map.of(), null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), List.of(),
1498+
null, List.of(), Map.of(), null, null, Map.of(), null))
14951499
.isInstanceOf(IllegalArgumentException.class)
14961500
.hasMessage("observationRegistry cannot be null");
14971501
}
@@ -2197,6 +2201,115 @@ void whenUserConsumerWithNullParamValueThenThrow() {
21972201
.hasMessage("value cannot be null");
21982202
}
21992203

2204+
@Test
2205+
void whenToolCallbackProviderThenNotEagerlyEvaluated() {
2206+
ChatModel chatModel = mock(ChatModel.class);
2207+
ToolCallbackProvider provider = mock(ToolCallbackProvider.class);
2208+
2209+
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
2210+
ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider);
2211+
2212+
// Verify that getToolCallbacks() was NOT called during configuration
2213+
verify(provider, never()).getToolCallbacks();
2214+
}
2215+
2216+
@Test
2217+
void whenToolCallbackProviderThenLazilyEvaluatedOnCall() {
2218+
ChatModel chatModel = mock(ChatModel.class);
2219+
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
2220+
given(chatModel.call(promptCaptor.capture()))
2221+
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))));
2222+
2223+
ToolCallbackProvider provider = mock(ToolCallbackProvider.class);
2224+
when(provider.getToolCallbacks()).thenReturn(new ToolCallback[] {});
2225+
2226+
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
2227+
ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider);
2228+
2229+
// Verify not called yet
2230+
verify(provider, never()).getToolCallbacks();
2231+
2232+
// Execute the call
2233+
spec.call().content();
2234+
2235+
// Verify getToolCallbacks() WAS called during execution
2236+
verify(provider, times(1)).getToolCallbacks();
2237+
}
2238+
2239+
@Test
2240+
void whenToolCallbackProviderThenLazilyEvaluatedOnStream() {
2241+
ChatModel chatModel = mock(ChatModel.class);
2242+
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
2243+
given(chatModel.stream(promptCaptor.capture()))
2244+
.willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))));
2245+
2246+
ToolCallbackProvider provider = mock(ToolCallbackProvider.class);
2247+
when(provider.getToolCallbacks()).thenReturn(new ToolCallback[] {});
2248+
2249+
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
2250+
ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider);
2251+
2252+
// Verify not called yet
2253+
verify(provider, never()).getToolCallbacks();
2254+
2255+
// Execute the stream
2256+
spec.stream().content().blockLast();
2257+
2258+
// Verify getToolCallbacks() WAS called during execution
2259+
verify(provider, times(1)).getToolCallbacks();
2260+
}
2261+
2262+
@Test
2263+
void whenMultipleToolCallbackProvidersThenAllLazilyEvaluated() {
2264+
ChatModel chatModel = mock(ChatModel.class);
2265+
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
2266+
given(chatModel.call(promptCaptor.capture()))
2267+
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))));
2268+
2269+
ToolCallbackProvider provider1 = mock(ToolCallbackProvider.class);
2270+
when(provider1.getToolCallbacks()).thenReturn(new ToolCallback[] {});
2271+
2272+
ToolCallbackProvider provider2 = mock(ToolCallbackProvider.class);
2273+
when(provider2.getToolCallbacks()).thenReturn(new ToolCallback[] {});
2274+
2275+
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
2276+
ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider1, provider2);
2277+
2278+
// Verify not called yet
2279+
verify(provider1, never()).getToolCallbacks();
2280+
verify(provider2, never()).getToolCallbacks();
2281+
2282+
// Execute the call
2283+
spec.call().content();
2284+
2285+
// Verify both getToolCallbacks() were called during execution
2286+
verify(provider1, times(1)).getToolCallbacks();
2287+
verify(provider2, times(1)).getToolCallbacks();
2288+
}
2289+
2290+
@Test
2291+
void whenToolCallbacksAndProvidersThenBothUsed() {
2292+
ChatModel chatModel = mock(ChatModel.class);
2293+
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
2294+
given(chatModel.call(promptCaptor.capture()))
2295+
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))));
2296+
2297+
ToolCallbackProvider provider = mock(ToolCallbackProvider.class);
2298+
when(provider.getToolCallbacks()).thenReturn(new ToolCallback[] {});
2299+
2300+
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
2301+
ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider);
2302+
2303+
// Verify provider not called yet
2304+
verify(provider, never()).getToolCallbacks();
2305+
2306+
// Execute the call
2307+
spec.call().content();
2308+
2309+
// Verify provider was called during execution
2310+
verify(provider, times(1)).getToolCallbacks();
2311+
}
2312+
22002313
record Person(String name) {
22012314
}
22022315

0 commit comments

Comments
 (0)