Skip to content

Commit

Permalink
Feat: callbacks.Handler support handle error (#457)
Browse files Browse the repository at this point in the history
Feat: support HandleLLMError for llms.LLM
  • Loading branch information
chyroc authored Dec 28, 2023
1 parent 74527d3 commit 534f757
Show file tree
Hide file tree
Showing 17 changed files with 134 additions and 33 deletions.
3 changes: 3 additions & 0 deletions callbacks/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@ type Handler interface {
HandleText(ctx context.Context, text string)
HandleLLMStart(ctx context.Context, prompts []string)
HandleLLMEnd(ctx context.Context, output llms.LLMResult)
HandleLLMError(ctx context.Context, err error)
HandleChainStart(ctx context.Context, inputs map[string]any)
HandleChainEnd(ctx context.Context, outputs map[string]any)
HandleChainError(ctx context.Context, err error)
HandleToolStart(ctx context.Context, input string)
HandleToolEnd(ctx context.Context, output string)
HandleToolError(ctx context.Context, err error)
HandleAgentAction(ctx context.Context, action schema.AgentAction)
HandleRetrieverStart(ctx context.Context, query string)
HandleRetrieverEnd(ctx context.Context, query string, documents []schema.Document)
Expand Down
12 changes: 12 additions & 0 deletions callbacks/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ func (l LogHandler) HandleLLMEnd(_ context.Context, output llms.LLMResult) {
fmt.Println("Exiting LLM with results:", formatLLMResult(output))
}

func (l LogHandler) HandleLLMError(_ context.Context, err error) {
fmt.Println("Exiting LLM with error:", err)
}

func (l LogHandler) HandleChainStart(_ context.Context, inputs map[string]any) {
fmt.Println("Entering chain with inputs:", formatChainValues(inputs))
}
Expand All @@ -37,6 +41,10 @@ func (l LogHandler) HandleChainEnd(_ context.Context, outputs map[string]any) {
fmt.Println("Exiting chain with outputs:", formatChainValues(outputs))
}

func (l LogHandler) HandleChainError(_ context.Context, err error) {
fmt.Println("Exiting chain with error:", err)
}

func (l LogHandler) HandleToolStart(_ context.Context, input string) {
fmt.Println("Entering tool with input:", removeNewLines(input))
}
Expand All @@ -45,6 +53,10 @@ func (l LogHandler) HandleToolEnd(_ context.Context, output string) {
fmt.Println("Exiting tool with output:", removeNewLines(output))
}

func (l LogHandler) HandleToolError(_ context.Context, err error) {
fmt.Println("Exiting tool with error:", err)
}

func (l LogHandler) HandleAgentAction(_ context.Context, action schema.AgentAction) {
fmt.Println("Agent selected action:", formatAgentAction(action))
}
Expand Down
3 changes: 3 additions & 0 deletions callbacks/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@ var _ Handler = SimpleHandler{}
func (SimpleHandler) HandleText(context.Context, string) {}
func (SimpleHandler) HandleLLMStart(context.Context, []string) {}
func (SimpleHandler) HandleLLMEnd(context.Context, llms.LLMResult) {}
func (SimpleHandler) HandleLLMError(context.Context, error) {}
func (SimpleHandler) HandleChainStart(context.Context, map[string]any) {}
func (SimpleHandler) HandleChainEnd(context.Context, map[string]any) {}
func (SimpleHandler) HandleChainError(context.Context, error) {}
func (SimpleHandler) HandleToolStart(context.Context, string) {}
func (SimpleHandler) HandleToolEnd(context.Context, string) {}
func (SimpleHandler) HandleToolError(context.Context, error) {}
func (SimpleHandler) HandleAgentAction(context.Context, schema.AgentAction) {}
func (SimpleHandler) HandleRetrieverStart(context.Context, string) {}
func (SimpleHandler) HandleRetrieverEnd(context.Context, string, []schema.Document) {}
Expand Down
34 changes: 25 additions & 9 deletions chains/chains.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,26 +47,42 @@ func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...C
callbacksHandler.HandleChainStart(ctx, inputValues)
}

if err := validateInputs(c, fullValues); err != nil {
return nil, err
}

outputValues, err := c.Call(ctx, fullValues, options...)
outputValues, err := callChain(ctx, c, fullValues, options...)
if err != nil {
return outputValues, err
}
if err := validateOutputs(c, outputValues); err != nil {
if callbacksHandler != nil {
callbacksHandler.HandleChainError(ctx, err)
}
return outputValues, err
}

if callbacksHandler != nil {
callbacksHandler.HandleChainEnd(ctx, outputValues)
}

err = c.GetMemory().SaveContext(ctx, inputValues, outputValues)
if err = c.GetMemory().SaveContext(ctx, inputValues, outputValues); err != nil {
return outputValues, err
}

return outputValues, nil
}

