Skip to content

Commit

Permalink
feat: New FlagEmbedding models (#6)
Browse files Browse the repository at this point in the history
* feat: new embedding models

* test: canonical value checks update
  • Loading branch information
Anush008 authored Oct 18, 2023
1 parent 62e65cc commit 7fe735e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 91 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@ The default embedding supports "query" and "passage" prefixes for the input text

## 🤖 Models

- [**BAAI/bge-base-en**](https://huggingface.co/BAAI/bge-base-en)
- [**BAAI/bge-base-en-v1.5**](https://huggingface.co/BAAI/bge-base-en-v1.5)
- [**BAAI/bge-small-en**](https://huggingface.co/BAAI/bge-small-en)
- [**BAAI/bge-small-en-v1.5**](https://huggingface.co/BAAI/bge-small-en-v1.5) - Default
- [**BAAI/bge-base-zh-v1.5**](https://huggingface.co/BAAI/bge-base-zh-v1.5)
- [**sentence-transformers/all-MiniLM-L6-v2**](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)

## 🚀 Installation
Expand Down
22 changes: 20 additions & 2 deletions fastembed.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ type EmbeddingModel string
const (
AllMiniLML6V2 EmbeddingModel = "fast-all-MiniLM-L6-v2"
BGEBaseEN EmbeddingModel = "fast-bge-base-en"
BGEBaseENV15 EmbeddingModel = "fast-bge-base-en-v1.5"
BGESmallEN EmbeddingModel = "fast-bge-small-en"
BGESmallENV15 EmbeddingModel = "fast-bge-small-en-v1.5"
BGESmallZH EmbeddingModel = "fast-bge-small-zh-v1.5"

// A model with type "Unigram" is not yet supported by the tokenizer
// Ref: https://github.com/sugarme/tokenizer/blob/448e79b1ed65947b8c6343bf9aa39e78364f45c8/pretrained/model.go#L152
Expand Down Expand Up @@ -79,7 +82,7 @@ func NewFlagEmbedding(options *InitOptions) (*FlagEmbedding, error) {
}

if options.Model == "" {
options.Model = BGESmallEN
options.Model = BGESmallENV15
}

if options.MaxLength == 0 {
Expand Down Expand Up @@ -281,10 +284,25 @@ func ListSupportedModels() []ModelInfo {
Dim: 768,
Description: "Base English model",
},
{
Model: BGEBaseENV15,
Dim: 768,
Description: "v1.5 release of the base English model",
},
{
Model: BGESmallEN,
Dim: 384,
Description: "Fast and Default English model",
Description: "Fast English model",
},
{
Model: BGESmallENV15,
Dim: 384,
Description: "Fast, default English model",
},
{
Model: BGESmallZH,
Dim: 512,
Description: "Fast Chinese model",
},
// {
// Model: MLE5Large,
Expand Down
97 changes: 8 additions & 89 deletions fastembed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,95 +5,14 @@ import (
"testing"
)

func TestEmbedBGEBaseEN(t *testing.T) {
// Test with a single input
fe, err := NewFlagEmbedding(&InitOptions{
Model: BGEBaseEN,
})
defer fe.Destroy()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
input := []string{"hello world"}
result, err := fe.Embed(input, 1)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}

if len(result) != len(input) {
t.Errorf("Expected result length %v, got %v", len(input), len(result))
}
}

func TestEmbedAllMiniLML6V2(t *testing.T) {
// Test with a single input
fe, err := NewFlagEmbedding(&InitOptions{
Model: AllMiniLML6V2,
})
defer fe.Destroy()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
input := []string{"hello world"}
result, err := fe.Embed(input, 1)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}

if len(result) != len(input) {
t.Errorf("Expected result length %v, got %v", len(input), len(result))
}
}

func TestEmbedBGESmallEN(t *testing.T) {
// Test with a single input
fe, err := NewFlagEmbedding(&InitOptions{
Model: BGESmallEN,
})
defer fe.Destroy()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
input := []string{"hello world"}
result, err := fe.Embed(input, 1)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}

if len(result) != len(input) {
t.Errorf("Expected result length %v, got %v", len(input), len(result))
}
}

// A model type "Unigram" is not yet supported by the tokenizer
// Ref: https://github.com/sugarme/tokenizer/blob/448e79b1ed65947b8c6343bf9aa39e78364f45c8/pretrained/model.go#L152
// func TestEmbedMLE5Large(t *testing.T) {
// // Test with a single input
// show := false
// fe, err := NewFlagEmbedding(&InitOptions{
// Model: MLE5Large,
// ShowDownloadProgress: &show,
// })
// defer fe.Destroy()
// if err != nil {
// t.Fatalf("Expected no error, got %v", err)
// }
// input := []string{"hello world"}
// result, err := fe.Embed(input, 1)
// if err != nil {
// t.Fatalf("Expected no error, got %v", err)
// }

// if len(result) != len(input) {
// t.Errorf("Expected result length %v, got %v", len(input), len(result))
// }
// }

func TestCanonicalValues(T *testing.T) {
canonicalValues := map[EmbeddingModel]([]float32){
AllMiniLML6V2: []float32{0.02591, 0.00573, 0.01147, 0.03796, -0.02328, -0.05493, 0.014040, -0.01079, -0.02440, -0.01822},
BGESmallEN: []float32{-0.02313, -0.02552, 0.017357, -0.06393, -0.00061, 0.02212, -0.01472, 0.03925, 0.03444, 0.00459},
BGEBaseEN: []float32{0.01140, 0.03722, 0.02941, 0.01230, 0.03451, 0.00876, 0.02356, 0.05414, -0.02945, -0.05472},
AllMiniLML6V2: []float32{0.02591, 0.00573, 0.01147, 0.03796, -0.02328},
BGESmallEN: []float32{-0.02313, -0.02552, 0.017357, -0.06393, -0.00061},
BGEBaseEN: []float32{0.01140, 0.03722, 0.02941, 0.01230, 0.03451},
BGEBaseENV15: []float32{0.01129394, 0.05493144, 0.02615099, 0.00328772, 0.02996045},
BGESmallENV15: []float32{0.01522374, -0.02271799, 0.00860278, -0.07424029, 0.00386434},
BGESmallZH: []float32{-0.01023294, 0.07634465, 0.0691722, -0.04458365, -0.03160762},
}

for model, expected := range canonicalValues {
Expand All @@ -114,10 +33,10 @@ func TestCanonicalValues(T *testing.T) {
T.Errorf("Expected result length %v, got %v", len(input), len(result))
}

epsilon := float64(1e-5)
epsilon := float64(1e-4)
for i, v := range expected {
if math.Abs(float64(result[0][i]-v)) > float64(epsilon) {
T.Errorf("Element %d mismatch: expected %.6f, got %.6f", i, v, result[0][i])
T.Errorf("Element %d mismatch for %s: expected %.6f, got %.6f", i, model, v, result[0][i])
}
}
}
Expand Down

0 comments on commit 7fe735e

Please sign in to comment.