
Refactor chat
hauselin committed Jul 28, 2024
1 parent 2233d83 commit 5c543d7
Showing 8 changed files with 44 additions and 6 deletions.
23 changes: 20 additions & 3 deletions R/ollama.R
@@ -160,13 +160,17 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
 #'
 #' @param model A character string of the model name such as "llama3".
 #' @param messages A list with list of messages for the model (see examples below).
-#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text".
+#' @param tools Tools for the model to use if supported. Requires stream = FALSE. Default is an empty list.
 #' @param stream Enable response streaming. Default is FALSE.
 #' @param keep_alive The duration to keep the connection alive. Default is "5m".
+#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text".
 #' @param endpoint The endpoint to chat with the model. Default is "/api/chat".
 #' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
 #' @param ... Additional options to pass to the model.
 #'
+#' @references
+#' [API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion)
+#'
 #' @return A response in the format specified in the output parameter.
 #' @export
 #'
@@ -190,7 +194,7 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
 #' list(role = "user", content = "List all the previous messages.")
 #' )
 #' chat("llama3", messages, stream = TRUE)
-chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, keep_alive = "5m", endpoint = "/api/chat", host = NULL, ...) {
+chat <- function(model, messages, tools = list(), stream = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text"), endpoint = "/api/chat", host = NULL, ...) {

   output <- output[1]
   if (!output %in% c("df", "resp", "jsonlist", "raw", "text")) {
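
The hunk above is the heart of the refactor: tools is inserted right after messages, and output moves behind keep_alive, so any caller that passed output positionally will need updating. A minimal usage sketch follows; the model name and the tool schema (an OpenAI-style function definition, which is what Ollama's tool calling generally expects) are illustrative assumptions, not taken from this commit.

# Sketch only: assumes a tool-capable model (e.g. "llama3.1") is available locally;
# the schema fields below are assumed for illustration, not defined by this commit.
messages <- list(
  list(role = "user", content = "What is the weather today in Paris?")
)
tools <- list(
  list(
    type = "function",
    `function` = list(
      name = "get_current_weather",
      description = "Get the current weather for a city",
      parameters = list(
        type = "object",
        properties = list(
          city = list(type = "string", description = "The city name")
        ),
        required = list("city")
      )
    )
  )
)
# tools requires stream = FALSE, per the parameter documentation above.
resp <- chat("llama3.1", messages, tools = tools, stream = FALSE, output = "resp")
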
@@ -203,6 +207,7 @@ chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "t
   body_json <- list(
     model = model,
     messages = messages,
+    tools = tools,
     stream = stream,
     keep_alive = keep_alive
   )
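
With the new default, an empty tools list is simply forwarded in the request body. To see roughly what goes over the wire, here is a sketch that mirrors (rather than reuses) the package's serialization; it assumes jsonlite is installed and that auto-unboxing matches what the package does internally.

library(jsonlite)
body_json <- list(
  model = "llama3",
  messages = list(list(role = "user", content = "Hello!")),
  tools = list(),  # the new default: an empty array when no tools are supplied
  stream = FALSE,
  keep_alive = "5m"
)
cat(toJSON(body_json, auto_unbox = TRUE, pretty = TRUE))
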
@@ -269,7 +274,8 @@ chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "t
 #' @param output The output format. Default is "df". Other options are "resp", "jsonlist", "raw", "text".
 #' @param endpoint The endpoint to get the models. Default is "/api/tags".
 #' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
-#'
+#' @references
+#' [API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models)
 #' @return A response in the format specified in the output parameter.
 #' @export
 #'
@@ -313,6 +319,8 @@ list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), end
 #' @param endpoint The endpoint to delete the model. Default is "/api/delete".
 #' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
 #'
+#' @references
+#' [API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md#delete-a-model)
 #' @return A httr2 response object.
 #' @export
 #'
@@ -351,6 +359,9 @@ delete <- function(model, endpoint = "/api/delete", host = NULL) {
 #' @param endpoint The endpoint to pull the model. Default is "/api/pull".
 #' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
 #'
+#' @references
+#' [API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md#pull-a-model)
+#'
 #' @return A httr2 response object.
 #' @export
 #'
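
The three hunks above only add @references links for the model-management helpers (list_models, delete, pull); their calling conventions are unchanged. For orientation, a short sketch, assuming a local Ollama server is running and using an arbitrary model name:

pull("llama3")     # download the model from the Ollama library
list_models("df")  # data frame of locally available models
delete("llama3")   # remove it again; returns an httr2 response object
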
@@ -422,6 +433,9 @@ normalize <- function(x) {
 #' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
 #' @param ... Additional options to pass to the model.
 #'
+#' @references
+#' [API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings)
+#'
 #' @return A numeric matrix of the embedding. Each column is the embedding for one input.
 #' @export
 #'
@@ -490,6 +504,9 @@ embed <- function(model, input, truncate = TRUE, normalize = TRUE, keep_alive =
 #' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
 #' @param ... Additional options to pass to the model.
 #'
+#' @references
+#' [API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embedding)
+#'
 #' @return A numeric vector of the embedding.
 #' @export
 #'
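
The last two hunks document the embedding helpers: embed() returns a numeric matrix with one column per input, and the companion function (embeddings(), per the changed man/embeddings.Rd below) returns a single numeric vector. A minimal sketch of embed(), assuming an embedding model such as "all-minilm" has already been pulled and the local server is running:

emb <- embed("all-minilm", c("first sentence", "second sentence"))
dim(emb)     # rows = embedding dimension, one column per input
emb[1:5, 1]  # first few values of the first input's embedding
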
10 changes: 8 additions & 2 deletions man/chat.Rd

3 changes: 3 additions & 0 deletions man/delete.Rd

3 changes: 3 additions & 0 deletions man/embed.Rd

3 changes: 3 additions & 0 deletions man/embeddings.Rd

3 changes: 3 additions & 0 deletions man/list_models.Rd

3 changes: 3 additions & 0 deletions man/pull.Rd

2 changes: 1 addition & 1 deletion man/test_connection.Rd
