Skip to content

Commit

Permalink
Add test case for embed
Browse files Browse the repository at this point in the history
  • Loading branch information
hauselin committed Jul 29, 2024
1 parent e1468a5 commit 9410121
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
10 changes: 4 additions & 6 deletions R/ollama.R
Original file line number Diff line number Diff line change
Expand Up @@ -451,10 +451,13 @@ pull <- function(name, stream = TRUE, insecure = FALSE, endpoint = "/api/pull",
}


vector_norm <- function(x) {
return(sqrt(sum(x^2)))
}


normalize <- function(x) {
norm <- sqrt(sum(x^2))
norm <- vector_norm(x)
normalized_vector <- x / norm
return(normalized_vector)
}
Expand All @@ -468,11 +471,6 @@ normalize <- function(x) {








#' Get embedding for inputs
#'
#' Supercedes the `embeddings()` function.
Expand Down
32 changes: 32 additions & 0 deletions tests/testthat/test-embed.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,36 @@ library(ollamar)

test_that("embed function works with basic input", {
skip_if_not(test_connection()$status_code == 200, "Ollama server not available")

# one input
result <- embed("all-minilm", "hello")
expect_type(result, "double")
expect_true(dim(result)[2] == 1)
expect_true(dim(result)[1] > 1)

# two inputs
result <- embed("all-minilm", c("hello", "world"))
expect_true(dim(result)[2] == 2)
expect_true(dim(result)[1] > 1)

# model options
expect_type(embed("all-minilm", "hello", temperature = 2), "double")
expect_error(embed("all-minilm", "hello", dfdsffds = 0))

# check normalize (default is normalize = TRUE)
result <- embed("all-minilm", "hello", normalize = TRUE)
v <- result[, 1]
expect_true(all.equal(1, vector_norm(v)))
result2 <- embed("all-minilm", "hello")
expect_true(sum(result[, 1] - result2[, 1]) == 0) # result and result2 vectors should be the same

# check unormalize
result3 <- embed("all-minilm", "hello", normalize = FALSE)
expect_false(sum(result[, 1] - result3[, 1]) == 0) # result and result3 vectors are different

# cosine similarity
expect_true(all.equal((t(result) %*% result)[1], 1))
expect_true(t(result) %*% result2 != 1)
expect_true(t(result) %*% result3 != 1)

})

0 comments on commit 9410121

Please sign in to comment.