
Commit

Refactor code
hauselin committed Jul 28, 2024
1 parent 17c0002 commit 2233d83
Showing 10 changed files with 105 additions and 50 deletions.
40 changes: 23 additions & 17 deletions R/ollama.R
@@ -84,6 +84,11 @@ create_request <- function(endpoint, host = NULL) {
#' generate("llama3", "The sky is...", stream = FALSE, output = "jsonlist")
generate <- function(model, prompt, suffix = "", images = "", system = "", template = "", context = list(), stream = FALSE, raw = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text"), endpoint = "/api/generate", host = NULL, ...) {

output <- output[1]
if (!output %in% c("df", "resp", "jsonlist", "raw", "text")) {
stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text'")
}

req <- create_request(endpoint, host)
req <- httr2::req_method(req, "POST")

@@ -120,7 +125,7 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
tryCatch(
{
resp <- httr2::req_perform(req)
return(resp_process(resp = resp, output = output[1]))
return(resp_process(resp = resp, output = output))
},
error = function(e) {
stop(e)
@@ -138,7 +143,7 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
cat("\n\n")
resp$body <- env$accumulated_data

return(resp_process(resp = resp, output = output[1]))
return(resp_process(resp = resp, output = output))
}


@@ -171,8 +176,8 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
#' list(role = "user", content = "How are you doing?")
#' )
#' chat("llama3", messages) # returns response by default
#' chat("llama3", messages, "text") # returns text/vector
#' chat("llama3", messages, "hello!", temperature = 2.8) # additional options
#' chat("llama3", messages, output = "text") # returns text/vector
#' chat("llama3", messages, temperature = 2.8) # additional options
#' chat("llama3", messages, stream = TRUE) # stream response
#' chat("llama3", messages, output = "df", stream = TRUE) # stream and return dataframe
#'
@@ -186,6 +191,12 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
#' )
#' chat("llama3", messages, stream = TRUE)
chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "text"), stream = FALSE, keep_alive = "5m", endpoint = "/api/chat", host = NULL, ...) {

output <- output[1]
if (!output %in% c("df", "resp", "jsonlist", "raw", "text")) {
stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text'")
}

req <- create_request(endpoint, host)
req <- httr2::req_method(req, "POST")

@@ -211,7 +222,7 @@ chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "t
tryCatch(
{
resp <- httr2::req_perform(req)
return(resp_process(resp = resp, output = output[1]))
return(resp_process(resp = resp, output = output))
},
error = function(e) {
stop(e)
@@ -229,8 +240,7 @@ chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "t
cat("\n\n")
resp$body <- env$accumulated_data

return(resp_process(resp = resp, output = output[1]))

return(resp_process(resp = resp, output = output))
}


@@ -264,16 +274,15 @@ chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "t
#' @export
#'
#' @examplesIf test_connection()$status_code == 200
#' list_models() # returns dataframe/tibble by default
#' list_models("df")
#' list_models() # returns dataframe
#' list_models("df") # returns dataframe
#' list_models("resp") # httr2 response object
#' list_models("jsonlist")
#' list_models("raw")
list_models <- function(output = c("df", "resp", "jsonlist", "raw", "text"), endpoint = "/api/tags", host = NULL) {

output <- output[1]
if (!output %in% c("df", "resp", "jsonlist", "raw", "text")) {
stop("Invalid output format specified. Supported formats are 'df', 'resp', 'jsonlist', 'raw', 'text'.")
stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text'")
}
req <- create_request(endpoint, host)
req <- httr2::req_method(req, "GET")
@@ -349,7 +358,6 @@ delete <- function(model, endpoint = "/api/delete", host = NULL) {
#' pull("llama3")
#' pull("all-minilm", stream = FALSE)
pull <- function(model, stream = TRUE, insecure = FALSE, endpoint = "/api/pull", host = NULL) {

req <- create_request(endpoint, host)
req <- httr2::req_method(req, "POST")

@@ -549,7 +557,6 @@ embeddings <- function(model, prompt, normalize = TRUE, keep_alive = "5m", endpo
#' ohelp(first_prompt = "quit")
#' # regular usage: ohelp()
ohelp <- function(model = "codegemma:7b", ...) {

if (!model_avail(model)) {
return(invisible())
}
@@ -568,19 +575,18 @@

while (prompt != "/q") {
if (n_messages == 0) {
messages <- create_message(prompt, role = 'user')
messages <- create_message(prompt, role = "user")
} else {
messages <- append(messages, create_message(prompt, role = 'user'))
messages <- append(messages, create_message(prompt, role = "user"))
}
n_messages <- n_messages + 1
response <- chat(model, messages = messages, output = 'text', stream = TRUE)
response <- chat(model, messages = messages, output = "text", stream = TRUE)
messages <- append_message(response, "assistant", messages)
n_messages <- n_messages + 1
prompt <- readline()
}

cat("Goodbye!\n")

}


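Note on the changes above: both generate() and chat() now take the first element of output and validate it before the request is built, so an unsupported format fails immediately instead of surfacing later in resp_process(). A minimal sketch of the new calling behavior (the model and prompt are just the examples from the docs):

library(ollamar)

# an unsupported output format now errors before any request is sent
try(generate("llama3", "The sky is...", output = "abc"))
# Error: Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text'

# supported formats behave as before
txt <- generate("llama3", "The sky is...", output = "text", stream = FALSE)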
24 changes: 0 additions & 24 deletions R/test_connection.R

This file was deleted.

37 changes: 36 additions & 1 deletion R/utils.R
@@ -1,3 +1,38 @@
#' Test connection to Ollama server
#'
#' @param url The URL of the Ollama server. Default is http://localhost:11434
#'
#' @return A httr2 response object.
#' @export
#'
#' @examples
#' test_connection()
#' test_connection("http://localhost:11434") # default url
#' test_connection("http://127.0.0.1:11434")
test_connection <- function(url = "http://localhost:11434") {
req <- httr2::request(url)
req <- httr2::req_method(req, "GET")
tryCatch(
{
resp <- httr2::req_perform(req)
message("Ollama local server running")
return(resp)
},
error = function(e) {
message("Ollama local server not running or wrong server.\nDownload and launch Ollama app to run the server. Visit https://ollama.com or https://github.com/ollama/ollama")
req$status_code <- 503
return(req)
}
)
}








#' Stream handler helper function
#'
#' Function to handle streaming.
@@ -133,7 +168,7 @@ resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text
}

if (output == "df") {
return(df_response)
return(data.frame(df_response))
} else if (output == "text") {
return(df_response$name)
}
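The test_connection() helper now lives in R/utils.R and no longer throws when the server is unreachable: on success it returns the httr2 response (status 200), and on failure it returns the request object with status_code set to 503. That is what lets the new tests guard themselves with a status check. A small usage sketch, assuming the default local server URL:

library(ollamar)

conn <- test_connection()  # defaults to http://localhost:11434
if (conn$status_code == 200) {
  models <- list_models()  # now returned as a plain data.frame
  print(models$name)
} else {
  message("Ollama server not reachable; skipping calls")
}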
4 changes: 2 additions & 2 deletions man/chat.Rd


4 changes: 2 additions & 2 deletions man/list_models.Rd


2 changes: 1 addition & 1 deletion man/test_connection.Rd


9 changes: 6 additions & 3 deletions tests/testthat/test-chat.R
@@ -8,6 +8,9 @@ test_that("chat function works with basic input", {
list(role = "user", content = "Tell me a 5-word story.")
)

# incorrect output type
expect_error(chat("llama3", messages, output = "abc"))

# not streaming
expect_s3_class(chat("llama3", messages), "httr2_response")
expect_s3_class(chat("llama3", messages, output = "resp"), "httr2_response")
@@ -58,9 +61,9 @@ test_that("chat function handles streaming correctly", {
)

result <- chat("llama3", messages, stream = TRUE, output = "text")
expect_type(result, "character") # BUG fail test
expect_true(nchar(result) > 0) # BUG
expect_match(result, "1.*2.*3.*4.*5", all = FALSE) # BUG
expect_type(result, "character")
expect_true(nchar(result) > 0)
expect_match(result, "1.*2.*3.*4.*5", all = FALSE)
})


19 changes: 19 additions & 0 deletions tests/testthat/test-list_models.R
@@ -0,0 +1,19 @@
library(testthat)
library(ollamar)

test_that("list_models function works", {
skip_if_not(test_connection()$status_code == 200, "Ollama server not available")

# incorrect output type
expect_error(list_models("sdf"))

result <- list_models()
expect_s3_class(result, "data.frame")
expect_true(all(c("name", "size", "parameter_size", "quantization_level", "modified") %in% names(result)))

expect_s3_class(list_models("df"), "data.frame")
expect_s3_class(list_models("resp"), "httr2_response")
expect_type(list_models("jsonlist"), "list")
expect_type(list_models("raw"), "character")
expect_type(list_models("text"), "character")
})
13 changes: 13 additions & 0 deletions tests/testthat/test-test_connection.R
@@ -0,0 +1,13 @@
library(testthat)
library(ollamar)

test_that("test_connection function works", {
result <- test_connection()
expect_s3_class(result, "httr2_response")
expect_equal(result$status_code, 200)

# wrong url
result <- test_connection(url = "dsfdsf")
expect_s3_class(result, "httr2_request")
expect_equal(result$status_code, 503)
})
3 changes: 3 additions & 0 deletions tests/testthat/test-test_generate.R
@@ -4,6 +4,9 @@ library(ollamar)
test_that("generate function works with different outputs and resp_process", {
skip_if_not(test_connection()$status_code == 200, "Ollama server not available")

# incorrect output type
expect_error(generate("llama3", "The sky is...", output = "abc"))

# not streaming
expect_s3_class(generate("llama3", "The sky is..."), "httr2_response")
expect_s3_class(generate("llama3", "The sky is...", output = "resp"), "httr2_response")
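To run just the test files touched by this commit locally (assuming the package is loaded, e.g. via devtools::load_all(), and an Ollama server with the llama3 model is running), something like the following should work:

library(testthat)

# run the individual test files changed in this commit
test_file("tests/testthat/test-chat.R")
test_file("tests/testthat/test-list_models.R")
test_file("tests/testthat/test-test_connection.R")
test_file("tests/testthat/test-test_generate.R")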
