Support structured output
hauselin committed Dec 26, 2024
1 parent 5b5cdc4 commit 36e64ca
Showing 11 changed files with 134 additions and 19 deletions.
5 changes: 5 additions & 0 deletions NEWS.md
@@ -1,5 +1,10 @@
# ollamar (development version)

- `generate()` and `chat()` support [structured output](https://ollama.com/blog/structured-outputs) via the `format` parameter.
- `test_connection()` returns a boolean instead of an `httr2` object. #29
- `chat()` supports [tool calling](https://ollama.com/blog/tool-support) via the `tools` parameter. Added the `get_tool_calls()` helper function to process tools. #30
- Simplify README and add a Get started vignette with more examples.

# ollamar 1.2.1

- `generate()` and `chat()` accept multiple images as prompts/messages.
24 changes: 17 additions & 7 deletions R/ollama.R
@@ -58,6 +58,7 @@ create_request <- function(endpoint, host = NULL) {
#' @param prompt A character string of the prompt like "The sky is..."
#' @param suffix A character string after the model response. Default is "".
#' @param images A path to an image file to include in the prompt. Default is "".
#' @param format Format to return the response in. Supply a list (JSON schema) to constrain the model's output (structured output). Default is an empty list.
#' @param system A character string of the system prompt (overrides what is defined in the Modelfile). Default is "".
#' @param template A character string of the prompt template (overrides what is defined in the Modelfile). Default is "".
#' @param context A list of context from a previous response to include previous conversation in the prompt. Default is an empty list.
@@ -86,10 +87,10 @@ create_request <- function(endpoint, host = NULL) {
#' image_path <- file.path(system.file("extdata", package = "ollamar"), "image1.png")
#' # use vision or multimodal model such as https://ollama.com/benzie/llava-phi-3
#' generate("benzie/llava-phi-3:latest", "What is in the image?", images = image_path, output = "text")
generate <- function(model, prompt, suffix = "", images = "", system = "", template = "", context = list(), stream = FALSE, raw = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req"), endpoint = "/api/generate", host = NULL, ...) {
generate <- function(model, prompt, suffix = "", images = "", format = list(), system = "", template = "", context = list(), stream = FALSE, raw = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req", "structured"), endpoint = "/api/generate", host = NULL, ...) {
output <- output[1]
if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req")) {
stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text', 'req'")
if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req", "structured")) {
stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text', 'req', 'structured'")
}

req <- create_request(endpoint, host)
@@ -112,6 +113,10 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
keep_alive = keep_alive
)

if (length(format) != 0 && inherits(format, "list")) {
body_json$format <- format
}

# check if model options are passed and specified correctly
opts <- list(...)
if (length(opts) > 0) {
@@ -169,8 +174,9 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
#' @param messages A list with list of messages for the model (see examples below).
#' @param tools Tools for the model to use if supported. Requires stream = FALSE. Default is an empty list.
#' @param stream Enable response streaming. Default is FALSE.
#' @param format Format to return the response in. Supply a list (JSON schema) to constrain the model's output (structured output). Default is an empty list.
#' @param keep_alive The duration to keep the connection alive. Default is "5m".
#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text", "req" (httr2_request object), "tools" (tool calling)
#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text", "req" (httr2_request object), "tools" (tool calling), "structured" (structured output).
#' @param endpoint The endpoint to chat with the model. Default is "/api/chat".
#' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
#' @param ... Additional options to pass to the model.
@@ -208,10 +214,10 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
#' list(role = "user", content = "What is in the image?", images = image_path)
#' )
#' chat("benzie/llava-phi-3", messages, output = 'text')
chat <- function(model, messages, tools = list(), stream = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req", "tools"), endpoint = "/api/chat", host = NULL, ...) {
chat <- function(model, messages, tools = list(), stream = FALSE, format = list(), keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req", "tools", "structured"), endpoint = "/api/chat", host = NULL, ...) {
output <- output[1]
if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req", "tools")) {
stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text'")
if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req", "tools", "structured")) {
stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text', 'tools', 'structured'")
}

req <- create_request(endpoint, host)
@@ -231,6 +237,10 @@ chat <- function(model, messages, tools = list(), stream = FALSE, keep_alive = "
keep_alive = keep_alive
)

if (length(format) != 0 && inherits(format, "list")) {
body_json$format <- format
}

opts <- list(...)
if (length(opts) > 0) {
if (validate_options(...)) {
7 changes: 5 additions & 2 deletions R/utils.R
@@ -112,7 +112,7 @@ get_tool_calls <- function(resp) {
#' Process httr2 response object
#'
#' @param resp A httr2 response object.
#' @param output The output format. Default is "df". Other options are "jsonlist", "raw", "resp" (httr2 response object), "text", "tools" (tool_calls)
#' @param output The output format. Default is "df". Other options are "jsonlist", "raw", "resp" (httr2 response object), "text", "tools" (tool_calls), "structured" (structured output).
#'
#' @return A data frame, JSON list, raw string, or httr2 response object.
#' @export
@@ -122,7 +122,6 @@ get_tool_calls <- function(resp) {
#' resp_process(resp, "df") # parse response to dataframe/tibble
#' resp_process(resp, "jsonlist") # parse response to list
#' resp_process(resp, "raw") # parse response to raw string
#' resp_process(resp, "resp") # return input response object
#' resp_process(resp, "text") # return text/character vector
#' resp_process(resp, "tools") # return tool_calls
resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text", "tools", "structured")) {
@@ -195,6 +194,8 @@ resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text
return(df_response)
} else if (output == "text") {
return(df_response$response)
} else if (output == "structured") {
return(jsonlite::fromJSON(df_response$response))
}
} else if (grepl("api/chat", resp$url)) { # process chat endpoint
json_body <- httr2::resp_body_json(resp)
@@ -209,6 +210,8 @@ resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text
return(df_response)
} else if (output == "text") {
return(df_response$content)
} else if (output == "structured") {
return(jsonlite::fromJSON(df_response$content))
}
} else if (grepl("api/tags", resp$url)) { # process tags endpoint
json_body <- httr2::resp_body_json(resp)[[1]]
2 changes: 1 addition & 1 deletion README.Rmd
@@ -27,7 +27,7 @@ The library also makes it easy to work with data structures (e.g., conversationa

To use this R library, ensure the [Ollama](https://ollama.com) app is installed. Ollama can use GPUs for accelerating LLM inference. See [Ollama GPU documentation](https://github.com/ollama/ollama/blob/main/docs/gpu.md) for more information.

See [Ollama's Github page](https://github.com/ollama/ollama) for more information. This library uses the [Ollama REST API (see documentation for details)](https://github.com/ollama/ollama/blob/main/docs/api.md) and has been tested on Ollama v0.1.30 and above. It was last tested on Ollama v0.3.10.
See [Ollama's GitHub page](https://github.com/ollama/ollama) for more information. This library uses the [Ollama REST API (see documentation for details)](https://github.com/ollama/ollama/blob/main/docs/api.md) and was last tested on Ollama v0.5.4.

> Note: You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
3 changes: 1 addition & 2 deletions README.md
@@ -31,8 +31,7 @@ for more information.
See [Ollama’s GitHub page](https://github.com/ollama/ollama) for more
information. This library uses the [Ollama REST API (see documentation
for details)](https://github.com/ollama/ollama/blob/main/docs/api.md)
and has been tested on Ollama v0.1.30 and above. It was last tested on
Ollama v0.3.10.
and was last tested on Ollama v0.5.4.

> Note: You should have at least 8 GB of RAM available to run the 7B
> models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
7 changes: 5 additions & 2 deletions man/chat.Rd


5 changes: 4 additions & 1 deletion man/generate.Rd


3 changes: 1 addition & 2 deletions man/resp_process.Rd


39 changes: 37 additions & 2 deletions tests/testthat/test-chat.R
@@ -130,7 +130,7 @@ test_that("chat function handles images in messages", {
})


test_that("chat function handles tools", {
test_that("chat function tool calling", {
skip_if_not(test_connection(), "Ollama server not available")

add_two_numbers <- function(x, y) {
@@ -188,7 +188,7 @@


# test multiple tools
msg <- create_message("what is three plus one? then multiply the output of that by ten")
msg <- create_message("add three plus four. then multiply by ten")
tools <- list(list(type = "function",
"function" = list(
name = "add_two_numbers",
@@ -234,3 +234,38 @@
# expect_equal(resp[[2]]$name, "multiply_two_numbers")

})




test_that("structured output", {
skip_if_not(test_connection(), "Ollama server not available")

format <- list(
type = "object",
properties = list(
name = list(
type = "string"
),
capital = list(
type = "string"
),
languages = list(
type = "array",
items = list(
type = "string"
)
)
),
required = list("name", "capital", "languages")
)

msg <- create_message("tell me about canada")
resp <- chat("llama3.1", msg, format = format)
# content <- httr2::resp_body_json(resp)$message$content
structured_output <- resp_process(resp, "structured")
expect_equal(tolower(structured_output$name), "canada")

})


34 changes: 34 additions & 0 deletions tests/testthat/test-generate.R
@@ -79,3 +79,37 @@ test_that("generate function works with images", {
})





test_that("structured output", {
skip_if_not(test_connection(), "Ollama server not available")

format <- list(
type = "object",
properties = list(
name = list(
type = "string"
),
capital = list(
type = "string"
),
languages = list(
type = "array",
items = list(
type = "string"
)
)
),
required = list("name", "capital", "languages")
)

msg <- "tell me about canada"
resp <- generate("llama3.1", prompt = msg, format = format)
# response <- httr2::resp_body_json(resp)$response
structured_output <- resp_process(resp, "structured")
expect_equal(tolower(structured_output$name), "canada")

})


24 changes: 24 additions & 0 deletions vignettes/ollamar.Rmd
@@ -362,6 +362,30 @@ do.call(resp[[1]]$name, resp[[1]]$arguments) # 7
do.call(resp[[2]]$name, resp[[2]]$arguments) # 70
```

### Structured outputs

The `chat()` and `generate()` functions support [structured outputs](https://ollama.com/blog/structured-outputs), making it possible to constrain a model's output to a format defined by a JSON schema (supplied as an R list).

```{r eval=FALSE}
# define a JSON schema as a list to constrain a model's output
format <- list(
type = "object",
properties = list(
name = list(type = "string"),
capital = list(type = "string"),
languages = list(type = "array",
items = list(type = "string")
)
),
required = list("name", "capital", "languages")
)
generate("llama3.1", "tell me about Canada", output = "structured", format = format)
msg <- create_message("tell me about Canada")
chat("llama3.1", msg, format = format, output = "structured")
```
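
The calls above return the parsed structured output directly. You can also keep the default `"resp"` output and parse the response yourself with `resp_process()`. A minimal sketch (assuming the `llama3.1` model is available locally and the `format` schema defined above):

```{r eval=FALSE}
# get the raw httr2 response, then parse the structured output to an R list
resp <- chat("llama3.1", create_message("tell me about Canada"), format = format)
structured_output <- resp_process(resp, "structured")
structured_output$name       # e.g., "Canada"
structured_output$languages  # e.g., c("English", "French")
```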

### Parallel requests

For the `generate()` and `chat()` endpoints/functions, you can specify `output = 'req'` so that the functions return `httr2_request` objects instead of `httr2_response` objects.
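
As a minimal sketch (assuming the `llama3.1` model is available locally and `httr2` version 1.0 or later, which provides `req_perform_parallel()`), you can build several requests and then perform them in parallel:

```{r eval=FALSE}
library(httr2)

# build httr2_request objects without sending them
prompts <- c("tell me a 5-word story", "tell me a 10-word story")
reqs <- lapply(prompts, function(p) generate("llama3.1", p, output = "req"))

# perform all requests in parallel, then parse each response to text
resps <- req_perform_parallel(reqs)
sapply(resps, resp_process, output = "text")
```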
