Skip to content

Commit

Permalink
ci: Canonical testing (#4)
Browse files Browse the repository at this point in the history
* test: canonical value testing

* ci: Added tests before release
  • Loading branch information
Anush008 authored Oct 9, 2023
1 parent baf6f77 commit a164354
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 13 deletions.
21 changes: 21 additions & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,29 @@ on:
workflow_dispatch:

jobs:

test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: 1.21
- name: Install dependencies
run: go get .
- name: Install ONNX Runtime
run: |
wget https://github.com/microsoft/onnxruntime/releases/download/v1.16.0/onnxruntime-linux-x64-1.16.0.tgz
tar xvzf onnxruntime-linux-x64-1.16.0.tgz
echo "ONNX_PATH=$(pwd)/onnxruntime-linux-x64-1.16.0/lib/libonnxruntime.so" >> $GITHUB_ENV
- name: Test with Go
run: go test

release:
runs-on: ubuntu-latest
needs:
- test
steps:
- name: "☁️ checkout repository"
uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion fastembed.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ func loadTokenizer(modelPath string, maxLength int) (*tokenizer.Tokenizer, error
}

// Handle overflow when coercing to int, major hassle.
modelMaxLen := int(math.Min(float64(math.MaxInt32), math.Abs(tokenizerConfig["model_max_length"].(float64))))
modelMaxLen := int(min(float64(math.MaxInt32), math.Abs(tokenizerConfig["model_max_length"].(float64))))
maxLength = min(maxLength, modelMaxLen)

tknzer.WithTruncation(&tokenizer.TruncationParams{
Expand Down
52 changes: 40 additions & 12 deletions fastembed_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package fastembed

import (
"fmt"
"math"
"testing"
)

Expand All @@ -14,13 +14,12 @@ func TestEmbedBGEBaseEN(t *testing.T) {
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
input := []string{"Is the world doing okay?"}
input := []string{"hello world"}
result, err := fe.Embed(input, 1)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}

fmt.Printf("result: %v\n", result[0][0:10])
if len(result) != len(input) {
t.Errorf("Expected result length %v, got %v", len(input), len(result))
}
Expand All @@ -35,22 +34,17 @@ func TestEmbedAllMiniLML6V2(t *testing.T) {
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
input := []string{"Is the world doing okay?"}
input := []string{"hello world"}
result, err := fe.Embed(input, 1)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}

fmt.Printf("result: %v\n", result[0][0:10])
if len(result) != len(input) {
t.Errorf("Expected result length %v, got %v", len(input), len(result))
}
}

// Breaks on GH Actions
// --- FAIL: TestEmbedBGESmallEN (2.29s)
//
// fastembed_test.go:63: Expected no error, got The tensor's shape ([1 512]) requires 512 elements, but only 8 were provided
func TestEmbedBGESmallEN(t *testing.T) {
// Test with a single input
fe, err := NewFlagEmbedding(&InitOptions{
Expand All @@ -60,13 +54,12 @@ func TestEmbedBGESmallEN(t *testing.T) {
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
input := []string{"Is the world doing okay?"}
input := []string{"hello world"}
result, err := fe.Embed(input, 1)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}

fmt.Printf("result: %v\n", result[0][0:10])
if len(result) != len(input) {
t.Errorf("Expected result length %v, got %v", len(input), len(result))
}
Expand All @@ -85,7 +78,7 @@ func TestEmbedBGESmallEN(t *testing.T) {
// if err != nil {
// t.Fatalf("Expected no error, got %v", err)
// }
// input := []string{"Is the world doing okay?"}
// input := []string{"hello world"}
// result, err := fe.Embed(input, 1)
// if err != nil {
// t.Fatalf("Expected no error, got %v", err)
Expand All @@ -95,3 +88,38 @@ func TestEmbedBGESmallEN(t *testing.T) {
// t.Errorf("Expected result length %v, got %v", len(input), len(result))
// }
// }

func TestCanonicalValues(T *testing.T) {
canonicalValues := map[EmbeddingModel]([]float32){
AllMiniLML6V2: []float32{0.02591, 0.00573, 0.01147, 0.03796, -0.02328, -0.05493, 0.014040, -0.01079, -0.02440, -0.01822},
BGESmallEN: []float32{-0.02313, -0.02552, 0.017357, -0.06393, -0.00061, 0.02212, -0.01472, 0.03925, 0.03444, 0.00459},
BGEBaseEN: []float32{0.01140, 0.03722, 0.02941, 0.01230, 0.03451, 0.00876, 0.02356, 0.05414, -0.02945, -0.05472},
}

for model, expected := range canonicalValues {
fe, err := NewFlagEmbedding(&InitOptions{
Model: model,
})
defer fe.Destroy()
if err != nil {
T.Fatalf("Expected no error, got %v", err)
}
input := []string{"hello world"}
result, err := fe.Embed(input, 1)
if err != nil {
T.Fatalf("Expected no error, got %v", err)
}

if len(result) != len(input) {
T.Errorf("Expected result length %v, got %v", len(input), len(result))
}

epsilon := float64(1e-5)
for i, v := range expected {
if math.Abs(float64(result[0][i]-v)) > float64(epsilon) {
T.Errorf("Element %d mismatch: expected %.6f, got %.6f", i, v, result[0][i])
}
}
}

}

0 comments on commit a164354

Please sign in to comment.