func callChain(
ctx context.Context,
c Chain,
fullValues map[string]any,
options ...ChainCallOption,
) (map[string]any, error) {
if err := validateInputs(c, fullValues); err != nil {
return nil, err
}

outputValues, err := c.Call(ctx, fullValues, options...)
if err != nil {
return outputValues, err
}
if err := validateOutputs(c, outputValues); err != nil {
return outputValues, err
}

return outputValues, nil
}
Expand Down
3 changes: 3 additions & 0 deletions llms/anthropic/anthropicllm.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.Ca
StreamingFunc: opts.StreamingFunc,
})
if err != nil {
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMError(ctx, err)
}
return nil, err
}
generations = append(generations, &llms.Generation{
Expand Down
3 changes: 3 additions & 0 deletions llms/cohere/coherellm.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.Ca
Prompt: prompt,
})
if err != nil {
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMError(ctx, err)
}
return nil, err
}

Expand Down
46 changes: 30 additions & 16 deletions llms/ernie/erniellm.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"os"

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/ernie/internal/ernieclient"
"github.com/tmc/langchaingo/schema"
Expand All @@ -17,8 +18,9 @@ var (
)

type LLM struct {
client *ernieclient.Client
model ModelName
client *ernieclient.Client
model ModelName
CallbacksHandler callbacks.Handler
}

