Skip to content

Commit 0522a5d

Browse files
committed
Updated Candle example [skip ci]
1 parent 5f3b7cf commit 0522a5d

File tree

2 files changed: +6 additions, −7 deletions

examples/candle/Cargo.toml

Lines changed: 4 additions & 4 deletions

@@ -5,11 +5,11 @@ edition = "2021"
 publish = false
 
 [dependencies]
-candle-core = "0.6"
-candle-nn = "0.6"
-candle-transformers = "0.6"
+candle-core = "0.7"
+candle-nn = "0.7"
+candle-transformers = "0.7"
 hf-hub = "0.3"
 pgvector = { path = "../..", features = ["postgres"] }
 postgres = "0.19"
 serde_json = "1"
-tokenizers = "0.19"
+tokenizers = "0.20"

examples/candle/src/main.rs

Lines changed: 2 additions & 3 deletions

@@ -81,13 +81,12 @@ impl EmbeddingModel {
         Ok(Self { tokenizer, model })
     }
 
-    // embed one at a time since BertModel does not support attention mask
-    // https://github.com/huggingface/candle/issues/1798
+    // TODO support multiple texts
     fn embed(&self, text: &str) -> Result<Vec<f32>, Box<dyn Error + Send + Sync>> {
         let tokens = self.tokenizer.encode(text, true)?;
         let token_ids = Tensor::new(vec![tokens.get_ids().to_vec()], &self.model.device)?;
         let token_type_ids = token_ids.zeros_like()?;
-        let embeddings = self.model.forward(&token_ids, &token_type_ids)?;
+        let embeddings = self.model.forward(&token_ids, &token_type_ids, None)?;
         let embeddings = (embeddings.sum(1)? / (embeddings.dim(1)? as f64))?;
         let embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?;
         Ok(embeddings.squeeze(0)?.to_vec1::<f32>()?)

0 commit comments

Comments
 (0)