diff --git a/NEWS.md b/NEWS.md
index ab5d57a..58ea07e 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,5 +1,10 @@
 # ollamar (development version)
 
+- `generate()` and `chat()` support [structured output](https://ollama.com/blog/structured-outputs) via the `format` parameter.
+- `test_connection()` returns a boolean instead of an `httr2` object. #29
+- `chat()` supports [tool calling](https://ollama.com/blog/tool-support) via the `tools` parameter. Added `get_tool_calls()` helper function to process tools. #30
+- Simplified README and added a "Get started" vignette with more examples.
+
 # ollamar 1.2.1
 
 - `generate()` and `chat()` accept multiple images as prompts/messages.
diff --git a/R/ollama.R b/R/ollama.R
index 2c78bd6..77d8045 100644
--- a/R/ollama.R
+++ b/R/ollama.R
@@ -58,6 +58,7 @@ create_request <- function(endpoint, host = NULL) {
 #' @param prompt A character string of the prompt like "The sky is..."
 #' @param suffix A character string after the model response. Default is "".
 #' @param images A path to an image file to include in the prompt. Default is "".
+#' @param format Format to return a response in. Accepts a JSON schema as an R list (structured output).
 #' @param system A character string of the system prompt (overrides what is defined in the Modelfile). Default is "".
 #' @param template A character string of the prompt template (overrides what is defined in the Modelfile). Default is "".
 #' @param context A list of context from a previous response to include previous conversation in the prompt. Default is an empty list.
@@ -86,10 +87,10 @@
 #' image_path <- file.path(system.file("extdata", package = "ollamar"), "image1.png")
 #' # use vision or multimodal model such as https://ollama.com/benzie/llava-phi-3
 #' generate("benzie/llava-phi-3:latest", "What is in the image?", images = image_path, output = "text")
-generate <- function(model, prompt, suffix = "", images = "", system = "", template = "", context = list(), stream = FALSE, raw = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req"), endpoint = "/api/generate", host = NULL, ...) {
+generate <- function(model, prompt, suffix = "", images = "", format = list(), system = "", template = "", context = list(), stream = FALSE, raw = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req", "structured"), endpoint = "/api/generate", host = NULL, ...) {
   output <- output[1]
-  if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req")) {
-    stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text', 'req'")
+  if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req", "structured")) {
+    stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text', 'req', 'structured'")
   }
 
   req <- create_request(endpoint, host)
@@ -112,6 +113,10 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
     keep_alive = keep_alive
   )
 
+  if (length(format) != 0 && inherits(format, "list")) {
+    body_json$format <- format
+  }
+
   # check if model options are passed and specified correctly
   opts <- list(...)
   if (length(opts) > 0) {
@@ -169,8 +174,9 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
 #' @param messages A list with list of messages for the model (see examples below).
 #' @param tools Tools for the model to use if supported. Requires stream = FALSE. Default is an empty list.
 #' @param stream Enable response streaming. Default is FALSE.
+#' @param format Format to return a response in. Accepts a JSON schema as an R list (structured output).
 #' @param keep_alive The duration to keep the connection alive. Default is "5m".
-#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text", "req" (httr2_request object), "tools" (tool calling)
+#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text", "req" (httr2_request object), "tools" (tool calling), "structured" (structured output)
 #' @param endpoint The endpoint to chat with the model. Default is "/api/chat".
 #' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
 #' @param ... Additional options to pass to the model.
@@ -208,10 +214,10 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
 #'   list(role = "user", content = "What is in the image?", images = image_path)
 #' )
 #' chat("benzie/llava-phi-3", messages, output = 'text')
-chat <- function(model, messages, tools = list(), stream = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req", "tools"), endpoint = "/api/chat", host = NULL, ...) {
+chat <- function(model, messages, tools = list(), stream = FALSE, format = list(), keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req", "tools", "structured"), endpoint = "/api/chat", host = NULL, ...) {
   output <- output[1]
-  if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req", "tools")) {
-    stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text'")
+  if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req", "tools", "structured")) {
+    stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text', 'req', 'tools', 'structured'")
   }
 
   req <- create_request(endpoint, host)
@@ -231,6 +237,10 @@ chat <- function(model, messages, tools = list(), stream = FALSE, keep_alive = "
     keep_alive = keep_alive
   )
 
+  if (length(format) != 0 && inherits(format, "list")) {
+    body_json$format <- format
+  }
+
   opts <- list(...)
   if (length(opts) > 0) {
     if (validate_options(...)) {
diff --git a/R/utils.R b/R/utils.R
index 23e30ff..d7898ab 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -112,7 +112,7 @@ get_tool_calls <- function(resp) {
 #' Process httr2 response object
 #'
 #' @param resp A httr2 response object.
-#' @param output The output format. Default is "df". Other options are "jsonlist", "raw", "resp" (httr2 response object), "text", "tools" (tool_calls)
+#' @param output The output format. Default is "df". Other options are "jsonlist", "raw", "resp" (httr2 response object), "text", "tools" (tool_calls), "structured" (structured output).
 #'
 #' @return A data frame, json list, raw or httr2 response object.
 #' @export
@@ -122,7 +122,6 @@ get_tool_calls <- function(resp) {
 #' resp_process(resp, "df") # parse response to dataframe/tibble
 #' resp_process(resp, "jsonlist") # parse response to list
 #' resp_process(resp, "raw") # parse response to raw string
-#' resp_process(resp, "resp") # return input response object
 #' resp_process(resp, "text") # return text/character vector
 #' resp_process(resp, "tools") # return tool_calls
 resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text", "tools")) {
@@ -195,6 +194,8 @@ resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text
       return(df_response)
     } else if (output == "text") {
       return(df_response$response)
+    } else if (output == "structured") {
+      return(jsonlite::fromJSON(df_response$response))
     }
   } else if (grepl("api/chat", resp$url)) { # process chat endpoint
     json_body <- httr2::resp_body_json(resp)
@@ -209,6 +210,8 @@ resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text
       return(df_response)
     } else if (output == "text") {
       return(df_response$content)
+    } else if (output == "structured") {
+      return(jsonlite::fromJSON(df_response$content))
     }
   } else if (grepl("api/tags", resp$url)) { # process tags endpoint
     json_body <- httr2::resp_body_json(resp)[[1]]
diff --git a/README.Rmd b/README.Rmd
index 32b45de..6659116 100644
--- a/README.Rmd
+++ b/README.Rmd
@@ -27,7 +27,7 @@ The library also makes it easy to work with data structures (e.g., conversationa
 
 To use this R library, ensure the [Ollama](https://ollama.com) app is installed. Ollama can use GPUs for accelerating LLM inference. See [Ollama GPU documentation](https://github.com/ollama/ollama/blob/main/docs/gpu.md) for more information.
 
-See [Ollama's Github page](https://github.com/ollama/ollama) for more information. This library uses the [Ollama REST API (see documentation for details)](https://github.com/ollama/ollama/blob/main/docs/api.md) and has been tested on Ollama v0.1.30 and above. It was last tested on Ollama v0.3.10.
+See [Ollama's Github page](https://github.com/ollama/ollama) for more information. This library uses the [Ollama REST API (see documentation for details)](https://github.com/ollama/ollama/blob/main/docs/api.md) and was last tested on Ollama v0.5.4.
 
 > Note: You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
diff --git a/README.md b/README.md
index 069e9c7..a0b9b28 100644
--- a/README.md
+++ b/README.md
@@ -31,8 +31,7 @@ for more information.
 See [Ollama’s Github page](https://github.com/ollama/ollama) for more
 information. This library uses the [Ollama REST API (see documentation
 for details)](https://github.com/ollama/ollama/blob/main/docs/api.md)
-and has been tested on Ollama v0.1.30 and above. It was last tested on
-Ollama v0.3.10.
+and was last tested on Ollama v0.5.4.
 
 > Note: You should have at least 8 GB of RAM available to run the 7B
 > models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
diff --git a/man/chat.Rd b/man/chat.Rd
index a58412d..8cb96a8 100644
--- a/man/chat.Rd
+++ b/man/chat.Rd
@@ -9,8 +9,9 @@ chat(
   messages,
   tools = list(),
   stream = FALSE,
+  format = list(),
   keep_alive = "5m",
-  output = c("resp", "jsonlist", "raw", "df", "text", "req", "tools"),
+  output = c("resp", "jsonlist", "raw", "df", "text", "req", "tools", "structured"),
   endpoint = "/api/chat",
   host = NULL,
   ...
@@ -25,9 +26,11 @@ chat(
 
 \item{stream}{Enable response streaming. Default is FALSE.}
 
+\item{format}{Format to return a response in. Accepts a JSON schema as an R list (structured output).}
+
 \item{keep_alive}{The duration to keep the connection alive. Default is "5m".}
 
-\item{output}{The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text", "req" (httr2_request object), "tools" (tool calling)}
+\item{output}{The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text", "req" (httr2_request object), "tools" (tool calling), "structured" (structured output)}
 
 \item{endpoint}{The endpoint to chat with the model. Default is "/api/chat".}
diff --git a/man/generate.Rd b/man/generate.Rd
index 77c622c..7cc8e44 100644
--- a/man/generate.Rd
+++ b/man/generate.Rd
@@ -9,13 +9,14 @@ generate(
   prompt,
   suffix = "",
   images = "",
+  format = list(),
   system = "",
   template = "",
   context = list(),
   stream = FALSE,
   raw = FALSE,
   keep_alive = "5m",
-  output = c("resp", "jsonlist", "raw", "df", "text", "req"),
+  output = c("resp", "jsonlist", "raw", "df", "text", "req", "structured"),
   endpoint = "/api/generate",
   host = NULL,
   ...
@@ -30,6 +31,8 @@ generate(
 
 \item{images}{A path to an image file to include in the prompt. Default is "".}
 
+\item{format}{Format to return a response in. Accepts a JSON schema as an R list (structured output).}
+
 \item{system}{A character string of the system prompt (overrides what is defined in the Modelfile). Default is "".}
 
 \item{template}{A character string of the prompt template (overrides what is defined in the Modelfile). Default is "".}
diff --git a/man/resp_process.Rd b/man/resp_process.Rd
index 11f9453..3f4fc27 100644
--- a/man/resp_process.Rd
+++ b/man/resp_process.Rd
@@ -12,7 +12,7 @@ resp_process(
 \arguments{
 \item{resp}{A httr2 response object.}
 
-\item{output}{The output format. Default is "df". Other options are "jsonlist", "raw", "resp" (httr2 response object), "text", "tools" (tool_calls)}
+\item{output}{The output format. Default is "df". Other options are "jsonlist", "raw", "resp" (httr2 response object), "text", "tools" (tool_calls), "structured" (structured output).}
 }
 \value{
 A data frame, json list, raw or httr2 response object.
@@ -26,7 +26,6 @@ resp <- list_models("resp")
 resp_process(resp, "df") # parse response to dataframe/tibble
 resp_process(resp, "jsonlist") # parse response to list
 resp_process(resp, "raw") # parse response to raw string
-resp_process(resp, "resp") # return input response object
 resp_process(resp, "text") # return text/character vector
 resp_process(resp, "tools") # return tool_calls
 \dontshow{\}) # examplesIf}
diff --git a/tests/testthat/test-chat.R b/tests/testthat/test-chat.R
index 1ee1e34..208f18f 100644
--- a/tests/testthat/test-chat.R
+++ b/tests/testthat/test-chat.R
@@ -130,7 +130,7 @@ test_that("chat function handles images in messages", {
 })
 
 
-test_that("chat function handles tools", {
+test_that("chat function handles tool calling", {
   skip_if_not(test_connection(), "Ollama server not available")
 
   add_two_numbers <- function(x, y) {
@@ -188,7 +188,7 @@ test_that("chat function handles tools", {
 
 
   # test multiple tools
-  msg <- create_message("what is three plus one? then multiply the output of that by ten")
+  msg <- create_message("add three plus four. then multiply by ten")
then multiply by ten") tools <- list(list(type = "function", "function" = list( name = "add_two_numbers", @@ -234,3 +234,38 @@ test_that("chat function handles tools", { # expect_equal(resp[[2]]$name, "multiply_two_numbers") }) + + + + +test_that("structured output", { + skip_if_not(test_connection(), "Ollama server not available") + + format <- list( + type = "object", + properties = list( + name = list( + type = "string" + ), + capital = list( + type = "string" + ), + languages = list( + type = "array", + items = list( + type = "string" + ) + ) + ), + required = list("name", "capital", "languages") + ) + + msg <- create_message("tell me about canada") + resp <- chat("llama3.1", msg, format = format) + # content <- httr2::resp_body_json(resp)$message$content + structured_output <- resp_process(resp, "structured") + expect_equal(tolower(structured_output$name), "canada") + +}) + + diff --git a/tests/testthat/test-generate.R b/tests/testthat/test-generate.R index 0263830..aca5bcc 100644 --- a/tests/testthat/test-generate.R +++ b/tests/testthat/test-generate.R @@ -79,3 +79,37 @@ test_that("generate function works with images", { }) + + + +test_that("structured output", { + skip_if_not(test_connection(), "Ollama server not available") + + format <- list( + type = "object", + properties = list( + name = list( + type = "string" + ), + capital = list( + type = "string" + ), + languages = list( + type = "array", + items = list( + type = "string" + ) + ) + ), + required = list("name", "capital", "languages") + ) + + msg <- "tell me about canada" + resp <- generate("llama3.1", prompt = msg, format = format) + # response <- httr2::resp_body_json(resp)$response + structured_output <- resp_process(resp, "structured") + expect_equal(tolower(structured_output$name), "canada") + +}) + + diff --git a/vignettes/ollamar.Rmd b/vignettes/ollamar.Rmd index 8354e2e..f883924 100644 --- a/vignettes/ollamar.Rmd +++ b/vignettes/ollamar.Rmd @@ -362,6 +362,30 @@ do.call(resp[[1]]$name, resp[[1]]$arguments) # 7 do.call(resp[[2]]$name, resp[[2]]$arguments) # 70 ``` +### Structured outputs + +The `chat()` and `generate()` functions support [structured outputs](https://ollama.com/blog/structured-outputs), making it possible to constrain a model's output to a specified format defined by a JSON schema (R list). + +```{r eval=FALSE} +# define a JSON schema as a list to constrain a model's output +format <- list( + type = "object", + properties = list( + name = list(type = "string"), + capital = list(type = "string"), + languages = list(type = "array", + items = list(type = "string") + ) + ), + required = list("name", "capital", "languages") + ) + +generate("llama3.1", "tell me about Canada", output = "structured", format = format) + +msg <- create_message("tell me about Canada") +chat("llama3.1", msg, format = format, output = "structured") +``` + ### Parallel requests For the `generate()` and `chat()` endpoints/functions, you can specify `output = 'req'` in the function so the functions return `httr2_request` objects instead of `httr2_response` objects.