From af21430853086d8656e7c4707275821687772baa Mon Sep 17 00:00:00 2001 From: Sophia Date: Thu, 14 Aug 2025 16:54:46 -0700 Subject: [PATCH 1/2] feat: add tool choice and parallel tool calls control - Implement provider-specific tool choice handling - Add InvalidToolChoiceError for validation - Enhance tool execution flow to prevent infinite loops with non-auto choices --- lib/ruby_llm/chat.rb | 44 +++++++++++++++++++---- lib/ruby_llm/error.rb | 1 + lib/ruby_llm/provider.rb | 7 +++- lib/ruby_llm/providers/anthropic/chat.rb | 15 +++++--- lib/ruby_llm/providers/anthropic/tools.rb | 9 +++++ lib/ruby_llm/providers/bedrock/chat.rb | 8 +++-- lib/ruby_llm/providers/gemini/chat.rb | 12 +++++-- lib/ruby_llm/providers/gemini/tools.rb | 15 ++++++++ lib/ruby_llm/providers/mistral/chat.rb | 3 +- lib/ruby_llm/providers/openai/chat.rb | 11 ++++-- lib/ruby_llm/providers/openai/tools.rb | 16 +++++++++ 11 files changed, 122 insertions(+), 19 deletions(-) diff --git a/lib/ruby_llm/chat.rb b/lib/ruby_llm/chat.rb index 7131aeb7..f9ec5150 100644 --- a/lib/ruby_llm/chat.rb +++ b/lib/ruby_llm/chat.rb @@ -11,7 +11,7 @@ module RubyLLM class Chat include Enumerable - attr_reader :model, :messages, :tools, :params, :headers, :schema + attr_reader :model, :messages, :tools, :tool_choice, :parallel_tool_calls, :params, :headers, :schema def initialize(model: nil, provider: nil, assume_model_exists: false, context: nil) if assume_model_exists && !provider @@ -23,6 +23,8 @@ def initialize(model: nil, provider: nil, assume_model_exists: false, context: n model_id = model || @config.default_model with_model(model_id, provider: provider, assume_exists: assume_model_exists) @temperature = 0.7 + @tool_choice = nil + @parallel_tool_calls = nil @messages = [] @tools = {} @params = {} @@ -50,15 +52,19 @@ def with_instructions(instructions, replace: false) self end - def with_tool(tool) - tool_instance = tool.is_a?(Class) ? tool.new : tool - @tools[tool_instance.name.to_sym] = tool_instance + def with_tool(tool, choice: nil, parallel: nil) + unless tool.nil? + tool_instance = tool.is_a?(Class) ? tool.new : tool + @tools[tool_instance.name.to_sym] = tool_instance + end + update_tool_options(choice:, parallel:) self end - def with_tools(*tools, replace: false) + def with_tools(*tools, replace: false, choice: nil, parallel: nil) @tools.clear if replace tools.compact.each { |tool| with_tool tool } + update_tool_options(choice:, parallel:) self end @@ -136,6 +142,8 @@ def complete(&) # rubocop:disable Metrics/PerceivedComplexity params: @params, headers: @headers, schema: @schema, + tool_choice: @tool_choice, + parallel_tool_calls: @parallel_tool_calls, &wrap_streaming_block(&) ) @@ -189,7 +197,7 @@ def wrap_streaming_block(&block) end end - def handle_tool_calls(response, &) + def handle_tool_calls(response, &) # rubocop:disable Metrics/PerceivedComplexity halt_result = nil response.tool_calls.each_value do |tool_call| @@ -203,7 +211,9 @@ def handle_tool_calls(response, &) halt_result = result if result.is_a?(Tool::Halt) end - halt_result || complete(&) + return halt_result if halt_result + + should_continue_after_tools? ? complete(&) : response end def execute_tool(tool_call) @@ -212,6 +222,26 @@ def execute_tool(tool_call) tool.call(args) end + def update_tool_options(choice:, parallel:) + unless choice.nil? + valid_tool_choices = %i[auto none any] + tools.keys + unless valid_tool_choices.include?(choice.to_sym) + raise InvalidToolChoiceError, + "Invalid tool choice: #{choice}. 
Valid choices are: #{valid_tool_choices.join(', ')}" + end + + @tool_choice = choice.to_sym + end + + @parallel_tool_calls = !!parallel unless parallel.nil? + end + + def should_continue_after_tools? + # Continue conversation only with :auto tool choice to avoid infinite loops. + # With :any or specific tool choices, the model would keep calling tools repeatedly. + tool_choice.nil? || tool_choice == :auto + end + def instance_variables super - %i[@connection @config] end diff --git a/lib/ruby_llm/error.rb b/lib/ruby_llm/error.rb index 1bcc018b..a20fa0f2 100644 --- a/lib/ruby_llm/error.rb +++ b/lib/ruby_llm/error.rb @@ -22,6 +22,7 @@ def initialize(response = nil, message = nil) # Error classes for non-HTTP errors class ConfigurationError < StandardError; end class InvalidRoleError < StandardError; end + class InvalidToolChoiceError < StandardError; end class ModelNotFoundError < StandardError; end class UnsupportedAttachmentError < StandardError; end diff --git a/lib/ruby_llm/provider.rb b/lib/ruby_llm/provider.rb index ed0c2ac8..be8c3127 100644 --- a/lib/ruby_llm/provider.rb +++ b/lib/ruby_llm/provider.rb @@ -40,7 +40,9 @@ def configuration_requirements self.class.configuration_requirements end - def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil, &) # rubocop:disable Metrics/ParameterLists + # rubocop:disable Metrics/ParameterLists + def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil, + tool_choice: nil, parallel_tool_calls: nil, &) normalized_temperature = maybe_normalize_temperature(temperature, model) payload = Utils.deep_merge( @@ -48,6 +50,8 @@ def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, sc render_payload( messages, tools: tools, + tool_choice: tool_choice, + parallel_tool_calls: parallel_tool_calls, temperature: normalized_temperature, model: model, stream: block_given?, @@ -61,6 +65,7 @@ def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, sc sync_response @connection, payload, headers end end + # rubocop:enable Metrics/ParameterLists def list_models response = @connection.get models_url diff --git a/lib/ruby_llm/providers/anthropic/chat.rb b/lib/ruby_llm/providers/anthropic/chat.rb index 64cd764b..cbc0db92 100644 --- a/lib/ruby_llm/providers/anthropic/chat.rb +++ b/lib/ruby_llm/providers/anthropic/chat.rb @@ -11,14 +11,17 @@ def completion_url '/v1/messages' end - def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument + # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument + def render_payload(messages, tools:, tool_choice:, parallel_tool_calls:, + temperature:, model:, stream: false, schema: nil) system_messages, chat_messages = separate_messages(messages) system_content = build_system_content(system_messages) build_base_payload(chat_messages, model, stream).tap do |payload| - add_optional_fields(payload, system_content:, tools:, temperature:) + add_optional_fields(payload, system_content:, tools:, tool_choice:, parallel_tool_calls:, temperature:) end end + # rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument def separate_messages(messages) messages.partition { |msg| msg.role == :system } @@ -44,8 +47,12 @@ def build_base_payload(chat_messages, model, stream) } end - def add_optional_fields(payload, system_content:, tools:, temperature:) - payload[:tools] = tools.values.map { |t| Tools.function_for(t) } if tools.any? 
+ def add_optional_fields(payload, system_content:, tools:, tool_choice:, parallel_tool_calls:, temperature:) # rubocop:disable Metrics/ParameterLists + if tools.any? + payload[:tools] = tools.values.map { |t| Tools.function_for(t) } + payload[:tool_choice] = build_tool_choice(tool_choice, parallel_tool_calls) unless tool_choice.nil? + end + payload[:system] = system_content unless system_content.empty? payload[:temperature] = temperature unless temperature.nil? end diff --git a/lib/ruby_llm/providers/anthropic/tools.rb b/lib/ruby_llm/providers/anthropic/tools.rb index 0b4d1c7c..8f26cb79 100644 --- a/lib/ruby_llm/providers/anthropic/tools.rb +++ b/lib/ruby_llm/providers/anthropic/tools.rb @@ -102,6 +102,15 @@ def clean_parameters(parameters) def required_parameters(parameters) parameters.select { |_, param| param.required }.keys end + + def build_tool_choice(tool_choice, parallel_tool_calls) + { + type: %i[auto any none].include?(tool_choice) ? tool_choice : :tool + }.tap do |tc| + tc[:name] = tool_choice if tc[:type] == :tool + tc[:disable_parallel_tool_use] = !parallel_tool_calls unless tc[:type] == :none || parallel_tool_calls.nil? + end + end end end end diff --git a/lib/ruby_llm/providers/bedrock/chat.rb b/lib/ruby_llm/providers/bedrock/chat.rb index 94655d11..a4d5220f 100644 --- a/lib/ruby_llm/providers/bedrock/chat.rb +++ b/lib/ruby_llm/providers/bedrock/chat.rb @@ -40,7 +40,9 @@ def completion_url "model/#{@model_id}/invoke" end - def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Lint/UnusedMethodArgument,Metrics/ParameterLists + # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument + def render_payload(messages, tools:, tool_choice:, parallel_tool_calls:, + temperature:, model:, stream: false, schema: nil) # Hold model_id in instance variable for use in completion_url and stream_url @model_id = model @@ -48,9 +50,11 @@ def render_payload(messages, tools:, temperature:, model:, stream: false, schema system_content = Anthropic::Chat.build_system_content(system_messages) build_base_payload(chat_messages, model).tap do |payload| - Anthropic::Chat.add_optional_fields(payload, system_content:, tools:, temperature:) + Anthropic::Chat.add_optional_fields(payload, system_content:, tools:, tool_choice:, + parallel_tool_calls:, temperature:) end end + # rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument def build_base_payload(chat_messages, model) { diff --git a/lib/ruby_llm/providers/gemini/chat.rb b/lib/ruby_llm/providers/gemini/chat.rb index 5254722a..d25dd8c2 100644 --- a/lib/ruby_llm/providers/gemini/chat.rb +++ b/lib/ruby_llm/providers/gemini/chat.rb @@ -11,7 +11,9 @@ def completion_url "models/#{@model}:generateContent" end - def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument + # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument + def render_payload(messages, tools:, tool_choice:, parallel_tool_calls:, + temperature:, model:, stream: false, schema: nil) @model = model # Store model for completion_url/stream_url payload = { contents: format_messages(messages), @@ -25,9 +27,15 @@ def render_payload(messages, tools:, temperature:, model:, stream: false, schema payload[:generationConfig][:responseSchema] = convert_schema_to_gemini(schema) end - payload[:tools] = format_tools(tools) if tools.any? + if tools.any? 
+ payload[:tools] = format_tools(tools) + # Gemini doesn't support controlling parallel tool calls + payload[:toolConfig] = build_tool_config(tool_choice) unless tool_choice.nil? + end + payload end + # rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument private diff --git a/lib/ruby_llm/providers/gemini/tools.rb b/lib/ruby_llm/providers/gemini/tools.rb index fd178d24..7a31a0a1 100644 --- a/lib/ruby_llm/providers/gemini/tools.rb +++ b/lib/ruby_llm/providers/gemini/tools.rb @@ -76,6 +76,21 @@ def param_type_for_gemini(type) else 'STRING' end end + + def build_tool_config(tool_choice) + { + functionCallingConfig: { + mode: specific_tool_choice?(tool_choice) ? 'any' : tool_choice + }.tap do |config| + # Use allowedFunctionNames to simulate specific tool choice + config[:allowedFunctionNames] = [tool_choice] if specific_tool_choice?(tool_choice) + end + } + end + + def specific_tool_choice?(tool_choice) + !%i[auto none any].include?(tool_choice) + end end end end diff --git a/lib/ruby_llm/providers/mistral/chat.rb b/lib/ruby_llm/providers/mistral/chat.rb index 74d508d1..c2a6321e 100644 --- a/lib/ruby_llm/providers/mistral/chat.rb +++ b/lib/ruby_llm/providers/mistral/chat.rb @@ -13,7 +13,8 @@ def format_role(role) end # rubocop:disable Metrics/ParameterLists - def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) + def render_payload(messages, tools:, tool_choice:, parallel_tool_calls:, temperature:, model:, stream: false, + schema: nil) payload = super # Mistral doesn't support stream_options payload.delete(:stream_options) diff --git a/lib/ruby_llm/providers/openai/chat.rb b/lib/ruby_llm/providers/openai/chat.rb index 9ed7e170..7d0a5fd6 100644 --- a/lib/ruby_llm/providers/openai/chat.rb +++ b/lib/ruby_llm/providers/openai/chat.rb @@ -11,7 +11,9 @@ def completion_url module_function - def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists + # rubocop:disable Metrics/ParameterLists + def render_payload(messages, tools:, tool_choice:, parallel_tool_calls:, + temperature:, model:, stream: false, schema: nil) payload = { model: model, messages: format_messages(messages), @@ -21,7 +23,11 @@ def render_payload(messages, tools:, temperature:, model:, stream: false, schema # Only include temperature if it's not nil (some models don't accept it) payload[:temperature] = temperature unless temperature.nil? - payload[:tools] = tools.map { |_, tool| tool_for(tool) } if tools.any? + if tools.any? + payload[:tools] = tools.map { |_, tool| tool_for(tool) } + payload[:tool_choice] = build_tool_choice(tool_choice) unless tool_choice.nil? + payload[:parallel_tool_calls] = parallel_tool_calls unless parallel_tool_calls.nil? 
+          end
 
           if schema
             # Use strict mode from schema if specified, default to true
@@ -40,6 +46,7 @@ def render_payload(messages, tools:, temperature:, model:, stream: false, schema
           payload[:stream_options] = { include_usage: true } if stream
           payload
         end
+        # rubocop:enable Metrics/ParameterLists
 
         def parse_completion_response(response)
           data = response.body
diff --git a/lib/ruby_llm/providers/openai/tools.rb b/lib/ruby_llm/providers/openai/tools.rb
index e4b76c0c..7cf87d89 100644
--- a/lib/ruby_llm/providers/openai/tools.rb
+++ b/lib/ruby_llm/providers/openai/tools.rb
@@ -67,6 +67,22 @@ def parse_tool_calls(tool_calls, parse_arguments: true)
             ]
           end
         end
+
+        def build_tool_choice(tool_choice)
+          case tool_choice
+          when :auto, :none
+            tool_choice
+          when :any
+            :required
+          else
+            {
+              type: 'function',
+              function: {
+                name: tool_choice
+              }
+            }
+          end
+        end
       end
     end
   end
 end

From 47923bce1257bad450cad69a012e51a726e04a8e Mon Sep 17 00:00:00 2001
From: Sophia
Date: Thu, 14 Aug 2025 18:00:05 -0700
Subject: [PATCH 2/2] docs: add Tool Choice Control section to tools.md

---
 docs/_core_features/tools.md | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/docs/_core_features/tools.md b/docs/_core_features/tools.md
index aadb73f6..3e352804 100644
--- a/docs/_core_features/tools.md
+++ b/docs/_core_features/tools.md
@@ -136,6 +136,27 @@
 puts response.content
 # => "Current weather at 52.52, 13.4: Temperature: 12.5°C, Wind Speed: 8.3 km/h, Conditions: Mainly clear, partly cloudy, and overcast."
 ```
 
+### Tool Choice Control
+
+Control when and how tools are called with the `choice` and `parallel` options:
+
+```ruby
+chat = RubyLLM.chat(model: 'gpt-4o')
+
+# Choice options
+chat.with_tool(Weather, choice: :auto)     # Let the model decide whether to call a tool (default)
+chat.with_tool(Weather, choice: :any)      # Model must call one of the provided tools
+chat.with_tool(Weather, choice: :none)     # Model must not call any tool
+chat.with_tool(Weather, choice: :weather)  # Force a call to the :weather tool
+
+# Parallel tool calls
+chat.with_tools(Weather, Calculator, parallel: true)   # Model may return several tool calls in one response (provider default)
+chat.with_tools(Weather, Calculator, parallel: false)  # At most one tool call per response
+```
+
+> With `:any` or a specific tool choice, tool results are not automatically sent back to the model for a follow-up response (see The Tool Execution Flow section below); this prevents infinite tool-calling loops.
+{: .note }
+
 ### Model Compatibility
 {: .d-inline-block }
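For reviewers, a minimal sketch (not part of the patch) of how the changed execution flow behaves: with a forced choice such as `choice: :weather`, `handle_tool_calls` still runs the tool and appends its result to the message history, but it returns the assistant's tool-call message instead of looping back into `complete`, so the caller decides what happens next. The `Weather` tool class is the one from the docs example, and the resume-with-`:auto` pattern below is one possible follow-up, not something the patch prescribes.

```ruby
require 'ruby_llm'

chat = RubyLLM.chat(model: 'gpt-4o')

# Force the :weather tool; the tool runs and its result is appended to chat.messages,
# but no automatic follow-up request is made because the choice is not :auto.
response = chat.with_tool(Weather, choice: :weather).ask("What's the weather in Berlin?")
response.tool_calls.each_value { |tc| puts "called #{tc.name} with #{tc.arguments}" }

# Relax the constraint back to :auto before asking the model to use the recorded result.
chat.with_tools(choice: :auto)
puts chat.ask('Summarize the weather data you just retrieved.').content
```

Under the hood, each provider module translates the choice: OpenAI maps `:any` to `tool_choice: :required` and a tool name to a forced function choice, Anthropic and Bedrock emit `{type: :any}` or `{type: :tool, name: ...}` plus `disable_parallel_tool_use`, and Gemini sets `functionCallingConfig` with `allowedFunctionNames` to simulate a specific choice.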
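The validation added in `update_tool_options` can be exercised the same way: a choice that is neither `:auto`, `:none`, `:any`, nor the name of a registered tool raises the new error before any request is sent. A small sketch, with `:humidity` as a hypothetical invalid choice and the error message reconstructed from the patch:

```ruby
chat = RubyLLM.chat(model: 'gpt-4o')

begin
  chat.with_tool(Weather, choice: :humidity) # Weather is registered first, then the choice is validated
rescue RubyLLM::InvalidToolChoiceError => e
  puts e.message
  # => Invalid tool choice: humidity. Valid choices are: auto, none, any, weather
end
```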