var (
Expand All @@ -40,8 +42,9 @@ func New(opts ...Option) (*LLM, error) {
c, err := newClient(options)

return &LLM{
client: c,
model: options.modelName,
client: c,
model: options.modelName,
CallbacksHandler: options.callbacksHandler,
}, err
}

Expand All @@ -61,22 +64,22 @@ doc: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2`, ernieclient.ErrNot
}

// GeneratePrompt implements llms.LanguageModel.
func (l *LLM) GeneratePrompt(ctx context.Context, promptValues []schema.PromptValue,
func (o *LLM) GeneratePrompt(ctx context.Context, promptValues []schema.PromptValue,
options ...llms.CallOption,
) (llms.LLMResult, error) {
return llms.GeneratePrompt(ctx, l, promptValues, options...)
return llms.GeneratePrompt(ctx, o, promptValues, options...)
}

// GetNumTokens implements llms.LanguageModel.
func (l *LLM) GetNumTokens(_ string) int {
func (o *LLM) GetNumTokens(_ string) int {
// todo: not provided yet
// see: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
return -1
}

// Call implements llms.LLM.
func (l *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
r, err := l.Generate(ctx, []string{prompt}, options...)
func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
r, err := o.Generate(ctx, []string{prompt}, options...)
if err != nil {
return "", err
}
Expand All @@ -89,15 +92,19 @@ func (l *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOptio
}

// Generate implements llms.LLM.
func (l *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) {
func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) {
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMStart(ctx, prompts)
}

opts := llms.CallOptions{}
for _, opt := range options {
opt(&opts)
}

generations := make([]*llms.Generation, 0, len(prompts))
for _, prompt := range prompts {
result, err := l.client.CreateCompletion(ctx, l.getModelPath(opts), &ernieclient.CompletionRequest{
result, err := o.client.CreateCompletion(ctx, o.getModelPath(opts), &ernieclient.CompletionRequest{
Messages: []ernieclient.Message{{Role: "user", Content: prompt}},
Temperature: opts.Temperature,
TopP: opts.TopP,
Expand All @@ -106,11 +113,18 @@ func (l *LLM) Generate(ctx context.Context, prompts []string, options ...llms.Ca
Stream: opts.StreamingFunc != nil,
})
if err != nil {
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMError(ctx, err)
}
return nil, err
}
if result.ErrorCode > 0 {
return nil, fmt.Errorf("%w, error_code:%v, erro_msg:%v, id:%v",
err = fmt.Errorf("%w, error_code:%v, erro_msg:%v, id:%v",
ErrCodeResponse, result.ErrorCode, result.ErrorMsg, result.ID)
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMError(ctx, err)
}
return nil, err
}

generations = append(generations, &llms.Generation{
Expand All @@ -125,8 +139,8 @@ func (l *LLM) Generate(ctx context.Context, prompts []string, options ...llms.Ca
// 1. texts counts less than 16
// 2. text runes counts less than 384
// doc: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu
func (l *LLM) CreateEmbedding(ctx context.Context, texts []string) ([][]float32, error) {
resp, e := l.client.CreateEmbedding(ctx, texts)
func (o *LLM) CreateEmbedding(ctx context.Context, texts []string) ([][]float32, error) {
resp, e := o.client.CreateEmbedding(ctx, texts)
if e != nil {
return nil, e
}
Expand All @@ -144,8 +158,8 @@ func (l *LLM) CreateEmbedding(ctx context.Context, texts []string) ([][]float32,
return emb, nil
}

func (l *LLM) getModelPath(opts llms.CallOptions) ernieclient.ModelPath {
model := l.model
func (o *LLM) getModelPath(opts llms.CallOptions) ernieclient.ModelPath {
model := o.model

if model == "" {
model = ModelName(opts.Model)
Expand Down
18 changes: 14 additions & 4 deletions llms/ernie/erniellm_option.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ernie

import "github.com/tmc/langchaingo/callbacks"

const (
ernieAPIKey = "ERNIE_API_KEY" //nolint:gosec
ernieSecretKey = "ERNIE_SECRET_KEY" //nolint:gosec
Expand All @@ -19,10 +21,11 @@ const (
)

type options struct {
apiKey string
secretKey string
accessToken string
modelName ModelName
apiKey string
secretKey string
accessToken string
modelName ModelName
callbacksHandler callbacks.Handler
}

type Option func(*options)
Expand Down Expand Up @@ -56,3 +59,10 @@ func WithModelName(modelName ModelName) Option {
opts.modelName = modelName
}
}

// WithCallbackHandler passes the callback Handler to the client.
func WithCallbackHandler(callbacksHandler callbacks.Handler) Option {
return func(opts *options) {
opts.callbacksHandler = callbacksHandler
}
}
3 changes: 3 additions & 0 deletions llms/huggingface/huggingfacellm.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.Ca
Seed: opts.Seed,
})
if err != nil {
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMError(ctx, err)
}
return nil, err
}

Expand Down
3 changes: 3 additions & 0 deletions llms/local/localllm.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.Ca
Prompt: prompt,
})
if err != nil {
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMError(ctx, err)
}
return nil, err
}

Expand Down
3 changes: 3 additions & 0 deletions llms/ollama/ollamallm.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.Ca

err := o.client.Generate(ctx, req, fn)
if err != nil {
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMError(ctx, err)
}
return []*llms.Generation{}, err
}

Expand Down
3 changes: 3 additions & 0 deletions llms/openai/openaillm.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.Ca
StreamingFunc: opts.StreamingFunc,
})
if err != nil {
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMError(ctx, err)
}
return nil, err
}
generations = append(generations, &llms.Generation{
Expand Down
3 changes: 3 additions & 0 deletions llms/vertexai/vertexai_palm_llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.Ca
StopSequences: opts.StopWords,
})
if err != nil {
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMError(ctx, err)
}
return nil, err
}

Expand Down
3 changes: 3 additions & 0 deletions tools/duckduckgo/ddg.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ func (t Tool) Call(ctx context.Context, input string) (string, error) {
if errors.Is(err, internal.ErrNoGoodResult) {
return "No good DuckDuckGo Search Results was found", nil
}
if t.CallbacksHandler != nil {
t.CallbacksHandler.HandleToolError(ctx, err)
}
return "", err
}

Expand Down
4 changes: 4 additions & 0 deletions tools/serpapi/serpapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ func (t Tool) Call(ctx context.Context, input string) (string, error) {
return "No good Google Search Results was found", nil
}

if t.CallbacksHandler != nil {
t.CallbacksHandler.HandleToolError(ctx, err)
}

return "", err
}

Expand Down
20 changes: 16 additions & 4 deletions tools/wikipedia/wikipedia.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,22 @@ func (t Tool) Call(ctx context.Context, input string) (string, error) {
t.CallbacksHandler.HandleToolStart(ctx, input)
}

result, err := t.searchWiKi(ctx, input)
if err != nil {
if t.CallbacksHandler != nil {
t.CallbacksHandler.HandleToolError(ctx, err)
}
return "", err
}

if t.CallbacksHandler != nil {
t.CallbacksHandler.HandleToolEnd(ctx, result)
}

return result, nil
}

func (t Tool) searchWiKi(ctx context.Context, input string) (string, error) {
searchResult, err := search(ctx, t.TopK, input, t.LanguageCode, t.UserAgent)
if err != nil {
return "", err
Expand Down Expand Up @@ -91,9 +107,5 @@ func (t Tool) Call(ctx context.Context, input string) (string, error) {
result += page.Extract
}

if t.CallbacksHandler != nil {
t.CallbacksHandler.HandleToolEnd(ctx, result)
}

return result, nil
}
Loading

0 comments on commit 534f757

Please sign in to comment.