diff --git a/encoders/closedai/go.mod b/encoders/closedai/go.mod index 30d1a13..3a674c5 100644 --- a/encoders/closedai/go.mod +++ b/encoders/closedai/go.mod @@ -1,5 +1,5 @@ module github.com/conneroisu/semanticrouter-go/encoders/closedai -go 1.23.0 +go 1.24.0 require github.com/sashabaranov/go-openai v1.29.1 diff --git a/encoders/closedai/openai.go b/encoders/closedai/openai.go index 26d32e5..e29ead1 100644 --- a/encoders/closedai/openai.go +++ b/encoders/closedai/openai.go @@ -7,14 +7,26 @@ import ( openai "github.com/sashabaranov/go-openai" ) +// Client is a minimal interface for the OpenAI client. +type Client interface { + CreateEmbeddings(ctx context.Context, req openai.EmbeddingRequest) (openai.EmbeddingResponse, error) +} + // Encoder encodes a query string into an OpenAI embedding. type Encoder struct { // Client is the OpenAI client. - Client *openai.Client + Client Client // Model is the OpenAI embedding model to use. Model openai.EmbeddingModel } +func NewEncoder(client Client, model openai.EmbeddingModel) Encoder { + return Encoder{ + Client: client, + Model: model, + } +} + // Encode encodes the given utterance using the OpenAI API. func (o Encoder) Encode( ctx context.Context, diff --git a/encoders/closedai/openai_test.go b/encoders/closedai/openai_test.go new file mode 100644 index 0000000..590e45c --- /dev/null +++ b/encoders/closedai/openai_test.go @@ -0,0 +1,39 @@ +package closedai + +import ( + "context" + "testing" + + openai "github.com/sashabaranov/go-openai" + "github.com/stretchr/testify/assert" +) + +func TestEncoder_Encode(t *testing.T) { + t.Parallel() + client := &mockClient{} + encoder := NewEncoder(client, "text-embedding-ada-002") + utterance := "Hello, world!" + got, err := encoder.Encode(t.Context(), utterance) + expected := []float64{ + 0.0, + } + assert.Equal(t, expected, got) + assert.NoError(t, err) +} + +type mockClient struct{} + +func (m *mockClient) CreateEmbeddings( + _ context.Context, + req openai.EmbeddingRequest, +) (openai.EmbeddingResponse, error) { + return openai.EmbeddingResponse{ + Data: []openai.Embedding{ + { + Embedding: []float32{ + 0.0, + }, + }, + }, + }, nil +} diff --git a/encoders/google/doc.go b/encoders/google/doc.go index 813a535..72b60f9 100644 --- a/encoders/google/doc.go +++ b/encoders/google/doc.go @@ -1,4 +1,4 @@ -// Package encoders provides encoders for Google language models. +// Package google provides encoders for Google language models. // // Google language models are language models that is trained on a large corpus of text data. -package encoders +package google diff --git a/encoders/google/go.mod b/encoders/google/go.mod index 0427456..2ea11da 100644 --- a/encoders/google/go.mod +++ b/encoders/google/go.mod @@ -1,8 +1,11 @@ module github.com/conneroisu/semanticrouter-go/encoders/google -go 1.23.0 +go 1.24.0 -require github.com/google/generative-ai-go v0.17.0 +require ( + github.com/google/generative-ai-go v0.17.0 + github.com/stretchr/testify v1.9.0 +) require ( cloud.google.com/go v0.115.0 // indirect @@ -11,6 +14,7 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect cloud.google.com/go/compute/metadata v0.3.0 // indirect cloud.google.com/go/longrunning v0.5.7 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -20,6 +24,7 @@ require ( github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.5 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect @@ -38,4 +43,5 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20240617180043-68d350f18fd4 // indirect google.golang.org/grpc v1.64.1 // indirect google.golang.org/protobuf v1.34.2 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/encoders/google/go.sum b/encoders/google/go.sum index 6594204..5ac2c3a 100644 --- a/encoders/google/go.sum +++ b/encoders/google/go.sum @@ -158,6 +158,7 @@ google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpAD google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/encoders/google/google.go b/encoders/google/google.go index bb20808..3a72193 100644 --- a/encoders/google/google.go +++ b/encoders/google/google.go @@ -1,4 +1,4 @@ -package encoders +package google import ( "context" @@ -6,23 +6,33 @@ import ( "github.com/google/generative-ai-go/genai" ) -// GoogleEncoder encodes a query string into a Google search URL. -type GoogleEncoder struct { - client genai.Client +// Client is a minimal client for the Google Generative AI API. +type Client interface { + EmbeddingModel(name string) Model +} + +// Model is a minimal model for the Google Generative AI API. +type Model interface { + EmbedContent(ctx context.Context, content genai.Text) (genai.EmbedContentResponse, error) +} + +// Encoder encodes a query string into a Google search URL. +type Encoder struct { + client Client name string } -// NewGoogleEncoder creates a new GoogleEncoder. -func NewGoogleEncoder( - client genai.Client, -) *GoogleEncoder { - return &GoogleEncoder{ +// NewEncoder creates a new GoogleEncoder. +func NewEncoder( + client Client, +) *Encoder { + return &Encoder{ client: client, } } // Encode encodes a query string into a Google search URL. -func (e *GoogleEncoder) Encode( +func (e *Encoder) Encode( ctx context.Context, query string, ) ([]float64, error) { @@ -30,8 +40,9 @@ func (e *GoogleEncoder) Encode( case <-ctx.Done(): return nil, ctx.Err() default: - model := e.client.EmbeddingModel(e.name) - embedding, err := model.EmbedContent(ctx, genai.Text(query)) + embedding, err := e.client.EmbeddingModel( + e.name, + ).EmbedContent(ctx, genai.Text(query)) if err != nil { return nil, err } diff --git a/encoders/google/google_test.go b/encoders/google/google_test.go new file mode 100644 index 0000000..a34e928 --- /dev/null +++ b/encoders/google/google_test.go @@ -0,0 +1,34 @@ +package google + +import ( + "context" + "testing" + + "github.com/google/generative-ai-go/genai" + "github.com/stretchr/testify/assert" +) + +func TestEncoder_Encode(t *testing.T) { + ctx := context.Background() + mockClient := mockClient{} + encoder := NewEncoder(mockClient) + result, err := encoder.Encode(ctx, "query") + assert.NoError(t, err) + assert.Equal(t, []float64{0.0}, result) +} + +type mockClient struct{} + +func (m mockClient) EmbeddingModel(_ string) Model { + return mockModel{} +} + +type mockModel struct{} + +func (m mockModel) EmbedContent(_ context.Context, _ genai.Text) (genai.EmbedContentResponse, error) { + return genai.EmbedContentResponse{ + Embedding: &genai.ContentEmbedding{ + Values: []float32{0.0}, + }, + }, nil +} diff --git a/encoders/ollama/go.mod b/encoders/ollama/go.mod index 970cf16..d18a848 100644 --- a/encoders/ollama/go.mod +++ b/encoders/ollama/go.mod @@ -1,6 +1,6 @@ module github.com/conneroisu/semanticrouter-go/encoders/ollama -go 1.23.0 +go 1.24.0 require ( github.com/ollama/ollama v0.3.10 diff --git a/encoders/voyageai/go.mod b/encoders/voyageai/go.mod index 689df13..3d49510 100644 --- a/encoders/voyageai/go.mod +++ b/encoders/voyageai/go.mod @@ -1,5 +1,5 @@ module github.com/conneroisu/semanticrouter-go/encoders/voyageai -go 1.23.0 +go 1.24.0 require github.com/conneroisu/go-voyageai v0.0.0-20240712192129-77bcd696824e diff --git a/examples/chit-chat/go.mod b/examples/chit-chat/go.mod index 43d76a4..527d1b5 100644 --- a/examples/chit-chat/go.mod +++ b/examples/chit-chat/go.mod @@ -1,5 +1,5 @@ module github.com/conneroisu/semanticrouter-go/examples/chit-chat -go 1.23.0 +go 1.24.0 require github.com/sashabaranov/go-openai v1.29.1 diff --git a/examples/veterinarian/go.mod b/examples/veterinarian/go.mod index a458581..9ca037e 100644 --- a/examples/veterinarian/go.mod +++ b/examples/veterinarian/go.mod @@ -1,5 +1,5 @@ module github.com/conneroisu/semanticrouter-go/examples/veterinarian -go 1.23.0 +go 1.24.0 require github.com/ollama/ollama v0.3.10 diff --git a/go.mod b/go.mod index 14ed44f..b9acb82 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/conneroisu/semanticrouter-go -go 1.23.0 +go 1.24.0 require ( github.com/conneroisu/semanticrouter-go/encoders/ollama v0.0.0-20240909025305-0a3db7c99137 diff --git a/go.work b/go.work index 8cf7052..f9ebbe9 100644 --- a/go.work +++ b/go.work @@ -1,6 +1,6 @@ -go 1.23.0 +go 1.24.0 -toolchain go1.23.0 +toolchain go1.24.1 use ( . diff --git a/stores/memory/go.mod b/stores/memory/go.mod index 202b4eb..15c5b15 100644 --- a/stores/memory/go.mod +++ b/stores/memory/go.mod @@ -1,6 +1,6 @@ module github.com/conneroisu/semanticrouter-go/stores/memory -go 1.23.0 +go 1.24.0 require github.com/stretchr/testify v1.9.0 diff --git a/stores/mongo/go.mod b/stores/mongo/go.mod index ba957be..a4e0f73 100644 --- a/stores/mongo/go.mod +++ b/stores/mongo/go.mod @@ -1,6 +1,6 @@ module github.com/conneroisu/semanticrouter-go/stores/mongo -go 1.23.0 +go 1.24.0 require ( github.com/stretchr/testify v1.9.0 diff --git a/stores/mongo/mongo.go b/stores/mongo/mongo.go index 1bbc6c0..acbec58 100644 --- a/stores/mongo/mongo.go +++ b/stores/mongo/mongo.go @@ -3,35 +3,52 @@ package mongo import ( "context" + "io" "github.com/conneroisu/semanticrouter-go" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) +// Cursor is a MongoDB cursor. +// +// It implements an minimal subset of the mongo.Cursor interface. +type Cursor interface { + All(ctx context.Context, result any) error + io.Closer +} + +// Collection is a MongoDB collection. +// +// It implements an minimal subset of the mongo.Collection interface. +type Collection interface { + Find(ctx context.Context, filter any, opts ...*options.FindOptions) (cur *mongo.Cursor, err error) + InsertOne(ctx context.Context, doc any, opts ...*options.InsertOneOptions) (*mongo.InsertOneResult, error) +} + // Store is a MongoDB store. // // It implements the Store interface. type Store struct { - coll *mongo.Collection + coll Collection } // New creates a new MongoDB store. -func New(collection *mongo.Collection) *Store { - return &Store{ - coll: collection, - } +func New(collection Collection) *Store { + return &Store{coll: collection} } // Get gets a value from the store. func (s *Store) Get(ctx context.Context, utterance string) ([]float64, error) { - var floats []float64 - filter := bson.M{"utterance": utterance} - cur, err := s.coll.Find(ctx, filter) + var ( + floats []float64 + results []semanticrouter.Utterance + ) + cur, err := s.coll.Find(ctx, bson.M{"utterance": utterance}) if err != nil { return nil, err } - var results []semanticrouter.Utterance if err = cur.All(ctx, &results); err != nil { panic(err) } diff --git a/stores/mongo/mongo_test.go b/stores/mongo/mongo_test.go index 95cb670..8e7df95 100644 --- a/stores/mongo/mongo_test.go +++ b/stores/mongo/mongo_test.go @@ -1,7 +1,6 @@ package mongo import ( - "context" "log" "testing" @@ -17,22 +16,23 @@ var ( ) func TestStore(t *testing.T) { + if testing.Short() { + t.Skip("skipping non-short test") + } a := assert.New(t) - ctx := context.Background() - - mongodbContainer, err := mongodb.Run(ctx, "mongo:6") + mongodbContainer, err := mongodb.Run(t.Context(), "mongo:6") a.NoError(err) defer func() { - if err := mongodbContainer.Terminate(ctx); err != nil { + if Terr := mongodbContainer.Terminate(t.Context()); Terr != nil { log.Fatalf("failed to terminate container: %s", err) } }() - uri, err := mongodbContainer.ConnectionString(ctx) + uri, err := mongodbContainer.ConnectionString(t.Context()) a.NoError(err) - client, err := mongo.Connect(ctx, options.Client().ApplyURI(uri)) + client, err := mongo.Connect(t.Context(), options.Client().ApplyURI(uri)) a.NoError(err) defer func() { - err = client.Disconnect(ctx) + err = client.Disconnect(t.Context()) if err != nil { log.Fatalf("failed to disconnect from mongodb: %s", err) } @@ -41,14 +41,14 @@ func TestStore(t *testing.T) { store := New(collection) a.NoError(err) err = store.Set( - ctx, + t.Context(), semanticrouter.Utterance{ Utterance: "key", Embed: []float64{1.0, 2.0, 3.0, 4.0, 5.0}, }) a.NoError(err) - floats, err := store.Get(ctx, "key") + floats, err := store.Get(t.Context(), "key") a.NoError(err) a.Len(floats, 5) } diff --git a/stores/valkey/go.mod b/stores/valkey/go.mod index c70997c..e59f8d5 100644 --- a/stores/valkey/go.mod +++ b/stores/valkey/go.mod @@ -1,6 +1,6 @@ module github.com/conneroisu/semanticrouter-go/stores/valkey -go 1.23.0 +go 1.24.0 require ( github.com/redis/go-redis/v9 v9.6.1 diff --git a/stores/valkey/valkey.go b/stores/valkey/valkey.go index ed121dc..254cdb1 100644 --- a/stores/valkey/valkey.go +++ b/stores/valkey/valkey.go @@ -5,6 +5,8 @@ import ( "encoding/json" "errors" "fmt" + "io" + "time" "github.com/conneroisu/semanticrouter-go" "github.com/redis/go-redis/v9" @@ -12,11 +14,20 @@ import ( // Store is a valkey/redis store for embeddings. type Store struct { - rds *redis.Client + rds Client +} + +// Client is a redis client for valkey. +// +// This is a minimal interface to allow for different redis clients. +type Client interface { + Get(ctx context.Context, key string) *redis.StringCmd + Set(ctx context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd + io.Closer } // NewStore creates a new Store from a redis client. -func NewStore(rds *redis.Client) *Store { +func NewStore(rds Client) *Store { return &Store{rds: rds} } @@ -30,17 +41,19 @@ func (s *Store) Get( ctx context.Context, utterance string, ) (embedding []float64, err error) { - cmd := s.rds.Get(ctx, utterance) - val, err := cmd.Result() + var ( + res *redis.StringCmd + val string + utPr semanticrouter.Utterance + ) + res = s.rds.Get(ctx, utterance) + val, err = res.Result() if err != nil { if errors.Is(err, redis.Nil) { - fmt.Println("key2 does not exist") - fmt.Println(err) return nil, fmt.Errorf("key does not exist: %w", err) } return nil, err } - var utPr semanticrouter.Utterance err = json.Unmarshal([]byte(val), &utPr) if err != nil { return nil, fmt.Errorf("error unmarshaling embedding: %w", err) @@ -52,18 +65,22 @@ func (s *Store) Get( func (s *Store) Set( ctx context.Context, utterance semanticrouter.Utterance, -) error { - val, err := json.Marshal(utterance) +) (err error) { + var ( + val []byte + res *redis.StatusCmd + ) + val, err = json.Marshal(utterance) if err != nil { return fmt.Errorf("error marshaling embedding: %w", err) } - cmd := s.rds.Set( + res = s.rds.Set( ctx, utterance.Utterance, string(val), 0, ) - err = cmd.Err() + err = res.Err() if err != nil { return fmt.Errorf("error setting embedding: %w", err) } diff --git a/stores/valkey/valkey_test.go b/stores/valkey/valkey_test.go index 7b4d6f8..fd821d9 100644 --- a/stores/valkey/valkey_test.go +++ b/stores/valkey/valkey_test.go @@ -37,12 +37,13 @@ func TestStore(t *testing.T) { assert.NoError(t, err) endpoint, err := redisContainer.Endpoint(ctx, "") assert.NoError(t, err) - store := valkey.NewStore(redis.NewClient( + cli := redis.NewClient( &redis.Options{ Addr: endpoint, Network: "tcp", }, - )) + ) + store := valkey.NewStore(cli) err = store.Set( ctx,