Skip to content

Commit

Permalink
Support tool calling #30
Browse files Browse the repository at this point in the history
  • Loading branch information
hauselin committed Dec 26, 2024
1 parent 5651de7 commit e28924c
Show file tree
Hide file tree
Showing 7 changed files with 244 additions and 5 deletions.
2 changes: 1 addition & 1 deletion R/ollama.R
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
#' chat("benzie/llava-phi-3", messages, output = 'text')
chat <- function(model, messages, tools = list(), stream = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req"), endpoint = "/api/chat", host = NULL, ...) {
output <- output[1]
if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req")) {
if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req", "tools")) {
stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text'")
}

Expand Down
48 changes: 46 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,49 @@ stream_handler <- function(x, env, endpoint) {





#' Get tool calls helper function
#'
#' Get tool calls from response object.
#'
#' @keywords internal
get_tool_calls <- function(resp) {
body <- httr2::resp_body_json(resp)
tools <- list()
tools_list <- list()
if (!is.null(body$message)) {
if (!is.null(body$message$tool_calls)) {
tools <- body$message$tool_calls

tools_list <- list()
if (length(tools) > 0) {
for (i in seq_along(tools)) {
func <- tools[[i]]$`function`
func_name <- func$name
tools_list[[i]] <- list()
names(tools_list)[[i]] <- func_name
tools_list[[func_name]] <- func
}
}

# remove empty lists
tools_list <- tools_list[which(sapply(tools_list, length) != 0)]
message(paste0("Tools: ", paste0(names(tools_list), collapse = ", ")))
}
}

return(tools_list)
}





#' Process httr2 response object
#'
#' @param resp A httr2 response object.
#' @param output The output format. Default is "df". Other options are "jsonlist", "raw", "resp" (httr2 response object), "text"
#' @param output The output format. Default is "df". Other options are "jsonlist", "raw", "resp" (httr2 response object), "text", "tools" (tool_calls)
#'
#' @return A data frame, json list, raw or httr2 response object.
#' @export
Expand All @@ -88,7 +127,8 @@ stream_handler <- function(x, env, endpoint) {
#' resp_process(resp, "raw") # parse response to raw string
#' resp_process(resp, "resp") # return input response object
#' resp_process(resp, "text") # return text/character vector
resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text")) {
#' resp_process(resp, "tools") # return tool_calls
resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text", "tools")) {

if (!inherits(resp, "httr2_response")) {
stop("Input must be a httr2 response object")
Expand All @@ -111,6 +151,10 @@ resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text
return(resp)
}

if (output == "tools") {
return(get_tool_calls(resp))
}

# process stream resp separately
stream <- FALSE
headers <- httr2::resp_headers(resp)
Expand Down
1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,4 @@ reference:
- model_options
- stream_handler
- resp_process_stream
- get_tool_calls
12 changes: 12 additions & 0 deletions man/get_tool_calls.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions man/resp_process.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

105 changes: 105 additions & 0 deletions tests/testthat/test-chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,108 @@ test_that("chat function handles images in messages", {
expect_true(grepl("melon", tolower(result)) | grepl("cam", tolower(result)))

})


test_that("chat function handles tools", {
skip_if_not(test_connection(), "Ollama server not available")

add_two_numbers <- function(x, y) {
return(x + y)
}

multiply_two_numbers <- function(x, y) {
return(x * y)
}

tools <- list(list(type = "function",
"function" = list(
name = "add_two_numbers",
description = "add two numbers",
parameters = list(
type = "object",
required = list("x", "y"),
properties = list(
x = list(class = "numeric", description = "first number"),
y = list(class = "numeric", description = "second number")
)
)
)
)
)

msg <- create_message("what is three plus one?")
resp <- chat("llama3.1", msg, tools = tools, output = "tools")
resp2 <- resp[[1]]
expect_equal(resp2$name, "add_two_numbers")
expect_equal(do.call(resp2$name, resp2$arguments), 4)

tools <- list(list(type = "function",
"function" = list(
name = "multiply_two_numbers",
description = "multiply two numbers",
parameters = list(
type = "object",
required = list("x", "y"),
properties = list(
x = list(class = "numeric", description = "first number"),
y = list(class = "numeric", description = "second number")
)
)
)
)
)

msg <- create_message("what is three times eleven?")
resp <- chat("llama3.1", msg, tools = tools, output = "tools")
resp2 <- resp[[1]]
expect_equal(resp2$name, "multiply_two_numbers")
expect_equal(do.call(resp2$name, resp2$arguments), 33)



# test multiple tools
msg <- create_message("what is three plus one? then multiply the output of that by ten")
tools <- list(list(type = "function",
"function" = list(
name = "add_two_numbers",
description = "add two numbers",
parameters = list(
type = "object",
required = list("x", "y"),
properties = list(
x = list(class = "numeric", description = "first number"),
y = list(class = "numeric", description = "second number")
)
)
)
),
list(type = "function",
"function" = list(
name = "multiply_two_numbers",
description = "multiply two numbers",
parameters = list(
type = "object",
required = list("x", "y"),
properties = list(
x = list(class = "numeric", description = "first number"),
y = list(class = "numeric", description = "second number")
)
)
)
)
)

msg <- create_message("what is four plus five?")
resp <- chat("llama3.1", msg, tools = tools, output = "tools")
expect_equal(resp[[1]]$name, "add_two_numbers")

msg <- create_message("what is four times five?")
resp <- chat("llama3.1", msg, tools = tools, output = "tools")
expect_equal(resp[[1]]$name, "multiply_two_numbers")

msg <- create_message("three and four. sum the numbers then multiply the output by ten")
resp <- chat("llama3.1", msg, tools = tools, output = "tools")
expect_equal(resp[[1]]$name, "add_two_numbers")
expect_equal(resp[[2]]$name, "multiply_two_numbers")

})
73 changes: 73 additions & 0 deletions vignettes/ollamar.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,79 @@ resp_process(resp, "text") # text vector

## Advanced usage

### Tool calling

You can use [tool calling](https://ollama.com/blog/tool-support) with the `chat()` function with certain models such as Llama3.1. See also [Python examples](https://github.com/ollama/ollama-python/blob/main/examples/tools.py).

First, define your tools as functions. Two example functions are shown below.

```{r eval=FALSE}
add_two_numbers <- function(x, y) {
return(x + y)
}
multiply_two_numbers <- function(a, b) {
return(a * b)
}
# each tool needs to be in a list
tool1 <- list(type = "function",
"function" = list(
name = "add_two_numbers", # function name
description = "add two numbers",
parameters = list(
type = "object",
required = list("x", "y"), # function parameters
properties = list(
x = list(class = "numeric", description = "first number"),
y = list(class = "numeric", description = "second number")))
)
)
tool2 <- list(type = "function",
"function" = list(
name = "multiply_two_numbers", # function name
description = "add two numbers",
parameters = list(
type = "object",
required = list("a", "b"), # function parameters
properties = list(
x = list(class = "numeric", description = "first number"),
y = list(class = "numeric", description = "second number")))
)
)
```

Next call the `chat()` function with the `tools` parameter set to a list of your tools. Pass in a single tool.

```{r eval=FALSE}
msg <- create_message("what is three plus one?")
resp <- chat("llama3.1", msg, tools = list(tool1), output = "tools")
tool <- resp[[1]] # get the first tool/function
do.call(tool$name, tool$arguments) # call the tool function with arguments: add_two_numbers(3, 1)
```

Pass in multiple tools. The model will pick the best tool to use based on the context of the message.

```{r eval=FALSE}
msg <- create_message("what is three times four?")
resp <- chat("llama3.1", msg, tools = list(tool1, tool2), output = "tools")
tool <- resp[[1]] # get the first tool/function
do.call(tool$name, tool$arguments) # call the tool function with arguments: multiply_two_numbers(3, 4)
```

Pass in multiple tools and get the model to use multiple tools.

```{r eval=FALSE}
msg <- create_message("add three plus four. then multiply by ten")
resp <- chat("llama3.1", msg, tools = list(tool1, tool2), output = "tools")
# first tool/function: add_two_numbers(3, 4)
do.call(resp[[1]]$name, resp[[1]]$arguments) # 7
# multiply_two_numbers(7, 10)
do.call(resp[[2]]$name, resp[[2]]$arguments) # 70
```

### Parallel requests

For the `generate()` and `chat()` endpoints/functions, you can specify `output = 'req'` in the function so the functions return `httr2_request` objects instead of `httr2_response` objects.
Expand Down

0 comments on commit e28924c

Please sign in to comment.