From e28924c2030d8f637a954a5a7e150bea6d380afc Mon Sep 17 00:00:00 2001 From: Hause Lin Date: Wed, 25 Dec 2024 22:16:18 -0500 Subject: [PATCH] Support tool calling #30 --- R/ollama.R | 2 +- R/utils.R | 48 ++++++++++++++++- _pkgdown.yml | 1 + man/get_tool_calls.Rd | 12 +++++ man/resp_process.Rd | 8 ++- tests/testthat/test-chat.R | 105 +++++++++++++++++++++++++++++++++++++ vignettes/ollamar.Rmd | 73 ++++++++++++++++++++++++++ 7 files changed, 244 insertions(+), 5 deletions(-) create mode 100644 man/get_tool_calls.Rd diff --git a/R/ollama.R b/R/ollama.R index e3dd592..93b0a05 100644 --- a/R/ollama.R +++ b/R/ollama.R @@ -210,7 +210,7 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ #' chat("benzie/llava-phi-3", messages, output = 'text') chat <- function(model, messages, tools = list(), stream = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req"), endpoint = "/api/chat", host = NULL, ...) { output <- output[1] - if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req")) { + if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req", "tools")) { stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text'") } diff --git a/R/utils.R b/R/utils.R index 397fb72..5ec3cc7 100644 --- a/R/utils.R +++ b/R/utils.R @@ -73,10 +73,49 @@ stream_handler <- function(x, env, endpoint) { + + +#' Get tool calls helper function +#' +#' Get tool calls from response object. +#' +#' @keywords internal +get_tool_calls <- function(resp) { + body <- httr2::resp_body_json(resp) + tools <- list() + tools_list <- list() + if (!is.null(body$message)) { + if (!is.null(body$message$tool_calls)) { + tools <- body$message$tool_calls + + tools_list <- list() + if (length(tools) > 0) { + for (i in seq_along(tools)) { + func <- tools[[i]]$`function` + func_name <- func$name + tools_list[[i]] <- list() + names(tools_list)[[i]] <- func_name + tools_list[[func_name]] <- func + } + } + + # remove empty lists + tools_list <- tools_list[which(sapply(tools_list, length) != 0)] + message(paste0("Tools: ", paste0(names(tools_list), collapse = ", "))) + } + } + + return(tools_list) +} + + + + + #' Process httr2 response object #' #' @param resp A httr2 response object. -#' @param output The output format. Default is "df". Other options are "jsonlist", "raw", "resp" (httr2 response object), "text" +#' @param output The output format. Default is "df". Other options are "jsonlist", "raw", "resp" (httr2 response object), "text", "tools" (tool_calls) #' #' @return A data frame, json list, raw or httr2 response object. #' @export @@ -88,7 +127,8 @@ stream_handler <- function(x, env, endpoint) { #' resp_process(resp, "raw") # parse response to raw string #' resp_process(resp, "resp") # return input response object #' resp_process(resp, "text") # return text/character vector -resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text")) { +#' resp_process(resp, "tools") # return tool_calls +resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text", "tools")) { if (!inherits(resp, "httr2_response")) { stop("Input must be a httr2 response object") @@ -111,6 +151,10 @@ resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text return(resp) } + if (output == "tools") { + return(get_tool_calls(resp)) + } + # process stream resp separately stream <- FALSE headers <- httr2::resp_headers(resp) diff --git a/_pkgdown.yml b/_pkgdown.yml index d8a8f82..719ff65 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -84,3 +84,4 @@ reference: - model_options - stream_handler - resp_process_stream + - get_tool_calls diff --git a/man/get_tool_calls.Rd b/man/get_tool_calls.Rd new file mode 100644 index 0000000..15a878c --- /dev/null +++ b/man/get_tool_calls.Rd @@ -0,0 +1,12 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{get_tool_calls} +\alias{get_tool_calls} +\title{Get tool calls helper function} +\usage{ +get_tool_calls(resp) +} +\description{ +Get tool calls from response object. +} +\keyword{internal} diff --git a/man/resp_process.Rd b/man/resp_process.Rd index caf8421..11f9453 100644 --- a/man/resp_process.Rd +++ b/man/resp_process.Rd @@ -4,12 +4,15 @@ \alias{resp_process} \title{Process httr2 response object} \usage{ -resp_process(resp, output = c("df", "jsonlist", "raw", "resp", "text")) +resp_process( + resp, + output = c("df", "jsonlist", "raw", "resp", "text", "tools") +) } \arguments{ \item{resp}{A httr2 response object.} -\item{output}{The output format. Default is "df". Other options are "jsonlist", "raw", "resp" (httr2 response object), "text"} +\item{output}{The output format. Default is "df". Other options are "jsonlist", "raw", "resp" (httr2 response object), "text", "tools" (tool_calls)} } \value{ A data frame, json list, raw or httr2 response object. @@ -25,5 +28,6 @@ resp_process(resp, "jsonlist") # parse response to list resp_process(resp, "raw") # parse response to raw string resp_process(resp, "resp") # return input response object resp_process(resp, "text") # return text/character vector +resp_process(resp, "tools") # return tool_calls \dontshow{\}) # examplesIf} } diff --git a/tests/testthat/test-chat.R b/tests/testthat/test-chat.R index 680a717..4ad798e 100644 --- a/tests/testthat/test-chat.R +++ b/tests/testthat/test-chat.R @@ -128,3 +128,108 @@ test_that("chat function handles images in messages", { expect_true(grepl("melon", tolower(result)) | grepl("cam", tolower(result))) }) + + +test_that("chat function handles tools", { + skip_if_not(test_connection(), "Ollama server not available") + + add_two_numbers <- function(x, y) { + return(x + y) + } + + multiply_two_numbers <- function(x, y) { + return(x * y) + } + + tools <- list(list(type = "function", + "function" = list( + name = "add_two_numbers", + description = "add two numbers", + parameters = list( + type = "object", + required = list("x", "y"), + properties = list( + x = list(class = "numeric", description = "first number"), + y = list(class = "numeric", description = "second number") + ) + ) + ) + ) + ) + + msg <- create_message("what is three plus one?") + resp <- chat("llama3.1", msg, tools = tools, output = "tools") + resp2 <- resp[[1]] + expect_equal(resp2$name, "add_two_numbers") + expect_equal(do.call(resp2$name, resp2$arguments), 4) + + tools <- list(list(type = "function", + "function" = list( + name = "multiply_two_numbers", + description = "multiply two numbers", + parameters = list( + type = "object", + required = list("x", "y"), + properties = list( + x = list(class = "numeric", description = "first number"), + y = list(class = "numeric", description = "second number") + ) + ) + ) + ) + ) + + msg <- create_message("what is three times eleven?") + resp <- chat("llama3.1", msg, tools = tools, output = "tools") + resp2 <- resp[[1]] + expect_equal(resp2$name, "multiply_two_numbers") + expect_equal(do.call(resp2$name, resp2$arguments), 33) + + + + # test multiple tools + msg <- create_message("what is three plus one? then multiply the output of that by ten") + tools <- list(list(type = "function", + "function" = list( + name = "add_two_numbers", + description = "add two numbers", + parameters = list( + type = "object", + required = list("x", "y"), + properties = list( + x = list(class = "numeric", description = "first number"), + y = list(class = "numeric", description = "second number") + ) + ) + ) + ), + list(type = "function", + "function" = list( + name = "multiply_two_numbers", + description = "multiply two numbers", + parameters = list( + type = "object", + required = list("x", "y"), + properties = list( + x = list(class = "numeric", description = "first number"), + y = list(class = "numeric", description = "second number") + ) + ) + ) + ) + ) + + msg <- create_message("what is four plus five?") + resp <- chat("llama3.1", msg, tools = tools, output = "tools") + expect_equal(resp[[1]]$name, "add_two_numbers") + + msg <- create_message("what is four times five?") + resp <- chat("llama3.1", msg, tools = tools, output = "tools") + expect_equal(resp[[1]]$name, "multiply_two_numbers") + + msg <- create_message("three and four. sum the numbers then multiply the output by ten") + resp <- chat("llama3.1", msg, tools = tools, output = "tools") + expect_equal(resp[[1]]$name, "add_two_numbers") + expect_equal(resp[[2]]$name, "multiply_two_numbers") + +}) diff --git a/vignettes/ollamar.Rmd b/vignettes/ollamar.Rmd index e6ea351..4d2f725 100644 --- a/vignettes/ollamar.Rmd +++ b/vignettes/ollamar.Rmd @@ -285,6 +285,79 @@ resp_process(resp, "text") # text vector ## Advanced usage +### Tool calling + +You can use [tool calling](https://ollama.com/blog/tool-support) with the `chat()` function with certain models such as Llama3.1. See also [Python examples](https://github.com/ollama/ollama-python/blob/main/examples/tools.py). + +First, define your tools as functions. Two example functions are shown below. + +```{r eval=FALSE} +add_two_numbers <- function(x, y) { + return(x + y) +} + +multiply_two_numbers <- function(a, b) { + return(a * b) +} + +# each tool needs to be in a list +tool1 <- list(type = "function", + "function" = list( + name = "add_two_numbers", # function name + description = "add two numbers", + parameters = list( + type = "object", + required = list("x", "y"), # function parameters + properties = list( + x = list(class = "numeric", description = "first number"), + y = list(class = "numeric", description = "second number"))) + ) + ) + +tool2 <- list(type = "function", + "function" = list( + name = "multiply_two_numbers", # function name + description = "add two numbers", + parameters = list( + type = "object", + required = list("a", "b"), # function parameters + properties = list( + x = list(class = "numeric", description = "first number"), + y = list(class = "numeric", description = "second number"))) + ) + ) +``` + +Next call the `chat()` function with the `tools` parameter set to a list of your tools. Pass in a single tool. + +```{r eval=FALSE} +msg <- create_message("what is three plus one?") +resp <- chat("llama3.1", msg, tools = list(tool1), output = "tools") +tool <- resp[[1]] # get the first tool/function +do.call(tool$name, tool$arguments) # call the tool function with arguments: add_two_numbers(3, 1) +``` + +Pass in multiple tools. The model will pick the best tool to use based on the context of the message. + +```{r eval=FALSE} +msg <- create_message("what is three times four?") +resp <- chat("llama3.1", msg, tools = list(tool1, tool2), output = "tools") +tool <- resp[[1]] # get the first tool/function +do.call(tool$name, tool$arguments) # call the tool function with arguments: multiply_two_numbers(3, 4) +``` + +Pass in multiple tools and get the model to use multiple tools. + +```{r eval=FALSE} +msg <- create_message("add three plus four. then multiply by ten") +resp <- chat("llama3.1", msg, tools = list(tool1, tool2), output = "tools") + +# first tool/function: add_two_numbers(3, 4) +do.call(resp[[1]]$name, resp[[1]]$arguments) # 7 +# multiply_two_numbers(7, 10) +do.call(resp[[2]]$name, resp[[2]]$arguments) # 70 +``` + ### Parallel requests For the `generate()` and `chat()` endpoints/functions, you can specify `output = 'req'` in the function so the functions return `httr2_request` objects instead of `httr2_response` objects.