Skip to content

Commit

Permalink
Refactor chat
Browse files Browse the repository at this point in the history
  • Loading branch information
hauselin committed Jul 28, 2024
1 parent 68f8402 commit 17c0002
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 47 deletions.
38 changes: 3 additions & 35 deletions R/ollama.R
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "t
}
}

req <- httr2::req_body_json(req, body_json)
req <- httr2::req_body_json(req, body_json, stream = stream)

if (!stream) {
tryCatch(
Expand All @@ -227,42 +227,10 @@ chat <- function(model, messages, output = c("resp", "jsonlist", "raw", "df", "t
wrapped_handler <- function(x) stream_handler(x, env, endpoint)
resp <- httr2::req_perform_stream(req, wrapped_handler, buffer_kb = 1)
cat("\n\n")
resp$body <- env$accumulated_data

# process streaming output
json_lines <- strsplit(rawToChar(env$accumulated_data), "\n")[[1]]
json_lines_output <- vector("list", length = length(json_lines))
df_response <- tibble::tibble(
model = character(length(json_lines_output)),
role = character(length(json_lines_output)),
content = character(length(json_lines_output)),
created_at = character(length(json_lines_output))
)

if (output[1] == "raw") {
return(rawToChar(env$accumulated_data))
}

for (i in seq_along(json_lines)) {
json_lines_output[[i]] <- jsonlite::fromJSON(json_lines[[i]])
df_response$model[i] <- json_lines_output[[i]]$model
df_response$role[i] <- json_lines_output[[i]]$message$role
df_response$content[i] <- json_lines_output[[i]]$message$content
df_response$created_at[i] <- json_lines_output[[i]]$created_at
}

if (output[1] == "jsonlist") {
return(json_lines_output)
}

if (output[1] == "df") {
return(df_response)
}

if (output[1] == "text") {
return(paste0(df_response$content, collapse = ""))
}
return(resp_process(resp = resp, output = output[1]))

return(resp)
}


Expand Down
34 changes: 30 additions & 4 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text
return(resp)
}

endpoints_without_stream <- c("api/tags", "api/delete") # endpoints that do not stream
# endpoints that do not stream
endpoints_without_stream <- c("api/tags", "api/delete")

# process stream resp separately
stream <- FALSE
Expand Down Expand Up @@ -176,8 +177,33 @@ resp_process_stream <- function(resp, output) {
return(paste0(df_response$response, collapse = ""))
}
} else if (grepl("api/chat", resp$url)) { # process chat endpoint
return(NULL) # TODO fill in
} else if (grepl("api/tags", resp$url)) { # process tags endpoint {
json_lines <- strsplit(rawToChar(resp$body), "\n")[[1]]
json_lines_output <- vector("list", length = length(json_lines))
df_response <- tibble::tibble(
model = character(length(json_lines_output)),
role = character(length(json_lines_output)),
content = character(length(json_lines_output)),
created_at = character(length(json_lines_output))
)

for (i in seq_along(json_lines)) {
json_lines_output[[i]] <- jsonlite::fromJSON(json_lines[[i]])
df_response$model[i] <- json_lines_output[[i]]$model
df_response$role[i] <- json_lines_output[[i]]$message$role
df_response$content[i] <- json_lines_output[[i]]$message$content
df_response$created_at[i] <- json_lines_output[[i]]$created_at
}

if (output[1] == "jsonlist") {
return(json_lines_output)
}
if (output[1] == "df") {
return(df_response)
}
if (output[1] == "text") {
return(paste0(df_response$content, collapse = ""))
}
} else if (grepl("api/tags", resp$url)) { # process tags endpoint
return(NULL) # TODO fill in
}
}
Expand All @@ -196,7 +222,7 @@ resp_process_stream <- function(resp, output) {
#'
#' @examples
#' image_path <- file.path(system.file("extdata", package = "ollamar"), "image1.png")
#' substr(image_encode_base64(image_path), 1, 5) # truncate output
#' substr(image_encode_base64(image_path), 1, 5) # truncate output
image_encode_base64 <- function(image_path) {
if (!file.exists(image_path)) {
stop("Image file does not exist.")
Expand Down
2 changes: 1 addition & 1 deletion man/image_encode_base64.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 7 additions & 7 deletions tests/testthat/test-chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ test_that("chat function works with basic input", {
result <- chat("llama3", messages, stream = TRUE)
expect_s3_class(result, "httr2_response")
expect_s3_class(resp_process(result, "resp"), "httr2_response")
# expect_s3_class(resp_process(result, "df"), "data.frame") # BUG: fail test
# expect_type(resp_process(result, "jsonlist"), "list") # BUG fail test
# expect_type(resp_process(result, "text"), "character") # BUG fail test
# expect_type(resp_process(result, "raw"), "character") # BUG fail test
expect_s3_class(resp_process(result, "df"), "data.frame")
expect_type(resp_process(result, "jsonlist"), "list")
expect_type(resp_process(result, "text"), "character")
expect_type(resp_process(result, "raw"), "character")

result <- chat("llama3", messages, output = "df")
expect_s3_class(result, "data.frame")
Expand All @@ -58,9 +58,9 @@ test_that("chat function handles streaming correctly", {
)

result <- chat("llama3", messages, stream = TRUE, output = "text")
expect_type(result, "character")
expect_true(nchar(result) > 0)
expect_match(result, "1.*2.*3.*4.*5", all = FALSE)
expect_type(result, "character") # BUG fail test
expect_true(nchar(result) > 0) # BUG
expect_match(result, "1.*2.*3.*4.*5", all = FALSE) # BUG
})


Expand Down

0 comments on commit 17c0002

Please sign in to comment.