diff --git a/components/.DS_Store b/components/.DS_Store new file mode 100644 index 000000000..1a19ed59a Binary files /dev/null and b/components/.DS_Store differ diff --git a/components/agentic/ark/README.md b/components/agentic/ark/README.md new file mode 100644 index 000000000..5a1923e8c --- /dev/null +++ b/components/agentic/ark/README.md @@ -0,0 +1,408 @@ +# Volcengine Ark Agentic Model + +A Volcengine Ark model implementation for [Eino](https://github.com/cloudwego/eino) that implements the `Model` interface in `agentic` component. This enables seamless integration with Eino's Agent capabilities for enhanced natural language processing and generation. + +## Features + +- Implements `github.com/cloudwego/eino/components/agentic.Model` +- Easy integration with Eino's agent system +- Configurable model parameters +- Support for responses api +- Support for streaming responses +- Support for tool calling (Function Tools, MCP Tools, Server Tools) +- Support for Prefix Cache and Session Cache + +## Installation + +```bash +go get github.com/cloudwego/eino-ext/components/agentic/ark@latest +``` + +## Quick Start +Here's a quick example of how to use the `Model`: + +```go +package main + +import ( + "context" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/ark" + "github.com/cloudwego/eino/schema" +) + +func main() { + ctx := context.Background() + + // Get ARK_API_KEY and ARK_MODEL_ID: https://www.volcengine.com/docs/82379/1399008 + am, err := ark.New(ctx, &ark.Config{ + Model: os.Getenv("ARK_MODEL_ID"), + APIKey: os.Getenv("ARK_API_KEY"), + }) + if err != nil { + log.Fatalf("failed to create agentic model, err: %v", err) + } + + input := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what is the weather like in Beijing"), + } + + msg, err := am.Generate(ctx, input) + if err != nil { + log.Fatalf("failed to generate, err: %v", err) + } + + meta := msg.ResponseMeta.Extension.(*ark.ResponseMetaExtension) + + log.Printf("request_id: %s\n", meta.ID) + respBody, _ := sonic.MarshalIndent(msg, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} +``` + +## Configuration + +The `Model` can be configured using the `ark.Config` struct: + +```go +type Config struct { + // Timeout specifies the maximum duration to wait for API responses + // If HTTPClient is set, Timeout will not be used. + // Optional. Default: 10 minutes + Timeout *time.Duration `json:"timeout"` + + // HTTPClient specifies the client to send HTTP requests. + // If HTTPClient is set, Timeout will not be used. + // Optional. Default &http.Client{Timeout: Timeout} + HTTPClient *http.Client `json:"http_client"` + + // RetryTimes specifies the number of retry attempts for failed API calls + // Optional. Default: 2 + RetryTimes *int `json:"retry_times"` + + // BaseURL specifies the base URL for Ark service + // Optional. Default: "https://ark.cn-beijing.volces.com/api/v3" + BaseURL string `json:"base_url"` + + // Region specifies the region where Ark service is located + // Optional. Default: "cn-beijing" + Region string `json:"region"` + + // The following three fields are about authentication - either APIKey or AccessKey/SecretKey pair is required + // For authentication details, see: https://www.volcengine.com/docs/82379/1298459 + // APIKey takes precedence if both are provided + APIKey string `json:"api_key"` + + AccessKey string `json:"access_key"` + + SecretKey string `json:"secret_key"` + + // The following fields correspond to Ark's responses API parameters + // Ref: https://www.volcengine.com/docs/82379/1298454 + + // Model specifies the ID of endpoint on ark platform + // Required + Model string `json:"model"` + + // MaxBlocks limits the maximum number of blocks that can be generated in the chat completion. + // Optional. Default: 4096 + MaxOutputTokens *int64 `json:"max_output_tokens,omitempty"` + + // Temperature specifies what sampling temperature to use + // Generally recommend altering this or TopP but not both + // Range: 0.0 to 1.0. Higher values make output more random + // Optional. Default: 1.0 + Temperature *float64 `json:"temperature,omitempty"` + + // TopP controls diversity via nucleus sampling + // Generally recommend altering this or Temperature but not both + // Range: 0.0 to 1.0. Lower values make output more focused + // Optional. Default: 0.7 + TopP *float64 `json:"top_p,omitempty"` + + // Stop sequences where the API will stop generating further tokens + // Optional. Example: []string{"\n", "User:"} + Stop []string `json:"stop,omitempty"` + + // FrequencyPenalty prevents repetition by penalizing tokens based on frequency + // Range: -2.0 to 2.0. Positive values decrease likelihood of repetition + // Optional. Default: 0 + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + + // LogitBias modifies likelihood of specific tokens appearing in completion + // Optional. Map token IDs to bias values from -100 to 100 + LogitBias map[string]int32 `json:"logit_bias,omitempty"` + + // PresencePenalty prevents repetition by penalizing tokens based on presence + // Range: -2.0 to 2.0. Positive values increase likelihood of new topics + // Optional. Default: 0 + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + + // LogProbs specifies whether to return log probabilities of the output tokens. + LogProbs *bool `json:"log_probs,omitempty"` + + // TopLogProbs specifies the number of most likely tokens to return at each token position, each with an associated log probability. + TopLogProbs *int `json:"top_log_probs,omitempty"` + + // RepetitionPenalty penalizes new tokens based on their existing frequency in the text so far. + // Range: 0.0 to 2.0. 1.0 means no penalty. + RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"` + + // Thinking controls whether the model is set to activate the deep thinking mode. + // It is set to be enabled by default. + Thinking *responses.ResponsesThinking `json:"thinking,omitempty"` + + // Reasoning specifies the reasoning effort of the model. + // Optional. + Reasoning *responses.ResponsesReasoning `json:"reasoning,omitempty"` + + // MaxToolCalls limits the maximum number of tool calls that can be generated in the chat completion. + // Optional. + MaxToolCalls *int64 `json:"max_tool_calls,omitempty"` + + // ParallelToolCalls controls whether the model is set to perform parallel tool calls. + // Optional. + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + + // ServerTools specifies the server-side tools that will be available for the model. + // Optional. + ServerTools []*ServerToolConfig `json:"server_tools,omitempty"` + + // MCPTools specifies the MCP tools that will be available for the model. + // Optional. + MCPTools []*responses.ToolMcp `json:"mcp_tools,omitempty"` + + // Cache specifies the cache configuration for the model. + // Optional. + Cache *CacheConfig `json:"cache,omitempty"` + + // CustomHeader the http header passed to model when requesting model + CustomHeader map[string]string `json:"custom_header"` +} +``` + +## Advanced Usage + +### Tool Calling + +The `Model` supports tool calling, including Function Tools, MCP Tools, and Server Tools. + +#### Function Tool Example + +```go +package main + +import ( + "context" + "errors" + "io" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/ark" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/eino-contrib/jsonschema" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +func main() { + ctx := context.Background() + + // Get ARK_API_KEY and ARK_MODEL_ID: https://www.volcengine.com/docs/82379/1399008 + am, err := ark.New(ctx, &ark.Config{ + Model: os.Getenv("ARK_MODEL_ID"), + APIKey: os.Getenv("ARK_API_KEY"), + Thinking: &responses.ResponsesThinking{ + Type: responses.ThinkingType_disabled.Enum(), + }, + }) + if err != nil { + log.Fatalf("failed to create agentic model, err=%v", err) + } + + functionTools := []*schema.ToolInfo{ + { + Name: "get_weather", + Desc: "get the weather in a city", + ParamsOneOf: schema.NewParamsOneOfByJSONSchema(&jsonschema.Schema{ + Type: "object", + Properties: orderedmap.New[string, *jsonschema.Schema]( + orderedmap.WithInitialData( + orderedmap.Pair[string, *jsonschema.Schema]{ + Key: "city", + Value: &jsonschema.Schema{ + Type: "string", + Description: "the city to get the weather", + }, + }, + ), + ), + Required: []string{"city"}, + }), + }, + } + + allowedTools := []*schema.AllowedTool{ + { + FunctionToolName: "get_weather", + }, + } + + opts := []agentic.Option{ + agentic.WithToolChoice(schema.ToolChoiceForced, allowedTools...), + agentic.WithTools(functionTools), + } + + firstInput := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what's the weather like in Beijing today"), + } + + sResp, err := am.Stream(ctx, firstInput, opts...) + if err != nil { + log.Fatalf("failed to stream, err: %v", err) + } + + var msgs []*schema.AgenticMessage + for { + msg, err := sResp.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + log.Fatalf("failed to receive stream response, err: %v", err) + } + msgs = append(msgs, msg) + } + + concatenated, err := schema.ConcatAgenticMessages(msgs) + if err != nil { + log.Fatalf("failed to concat agentic messages, err: %v", err) + } + + lastBlock := concatenated.ContentBlocks[len(concatenated.ContentBlocks)-1] + + toolCall := lastBlock.FunctionToolCall + toolResultMsg := schema.FunctionToolResultAgenticMessage(toolCall.CallID, toolCall.Name, "20 degrees") + + secondInput := append(firstInput, concatenated, toolResultMsg) + + gResp, err := am.Generate(ctx, secondInput, opts...) + if err != nil { + log.Fatalf("failed to generate, err: %v", err) + } + + meta := concatenated.ResponseMeta.Extension.(*ark.ResponseMetaExtension) + log.Printf("request_id: %s\n", meta.ID) + + respBody, _ := sonic.MarshalIndent(gResp, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} +``` + + +#### Server Tool Example + +```go +package main + +import ( + "context" + "errors" + "io" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/ark" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" +) + +func main() { + ctx := context.Background() + + // Get ARK_API_KEY and ARK_MODEL_ID: https://www.volcengine.com/docs/82379/1399008 + am, err := ark.New(ctx, &ark.Config{ + Model: os.Getenv("ARK_MODEL_ID"), + APIKey: os.Getenv("ARK_API_KEY"), + }) + if err != nil { + log.Fatalf("failed to create agentic model, err=%v", err) + } + + serverTools := []*ark.ServerToolConfig{ + { + WebSearch: &responses.ToolWebSearch{ + Type: responses.ToolType_web_search, + }, + }, + } + + allowedTools := []*schema.AllowedTool{ + { + ServerTool: &schema.AllowedServerTool{ + Name: string(ark.ServerToolNameWebSearch), + }, + }, + } + + opts := []agentic.Option{ + ark.WithServerTools(serverTools), + agentic.WithToolChoice(schema.ToolChoiceForced, allowedTools...), + ark.WithThinking(&responses.ResponsesThinking{ + Type: responses.ThinkingType_disabled.Enum(), + }), + } + + input := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what's the weather like in Beijing today"), + } + + resp, err := am.Stream(ctx, input, opts...) + if err != nil { + log.Fatalf("failed to stream, err: %v", err) + } + + var msgs []*schema.AgenticMessage + for { + msg, err := resp.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + log.Fatalf("failed to receive stream response, err: %v", err) + } + msgs = append(msgs, msg) + } + + concatenated, err := schema.ConcatAgenticMessages(msgs) + if err != nil { + log.Fatalf("failed to concat agentic messages, err: %v", err) + } + + meta := concatenated.ResponseMeta.Extension.(*ark.ResponseMetaExtension) + for _, block := range concatenated.ContentBlocks { + if block.ServerToolCall == nil { + continue + } + + serverToolArgs := block.ServerToolCall.Arguments.(*ark.ServerToolCallArguments) + + args, _ := sonic.MarshalIndent(serverToolArgs, " ", " ") + log.Printf("server_tool_args: %s\n", string(args)) + } + + log.Printf("request_id: %s\n", meta.ID) + respBody, _ := sonic.MarshalIndent(concatenated, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} +``` + +For more examples, please refer to the `examples` directory. diff --git a/components/agentic/ark/README.zh_CN.md b/components/agentic/ark/README.zh_CN.md new file mode 100644 index 000000000..7ae772142 --- /dev/null +++ b/components/agentic/ark/README.zh_CN.md @@ -0,0 +1,409 @@ +# Volcengine Ark Agentic Model + +基于 [Eino](https://github.com/cloudwego/eino) 的火山引擎 Ark 模型实现,实现了 `agentic` 组件中的 `Model` 接口。这使得该模型能够无缝集成到 Eino 的 Agent 能力中,提供增强的自然语言处理和生成功能。 + +## 功能特性 + +- 实现了 `github.com/cloudwego/eino/components/agentic.Model` 接口 +- 易于集成到 Eino 的 agent 系统中 +- 可配置的模型参数 +- 支持 Responses API +- 支持流式响应 (Streaming) +- 支持工具调用 (Tools),包括函数工具 (Function Tools)、MCP 工具 (MCP Tools) 和服务器工具 (Server Tools) +- 支持前缀缓存 (Prefix Cache) 和会话缓存 (Session Cache) + +## 安装 + +```bash +go get github.com/cloudwego/eino-ext/components/agentic/ark@latest +``` + +## 快速开始 + +以下是如何使用 `Model` 的一个快速示例: + +```go +package main + +import ( + "context" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/ark" + "github.com/cloudwego/eino/schema" +) + +func main() { + ctx := context.Background() + + // 获取 ARK_API_KEY 和 ARK_MODEL_ID: https://www.volcengine.com/docs/82379/1399008 + am, err := ark.New(ctx, &ark.Config{ + Model: os.Getenv("ARK_MODEL_ID"), + APIKey: os.Getenv("ARK_API_KEY"), + }) + if err != nil { + log.Fatalf("failed to create agentic model, err: %v", err) + } + + input := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what is the weather like in Beijing"), + } + + msg, err := am.Generate(ctx, input) + if err != nil { + log.Fatalf("failed to generate, err: %v", err) + } + + meta := msg.ResponseMeta.Extension.(*ark.ResponseMetaExtension) + + log.Printf("request_id: %s\n", meta.ID) + respBody, _ := sonic.MarshalIndent(msg, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} +``` + +## 配置 + +可以使用 `ark.Config` 结构体配置 `Model`: + +```go +type Config struct { + // Timeout 指定等待 API 响应的最大持续时间 + // 如果设置了 HTTPClient,则不会使用 Timeout。 + // 可选。默认值:10 分钟 + Timeout *time.Duration `json:"timeout"` + + // HTTPClient 指定用于发送 HTTP 请求的客户端。 + // 如果设置了 HTTPClient,则不会使用 Timeout。 + // 可选。默认值 &http.Client{Timeout: Timeout} + HTTPClient *http.Client `json:"http_client"` + + // RetryTimes 指定失败 API 调用的重试次数 + // 可选。默认值:2 + RetryTimes *int `json:"retry_times"` + + // BaseURL 指定 Ark 服务的基准 URL + // 可选。默认值:"https://ark.cn-beijing.volces.com/api/v3" + BaseURL string `json:"base_url"` + + // Region 指定 Ark 服务所在的区域 + // 可选。默认值:"cn-beijing" + Region string `json:"region"` + + // 以下三个字段与认证有关 - 需要 APIKey 或 AccessKey/SecretKey 对之一 + // 有关认证的详细信息,请参阅:https://www.volcengine.com/docs/82379/1298459 + // 如果同时提供,APIKey 优先 + APIKey string `json:"api_key"` + + AccessKey string `json:"access_key"` + + SecretKey string `json:"secret_key"` + + // 以下字段对应于 Ark 的 responses API 参数 + // 参考:https://www.volcengine.com/docs/82379/1298454 + + // Model 指定 ark 平台上的端点 ID + // 必填 + Model string `json:"model"` + + // MaxBlocks 限制聊天补全中生成的最大块数。 + // 可选。默认值:4096 + MaxOutputTokens *int64 `json:"max_output_tokens,omitempty"` + + // Temperature 指定要使用的采样温度 + // 通常建议修改此项或 TopP,但不能同时修改 + // 范围:0.0 到 1.0。值越高,输出越随机 + // 可选。默认值:1.0 + Temperature *float64 `json:"temperature,omitempty"` + + // TopP 通过核心采样控制多样性 + // 通常建议修改此项或 Temperature,但不能同时修改 + // 范围:0.0 到 1.0。值越低,输出越集中 + // 可选。默认值:0.7 + TopP *float64 `json:"top_p,omitempty"` + + // Stop 序列,API 将在这些序列处停止生成更多 token + // 可选。示例:[]string{"\n", "User:"} + Stop []string `json:"stop,omitempty"` + + // FrequencyPenalty 根据频率惩罚 token 以防止重复 + // 范围:-2.0 到 2.0。正值降低重复的可能性 + // 可选。默认值:0 + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + + // LogitBias 修改特定 token 在补全中出现的可能性 + // 可选。将 token ID 映射到 -100 到 100 的偏置值 + LogitBias map[string]int32 `json:"logit_bias,omitempty"` + + // PresencePenalty 根据存在与否惩罚 token 以防止重复 + // 范围:-2.0 到 2.0。正值增加新主题的可能性 + // 可选。默认值:0 + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + + // LogProbs 指定是否返回输出 token 的对数概率。 + LogProbs *bool `json:"log_probs,omitempty"` + + // TopLogProbs 指定每个 token 位置返回的最可能 token 的数量,每个都带有相关的对数概率。 + TopLogProbs *int `json:"top_log_probs,omitempty"` + + // RepetitionPenalty 基于 token 在目前为止的文本中的现有频率对其进行惩罚。 + // 范围:0.0 到 2.0。1.0 表示无惩罚。 + RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"` + + // Thinking 控制模型是否设置为激活深度思考模式。 + // 默认设置为启用。 + Thinking *responses.ResponsesThinking `json:"thinking,omitempty"` + + // Reasoning 指定模型的推理力度。 + // 可选。 + Reasoning *responses.ResponsesReasoning `json:"reasoning,omitempty"` + + // MaxToolCalls 限制聊天补全中生成的最大工具调用数。 + // 可选。 + MaxToolCalls *int64 `json:"max_tool_calls,omitempty"` + + // ParallelToolCalls 控制模型是否设置为执行并行工具调用。 + // 可选。 + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + + // ServerTools 指定模型可用的服务器端工具。 + // 可选。 + ServerTools []*ServerToolConfig `json:"server_tools,omitempty"` + + // MCPTools 指定模型可用的 MCP 工具。 + // 可选。 + MCPTools []*responses.ToolMcp `json:"mcp_tools,omitempty"` + + // Cache 指定模型的缓存配置。 + // 可选。 + Cache *CacheConfig `json:"cache,omitempty"` + + // CustomHeader 请求模型时传递的 http 标头 + CustomHeader map[string]string `json:"custom_header"` +} +``` + +## 高级用法 + +### 工具调用 (Tool Calling) + +`Model` 支持工具调用,包括函数工具、MCP 工具和服务器工具。 + +#### 函数工具示例 + +```go +package main + +import ( + "context" + "errors" + "io" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/ark" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/eino-contrib/jsonschema" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +func main() { + ctx := context.Background() + + // 获取 ARK_API_KEY 和 ARK_MODEL_ID: https://www.volcengine.com/docs/82379/1399008 + am, err := ark.New(ctx, &ark.Config{ + Model: os.Getenv("ARK_MODEL_ID"), + APIKey: os.Getenv("ARK_API_KEY"), + Thinking: &responses.ResponsesThinking{ + Type: responses.ThinkingType_disabled.Enum(), + }, + }) + if err != nil { + log.Fatalf("failed to create agentic model, err=%v", err) + } + + functionTools := []*schema.ToolInfo{ + { + Name: "get_weather", + Desc: "get the weather in a city", + ParamsOneOf: schema.NewParamsOneOfByJSONSchema(&jsonschema.Schema{ + Type: "object", + Properties: orderedmap.New[string, *jsonschema.Schema]( + orderedmap.WithInitialData( + orderedmap.Pair[string, *jsonschema.Schema]{ + Key: "city", + Value: &jsonschema.Schema{ + Type: "string", + Description: "the city to get the weather", + }, + }, + ), + ), + Required: []string{"city"}, + }), + }, + } + + allowedTools := []*schema.AllowedTool{ + { + FunctionToolName: "get_weather", + }, + } + + opts := []agentic.Option{ + agentic.WithToolChoice(schema.ToolChoiceForced, allowedTools...), + agentic.WithTools(functionTools), + } + + firstInput := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what's the weather like in Beijing today"), + } + + sResp, err := am.Stream(ctx, firstInput, opts...) + if err != nil { + log.Fatalf("failed to stream, err: %v", err) + } + + var msgs []*schema.AgenticMessage + for { + msg, err := sResp.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + log.Fatalf("failed to receive stream response, err: %v", err) + } + msgs = append(msgs, msg) + } + + concatenated, err := schema.ConcatAgenticMessages(msgs) + if err != nil { + log.Fatalf("failed to concat agentic messages, err: %v", err) + } + + lastBlock := concatenated.ContentBlocks[len(concatenated.ContentBlocks)-1] + + toolCall := lastBlock.FunctionToolCall + toolResultMsg := schema.FunctionToolResultAgenticMessage(toolCall.CallID, toolCall.Name, "20 degrees") + + secondInput := append(firstInput, concatenated, toolResultMsg) + + gResp, err := am.Generate(ctx, secondInput, opts...) + if err != nil { + log.Fatalf("failed to generate, err: %v", err) + } + + meta := concatenated.ResponseMeta.Extension.(*ark.ResponseMetaExtension) + log.Printf("request_id: %s\n", meta.ID) + + respBody, _ := sonic.MarshalIndent(gResp, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} +``` + + +#### 服务器工具示例 + +```go +package main + +import ( + "context" + "errors" + "io" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/ark" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" +) + +func main() { + ctx := context.Background() + + // Get ARK_API_KEY and ARK_MODEL_ID: https://www.volcengine.com/docs/82379/1399008 + am, err := ark.New(ctx, &ark.Config{ + Model: os.Getenv("ARK_MODEL_ID"), + APIKey: os.Getenv("ARK_API_KEY"), + }) + if err != nil { + log.Fatalf("failed to create agentic model, err=%v", err) + } + + serverTools := []*ark.ServerToolConfig{ + { + WebSearch: &responses.ToolWebSearch{ + Type: responses.ToolType_web_search, + }, + }, + } + + allowedTools := []*schema.AllowedTool{ + { + ServerTool: &schema.AllowedServerTool{ + Name: string(ark.ServerToolNameWebSearch), + }, + }, + } + + opts := []agentic.Option{ + ark.WithServerTools(serverTools), + agentic.WithToolChoice(schema.ToolChoiceForced, allowedTools...), + ark.WithThinking(&responses.ResponsesThinking{ + Type: responses.ThinkingType_disabled.Enum(), + }), + } + + input := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what's the weather like in Beijing today"), + } + + resp, err := am.Stream(ctx, input, opts...) + if err != nil { + log.Fatalf("failed to stream, err: %v", err) + } + + var msgs []*schema.AgenticMessage + for { + msg, err := resp.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + log.Fatalf("failed to receive stream response, err: %v", err) + } + msgs = append(msgs, msg) + } + + concatenated, err := schema.ConcatAgenticMessages(msgs) + if err != nil { + log.Fatalf("failed to concat agentic messages, err: %v", err) + } + + meta := concatenated.ResponseMeta.Extension.(*ark.ResponseMetaExtension) + for _, block := range concatenated.ContentBlocks { + if block.ServerToolCall == nil { + continue + } + + serverToolArgs := block.ServerToolCall.Arguments.(*ark.ServerToolCallArguments) + + args, _ := sonic.MarshalIndent(serverToolArgs, " ", " ") + log.Printf("server_tool_args: %s\n", string(args)) + } + + log.Printf("request_id: %s\n", meta.ID) + respBody, _ := sonic.MarshalIndent(concatenated, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} +``` + +更多示例请参考 `examples` 目录。 diff --git a/components/agentic/ark/consts.go b/components/agentic/ark/consts.go new file mode 100644 index 000000000..095a827d9 --- /dev/null +++ b/components/agentic/ark/consts.go @@ -0,0 +1,62 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ark + +const implType = "Ark" + +type WebSearchAction string + +const ( + WebSearchActionSearch WebSearchAction = "search" +) + +type ServerToolName string + +const ( + ServerToolNameWebSearch ServerToolName = "web_search" +) + +type TextAnnotationType string + +const ( + TextAnnotationTypeURLCitation TextAnnotationType = "url_citation" + TextAnnotationTypeDocCitation TextAnnotationType = "doc_citation" +) + +type ThinkingType string + +const ( + ThinkingTypeAuto ThinkingType = "auto" + ThinkingTypeEnabled ThinkingType = "enabled" + ThinkingTypeDisabled ThinkingType = "disabled" +) + +type ResponseStatus string + +const ( + ResponseStatusInProgress ResponseStatus = "in_progress" + ResponseStatusCompleted ResponseStatus = "completed" + ResponseStatusIncomplete ResponseStatus = "incomplete" + ResponseStatusFailed ResponseStatus = "failed" +) + +type ServiceTier string + +const ( + ServiceTierAuto ServiceTier = "auto" + ServiceTierDefault ServiceTier = "default" +) diff --git a/components/agentic/ark/content_block_extra.go b/components/agentic/ark/content_block_extra.go new file mode 100644 index 000000000..b77b44ed8 --- /dev/null +++ b/components/agentic/ark/content_block_extra.go @@ -0,0 +1,115 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ark + +import ( + "reflect" + + "github.com/cloudwego/eino/schema" +) + +type blockExtraItemID string +type blockExtraItemStatus string + +const ( + videoURLFPS = "ark-user-input-video-url-fps" + itemIDKey = "ark-item-id" + itemStatusKey = "ark-item-status" +) + +func SetUserInputVideoFPS(block *schema.UserInputVideo, fps float64) { + setBlockExtraValue(schema.NewContentBlock(block), videoURLFPS, fps) +} + +func GetUserInputVideoFPS(block *schema.UserInputVideo) (float64, bool) { + return getBlockExtraValue[float64](schema.NewContentBlock(block), videoURLFPS) +} + +func setItemID(block *schema.ContentBlock, itemID string) { + setBlockExtraValue(block, itemIDKey, blockExtraItemID(itemID)) +} + +func getItemID(block *schema.ContentBlock) (string, bool) { + itemID, ok := getBlockExtraValue[blockExtraItemID](block, itemIDKey) + if !ok { + return "", false + } + return string(itemID), true +} + +func setItemStatus(block *schema.ContentBlock, status string) { + setBlockExtraValue(block, itemStatusKey, blockExtraItemStatus(status)) +} + +func GetItemStatus(block *schema.ContentBlock) (string, bool) { + itemStatus, ok := getBlockExtraValue[blockExtraItemStatus](block, itemStatusKey) + if !ok { + return "", false + } + return string(itemStatus), true +} + +func setBlockExtraValue[T any](block *schema.ContentBlock, key string, value T) { + if block == nil { + return + } + if block.Extra == nil { + block.Extra = map[string]any{} + } + block.Extra[key] = value +} + +func getBlockExtraValue[T any](block *schema.ContentBlock, key string) (T, bool) { + var zero T + if block == nil { + return zero, false + } + if block.Extra == nil { + return zero, false + } + val, ok := block.Extra[key].(T) + if !ok { + return zero, false + } + return val, true +} + +func concatFirstNonZero[T any](chunks []T) (T, error) { + for _, chunk := range chunks { + if !reflect.ValueOf(chunk).IsZero() { + return chunk, nil + } + } + var zero T + return zero, nil +} + +func concatFirst[T any](chunks []T) (T, error) { + if len(chunks) == 0 { + var zero T + return zero, nil + } + return chunks[0], nil +} + +func concatLast[T any](chunks []T) (T, error) { + if len(chunks) == 0 { + var zero T + return zero, nil + } + return chunks[len(chunks)-1], nil +} diff --git a/components/agentic/ark/convertor.go b/components/agentic/ark/convertor.go new file mode 100644 index 000000000..749a0f248 --- /dev/null +++ b/components/agentic/ark/convertor.go @@ -0,0 +1,1171 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ark + +import ( + "fmt" + "strings" + "sync" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino/schema" + "github.com/eino-contrib/jsonschema" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" + "golang.org/x/sync/errgroup" + "google.golang.org/protobuf/types/known/structpb" +) + +func toSystemRoleInputItems(msg *schema.AgenticMessage) (items []*responses.InputItem, err error) { + items = make([]*responses.InputItem, 0, len(msg.ContentBlocks)) + + for _, block := range msg.ContentBlocks { + var item *responses.InputItem + + switch block.Type { + case schema.ContentBlockTypeUserInputText: + item, err = userInputTextToInputItem(responses.MessageRole_system, block.UserInputText) + if err != nil { + return nil, fmt.Errorf("failed to convert user input text to input item, err: %w", err) + } + + case schema.ContentBlockTypeUserInputImage: + item, err = userInputImageToInputItem(responses.MessageRole_system, block.UserInputImage) + if err != nil { + return nil, fmt.Errorf("failed to convert user input image to input item, err: %w", err) + } + + default: + return nil, fmt.Errorf("invalid content block type %q with system role", block.Type) + } + + items = append(items, item) + } + + return items, nil +} + +func toDeveloperRoleInputItems(msg *schema.AgenticMessage) (items []*responses.InputItem, err error) { + items = make([]*responses.InputItem, 0, len(msg.ContentBlocks)) + + for _, block := range msg.ContentBlocks { + var item *responses.InputItem + + switch block.Type { + case schema.ContentBlockTypeUserInputText: + item, err = userInputTextToInputItem(responses.MessageRole_developer, block.UserInputText) + if err != nil { + return nil, fmt.Errorf("failed to convert user input text to input item, err: %w", err) + } + + case schema.ContentBlockTypeUserInputImage: + item, err = userInputImageToInputItem(responses.MessageRole_developer, block.UserInputImage) + if err != nil { + return nil, fmt.Errorf("failed to convert user input image to input item, err: %w", err) + } + + default: + return nil, fmt.Errorf("invalid content block type %q with developer role", block.Type) + } + + items = append(items, item) + } + + return items, nil +} + +func toAssistantRoleInputItems(msg *schema.AgenticMessage) (items []*responses.InputItem, err error) { + items = make([]*responses.InputItem, 0, len(msg.ContentBlocks)) + + for _, block := range msg.ContentBlocks { + var item *responses.InputItem + + switch block.Type { + case schema.ContentBlockTypeAssistantGenText: + item, err = assistantGenTextToInputItem(block.AssistantGenText) + if err != nil { + return nil, fmt.Errorf("failed to convert assistant generated text to input item, err: %w", err) + } + + case schema.ContentBlockTypeReasoning: + item, err = reasoningToInputItem(block.Reasoning) + if err != nil { + return nil, fmt.Errorf("failed to convert reasoning to input item, err: %w", err) + } + + case schema.ContentBlockTypeFunctionToolCall: + item, err = functionToolCallToInputItem(block.FunctionToolCall) + if err != nil { + return nil, fmt.Errorf("failed to convert function tool call to input item, err: %w", err) + } + + case schema.ContentBlockTypeServerToolCall: + item, err = serverToolCallToInputItem(block.ServerToolCall) + if err != nil { + return nil, fmt.Errorf("failed to convert server tool call to input item, err: %w", err) + } + + case schema.ContentBlockTypeMCPToolApprovalRequest: + item, err = mcpToolApprovalRequestToInputItem(block.MCPToolApprovalRequest) + if err != nil { + return nil, fmt.Errorf("failed to convert mcp tool approval request to input item, err: %w", err) + } + + case schema.ContentBlockTypeMCPListToolsResult: + item, err = mcpListToolsResultToInputItem(block.MCPListToolsResult) + if err != nil { + return nil, fmt.Errorf("failed to convert mcp list tools result to input item, err: %w", err) + } + + case schema.ContentBlockTypeMCPToolCall: + item, err = mcpToolCallToInputItem(block.MCPToolCall) + if err != nil { + return nil, fmt.Errorf("failed to convert mcp tool call to input item, err: %w", err) + } + + case schema.ContentBlockTypeMCPToolResult: + item, err = mcpToolResultToInputItem(block.MCPToolResult) + if err != nil { + return nil, fmt.Errorf("failed to convert mcp tool result to input item, err: %w", err) + } + + default: + return nil, fmt.Errorf("invalid content block type %q with assistant role", block.Type) + } + + items = append(items, item) + } + + items, err = pairMCPToolCallItems(items) + if err != nil { + return nil, fmt.Errorf("pairMCPToolCallItems failed, err: %w", err) + } + + return items, nil +} + +func pairMCPToolCallItems(items []*responses.InputItem) (newItems []*responses.InputItem, err error) { + processed := make(map[int]bool) + mcpCallItemIDIndices := make(map[string][]int) + + for i, item := range items { + mcpCall := item.GetFunctionMcpCall() + if mcpCall == nil { + continue + } + + id := mcpCall.GetId() + if id == "" { + return nil, fmt.Errorf("found mcp tool call item with empty id at index %d", i) + } + + mcpCallItemIDIndices[id] = append(mcpCallItemIDIndices[id], i) + } + + for id, indices := range mcpCallItemIDIndices { + if len(indices) != 2 { + return nil, fmt.Errorf("mcp tool call %q should have exactly 2 items (call and result), "+ + "but found %d", id, len(indices)) + } + } + + for i, item := range items { + if processed[i] { + continue + } + + mcpCall := item.GetFunctionMcpCall() + if mcpCall == nil { + newItems = append(newItems, item) + continue + } + + id := mcpCall.GetId() + indices := mcpCallItemIDIndices[id] + + var pairIndex int + if indices[0] == i { + pairIndex = indices[1] + } else { + pairIndex = indices[0] + } + + pairMcpCall := items[pairIndex].GetFunctionMcpCall() + + mergedItem := &responses.InputItem{ + Union: &responses.InputItem_FunctionMcpCall{ + FunctionMcpCall: &responses.ItemFunctionMcpCall{ + Type: responses.ItemType_mcp_call, + Id: &id, + ServerLabel: mcpCall.ServerLabel, + ApprovalRequestId: coalesce(mcpCall.ApprovalRequestId, pairMcpCall.ApprovalRequestId), + Name: mcpCall.Name, + Arguments: coalesce(mcpCall.Arguments, pairMcpCall.Arguments), + Output: coalesce(mcpCall.Output, pairMcpCall.Output), + Error: coalesce(mcpCall.Error, pairMcpCall.Error), + }, + }, + } + + newItems = append(newItems, mergedItem) + + processed[i] = true + processed[pairIndex] = true + } + + return newItems, nil +} + +func toUserRoleInputItems(msg *schema.AgenticMessage) (items []*responses.InputItem, err error) { + items = make([]*responses.InputItem, 0, len(msg.ContentBlocks)) + + for _, block := range msg.ContentBlocks { + var item *responses.InputItem + + switch block.Type { + case schema.ContentBlockTypeUserInputText: + item, err = userInputTextToInputItem(responses.MessageRole_user, block.UserInputText) + if err != nil { + return nil, fmt.Errorf("failed to convert user input text to input item, err: %w", err) + } + + case schema.ContentBlockTypeUserInputImage: + item, err = userInputImageToInputItem(responses.MessageRole_user, block.UserInputImage) + if err != nil { + return nil, fmt.Errorf("failed to convert user input image to input item, err: %w", err) + } + + case schema.ContentBlockTypeUserInputVideo: + item, err = userInputVideoToInputItem(responses.MessageRole_user, block.UserInputVideo) + if err != nil { + return nil, fmt.Errorf("failed to convert user input video to input item, err: %w", err) + } + + case schema.ContentBlockTypeFunctionToolResult: + item, err = functionToolResultToInputItem(block.FunctionToolResult) + if err != nil { + return nil, fmt.Errorf("failed to convert function tool result to input item, err: %w", err) + } + + case schema.ContentBlockTypeMCPToolApprovalResponse: + item, err = mcpToolApprovalResponseToInputItem(block.MCPToolApprovalResponse) + if err != nil { + return nil, fmt.Errorf("failed to convert mcp tool approval response to input item, err: %w", err) + } + + case schema.ContentBlockTypeUserInputFile: + item, err = userInputFileToInputItem(responses.MessageRole_user, block.UserInputFile) + if err != nil { + return nil, fmt.Errorf("failed to convert user input file to input item, err: %w", err) + } + + default: + return nil, fmt.Errorf("invalid content block type %q with user role", block.Type) + } + + items = append(items, item) + } + + return items, nil +} + +func userInputTextToInputItem(role responses.MessageRole_Enum, block *schema.UserInputText) (inputItem *responses.InputItem, err error) { + item := &responses.ContentItem{ + Union: &responses.ContentItem_Text{ + Text: &responses.ContentItemText{ + Type: responses.ContentItemType_input_text, + Text: block.Text, + }, + }, + } + + inputItem = &responses.InputItem{ + Union: &responses.InputItem_InputMessage{ + InputMessage: &responses.ItemInputMessage{ + Type: ptrOf(responses.ItemType_message), + Role: role, + Content: []*responses.ContentItem{item}, + }, + }, + } + + return inputItem, nil +} + +func userInputImageToInputItem(role responses.MessageRole_Enum, block *schema.UserInputImage) (inputItem *responses.InputItem, err error) { + imageURL, err := resolveURL(block.URL, block.Base64Data, block.MIMEType) + if err != nil { + return nil, err + } + + detail, err := toContentItemImageDetail(block.Detail) + if err != nil { + return nil, err + } + + item := &responses.ContentItem{ + Union: &responses.ContentItem_Image{ + Image: &responses.ContentItemImage{ + Type: responses.ContentItemType_input_image, + ImageUrl: &imageURL, + Detail: detail, + }, + }, + } + + inputItem = &responses.InputItem{ + Union: &responses.InputItem_InputMessage{ + InputMessage: &responses.ItemInputMessage{ + Type: ptrOf(responses.ItemType_message), + Role: role, + Content: []*responses.ContentItem{item}, + }, + }, + } + + return inputItem, nil +} + +func toContentItemImageDetail(detail schema.ImageURLDetail) (*responses.ContentItemImageDetail_Enum, error) { + switch detail { + case schema.ImageURLDetailHigh: + return responses.ContentItemImageDetail_high.Enum(), nil + case schema.ImageURLDetailLow: + return responses.ContentItemImageDetail_low.Enum(), nil + case schema.ImageURLDetailAuto: + return responses.ContentItemImageDetail_auto.Enum(), nil + default: + return nil, fmt.Errorf("invalid image detail: %s", detail) + } +} + +func userInputVideoToInputItem(role responses.MessageRole_Enum, block *schema.UserInputVideo) (inputItem *responses.InputItem, err error) { + videoURL, err := resolveURL(block.URL, block.Base64Data, block.MIMEType) + if err != nil { + return nil, err + } + + var fpsPtr *float32 + if fps, ok := GetUserInputVideoFPS(block); ok { + fpsPtr = ptrOf(float32(fps)) + } + + contentItem := &responses.ContentItem{ + Union: &responses.ContentItem_Video{ + Video: &responses.ContentItemVideo{ + Type: responses.ContentItemType_input_video, + VideoUrl: videoURL, + Fps: fpsPtr, + }, + }, + } + + inputItem = &responses.InputItem{ + Union: &responses.InputItem_InputMessage{ + InputMessage: &responses.ItemInputMessage{ + Type: ptrOf(responses.ItemType_message), + Role: role, + Content: []*responses.ContentItem{contentItem}, + }, + }, + } + + return inputItem, nil +} + +func userInputFileToInputItem(role responses.MessageRole_Enum, block *schema.UserInputFile) (inputItem *responses.InputItem, err error) { + fileItem := &responses.ContentItemFile{ + Type: responses.ContentItemType_input_file, + Filename: &block.Name, + } + + if block.URL != "" { + fileItem.FileUrl = &block.URL + } else if block.Base64Data != "" { + fileItem.FileData = &block.Base64Data + } else { + return nil, fmt.Errorf("file input must have either URL or Base64Data") + } + + contentItem := &responses.ContentItem{ + Union: &responses.ContentItem_File{ + File: fileItem, + }, + } + + inputItem = &responses.InputItem{ + Union: &responses.InputItem_InputMessage{ + InputMessage: &responses.ItemInputMessage{ + Type: ptrOf(responses.ItemType_message), + Role: role, + Content: []*responses.ContentItem{contentItem}, + }, + }, + } + + return inputItem, nil +} + +func functionToolResultToInputItem(block *schema.FunctionToolResult) (item *responses.InputItem, err error) { + item = &responses.InputItem{ + Union: &responses.InputItem_FunctionToolCallOutput{ + FunctionToolCallOutput: &responses.ItemFunctionToolCallOutput{ + Type: responses.ItemType_function_call_output, + CallId: block.CallID, + Output: block.Result, + }, + }, + } + + return item, nil +} + +func assistantGenTextToInputItem(block *schema.AssistantGenText) (item *responses.InputItem, err error) { + block_ := schema.NewContentBlock(block) + id, _ := getItemID(block_) + status, _ := GetItemStatus(block_) + + content := &responses.ContentItem{ + Union: &responses.ContentItem_Text{ + Text: &responses.ContentItemText{ + Type: responses.ContentItemType_output_text, + Text: block.Text, + }, + }, + } + + item = &responses.InputItem{ + Union: &responses.InputItem_InputMessage{ + InputMessage: &responses.ItemInputMessage{ + Type: ptrOf(responses.ItemType_message), + Id: ptrIfNonZero(id), + Status: func() *responses.ItemStatus_Enum { + if status == "" { + return nil + } + return ptrOf(responses.ItemStatus_Enum(responses.ItemStatus_Enum_value[status])) + }(), + Role: responses.MessageRole_assistant, + Content: []*responses.ContentItem{content}, + }, + }, + } + + return item, nil +} + +func functionToolCallToInputItem(block *schema.FunctionToolCall) (item *responses.InputItem, err error) { + block_ := schema.NewContentBlock(block) + id, _ := getItemID(block_) + status, _ := GetItemStatus(block_) + + item = &responses.InputItem{ + Union: &responses.InputItem_FunctionToolCall{ + FunctionToolCall: &responses.ItemFunctionToolCall{ + Type: responses.ItemType_function_call, + Id: ptrIfNonZero(id), + Status: func() *responses.ItemStatus_Enum { + if status == "" { + return nil + } + return ptrOf(responses.ItemStatus_Enum(responses.ItemStatus_Enum_value[status])) + }(), + CallId: block.CallID, + Name: block.Name, + Arguments: block.Arguments, + }, + }, + } + + return item, nil +} + +func reasoningToInputItem(block *schema.Reasoning) (item *responses.InputItem, err error) { + block_ := schema.NewContentBlock(block) + id, _ := getItemID(block_) + status, _ := GetItemStatus(block_) + + summary := make([]*responses.ReasoningSummaryPart, 0, len(block.Summary)) + for _, s := range block.Summary { + summary = append(summary, &responses.ReasoningSummaryPart{ + Text: s.Text, + }) + } + + item = &responses.InputItem{ + Union: &responses.InputItem_Reasoning{ + Reasoning: &responses.ItemReasoning{ + Type: responses.ItemType_reasoning, + Id: ptrIfNonZero(id), + Status: responses.ItemStatus_Enum(responses.ItemStatus_Enum_value[status]), + Summary: summary, + }, + }, + } + + return item, nil +} + +func serverToolCallToInputItem(block *schema.ServerToolCall) (item *responses.InputItem, err error) { + block_ := schema.NewContentBlock(block) + id, _ := getItemID(block_) + status, _ := GetItemStatus(block_) + + arguments, err := getServerToolCallArguments(block) + if err != nil { + return nil, err + } + + ws := arguments.WebSearch + if ws == nil { + return nil, fmt.Errorf("web search arguments is nil") + } + + var action *responses.Action + switch ws.ActionType { + case WebSearchActionSearch: + action = &responses.Action{ + Type: responses.ActionType_search, + Query: ws.Search.Query, + } + + default: + return nil, fmt.Errorf("invalid web search action type: %s", ws.ActionType) + } + + item = &responses.InputItem{ + Union: &responses.InputItem_FunctionWebSearchCall{ + FunctionWebSearchCall: &responses.ItemFunctionWebSearch{ + Type: responses.ItemType_web_search_call, + Id: id, + Status: responses.ItemStatus_Enum(responses.ItemStatus_Enum_value[status]), + Action: action, + }, + }, + } + + return item, nil +} + +func mcpToolApprovalRequestToInputItem(block *schema.MCPToolApprovalRequest) (item *responses.InputItem, err error) { + item = &responses.InputItem{ + Union: &responses.InputItem_McpApprovalRequest{ + McpApprovalRequest: &responses.ItemFunctionMcpApprovalRequest{ + Type: responses.ItemType_mcp_approval_request, + Id: ptrIfNonZero(block.ID), + ServerLabel: block.ServerLabel, + Arguments: block.Arguments, + Name: block.Name, + }, + }, + } + + return item, nil +} + +func mcpToolApprovalResponseToInputItem(block *schema.MCPToolApprovalResponse) (item *responses.InputItem, err error) { + item = &responses.InputItem{ + Union: &responses.InputItem_McpApprovalResponse{ + McpApprovalResponse: &responses.ItemFunctionMcpApprovalResponse{ + Type: responses.ItemType_mcp_approval_response, + Approve: block.Approve, + ApprovalRequestId: block.ApprovalRequestID, + Reason: func() *string { + if block.Reason == "" { + return nil + } + return &block.Reason + }(), + }, + }, + } + + return item, nil +} + +func mcpListToolsResultToInputItem(block *schema.MCPListToolsResult) (item *responses.InputItem, err error) { + tools := make([]*responses.McpTool, 0, len(block.Tools)) + + for i := range block.Tools { + tool := block.Tools[i] + + sc, err := jsonschemaToMap(tool.InputSchema) + if err != nil { + return nil, fmt.Errorf("failed to convert tool input schema to map, err: %w", err) + } + + sc_, err := structpb.NewStruct(sc) + if err != nil { + return nil, fmt.Errorf("failed to new structpb struct, err: %w", err) + } + + tools = append(tools, &responses.McpTool{ + Name: tool.Name, + Description: tool.Description, + InputSchema: sc_, + }) + } + + id, _ := getItemID(schema.NewContentBlock(block)) + + item = &responses.InputItem{ + Union: &responses.InputItem_McpListTools{ + McpListTools: &responses.ItemFunctionMcpListTools{ + Type: responses.ItemType_mcp_list_tools, + ServerLabel: block.ServerLabel, + Tools: tools, + Id: ptrIfNonZero(id), + Error: ptrIfNonZero(block.Error), + }, + }, + } + + return item, nil +} + +func mcpToolCallToInputItem(block *schema.MCPToolCall) (item *responses.InputItem, err error) { + id, _ := getItemID(schema.NewContentBlock(block)) + + item = &responses.InputItem{ + Union: &responses.InputItem_FunctionMcpCall{ + FunctionMcpCall: &responses.ItemFunctionMcpCall{ + Type: responses.ItemType_mcp_call, + Id: ptrIfNonZero(id), + ServerLabel: block.ServerLabel, + ApprovalRequestId: ptrIfNonZero(block.ApprovalRequestID), + Arguments: block.Arguments, + Name: block.Name, + }, + }, + } + + return item, nil +} + +func mcpToolResultToInputItem(block *schema.MCPToolResult) (item *responses.InputItem, err error) { + id, _ := getItemID(schema.NewContentBlock(block)) + + item = &responses.InputItem{ + Union: &responses.InputItem_FunctionMcpCall{ + FunctionMcpCall: &responses.ItemFunctionMcpCall{ + Type: responses.ItemType_mcp_call, + Id: ptrIfNonZero(id), + ServerLabel: block.ServerLabel, + Name: block.Name, + Output: ptrIfNonZero(block.Result), + Error: func() *string { + if block.Error == nil { + return nil + } + return &block.Error.Message + }(), + }, + }, + } + + return item, nil +} + +func toOutputMessage(resp *responses.ResponseObject) (msg *schema.AgenticMessage, err error) { + blocks := make([]*schema.ContentBlock, 0, len(resp.Output)) + + for _, item := range resp.Output { + var tmpBlocks []*schema.ContentBlock + + switch t := item.Union.(type) { + case *responses.OutputItem_Reasoning: + block, err := reasoningToContentBlocks(t) + if err != nil { + return nil, fmt.Errorf("failed to convert reasoning to content block, err: %w", err) + } + + tmpBlocks = append(tmpBlocks, block) + + case *responses.OutputItem_OutputMessage: + tmpBlocks, err = outputMessageToContentBlocks(t) + if err != nil { + return nil, fmt.Errorf("failed to convert output message to content blocks, err: %w", err) + } + + case *responses.OutputItem_FunctionToolCall: + block, err := functionToolCallToContentBlock(t) + if err != nil { + return nil, fmt.Errorf("failed to convert function tool call to content block, err: %w", err) + } + + tmpBlocks = append(tmpBlocks, block) + + case *responses.OutputItem_FunctionMcpListTools: + block, err := mcpListToolsToContentBlock(t) + if err != nil { + return nil, fmt.Errorf("failed to convert function mcp list tools to content block, err: %w", err) + } + + tmpBlocks = append(tmpBlocks, block) + + case *responses.OutputItem_FunctionMcpCall: + tmpBlocks, err = mcpCallToContentBlocks(t) + if err != nil { + return nil, fmt.Errorf("failed to convert function mcp call to content block, err: %w", err) + } + + tmpBlocks = append(tmpBlocks, tmpBlocks...) + + case *responses.OutputItem_FunctionMcpApprovalRequest: + block, err := mcpApprovalRequestToContentBlock(t) + if err != nil { + return nil, fmt.Errorf("failed to convert function mcp approval request to content block, err: %w", err) + } + + tmpBlocks = append(tmpBlocks, block) + + case *responses.OutputItem_FunctionWebSearch: + block, err := webSearchToContentBlock(t) + if err != nil { + return nil, fmt.Errorf("failed to convert function web search to content block, err: %w", err) + } + + tmpBlocks = append(tmpBlocks, block) + + default: + return nil, fmt.Errorf("invalid output item type: %T", t) + } + + blocks = append(blocks, tmpBlocks...) + } + + msg = &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: blocks, + ResponseMeta: responseObjectToResponseMeta(resp), + } + + return msg, nil +} + +func outputMessageToContentBlocks(item *responses.OutputItem_OutputMessage) (blocks []*schema.ContentBlock, err error) { + outputMsg := item.OutputMessage + if outputMsg == nil { + return nil, fmt.Errorf("received empty output message") + } + + blocks = make([]*schema.ContentBlock, 0, len(outputMsg.Content)) + + for _, content := range outputMsg.Content { + var block *schema.ContentBlock + + switch t := content.Union.(type) { + case *responses.OutputContentItem_Text: + block, err = outputContentTextToContentBlock(t.Text) + if err != nil { + return nil, fmt.Errorf("failed to convert output text to content block, err: %w", err) + } + + default: + return nil, fmt.Errorf("invalid output content item type: %T", t) + } + + setItemID(block, outputMsg.Id) + setItemStatus(block, outputMsg.Status.String()) + + blocks = append(blocks, block) + } + + return blocks, nil +} + +func outputContentTextToContentBlock(text *responses.OutputContentItemText) (block *schema.ContentBlock, err error) { + annotations := make([]*TextAnnotation, 0, len(text.Annotations)) + for _, anno := range text.Annotations { + ta, err := outputTextAnnotationToTextAnnotation(anno) + if err != nil { + return nil, fmt.Errorf("failed to convert text annotation to text annotation, err: %w", err) + } + annotations = append(annotations, ta) + } + + block = schema.NewContentBlock(&schema.AssistantGenText{ + Text: text.Text, + Extension: &AssistantGenTextExtension{ + Annotations: annotations, + }, + }) + + return block, nil +} + +func outputTextAnnotationToTextAnnotation(anno *responses.Annotation) (*TextAnnotation, error) { + var ta *TextAnnotation + switch anno.Type { + case responses.AnnotationType_url_citation: + var coverImage *CoverImage + if anno.CoverImage != nil { + coverImage = &CoverImage{ + URL: anno.GetCoverImage().GetUrl(), + Width: anno.CoverImage.Width, + Height: anno.CoverImage.Height, + } + } + + ta = &TextAnnotation{ + Type: TextAnnotationTypeURLCitation, + URLCitation: &URLCitation{ + Title: anno.GetTitle(), + URL: anno.GetUrl(), + LogoURL: anno.GetLogoUrl(), + MobileURL: anno.GetMobileUrl(), + SiteName: anno.GetSiteName(), + PublishTime: anno.GetPublishTime(), + CoverImage: coverImage, + Summary: anno.GetSummary(), + FreshnessInfo: anno.GetFreshnessInfo(), + }, + } + + case responses.AnnotationType_doc_citation: + var chunkAttachment []map[string]any + for _, ca := range anno.ChunkAttachment { + chunkAttachment = append(chunkAttachment, ca.AsMap()) + } + + ta = &TextAnnotation{ + Type: TextAnnotationTypeDocCitation, + DocCitation: &DocCitation{ + DocID: anno.GetDocId(), + DocName: anno.GetDocName(), + ChunkID: anno.ChunkId, + ChunkAttachment: chunkAttachment, + }, + } + + default: + return nil, fmt.Errorf("invalid annotation type: %s", anno.Type.String()) + } + + return ta, nil +} + +func functionToolCallToContentBlock(item *responses.OutputItem_FunctionToolCall) (block *schema.ContentBlock, err error) { + toolCall := item.FunctionToolCall + if toolCall == nil { + return nil, fmt.Errorf("received empty function tool call") + } + + block = schema.NewContentBlock(&schema.FunctionToolCall{ + CallID: toolCall.CallId, + Name: toolCall.Name, + Arguments: toolCall.Arguments, + }) + + if toolCall.Id != nil { + setItemID(block, *toolCall.Id) + } + if toolCall.Status != nil { + setItemStatus(block, toolCall.Status.String()) + } + + return block, nil +} + +func webSearchToContentBlock(item *responses.OutputItem_FunctionWebSearch) (block *schema.ContentBlock, err error) { + webSearch := item.FunctionWebSearch + if webSearch == nil { + return nil, fmt.Errorf("received empty function web search") + } + + var args *ServerToolCallArguments + if action := webSearch.Action; action != nil { + switch action_ := WebSearchAction(action.Type.String()); action_ { + case WebSearchActionSearch: + args = &ServerToolCallArguments{ + WebSearch: &WebSearchArguments{ + ActionType: action_, + Search: &WebSearchQuery{ + Query: webSearch.Action.Query, + }, + }, + } + + default: + return nil, fmt.Errorf("invalid web search action type: %s", action_) + } + } + + block = schema.NewContentBlock(&schema.ServerToolCall{ + Name: string(ServerToolNameWebSearch), + Arguments: args, + }) + + setItemID(block, webSearch.Id) + setItemStatus(block, webSearch.Status.String()) + + return block, nil +} + +func reasoningToContentBlocks(item *responses.OutputItem_Reasoning) (block *schema.ContentBlock, err error) { + reasoning := item.Reasoning + if reasoning == nil { + return nil, fmt.Errorf("received empty reasoning") + } + + summary := make([]*schema.ReasoningSummary, 0, len(reasoning.Summary)) + for _, s := range reasoning.Summary { + summary = append(summary, &schema.ReasoningSummary{ + Text: s.Text, + }) + } + + block = schema.NewContentBlock(&schema.Reasoning{ + Summary: summary, + }) + + if reasoning.Id != nil { + setItemID(block, *reasoning.Id) + } + setItemStatus(block, reasoning.Status.String()) + + return block, nil +} + +func mcpCallToContentBlocks(item *responses.OutputItem_FunctionMcpCall) (blocks []*schema.ContentBlock, err error) { + mcpCall := item.FunctionMcpCall + if mcpCall == nil { + return nil, fmt.Errorf("received empty MCP call") + } + + callBlock := schema.NewContentBlock(&schema.MCPToolCall{ + ServerLabel: mcpCall.ServerLabel, + ApprovalRequestID: mcpCall.GetApprovalRequestId(), + Name: mcpCall.Name, + Arguments: mcpCall.Arguments, + }) + + resultBlock := schema.NewContentBlock(&schema.MCPToolResult{ + ServerLabel: mcpCall.ServerLabel, + Name: mcpCall.Name, + Result: mcpCall.GetOutput(), + Error: func() *schema.MCPToolCallError { + if mcpCall.Error == nil { + return nil + } + return &schema.MCPToolCallError{ + Message: mcpCall.GetError(), + } + }(), + }) + + if mcpCall.Id != nil { + setItemID(callBlock, *mcpCall.Id) + setItemID(resultBlock, *mcpCall.Id) + } + if resultBlock.MCPToolResult.Error == nil { + setItemStatus(resultBlock, responses.ItemStatus_completed.String()) + } else { + setItemStatus(resultBlock, responses.ItemStatus_failed.String()) + } + + blocks = []*schema.ContentBlock{callBlock, resultBlock} + + return blocks, nil +} + +func mcpListToolsToContentBlock(item *responses.OutputItem_FunctionMcpListTools) (block *schema.ContentBlock, err error) { + mcpListTools := item.FunctionMcpListTools + if mcpListTools == nil { + return nil, fmt.Errorf("received empty MCP list tools") + } + + group := &errgroup.Group{} + group.SetLimit(5) + mu := sync.Mutex{} + + tools := make([]*schema.MCPListToolsItem, 0, len(mcpListTools.Tools)) + for i := range mcpListTools.Tools { + tool := mcpListTools.Tools[i] + + group.Go(func() error { + b, err := sonic.Marshal(tool.InputSchema) + if err != nil { + return fmt.Errorf("failed to marshal tool input schema, err: %w", err) + } + + sc := &jsonschema.Schema{} + if err := sonic.Unmarshal(b, sc); err != nil { + return fmt.Errorf("failed to unmarshal tool input schema, err: %w", err) + } + + mu.Lock() + defer mu.Unlock() + + tools = append(tools, &schema.MCPListToolsItem{ + Name: tool.Name, + Description: tool.Description, + InputSchema: sc, + }) + + return nil + }) + } + + if err = group.Wait(); err != nil { + return nil, err + } + + block = schema.NewContentBlock(&schema.MCPListToolsResult{ + ServerLabel: mcpListTools.ServerLabel, + Tools: tools, + Error: mcpListTools.GetError(), + }) + + if mcpListTools.Id != nil { + setItemID(block, *mcpListTools.Id) + } + + return block, nil +} + +func mcpApprovalRequestToContentBlock(item *responses.OutputItem_FunctionMcpApprovalRequest) (block *schema.ContentBlock, err error) { + apReq := item.FunctionMcpApprovalRequest + if apReq == nil { + return nil, fmt.Errorf("received empty MCP approval request") + } + + block = schema.NewContentBlock(&schema.MCPToolApprovalRequest{ + ID: apReq.GetId(), + ServerLabel: apReq.ServerLabel, + Name: apReq.Name, + Arguments: apReq.Arguments, + }) + + if apReq.Id != nil { + setItemID(block, *apReq.Id) + } + + return block, nil +} + +func responseObjectToResponseMeta(obj *responses.ResponseObject) *schema.AgenticResponseMeta { + return &schema.AgenticResponseMeta{ + TokenUsage: toTokenUsage(obj), + Extension: toResponseMetaExtension(obj), + } +} + +func toTokenUsage(resp *responses.ResponseObject) (tokenUsage *schema.TokenUsage) { + if resp.Usage == nil { + return nil + } + + usage := &schema.TokenUsage{ + PromptTokens: int(resp.Usage.InputTokens), + PromptTokenDetails: schema.PromptTokenDetails{ + CachedTokens: int(resp.Usage.InputTokensDetails.GetCachedTokens()), + }, + CompletionTokens: int(resp.Usage.OutputTokens), + CompletionTokensDetails: schema.CompletionTokensDetails{ + ReasoningTokens: int(resp.Usage.OutputTokensDetails.GetReasoningTokens()), + }, + TotalTokens: int(resp.Usage.TotalTokens), + } + + return usage +} + +func toResponseMetaExtension(resp *responses.ResponseObject) *ResponseMetaExtension { + if resp == nil { + return nil + } + + var incompleteDetails *IncompleteDetails + if details := resp.IncompleteDetails; details != nil { + var contentFilter *ContentFilter + if filter := details.ContentFilter; filter != nil { + contentFilter = &ContentFilter{ + Type: filter.Type, + Details: filter.Details, + } + } + incompleteDetails = &IncompleteDetails{ + Reason: details.Reason, + ContentFilter: contentFilter, + } + } + + var respErr *ResponseError + if e := resp.Error; e != nil { + respErr = &ResponseError{ + Code: e.Code, + Message: e.Message, + } + } + + var thinking *ResponseThinking + if t := resp.Thinking; t != nil { + thinking = &ResponseThinking{ + Type: ThinkingType(t.Type.String()), + } + } + + var serviceTier ServiceTier + if s := resp.ServiceTier; s != nil { + serviceTier = ServiceTier(s.String()) + } + + var status ResponseStatus + if s := resp.Status; s != responses.ResponseStatus_unspecified { + status = ResponseStatus(s.String()) + } + + extension := &ResponseMetaExtension{ + ID: resp.Id, + Status: status, + IncompleteDetails: incompleteDetails, + Error: respErr, + PreviousResponseID: resp.GetPreviousResponseId(), + Thinking: thinking, + ExpireAt: resp.ExpireAt, + ServiceTier: serviceTier, + } + + return extension +} + +func resolveURL(url string, base64Data string, mimeType string) (real string, err error) { + if url != "" { + return url, nil + } + + if mimeType == "" { + return "", fmt.Errorf("mimeType is required when using base64Data") + } + + real, err = ensureDataURL(base64Data, mimeType) + if err != nil { + return "", err + } + + return real, nil +} + +func ensureDataURL(base64Data, mimeType string) (string, error) { + if strings.HasPrefix(base64Data, "data:") { + return "", fmt.Errorf("base64Data field must be a raw base64 string, but got a string with prefix 'data:'") + } + if mimeType == "" { + return "", fmt.Errorf("mimeType is required") + } + return fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data), nil +} diff --git a/components/agentic/ark/convertor_test.go b/components/agentic/ark/convertor_test.go new file mode 100644 index 000000000..e21ff22e7 --- /dev/null +++ b/components/agentic/ark/convertor_test.go @@ -0,0 +1,728 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ark + +import ( + "errors" + "testing" + + "github.com/bytedance/mockey" + "github.com/cloudwego/eino/schema" + "github.com/eino-contrib/jsonschema" + "github.com/stretchr/testify/assert" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" + "google.golang.org/protobuf/types/known/structpb" +) + +func TestToSystemRoleInputItems(t *testing.T) { + msg := &schema.AgenticMessage{ + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.UserInputText{Text: "hello"}), + schema.NewContentBlock(&schema.UserInputImage{ + URL: "http://example.com/image.png", + MIMEType: "image/png", + Detail: schema.ImageURLDetailHigh, + }), + }, + } + + items, err := toSystemRoleInputItems(msg) + assert.NoError(t, err) + assert.Equal(t, 2, len(items)) + assert.Equal(t, responses.MessageRole_system, items[0].GetInputMessage().Role) + + msgInvalid := &schema.AgenticMessage{ + ContentBlocks: []*schema.ContentBlock{ + {Type: "invalid"}, + }, + } + _, err = toSystemRoleInputItems(msgInvalid) + assert.Error(t, err) +} + +func TestToDeveloperRoleInputItems(t *testing.T) { + msg := &schema.AgenticMessage{ + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.UserInputText{Text: "dev"}), + }, + } + + items, err := toDeveloperRoleInputItems(msg) + assert.NoError(t, err) + assert.Equal(t, 1, len(items)) + assert.Equal(t, responses.MessageRole_developer, items[0].GetInputMessage().Role) +} + +func TestToAssistantRoleInputItems(t *testing.T) { + msg := &schema.AgenticMessage{ + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "answer"}), + schema.NewContentBlock(&schema.Reasoning{ + Summary: []*schema.ReasoningSummary{{Text: "reason"}}, + }), + }, + } + setItemID(msg.ContentBlocks[1], "id-1") + setItemStatus(msg.ContentBlocks[1], responses.ItemStatus_completed.String()) + + items, err := toAssistantRoleInputItems(msg) + assert.NoError(t, err) + assert.Equal(t, 2, len(items)) + assert.Equal(t, responses.MessageRole_assistant, items[0].GetInputMessage().Role) + assert.NotNil(t, items[1].GetReasoning()) +} + +func TestPairMCPToolCallItems(t *testing.T) { + id := "call-1" + out := "result" + errStr := "err" + + call := &responses.InputItem{ + Union: &responses.InputItem_FunctionMcpCall{ + FunctionMcpCall: &responses.ItemFunctionMcpCall{ + Type: responses.ItemType_mcp_call, + Id: &id, + ServerLabel: "server", + Name: "tool", + }, + }, + } + result := &responses.InputItem{ + Union: &responses.InputItem_FunctionMcpCall{ + FunctionMcpCall: &responses.ItemFunctionMcpCall{ + Type: responses.ItemType_mcp_call, + Id: &id, + ServerLabel: "server", + Name: "tool", + Output: &out, + Error: &errStr, + }, + }, + } + + items, err := pairMCPToolCallItems([]*responses.InputItem{call, result}) + assert.NoError(t, err) + assert.Equal(t, 1, len(items)) + mcp := items[0].GetFunctionMcpCall() + assert.NotNil(t, mcp) + assert.Equal(t, out, mcp.GetOutput()) + assert.Equal(t, errStr, mcp.GetError()) + + onlyCall := []*responses.InputItem{call} + _, err = pairMCPToolCallItems(onlyCall) + assert.Error(t, err) +} + +func TestToUserRoleInputItems(t *testing.T) { + msg := &schema.AgenticMessage{ + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.UserInputText{Text: "u"}), + schema.NewContentBlock(&schema.UserInputVideo{ + URL: "http://example.com/video.mp4", + MIMEType: "video/mp4", + }), + schema.NewContentBlock(&schema.FunctionToolResult{ + CallID: "c1", + Name: "n1", + Result: "r1", + }), + schema.NewContentBlock(&schema.MCPToolApprovalResponse{ + ApprovalRequestID: "ar1", + Approve: true, + }), + }, + } + + items, err := toUserRoleInputItems(msg) + assert.NoError(t, err) + assert.Equal(t, 4, len(items)) + assert.Equal(t, responses.MessageRole_user, items[0].GetInputMessage().Role) + assert.NotNil(t, items[1].GetInputMessage().Content[0].GetVideo()) + assert.NotNil(t, items[2].GetFunctionToolCallOutput()) + assert.NotNil(t, items[3].GetMcpApprovalResponse()) +} + +func TestUserInputTextToInputItem(t *testing.T) { + block := &schema.UserInputText{Text: "hello"} + item, err := userInputTextToInputItem(responses.MessageRole_user, block) + assert.NoError(t, err) + assert.Equal(t, "hello", item.GetInputMessage().Content[0].GetText().Text) +} + +func TestUserInputImageToInputItem(t *testing.T) { + block := &schema.UserInputImage{ + URL: "http://example.com/image.png", + MIMEType: "image/png", + Detail: schema.ImageURLDetailLow, + } + item, err := userInputImageToInputItem(responses.MessageRole_user, block) + assert.NoError(t, err) + img := item.GetInputMessage().Content[0].GetImage() + assert.NotNil(t, img) + assert.NotNil(t, img.ImageUrl) + assert.Equal(t, block.URL, *img.ImageUrl) + + blockInvalid := &schema.UserInputImage{ + Base64Data: "xxx", + MIMEType: "", + Detail: "invalid", + } + _, err = userInputImageToInputItem(responses.MessageRole_user, blockInvalid) + assert.Error(t, err) +} + +func TestToContentItemImageDetail(t *testing.T) { + tests := []struct { + in schema.ImageURLDetail + valid bool + }{ + {schema.ImageURLDetailHigh, true}, + {schema.ImageURLDetailLow, true}, + {schema.ImageURLDetailAuto, true}, + {"invalid", false}, + } + for _, tt := range tests { + detail, err := toContentItemImageDetail(tt.in) + if tt.valid { + assert.NoError(t, err) + assert.NotNil(t, detail) + } else { + assert.Error(t, err) + } + } +} + +func TestUserInputVideoToInputItem(t *testing.T) { + video := &schema.UserInputVideo{ + URL: "http://example.com/video.mp4", + MIMEType: "video/mp4", + } + item, err := userInputVideoToInputItem(responses.MessageRole_user, video) + assert.NoError(t, err) + assert.Equal(t, video.URL, item.GetInputMessage().Content[0].GetVideo().VideoUrl) +} + +func TestUserInputFileToInputItem(t *testing.T) { + tests := []struct { + name string + block *schema.UserInputFile + hasURL bool + }{ + { + name: "with_url", + block: &schema.UserInputFile{ + Name: "file.txt", + URL: "http://example.com/file.txt", + }, + hasURL: true, + }, + { + name: "with_base64", + block: &schema.UserInputFile{ + Name: "file.bin", + Base64Data: "ZGF0YQ==", + }, + hasURL: false, + }, + } + + for _, tt := range tests { + item, err := userInputFileToInputItem(responses.MessageRole_user, tt.block) + assert.NoError(t, err) + msg := item.GetInputMessage() + assert.Equal(t, responses.MessageRole_user, msg.Role) + assert.Len(t, msg.Content, 1) + file := msg.Content[0].GetFile() + assert.NotNil(t, file) + assert.Equal(t, responses.ContentItemType_input_file, file.Type) + assert.NotNil(t, file.Filename) + assert.Equal(t, tt.block.Name, *file.Filename) + if tt.hasURL { + assert.NotNil(t, file.FileUrl) + assert.Equal(t, tt.block.URL, *file.FileUrl) + assert.Nil(t, file.FileData) + } else { + assert.NotNil(t, file.FileData) + assert.Equal(t, tt.block.Base64Data, *file.FileData) + assert.Nil(t, file.FileUrl) + } + } +} + +func TestFunctionToolResultToInputItem(t *testing.T) { + block := &schema.FunctionToolResult{ + CallID: "c1", + Name: "n1", + Result: "r1", + } + item, err := functionToolResultToInputItem(block) + assert.NoError(t, err) + out := item.GetFunctionToolCallOutput() + assert.NotNil(t, out) + assert.Equal(t, "c1", out.CallId) + assert.Equal(t, "r1", out.Output) +} + +func TestAssistantGenTextToInputItem(t *testing.T) { + block := &schema.AssistantGenText{Text: "answer"} + item, err := assistantGenTextToInputItem(block) + assert.NoError(t, err) + msg := item.GetInputMessage() + assert.Equal(t, responses.MessageRole_assistant, msg.Role) + assert.Equal(t, "answer", msg.Content[0].GetText().Text) +} + +func TestFunctionToolCallToInputItem(t *testing.T) { + block := &schema.FunctionToolCall{ + CallID: "cid", + Name: "name", + Arguments: "{}", + } + item, err := functionToolCallToInputItem(block) + assert.NoError(t, err) + call := item.GetFunctionToolCall() + assert.NotNil(t, call) + assert.Equal(t, "cid", call.CallId) + assert.Equal(t, "name", call.Name) +} + +func TestReasoningToInputItem(t *testing.T) { + block := &schema.Reasoning{ + Summary: []*schema.ReasoningSummary{{Text: "r"}}, + } + + item, err := reasoningToInputItem(block) + assert.NoError(t, err) + reason := item.GetReasoning() + assert.NotNil(t, reason) + assert.Equal(t, 1, len(reason.Summary)) + assert.Equal(t, "r", reason.Summary[0].Text) +} + +func TestServerToolCallToInputItem(t *testing.T) { + mockey.PatchConvey("TestServerToolCallToInputItem", t, func() { + args := &ServerToolCallArguments{ + WebSearch: &WebSearchArguments{ + ActionType: WebSearchActionSearch, + Search: &WebSearchQuery{Query: "q"}, + }, + } + call := &schema.ServerToolCall{ + Name: string(ServerToolNameWebSearch), + Arguments: args, + } + + item, err := serverToolCallToInputItem(call) + assert.NoError(t, err) + ws := item.GetFunctionWebSearchCall() + assert.NotNil(t, ws) + assert.NotNil(t, ws.Action) + assert.Equal(t, "q", ws.Action.Query) + + mockey.Mock(getServerToolCallArguments).Return(nil, errors.New("mock")).Build() + _, err = serverToolCallToInputItem(call) + assert.Error(t, err) + }) +} + +func TestMcpToolApprovalRequestToInputItem(t *testing.T) { + req := &schema.MCPToolApprovalRequest{ + ID: "id", + ServerLabel: "server", + Name: "name", + Arguments: "{}", + } + + item, err := mcpToolApprovalRequestToInputItem(req) + assert.NoError(t, err) + ap := item.GetMcpApprovalRequest() + assert.NotNil(t, ap) + assert.NotEmpty(t, ap.GetId()) + assert.Equal(t, "server", ap.ServerLabel) +} + +func TestMcpToolApprovalResponseToInputItem(t *testing.T) { + resp := &schema.MCPToolApprovalResponse{ + ApprovalRequestID: "rid", + Approve: true, + Reason: "ok", + } + item, err := mcpToolApprovalResponseToInputItem(resp) + assert.NoError(t, err) + ap := item.GetMcpApprovalResponse() + assert.NotNil(t, ap) + assert.True(t, ap.Approve) + assert.Equal(t, "rid", ap.ApprovalRequestId) +} + +func TestMcpListToolsResultToInputItem(t *testing.T) { + sc := &jsonschema.Schema{ + Title: "t", + Description: "d", + } + + content := &schema.MCPListToolsResult{ + ServerLabel: "server", + Tools: []*schema.MCPListToolsItem{ + { + Name: "tool", + Description: "desc", + InputSchema: sc, + }, + }, + Error: "err", + } + + item, err := mcpListToolsResultToInputItem(content) + assert.NoError(t, err) + list := item.GetMcpListTools() + assert.NotNil(t, list) + assert.Equal(t, 1, len(list.Tools)) + assert.Equal(t, "tool", list.Tools[0].Name) +} + +func TestMcpToolCallToInputItem(t *testing.T) { + call := &schema.MCPToolCall{ + ServerLabel: "server", + Name: "name", + Arguments: "{}", + ApprovalRequestID: "ar", + } + + item, err := mcpToolCallToInputItem(call) + assert.NoError(t, err) + mcp := item.GetFunctionMcpCall() + assert.NotNil(t, mcp) + assert.Equal(t, "server", mcp.ServerLabel) + assert.Equal(t, "ar", mcp.GetApprovalRequestId()) +} + +func TestMcpToolResultToInputItem(t *testing.T) { + res := &schema.MCPToolResult{ + ServerLabel: "server", + Name: "name", + Result: "r", + Error: &schema.MCPToolCallError{Message: "e"}, + } + + item, err := mcpToolResultToInputItem(res) + assert.NoError(t, err) + mcp := item.GetFunctionMcpCall() + assert.NotNil(t, mcp) + assert.Equal(t, "server", mcp.ServerLabel) + assert.Equal(t, "r", mcp.GetOutput()) +} + +func TestToOutputMessage(t *testing.T) { + outputText := &responses.OutputContentItemText{ + Text: "answer", + } + outMsg := &responses.OutputItem{ + Union: &responses.OutputItem_OutputMessage{ + OutputMessage: &responses.ItemOutputMessage{ + Content: []*responses.OutputContentItem{ + {Union: &responses.OutputContentItem_Text{Text: outputText}}, + }, + }, + }, + } + + id := "mid" + mcpCall := &responses.OutputItem{ + Union: &responses.OutputItem_FunctionMcpCall{ + FunctionMcpCall: &responses.ItemFunctionMcpCall{ + Type: responses.ItemType_mcp_call, + Id: &id, + ServerLabel: "server", + Name: "tool", + Output: ptrOf("out"), + }, + }, + } + + resp := &responses.ResponseObject{ + Output: []*responses.OutputItem{outMsg, mcpCall}, + } + + msg, err := toOutputMessage(resp) + assert.NoError(t, err) + assert.Equal(t, schema.AgenticRoleTypeAssistant, msg.Role) + assert.Greater(t, len(msg.ContentBlocks), 0) + assert.NotNil(t, msg.ContentBlocks[0].AssistantGenText) + assert.Equal(t, "answer", msg.ContentBlocks[0].AssistantGenText.Text) +} + +func TestOutputMessageToContentBlocks(t *testing.T) { + out := &responses.ItemOutputMessage{ + Id: "id", + Status: responses.ItemStatus_completed, + Content: []*responses.OutputContentItem{ + { + Union: &responses.OutputContentItem_Text{ + Text: &responses.OutputContentItemText{Text: "a"}, + }, + }, + }, + } + blocks, err := outputMessageToContentBlocks(&responses.OutputItem_OutputMessage{OutputMessage: out}) + assert.NoError(t, err) + assert.Equal(t, 1, len(blocks)) + assert.NotNil(t, blocks[0].AssistantGenText) + + _, err = outputMessageToContentBlocks(&responses.OutputItem_OutputMessage{}) + assert.Error(t, err) +} + +func TestOutputContentTextToContentBlock(t *testing.T) { + title := "t" + url := "u" + anno := &responses.Annotation{ + Type: responses.AnnotationType_url_citation, + Title: title, + Url: url, + } + block, err := outputContentTextToContentBlock(&responses.OutputContentItemText{ + Text: "a", + Annotations: []*responses.Annotation{anno}, + }) + assert.NoError(t, err) + assert.NotNil(t, block.AssistantGenText) + assert.Equal(t, "a", block.AssistantGenText.Text) +} + +func TestOutputTextAnnotationToTextAnnotation(t *testing.T) { + docID := "d" + docName := "n" + a := &responses.Annotation{ + Type: responses.AnnotationType_doc_citation, + DocId: &docID, + DocName: &docName, + ChunkId: ptrOf[int32](1), + ChunkAttachment: []*structpb.Struct{ + structpb.NewStructValue(&structpb.Struct{}).GetStructValue(), + }, + } + ta, err := outputTextAnnotationToTextAnnotation(a) + assert.NoError(t, err) + assert.NotNil(t, ta) + assert.NotNil(t, ta.DocCitation) + assert.Equal(t, "d", ta.DocCitation.DocID) + + invalid := &responses.Annotation{ + Type: responses.AnnotationType_unspecified, + } + _, err = outputTextAnnotationToTextAnnotation(invalid) + assert.Error(t, err) +} + +func TestFunctionToolCallToContentBlock(t *testing.T) { + id := "id" + item := &responses.OutputItem_FunctionToolCall{ + FunctionToolCall: &responses.ItemFunctionToolCall{ + CallId: "cid", + Name: "name", + Status: responses.ItemStatus_completed.Enum(), + Id: &id, + }, + } + block, err := functionToolCallToContentBlock(item) + assert.NoError(t, err) + assert.NotNil(t, block.FunctionToolCall) + assert.Equal(t, "cid", block.FunctionToolCall.CallID) + + _, err = functionToolCallToContentBlock(&responses.OutputItem_FunctionToolCall{}) + assert.Error(t, err) +} + +func TestWebSearchToContentBlock(t *testing.T) { + item := &responses.OutputItem_FunctionWebSearch{ + FunctionWebSearch: &responses.ItemFunctionWebSearch{ + Id: "id", + Status: responses.ItemStatus_completed, + Action: &responses.Action{ + Type: responses.ActionType_search, + Query: "q", + }, + }, + } + block, err := webSearchToContentBlock(item) + assert.NoError(t, err) + assert.NotNil(t, block.ServerToolCall) + args := block.ServerToolCall.Arguments.(*ServerToolCallArguments) + assert.NotNil(t, args.WebSearch) + assert.Equal(t, "q", args.WebSearch.Search.Query) + + itemInvalid := &responses.OutputItem_FunctionWebSearch{ + FunctionWebSearch: &responses.ItemFunctionWebSearch{ + Action: &responses.Action{ + Type: responses.ActionType_unspecified, + }, + }, + } + _, err = webSearchToContentBlock(itemInvalid) + assert.Error(t, err) +} + +func TestReasoningToContentBlocks(t *testing.T) { + id := "id" + item := &responses.OutputItem_Reasoning{ + Reasoning: &responses.ItemReasoning{ + Id: &id, + Status: responses.ItemStatus_completed, + Summary: []*responses.ReasoningSummaryPart{ + {Text: "r"}, + }, + }, + } + block, err := reasoningToContentBlocks(item) + assert.NoError(t, err) + assert.NotNil(t, block.Reasoning) + assert.Equal(t, 1, len(block.Reasoning.Summary)) + + _, err = reasoningToContentBlocks(&responses.OutputItem_Reasoning{}) + assert.Error(t, err) +} + +func TestMcpCallToContentBlocks(t *testing.T) { + id := "id" + item := &responses.OutputItem_FunctionMcpCall{ + FunctionMcpCall: &responses.ItemFunctionMcpCall{ + Id: &id, + ServerLabel: "server", + Name: "tool", + Arguments: "{}", + Output: ptrOf("out"), + }, + } + blocks, err := mcpCallToContentBlocks(item) + assert.NoError(t, err) + assert.Equal(t, 2, len(blocks)) + assert.NotNil(t, blocks[0].MCPToolCall) + assert.NotNil(t, blocks[1].MCPToolResult) + + _, err = mcpCallToContentBlocks(&responses.OutputItem_FunctionMcpCall{}) + assert.Error(t, err) +} + +func TestMcpListToolsToContentBlock(t *testing.T) { + toolSchema, err := structpb.NewStruct(map[string]any{"type": "object"}) + assert.NoError(t, err) + id := "id" + item := &responses.OutputItem_FunctionMcpListTools{ + FunctionMcpListTools: &responses.ItemFunctionMcpListTools{ + Id: &id, + ServerLabel: "server", + Tools: []*responses.McpTool{ + { + Name: "tool", + Description: "desc", + InputSchema: toolSchema, + }, + }, + }, + } + block, err := mcpListToolsToContentBlock(item) + assert.NoError(t, err) + assert.NotNil(t, block.MCPListToolsResult) + assert.Equal(t, 1, len(block.MCPListToolsResult.Tools)) + + _, err = mcpListToolsToContentBlock(&responses.OutputItem_FunctionMcpListTools{}) + assert.Error(t, err) +} + +func TestMcpApprovalRequestToContentBlock(t *testing.T) { + item := &responses.OutputItem_FunctionMcpApprovalRequest{ + FunctionMcpApprovalRequest: &responses.ItemFunctionMcpApprovalRequest{ + Id: ptrOf("id"), + ServerLabel: "server", + Name: "tool", + Arguments: "{}", + }, + } + block, err := mcpApprovalRequestToContentBlock(item) + assert.NoError(t, err) + assert.NotNil(t, block.MCPToolApprovalRequest) + assert.NotEmpty(t, block.MCPToolApprovalRequest.ID) + + _, err = mcpApprovalRequestToContentBlock(&responses.OutputItem_FunctionMcpApprovalRequest{}) + assert.Error(t, err) +} + +func TestResponseObjectToResponseMeta(t *testing.T) { + resp := &responses.ResponseObject{ + Id: "id", + } + meta := responseObjectToResponseMeta(resp) + assert.NotNil(t, meta) + assert.NotNil(t, meta.Extension) +} + +func TestToTokenUsage(t *testing.T) { + assert.Nil(t, toTokenUsage(&responses.ResponseObject{})) +} + +func TestToResponseMetaExtension(t *testing.T) { + resp := &responses.ResponseObject{ + Id: "id", + IncompleteDetails: &responses.IncompleteDetails{ + Reason: "r", + ContentFilter: &responses.ContentFilter{ + Type: "t", + Details: "d", + }, + }, + Error: &responses.Error{ + Code: "c", + Message: "m", + }, + Thinking: &responses.ResponsesThinking{ + Type: responses.ThinkingType_enabled.Enum(), + }, + ServiceTier: responses.ResponsesServiceTier_default.Enum(), + Status: responses.ResponseStatus_completed, + } + ext := toResponseMetaExtension(resp) + assert.NotNil(t, ext) + assert.Equal(t, "id", ext.ID) + assert.NotNil(t, ext.IncompleteDetails) + assert.NotNil(t, ext.Error) + assert.NotNil(t, ext.Thinking) + assert.Nil(t, toResponseMetaExtension(nil)) +} + +func TestResolveURL(t *testing.T) { + u, err := resolveURL("http://example.com/image.png", "", "") + assert.NoError(t, err) + assert.Equal(t, "http://example.com/image.png", u) + + u, err = resolveURL("", "abcd", "image/png") + assert.NoError(t, err) + assert.NotEmpty(t, u) + + _, err = resolveURL("", "abcd", "") + assert.Error(t, err) +} + +func TestEnsureDataURL(t *testing.T) { + _, err := ensureDataURL("data:xxx", "image/png") + assert.Error(t, err) + u, err := ensureDataURL("abcd", "image/png") + assert.NoError(t, err) + assert.Equal(t, "data:image/png;base64,abcd", u) + _, err = ensureDataURL("abcd", "") + assert.Error(t, err) +} diff --git a/components/agentic/ark/event_convertor.go b/components/agentic/ark/event_convertor.go new file mode 100644 index 000000000..b7400dccb --- /dev/null +++ b/components/agentic/ark/event_convertor.go @@ -0,0 +1,760 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ark + +import ( + "errors" + "fmt" + "io" + + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/utils" +) + +func receivedStreamResponse(streamReader *utils.ResponsesStreamReader, + config *agentic.Config, sw *schema.StreamWriter[*agentic.CallbackOutput]) { + + receiver := newStreamReceiver() + sender := newCallbackSender(sw, config) + + for { + event, err := streamReader.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + return + } + _ = sw.Send(nil, fmt.Errorf("failed to read stream, err: %w", err)) + return + } + + sender.errHeader = fmt.Sprintf("failed to convert event '%s'", event.GetEventType()) + + switch ev := event.Event.(type) { + case *responses.Event_TextDone, + *responses.Event_ReasoningPart, + *responses.Event_ReasoningPartDone, + *responses.Event_ReasoningTextDone, + *responses.Event_FunctionCallArgumentsDone, + *responses.Event_ResponseMcpCallArgumentsDone, + *responses.Event_ResponseMcpApprovalRequest: + + // Do nothing. + continue + + case *responses.Event_Error: + meta := receiver.errorEventToResponseMeta(ev.Error) + sender.sendMeta(meta, nil) + + case *responses.Event_Response: + meta := responseObjectToResponseMeta(ev.Response.Response) + sender.sendMeta(meta, nil) + + case *responses.Event_ResponseInProgress: + meta := responseObjectToResponseMeta(ev.ResponseInProgress.Response) + sender.sendMeta(meta, nil) + + case *responses.Event_ResponseCompleted: + meta := responseObjectToResponseMeta(ev.ResponseCompleted.Response) + sender.sendMeta(meta, nil) + + case *responses.Event_ResponseIncomplete: + meta := responseObjectToResponseMeta(ev.ResponseIncomplete.Response) + sender.sendMeta(meta, nil) + + case *responses.Event_ResponseFailed: + meta := responseObjectToResponseMeta(ev.ResponseFailed.Response) + sender.sendMeta(meta, nil) + + case *responses.Event_Item: + blocks, err := receiver.itemAddedEventToContentBlock(ev.Item) + for _, block := range blocks { + sender.sendBlock(block, err) + } + + case *responses.Event_ItemDone: + blocks, err := receiver.itemDoneEventToContentBlocks(ev.ItemDone) + for _, block := range blocks { + sender.sendBlock(block, err) + } + + case *responses.Event_ContentPart: + block, err := receiver.contentPartAddedEventToContentBlock(ev.ContentPart) + sender.sendBlock(block, err) + + case *responses.Event_ContentPartDone: + block, err := receiver.contentPartDoneEventToContentBlock(ev.ContentPartDone) + sender.sendBlock(block, err) + + case *responses.Event_Text: + block := receiver.outputTextDeltaEventToContentBlock(ev.Text) + sender.sendBlock(block, nil) + + case *responses.Event_ResponseAnnotationAdded: + block, err := receiver.annotationAddedEventToContentBlock(ev.ResponseAnnotationAdded) + sender.sendBlock(block, err) + + case *responses.Event_ReasoningText: + block := receiver.reasoningSummaryTextDeltaEventToContentBlock(ev.ReasoningText) + sender.sendBlock(block, nil) + + case *responses.Event_FunctionCallArguments: + block := receiver.functionCallArgumentsDeltaEventToContentBlock(ev.FunctionCallArguments) + sender.sendBlock(block, nil) + + case *responses.Event_ResponseMcpListToolsInProgress: + phase := ev.ResponseMcpListToolsInProgress + block := receiver.mcpListToolsPhaseToContentBlock(phase.ItemId, phase.OutputIndex, responses.ItemStatus_in_progress) + sender.sendBlock(block, nil) + + case *responses.Event_ResponseMcpListToolsCompleted: + phase := ev.ResponseMcpListToolsCompleted + block := receiver.mcpListToolsPhaseToContentBlock(phase.ItemId, phase.OutputIndex, responses.ItemStatus_completed) + sender.sendBlock(block, nil) + + case *responses.Event_ResponseMcpCallArgumentsDelta: + block := receiver.mcpCallArgumentsDeltaEventToContentBlock(ev.ResponseMcpCallArgumentsDelta) + sender.sendBlock(block, nil) + + case *responses.Event_ResponseMcpCallInProgress: + phase := ev.ResponseMcpCallInProgress + block := receiver.mcpCallPhaseToContentBlock(phase.ItemId, phase.OutputIndex, responses.ItemStatus_in_progress) + sender.sendBlock(block, nil) + + case *responses.Event_ResponseMcpCallCompleted: + phase := ev.ResponseMcpCallCompleted + block := receiver.mcpCallPhaseToContentBlock(phase.ItemId, phase.OutputIndex, responses.ItemStatus_completed) + sender.sendBlock(block, nil) + + case *responses.Event_ResponseMcpCallFailed: + phase := ev.ResponseMcpCallFailed + block := receiver.mcpCallPhaseToContentBlock(phase.ItemId, phase.OutputIndex, responses.ItemStatus_failed) + sender.sendBlock(block, nil) + + case *responses.Event_ResponseWebSearchCallInProgress: + phase := ev.ResponseWebSearchCallInProgress + block := receiver.webSearchPhaseToContentBlock(phase.ItemId, phase.OutputIndex, responses.ItemStatus_in_progress) + sender.sendBlock(block, nil) + + case *responses.Event_ResponseWebSearchCallSearching: + phase := ev.ResponseWebSearchCallSearching + block := receiver.webSearchPhaseToContentBlock(phase.ItemId, phase.OutputIndex, responses.ItemStatus_searching) + sender.sendBlock(block, nil) + + case *responses.Event_ResponseWebSearchCallCompleted: + phase := ev.ResponseWebSearchCallCompleted + block := receiver.webSearchPhaseToContentBlock(phase.ItemId, phase.OutputIndex, responses.ItemStatus_completed) + sender.sendBlock(block, nil) + + default: + sw.Send(nil, fmt.Errorf("invalid event type: %T", ev)) + } + } +} + +type callbackSender struct { + sw *schema.StreamWriter[*agentic.CallbackOutput] + config *agentic.Config + errHeader string +} + +func newCallbackSender(sw *schema.StreamWriter[*agentic.CallbackOutput], config *agentic.Config) *callbackSender { + return &callbackSender{ + sw: sw, + config: config, + } +} + +func (s *callbackSender) sendMeta(meta *schema.AgenticResponseMeta, err error) { + s.send(meta, nil, err) +} + +func (s *callbackSender) sendBlock(block *schema.ContentBlock, err error) { + s.send(nil, block, err) +} + +func (s *callbackSender) send(meta *schema.AgenticResponseMeta, block *schema.ContentBlock, err error) { + if err != nil { + _ = s.sw.Send(nil, fmt.Errorf("%s: %w", s.errHeader, err)) + return + } + + msg := &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ResponseMeta: meta, + } + if block != nil { + msg.ContentBlocks = []*schema.ContentBlock{block} + } + + s.sw.Send(&agentic.CallbackOutput{ + Message: msg, + Config: s.config, + }, nil) +} + +type streamReceiver struct { + ProcessingAssistantGenTextBlockIndex map[string]map[int]bool + + MaxBlockIndex int + IndexMapper map[string]int + + MaxReasoningSummaryIndex map[string]int + ReasoningSummaryIndexMapper map[string]int + + MaxTextAnnotationIndex map[string]int + TextAnnotationIndexMapper map[string]int +} + +func newStreamReceiver() *streamReceiver { + return &streamReceiver{ + ProcessingAssistantGenTextBlockIndex: map[string]map[int]bool{}, + MaxBlockIndex: -1, + IndexMapper: map[string]int{}, + MaxReasoningSummaryIndex: map[string]int{}, + ReasoningSummaryIndexMapper: map[string]int{}, + TextAnnotationIndexMapper: map[string]int{}, + MaxTextAnnotationIndex: map[string]int{}, + } +} + +func (r *streamReceiver) getBlockIndex(key string) int { + if idx, ok := r.IndexMapper[key]; ok { + return idx + } + + r.MaxBlockIndex++ + r.IndexMapper[key] = r.MaxBlockIndex + + return r.MaxBlockIndex +} + +func (r *streamReceiver) getReasoningSummaryIndex(outputIdx, summaryIdx int64) int { + maxSummaryIndex := -1 + if idx, ok := r.MaxReasoningSummaryIndex[int64ToStr(outputIdx)]; ok { + maxSummaryIndex = idx + } + + idxKey := fmt.Sprintf("%d:%d", outputIdx, summaryIdx) + if idx, ok := r.ReasoningSummaryIndexMapper[idxKey]; ok { + return idx + } + + maxSummaryIndex++ + r.ReasoningSummaryIndexMapper[idxKey] = maxSummaryIndex + r.MaxReasoningSummaryIndex[int64ToStr(outputIdx)] = maxSummaryIndex + + return maxSummaryIndex +} + +func (r *streamReceiver) getTextAnnotationIndex(outputIdx, contentIdx, annotationIdx int64) int { + maxAnnotationIndex := -1 + + maxIdxKey := fmt.Sprintf("%d:%d", outputIdx, contentIdx) + if idx, ok := r.MaxTextAnnotationIndex[maxIdxKey]; ok { + maxAnnotationIndex = idx + } + + idxKey := fmt.Sprintf("%d:%d:%d", outputIdx, contentIdx, annotationIdx) + if idx, ok := r.TextAnnotationIndexMapper[idxKey]; ok { + return idx + } + + maxAnnotationIndex++ + r.TextAnnotationIndexMapper[idxKey] = maxAnnotationIndex + r.MaxTextAnnotationIndex[maxIdxKey] = maxAnnotationIndex + + return maxAnnotationIndex +} + +func (r *streamReceiver) errorEventToResponseMeta(ev *responses.ErrorEvent) *schema.AgenticResponseMeta { + return &schema.AgenticResponseMeta{ + Extension: &ResponseMetaExtension{ + StreamingError: &StreamingResponseError{ + Code: ev.GetCode(), + Message: ev.GetMessage(), + Param: ev.GetParam(), + }, + }, + } +} + +func (r *streamReceiver) itemAddedEventToContentBlock(ev *responses.ItemEvent) (blocks []*schema.ContentBlock, err error) { + switch item := ev.Item.Union.(type) { + case *responses.OutputItem_FunctionToolCall: + block, err := r.itemAddedEventFunctionToolCallToContentBlock(ev.OutputIndex, item) + if err != nil { + return nil, err + } + + blocks = append(blocks, block) + + case *responses.OutputItem_Reasoning: + block, err := r.itemAddedEventReasoningToContentBlock(ev.OutputIndex, item) + if err != nil { + return nil, err + } + + blocks = append(blocks, block) + + case *responses.OutputItem_OutputMessage, + *responses.OutputItem_FunctionWebSearch, + *responses.OutputItem_FunctionMcpListTools, + *responses.OutputItem_FunctionMcpApprovalRequest, + *responses.OutputItem_FunctionMcpCall: + + // Do nothing. + + default: + return nil, fmt.Errorf("invalid item type %T with 'output_item.added' event", item) + } + + return blocks, nil +} + +func (r *streamReceiver) itemAddedEventFunctionToolCallToContentBlock(outputIdx int64, item *responses.OutputItem_FunctionToolCall) (block *schema.ContentBlock, err error) { + block, err = functionToolCallToContentBlock(item) + if err != nil { + return nil, err + } + + block.StreamingMeta = &schema.StreamingMeta{ + Index: r.getBlockIndex(makeFunctionToolCallIndexKey(outputIdx)), + } + + return block, nil +} + +func (r *streamReceiver) itemAddedEventReasoningToContentBlock(outputIdx int64, item *responses.OutputItem_Reasoning) (block *schema.ContentBlock, err error) { + block, err = reasoningToContentBlocks(item) + if err != nil { + return nil, err + } + + block.StreamingMeta = &schema.StreamingMeta{ + Index: r.getBlockIndex(makeReasoningIndexKey(outputIdx)), + } + + return block, nil +} + +func (r *streamReceiver) itemDoneEventToContentBlocks(ev *responses.ItemDoneEvent) (blocks []*schema.ContentBlock, err error) { + switch item := ev.Item.Union.(type) { + case *responses.OutputItem_OutputMessage: + blocks, err = r.itemDoneEventOutputMessageToContentBlock(item) + if err != nil { + return nil, err + } + + case *responses.OutputItem_Reasoning: + block, err := r.itemDoneEventReasoningToContentBlock(ev.OutputIndex, item) + if err != nil { + return nil, err + } + + blocks = append(blocks, block) + + case *responses.OutputItem_FunctionToolCall: + block, err := r.itemDoneEventFunctionToolCallToContentBlock(ev.OutputIndex, item) + if err != nil { + return nil, err + } + + blocks = append(blocks, block) + + case *responses.OutputItem_FunctionWebSearch: + block, err := r.itemDoneEventFunctionWebSearchToContentBlock(ev.OutputIndex, item) + if err != nil { + return nil, err + } + + blocks = append(blocks, block) + + case *responses.OutputItem_FunctionMcpCall: + blocks, err = r.itemDoneEventFunctionMCPCallToContentBlocks(ev.OutputIndex, item) + if err != nil { + return nil, err + } + + case *responses.OutputItem_FunctionMcpListTools: + block, err := r.itemDoneEventFunctionMCPListToolsToContentBlock(ev.OutputIndex, item) + if err != nil { + return nil, err + } + + blocks = append(blocks, block) + + case *responses.OutputItem_FunctionMcpApprovalRequest: + block, err := r.itemDoneEventFunctionMCPApprovalRequestToContentBlock(ev.OutputIndex, item) + if err != nil { + return nil, err + } + + blocks = append(blocks, block) + + default: + return nil, fmt.Errorf("invalid item type %T with 'output_item.done' event", item) + } + + return blocks, nil +} + +func (r *streamReceiver) itemDoneEventOutputMessageToContentBlock(item *responses.OutputItem_OutputMessage) (blocks []*schema.ContentBlock, err error) { + msg := item.OutputMessage + if msg == nil { + return nil, fmt.Errorf("received empty output message") + } + + indices, ok := r.ProcessingAssistantGenTextBlockIndex[msg.Id] + if !ok { + return nil, fmt.Errorf("item %q not found in processing queue", msg.Id) + } + + for idx := range indices { + meta := &schema.StreamingMeta{Index: idx} + block := schema.NewContentBlockChunk(&schema.AssistantGenText{}, meta) + setItemID(block, msg.Id) + setItemStatus(block, msg.Status.String()) + + blocks = append(blocks, block) + } + + return blocks, nil +} + +func (r *streamReceiver) itemDoneEventReasoningToContentBlock(outputIdx int64, item *responses.OutputItem_Reasoning) (block *schema.ContentBlock, err error) { + reasoning := item.Reasoning + if reasoning == nil { + return nil, fmt.Errorf("received empty reasoning") + } + + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeReasoningIndexKey(outputIdx)), + } + block = schema.NewContentBlockChunk(&schema.Reasoning{}, meta) + + if reasoning.Id != nil { + setItemID(block, *reasoning.Id) + } + setItemStatus(block, reasoning.Status.String()) + + return block, nil +} + +func (r *streamReceiver) itemDoneEventFunctionToolCallToContentBlock(outputIdx int64, item *responses.OutputItem_FunctionToolCall) (block *schema.ContentBlock, err error) { + toolCall := item.FunctionToolCall + if toolCall == nil { + return nil, fmt.Errorf("received empty function tool call") + } + + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeFunctionToolCallIndexKey(outputIdx)), + } + block = schema.NewContentBlockChunk(&schema.FunctionToolCall{ + CallID: toolCall.CallId, + Name: toolCall.Name, + }, meta) + + if toolCall.Id != nil { + setItemID(block, *toolCall.Id) + } + if toolCall.Status != nil { + setItemStatus(block, toolCall.Status.String()) + } + + return block, nil +} + +func (r *streamReceiver) itemDoneEventFunctionWebSearchToContentBlock(outputIdx int64, item *responses.OutputItem_FunctionWebSearch) (block *schema.ContentBlock, err error) { + block, err = webSearchToContentBlock(item) + if err != nil { + return nil, err + } + + block.StreamingMeta = &schema.StreamingMeta{ + Index: r.getBlockIndex(makeServerToolCallIndexKey(outputIdx)), + } + + return block, nil +} + +func (r *streamReceiver) itemDoneEventFunctionMCPCallToContentBlocks(outputIdx int64, item *responses.OutputItem_FunctionMcpCall) (blocks []*schema.ContentBlock, err error) { + blocks, err = mcpCallToContentBlocks(item) + if err != nil { + return nil, err + } + + for _, block := range blocks { + switch block.Type { + case schema.ContentBlockTypeMCPToolCall: + block.StreamingMeta = &schema.StreamingMeta{ + Index: r.getBlockIndex(makeMCPToolCallIndexKey(outputIdx)), + } + case schema.ContentBlockTypeMCPToolResult: + block.StreamingMeta = &schema.StreamingMeta{ + Index: r.getBlockIndex(makeMCPToolResultIndexKey(outputIdx)), + } + default: + return nil, fmt.Errorf("expected mcp tool call or result block, but got '%s'", block.Type) + } + } + + return blocks, nil +} + +func (r *streamReceiver) itemDoneEventFunctionMCPListToolsToContentBlock(outputIdx int64, item *responses.OutputItem_FunctionMcpListTools) (block *schema.ContentBlock, err error) { + block, err = mcpListToolsToContentBlock(item) + if err != nil { + return nil, err + } + + block.StreamingMeta = &schema.StreamingMeta{ + Index: r.getBlockIndex(makeMCPListToolsResultIndexKey(outputIdx)), + } + + return block, nil +} + +func (r *streamReceiver) itemDoneEventFunctionMCPApprovalRequestToContentBlock(outputIdx int64, item *responses.OutputItem_FunctionMcpApprovalRequest) (block *schema.ContentBlock, err error) { + block, err = mcpApprovalRequestToContentBlock(item) + if err != nil { + return nil, err + } + + block.StreamingMeta = &schema.StreamingMeta{ + Index: r.getBlockIndex(makeMCPToolApprovalRequestIndexKey(outputIdx)), + } + + return block, nil +} + +func (r *streamReceiver) contentPartAddedEventToContentBlock(ev *responses.ContentPartEvent) (block *schema.ContentBlock, err error) { + key := makeAssistantGenTextIndexKey(ev.OutputIndex, ev.ContentIndex) + blockIdx := r.getBlockIndex(key) + + indices, ok := r.ProcessingAssistantGenTextBlockIndex[ev.ItemId] + if !ok { + indices = map[int]bool{} + r.ProcessingAssistantGenTextBlockIndex[ev.ItemId] = indices + } + + indices[blockIdx] = true + + return r.eventContentPartToContentBlock(ev.ItemId, ev.Part, blockIdx, responses.ItemStatus_in_progress) +} + +func (r *streamReceiver) contentPartDoneEventToContentBlock(ev *responses.ContentPartDoneEvent) (block *schema.ContentBlock, err error) { + key := makeAssistantGenTextIndexKey(ev.OutputIndex, ev.ContentIndex) + blockIdx := r.getBlockIndex(key) + + indices, ok := r.ProcessingAssistantGenTextBlockIndex[ev.ItemId] + if !ok { + return nil, fmt.Errorf("item %s has no processing assistant gen text block index", ev.ItemId) + } + + delete(indices, blockIdx) + + return r.eventContentPartToContentBlock(ev.ItemId, ev.Part, blockIdx, responses.ItemStatus_completed) +} + +func (r *streamReceiver) eventContentPartToContentBlock(itemID string, content *responses.OutputContentItem, + blockIdx int, status responses.ItemStatus_Enum) (block *schema.ContentBlock, err error) { + + meta := &schema.StreamingMeta{Index: blockIdx} + + switch part := content.Union.(type) { + case *responses.OutputContentItem_Text: + block = schema.NewContentBlockChunk(&schema.AssistantGenText{}, meta) + default: + return nil, fmt.Errorf("invalid content part type: %T", part) + } + + setItemStatus(block, status.String()) + setItemID(block, itemID) + + return block, nil +} + +func (r *streamReceiver) outputTextDeltaEventToContentBlock(ev *responses.OutputTextEvent) *schema.ContentBlock { + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeAssistantGenTextIndexKey(ev.OutputIndex, ev.ContentIndex)), + } + block := schema.NewContentBlockChunk(&schema.AssistantGenText{ + Text: ev.GetDelta(), + }, meta) + + setItemID(block, ev.ItemId) + + return block +} + +func (r *streamReceiver) annotationAddedEventToContentBlock(ev *responses.ResponseAnnotationAddedEvent) (block *schema.ContentBlock, err error) { + annotation, err := outputTextAnnotationToTextAnnotation(ev.Annotation) + if err != nil { + return nil, fmt.Errorf("failed to convert annotation, err: %w", err) + } + + annotation.Index = r.getTextAnnotationIndex(ev.OutputIndex, ev.ContentIndex, ev.AnnotationIndex) + + genText := &schema.AssistantGenText{ + Text: ev.GetDelta(), + Extension: &AssistantGenTextExtension{ + Annotations: []*TextAnnotation{annotation}, + }, + } + + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeAssistantGenTextIndexKey(ev.OutputIndex, ev.ContentIndex)), + } + block = schema.NewContentBlockChunk(genText, meta) + + setItemID(block, ev.ItemId) + + return block, nil +} + +func (r *streamReceiver) reasoningSummaryTextDeltaEventToContentBlock(ev *responses.ReasoningSummaryTextEvent) *schema.ContentBlock { + reasoning := &schema.Reasoning{ + Summary: []*schema.ReasoningSummary{ + { + Index: r.getReasoningSummaryIndex(ev.OutputIndex, ev.SummaryIndex), + Text: ev.GetDelta(), + }, + }, + } + + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeReasoningIndexKey(ev.OutputIndex)), + } + block := schema.NewContentBlockChunk(reasoning, meta) + + setItemID(block, ev.ItemId) + + return block +} + +func (r *streamReceiver) functionCallArgumentsDeltaEventToContentBlock(ev *responses.FunctionCallArgumentsEvent) *schema.ContentBlock { + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeFunctionToolCallIndexKey(ev.OutputIndex)), + } + block := schema.NewContentBlockChunk(&schema.FunctionToolCall{ + Arguments: ev.GetDelta(), + }, meta) + + setItemID(block, ev.ItemId) + + return block +} + +func (r *streamReceiver) mcpListToolsPhaseToContentBlock(itemID string, outputIdx int64, status responses.ItemStatus_Enum) *schema.ContentBlock { + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeMCPListToolsResultIndexKey(outputIdx)), + } + block := schema.NewContentBlockChunk(&schema.MCPListToolsResult{}, meta) + + setItemID(block, itemID) + setItemStatus(block, status.String()) + + return block +} + +func (r *streamReceiver) mcpApprovalRequestEventToContentBlock(ev *responses.ResponseMcpApprovalRequestEvent) (block *schema.ContentBlock, err error) { + apReq := ev.FunctionMcpApprovalRequest + + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeMCPToolApprovalRequestIndexKey(ev.OutputIndex)), + } + block = schema.NewContentBlockChunk(&schema.MCPToolApprovalRequest{ + ID: apReq.GetId(), + Name: apReq.Name, + Arguments: apReq.Arguments, + ServerLabel: apReq.ServerLabel, + }, meta) + + setItemID(block, apReq.GetId()) + + return block, nil +} + +func (r *streamReceiver) mcpCallArgumentsDeltaEventToContentBlock(ev *responses.ResponseMcpCallArgumentsDeltaEvent) *schema.ContentBlock { + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeMCPToolCallIndexKey(ev.OutputIndex)), + } + block := schema.NewContentBlockChunk(&schema.MCPToolCall{ + Arguments: ev.Delta, + }, meta) + + setItemID(block, ev.ItemId) + + return block +} + +func (r *streamReceiver) mcpCallPhaseToContentBlock(itemID string, outputIdx int64, status responses.ItemStatus_Enum) *schema.ContentBlock { + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeMCPToolCallIndexKey(outputIdx)), + } + block := schema.NewContentBlockChunk(&schema.MCPToolCall{}, meta) + + setItemID(block, itemID) + setItemStatus(block, status.String()) + return block +} + +func (r *streamReceiver) webSearchPhaseToContentBlock(itemID string, outputIdx int64, status responses.ItemStatus_Enum) *schema.ContentBlock { + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeServerToolCallIndexKey(outputIdx)), + } + block := schema.NewContentBlockChunk(&schema.ServerToolCall{}, meta) + + setItemID(block, itemID) + setItemStatus(block, status.String()) + + return block +} + +func makeAssistantGenTextIndexKey(outputIndex, contentIndex int64) string { + return fmt.Sprintf("assistant_gen_text:%d:%d", outputIndex, contentIndex) +} + +func makeReasoningIndexKey(outputIndex int64) string { + return fmt.Sprintf("reasoning:%d", outputIndex) +} + +func makeFunctionToolCallIndexKey(outputIndex int64) string { + return fmt.Sprintf("function_tool_call:%d", outputIndex) +} + +func makeServerToolCallIndexKey(outputIndex int64) string { + return fmt.Sprintf("server_tool_call:%d", outputIndex) +} + +func makeMCPListToolsResultIndexKey(outputIndex int64) string { + return fmt.Sprintf("mcp_list_tools_result:%d", outputIndex) +} + +func makeMCPToolApprovalRequestIndexKey(outputIndex int64) string { + return fmt.Sprintf("mcp_tool_approval_request:%d", outputIndex) +} + +func makeMCPToolCallIndexKey(outputIndex int64) string { + return fmt.Sprintf("mcp_tool_call:%d", outputIndex) +} + +func makeMCPToolResultIndexKey(outputIndex int64) string { + return fmt.Sprintf("mcp_tool_result:%d", outputIndex) +} diff --git a/components/agentic/ark/event_convertor_test.go b/components/agentic/ark/event_convertor_test.go new file mode 100644 index 000000000..cef7cb90f --- /dev/null +++ b/components/agentic/ark/event_convertor_test.go @@ -0,0 +1,534 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ark + +import ( + "errors" + "io" + "testing" + + "github.com/bytedance/mockey" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/stretchr/testify/assert" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/utils" +) + +func TestNewStreamReceiverInit(t *testing.T) { + r := newStreamReceiver() + assert.NotNil(t, r.ProcessingAssistantGenTextBlockIndex) + assert.Equal(t, -1, r.MaxBlockIndex) + assert.NotNil(t, r.IndexMapper) + assert.NotNil(t, r.MaxReasoningSummaryIndex) + assert.NotNil(t, r.ReasoningSummaryIndexMapper) + assert.NotNil(t, r.TextAnnotationIndexMapper) + assert.NotNil(t, r.MaxTextAnnotationIndex) +} + +func TestGetBlockIndexAndReuse(t *testing.T) { + r := newStreamReceiver() + a := r.getBlockIndex("k1") + b := r.getBlockIndex("k2") + c := r.getBlockIndex("k1") + assert.Equal(t, a, c) + assert.NotEqual(t, a, b) + assert.GreaterOrEqual(t, r.MaxBlockIndex, 1) +} + +func TestGetReasoningSummaryIndex(t *testing.T) { + r := newStreamReceiver() + i1 := r.getReasoningSummaryIndex(1, 1) + i2 := r.getReasoningSummaryIndex(1, 2) + i3 := r.getReasoningSummaryIndex(2, 1) + i4 := r.getReasoningSummaryIndex(1, 1) + assert.Equal(t, 0, i1) + assert.Equal(t, 1, i2) + assert.Equal(t, 0, i3) + assert.Equal(t, i1, i4) +} + +func TestGetTextAnnotationIndex(t *testing.T) { + r := newStreamReceiver() + i1 := r.getTextAnnotationIndex(1, 1, 1) + i2 := r.getTextAnnotationIndex(1, 1, 2) + i3 := r.getTextAnnotationIndex(1, 2, 1) + i4 := r.getTextAnnotationIndex(2, 1, 1) + i5 := r.getTextAnnotationIndex(1, 1, 1) + assert.Equal(t, 0, i1) + assert.Equal(t, 1, i2) + assert.Equal(t, 0, i3) + assert.Equal(t, 0, i4) + assert.Equal(t, i1, i5) +} + +func TestItemAddedEventToContentBlockFunctionToolCall(t *testing.T) { + r := newStreamReceiver() + ev := &responses.ItemEvent{ + OutputIndex: 1, + Item: &responses.OutputItem{ + Union: &responses.OutputItem_FunctionToolCall{ + FunctionToolCall: &responses.ItemFunctionToolCall{ + CallId: "cid", + Name: "name", + }, + }, + }, + } + blocks, err := r.itemAddedEventToContentBlock(ev) + assert.NoError(t, err) + assert.Equal(t, 1, len(blocks)) + assert.NotNil(t, blocks[0].FunctionToolCall) + assert.GreaterOrEqual(t, blocks[0].StreamingMeta.Index, 0) +} + +func TestItemAddedEventToContentBlockReasoning(t *testing.T) { + r := newStreamReceiver() + ev := &responses.ItemEvent{ + OutputIndex: 2, + Item: &responses.OutputItem{ + Union: &responses.OutputItem_Reasoning{ + Reasoning: &responses.ItemReasoning{ + Status: responses.ItemStatus_completed, + Summary: []*responses.ReasoningSummaryPart{ + {Text: "x"}, + }, + }, + }, + }, + } + blocks, err := r.itemAddedEventToContentBlock(ev) + assert.NoError(t, err) + assert.Equal(t, 1, len(blocks)) + assert.NotNil(t, blocks[0].Reasoning) +} + +func TestItemAddedEventToContentBlockInvalid(t *testing.T) { + r := newStreamReceiver() + ev := &responses.ItemEvent{ + Item: &responses.OutputItem{Union: nil}, + } + _, err := r.itemAddedEventToContentBlock(ev) + assert.Error(t, err) +} + +func TestItemDoneEventToContentBlocksReasoning(t *testing.T) { + r := newStreamReceiver() + ev := &responses.ItemDoneEvent{ + OutputIndex: 3, + Item: &responses.OutputItem{ + Union: &responses.OutputItem_Reasoning{ + Reasoning: &responses.ItemReasoning{ + Status: responses.ItemStatus_completed, + }, + }, + }, + } + blocks, err := r.itemDoneEventToContentBlocks(ev) + assert.NoError(t, err) + assert.Equal(t, 1, len(blocks)) + assert.NotNil(t, blocks[0].Reasoning) +} + +func TestItemDoneEventToContentBlocksFunctionToolCall(t *testing.T) { + r := newStreamReceiver() + ev := &responses.ItemDoneEvent{ + OutputIndex: 4, + Item: &responses.OutputItem{ + Union: &responses.OutputItem_FunctionToolCall{ + FunctionToolCall: &responses.ItemFunctionToolCall{ + CallId: "cid", + Name: "nm", + }, + }, + }, + } + blocks, err := r.itemDoneEventToContentBlocks(ev) + assert.NoError(t, err) + assert.Equal(t, 1, len(blocks)) + assert.NotNil(t, blocks[0].FunctionToolCall) +} + +func TestItemDoneEventToContentBlocksInvalid(t *testing.T) { + r := newStreamReceiver() + ev := &responses.ItemDoneEvent{ + Item: &responses.OutputItem{Union: nil}, + } + _, err := r.itemDoneEventToContentBlocks(ev) + assert.Error(t, err) +} + +func TestItemDoneEventOutputMessageToContentBlockMissingProcessing(t *testing.T) { + r := newStreamReceiver() + ev := &responses.OutputItem_OutputMessage{ + OutputMessage: &responses.ItemOutputMessage{ + Id: "mid", + Status: responses.ItemStatus_completed, + }, + } + _, err := r.itemDoneEventOutputMessageToContentBlock(ev) + assert.Error(t, err) +} + +func TestItemDoneEventOutputMessageToContentBlockOK(t *testing.T) { + r := newStreamReceiver() + r.ProcessingAssistantGenTextBlockIndex["mid"] = map[int]bool{0: true, 2: true} + ev := &responses.OutputItem_OutputMessage{ + OutputMessage: &responses.ItemOutputMessage{ + Id: "mid", + Status: responses.ItemStatus_completed, + }, + } + blocks, err := r.itemDoneEventOutputMessageToContentBlock(ev) + assert.NoError(t, err) + assert.Equal(t, 2, len(blocks)) + id, ok := getItemID(blocks[0]) + assert.True(t, ok) + assert.Equal(t, "mid", id) + status, ok := GetItemStatus(blocks[0]) + assert.True(t, ok) + assert.Equal(t, responses.ItemStatus_completed.String(), status) +} + +func TestItemDoneEventReasoningToContentBlockNilError(t *testing.T) { + r := newStreamReceiver() + _, err := r.itemDoneEventReasoningToContentBlock(1, &responses.OutputItem_Reasoning{}) + assert.Error(t, err) +} + +func TestItemDoneEventFunctionToolCallToContentBlockNilError(t *testing.T) { + r := newStreamReceiver() + _, err := r.itemDoneEventFunctionToolCallToContentBlock(1, &responses.OutputItem_FunctionToolCall{}) + assert.Error(t, err) +} + +func TestItemDoneEventFunctionWebSearchToContentBlock(t *testing.T) { + r := newStreamReceiver() + block, err := r.itemDoneEventFunctionWebSearchToContentBlock(1, &responses.OutputItem_FunctionWebSearch{ + FunctionWebSearch: &responses.ItemFunctionWebSearch{ + Id: "id", + Status: responses.ItemStatus_completed, + Action: &responses.Action{ + Type: responses.ActionType_search, + Query: "q", + }, + }, + }) + assert.NoError(t, err) + assert.NotNil(t, block.ServerToolCall) +} + +func TestItemDoneEventFunctionMCPCallToContentBlocksTypeCheck(t *testing.T) { + r := newStreamReceiver() + blocks, err := r.itemDoneEventFunctionMCPCallToContentBlocks(1, &responses.OutputItem_FunctionMcpCall{ + FunctionMcpCall: &responses.ItemFunctionMcpCall{ + Id: ptrOf("id"), + ServerLabel: "server", + Name: "tool", + Arguments: "{}", + Output: ptrOf("out"), + }, + }) + assert.NoError(t, err) + assert.Equal(t, 2, len(blocks)) + assert.NotNil(t, blocks[0].StreamingMeta) + assert.GreaterOrEqual(t, blocks[0].StreamingMeta.Index, 0) + assert.NotNil(t, blocks[1].StreamingMeta) + assert.GreaterOrEqual(t, blocks[1].StreamingMeta.Index, 0) +} + +func TestItemDoneEventFunctionMCPListToolsToContentBlock(t *testing.T) { + r := newStreamReceiver() + block, err := r.itemDoneEventFunctionMCPListToolsToContentBlock(1, &responses.OutputItem_FunctionMcpListTools{ + FunctionMcpListTools: &responses.ItemFunctionMcpListTools{ + ServerLabel: "server", + }, + }) + assert.NoError(t, err) + assert.NotNil(t, block.MCPListToolsResult) +} + +func TestItemDoneEventFunctionMCPApprovalRequestToContentBlock(t *testing.T) { + r := newStreamReceiver() + block, err := r.itemDoneEventFunctionMCPApprovalRequestToContentBlock(1, &responses.OutputItem_FunctionMcpApprovalRequest{ + FunctionMcpApprovalRequest: &responses.ItemFunctionMcpApprovalRequest{ + Id: ptrOf("id"), + ServerLabel: "server", + Name: "tool", + Arguments: "{}", + }, + }) + assert.NoError(t, err) + assert.NotNil(t, block.MCPToolApprovalRequest) +} + +func TestContentPartAddedEventToContentBlock(t *testing.T) { + r := newStreamReceiver() + ev := &responses.ContentPartEvent{ + ItemId: "mid", + OutputIndex: 1, + ContentIndex: 2, + Part: &responses.OutputContentItem{ + Union: &responses.OutputContentItem_Text{Text: &responses.OutputContentItemText{}}, + }, + } + block, err := r.contentPartAddedEventToContentBlock(ev) + assert.NoError(t, err) + assert.NotNil(t, block.AssistantGenText) + id, ok := getItemID(block) + assert.True(t, ok) + assert.Equal(t, "mid", id) +} + +func TestContentPartDoneEventToContentBlockNoIndex(t *testing.T) { + r := newStreamReceiver() + ev := &responses.ContentPartDoneEvent{ + ItemId: "mid", + OutputIndex: 1, + ContentIndex: 2, + Part: &responses.OutputContentItem{ + Union: &responses.OutputContentItem_Text{Text: &responses.OutputContentItemText{}}, + }, + } + _, err := r.contentPartDoneEventToContentBlock(ev) + assert.Error(t, err) +} + +func TestContentPartDoneEventToContentBlockOK(t *testing.T) { + r := newStreamReceiver() + evAdd := &responses.ContentPartEvent{ + ItemId: "mid", + OutputIndex: 1, + ContentIndex: 1, + Part: &responses.OutputContentItem{ + Union: &responses.OutputContentItem_Text{Text: &responses.OutputContentItemText{}}, + }, + } + _, _ = r.contentPartAddedEventToContentBlock(evAdd) + evDone := &responses.ContentPartDoneEvent{ + ItemId: "mid", + OutputIndex: 1, + ContentIndex: 1, + Part: &responses.OutputContentItem{ + Union: &responses.OutputContentItem_Text{Text: &responses.OutputContentItemText{}}, + }, + } + block, err := r.contentPartDoneEventToContentBlock(evDone) + assert.NoError(t, err) + assert.NotNil(t, block.AssistantGenText) + status, ok := GetItemStatus(block) + assert.True(t, ok) + assert.Equal(t, responses.ItemStatus_completed.String(), status) +} + +func TestEventContentPartToContentBlockInvalid(t *testing.T) { + r := newStreamReceiver() + _, err := r.eventContentPartToContentBlock("id", &responses.OutputContentItem{}, 1, responses.ItemStatus_in_progress) + assert.Error(t, err) +} + +func TestOutputTextDeltaEventToContentBlock(t *testing.T) { + r := newStreamReceiver() + block := r.outputTextDeltaEventToContentBlock(&responses.OutputTextEvent{ + Delta: ptrOf("d"), + ItemId: "iid", + OutputIndex: 1, + ContentIndex: 1, + }) + assert.NotNil(t, block.AssistantGenText) + assert.Equal(t, "d", block.AssistantGenText.Text) +} + +func TestAnnotationAddedEventToContentBlock(t *testing.T) { + r := newStreamReceiver() + title := "t" + url := "u" + block, err := r.annotationAddedEventToContentBlock(&responses.ResponseAnnotationAddedEvent{ + ItemId: "iid", + OutputIndex: 1, + ContentIndex: 1, + AnnotationIndex: 0, + Annotation: &responses.Annotation{ + Type: responses.AnnotationType_url_citation, + Title: title, + Url: url, + }, + }) + assert.NoError(t, err) + assert.NotNil(t, block.AssistantGenText) + assert.NotNil(t, block.AssistantGenText.Extension) + id, ok := getItemID(block) + assert.True(t, ok) + assert.Equal(t, "iid", id) +} + +func TestReasoningSummaryTextDeltaEventToContentBlock(t *testing.T) { + r := newStreamReceiver() + block := r.reasoningSummaryTextDeltaEventToContentBlock(&responses.ReasoningSummaryTextEvent{ + ItemId: "iid", + OutputIndex: 2, + SummaryIndex: 0, + Delta: ptrOf("x"), + }) + assert.NotNil(t, block.Reasoning) + assert.Equal(t, "x", block.Reasoning.Summary[0].Text) +} + +func TestFunctionCallArgumentsDeltaEventToContentBlock(t *testing.T) { + r := newStreamReceiver() + block := r.functionCallArgumentsDeltaEventToContentBlock(&responses.FunctionCallArgumentsEvent{ + ItemId: "iid", + OutputIndex: 3, + Delta: ptrOf("{}"), + }) + assert.NotNil(t, block.FunctionToolCall) + assert.Equal(t, "{}", block.FunctionToolCall.Arguments) +} + +func TestMcpListToolsPhaseToContentBlock(t *testing.T) { + r := newStreamReceiver() + block := r.mcpListToolsPhaseToContentBlock("iid", 4, responses.ItemStatus_in_progress) + assert.NotNil(t, block.MCPListToolsResult) + id, ok := getItemID(block) + assert.True(t, ok) + assert.Equal(t, "iid", id) + status, ok := GetItemStatus(block) + assert.True(t, ok) + assert.Equal(t, responses.ItemStatus_in_progress.String(), status) +} + +func TestMcpCallArgumentsDeltaEventToContentBlock(t *testing.T) { + r := newStreamReceiver() + block := r.mcpCallArgumentsDeltaEventToContentBlock(&responses.ResponseMcpCallArgumentsDeltaEvent{ + ItemId: "iid", + OutputIndex: 6, + Delta: "{}", + }) + assert.NotNil(t, block.MCPToolCall) + id, ok := getItemID(block) + assert.True(t, ok) + assert.Equal(t, "iid", id) +} + +func TestMcpCallPhaseToContentBlock(t *testing.T) { + r := newStreamReceiver() + block := r.mcpCallPhaseToContentBlock("iid", 7, responses.ItemStatus_failed) + assert.NotNil(t, block.MCPToolCall) + status, ok := GetItemStatus(block) + assert.True(t, ok) + assert.Equal(t, responses.ItemStatus_failed.String(), status) +} + +func TestWebSearchPhaseToContentBlock(t *testing.T) { + r := newStreamReceiver() + block := r.webSearchPhaseToContentBlock("iid", 8, responses.ItemStatus_completed) + assert.NotNil(t, block.ServerToolCall) + status, ok := GetItemStatus(block) + assert.True(t, ok) + assert.Equal(t, responses.ItemStatus_completed.String(), status) +} + +func TestMakeIndexKeyFunctions(t *testing.T) { + assert.Equal(t, "assistant_gen_text:1:2", makeAssistantGenTextIndexKey(1, 2)) + assert.Equal(t, "reasoning:3", makeReasoningIndexKey(3)) + assert.Equal(t, "function_tool_call:4", makeFunctionToolCallIndexKey(4)) + assert.Equal(t, "server_tool_call:5", makeServerToolCallIndexKey(5)) + assert.Equal(t, "mcp_list_tools_result:6", makeMCPListToolsResultIndexKey(6)) + assert.Equal(t, "mcp_tool_approval_request:7", makeMCPToolApprovalRequestIndexKey(7)) + assert.Equal(t, "mcp_tool_call:8", makeMCPToolCallIndexKey(8)) + assert.Equal(t, "mcp_tool_result:9", makeMCPToolResultIndexKey(9)) +} + +func TestNewCallbackSenderAndSend(t *testing.T) { + sr, sw := schema.Pipe[*agentic.CallbackOutput](8) + s := newCallbackSender(sw, &agentic.Config{}) + r0 := sr.Copy(1)[0] + + // Send a meta message first + s.sendMeta(&schema.AgenticResponseMeta{}, nil) + ch, err := r0.Recv() + assert.NoError(t, err) + assert.NotNil(t, ch) + assert.NotNil(t, ch.Message.ResponseMeta) + + // Send a block + block := schema.NewContentBlock(&schema.AssistantGenText{Text: "x"}) + s.sendBlock(block, nil) + ch, err = r0.Recv() + assert.NoError(t, err) + assert.NotNil(t, ch.Message.ContentBlocks) + + // Send an error + s.errHeader = "h" + s.sendMeta(nil, errors.New("e")) + _, err = r0.Recv() + assert.Error(t, err) +} + +func TestReceivedStreamResponseResponseAndText(t *testing.T) { + mockey.PatchConvey("TestReceivedStreamResponseResponseAndText", t, func() { + reader, writer := schema.Pipe[*agentic.CallbackOutput](16) + var rr utils.ResponsesStreamReader + call := 0 + mockey.Mock((*utils.ResponsesStreamReader).Recv).To(func(_ *utils.ResponsesStreamReader) (*responses.Event, error) { + call++ + if call == 1 { + return &responses.Event{ + Event: &responses.Event_Response{ + Response: &responses.ResponseEvent{ + Response: &responses.ResponseObject{Id: "rid"}, + }, + }, + }, nil + } + if call == 2 { + return &responses.Event{ + Event: &responses.Event_Text{ + Text: &responses.OutputTextEvent{ + Delta: ptrOf("d"), + ItemId: "iid", + OutputIndex: 1, + ContentIndex: 1, + }, + }, + }, nil + } + return nil, io.EOF + }).Build() + receivedStreamResponse(&rr, &agentic.Config{}, writer) + r := reader.Copy(1)[0] + out1, err1 := r.Recv() + assert.NoError(t, err1) + assert.NotNil(t, out1) + assert.NotNil(t, out1.Message.ResponseMeta) + out2, err2 := r.Recv() + assert.NoError(t, err2) + assert.NotNil(t, out2.Message.ContentBlocks) + }) +} + +func TestReceivedStreamResponseRecvError(t *testing.T) { + mockey.PatchConvey("TestReceivedStreamResponseRecvError", t, func() { + reader, writer := schema.Pipe[*agentic.CallbackOutput](4) + var rr utils.ResponsesStreamReader + mockey.Mock((*utils.ResponsesStreamReader).Recv).Return(nil, errors.New("x")).Build() + mockey.Mock((*utils.ResponsesStreamReader).Close).Return(nil).Build() + receivedStreamResponse(&rr, &agentic.Config{}, writer) + _, err := reader.Copy(1)[0].Recv() + assert.Error(t, err) + }) +} diff --git a/components/agentic/ark/examples/generate/main.go b/components/agentic/ark/examples/generate/main.go new file mode 100644 index 000000000..73a334157 --- /dev/null +++ b/components/agentic/ark/examples/generate/main.go @@ -0,0 +1,82 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/ark" + "github.com/cloudwego/eino/schema" + "github.com/eino-contrib/jsonschema" + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +func main() { + ctx := context.Background() + + // Get ARK_API_KEY and ARK_MODEL_ID: https://www.volcengine.com/docs/82379/1399008 + am, err := ark.New(ctx, &ark.Config{ + Model: os.Getenv("ARK_MODEL_ID"), + APIKey: os.Getenv("ARK_API_KEY"), + }) + if err != nil { + log.Fatalf("failed to create agentic model, err: %v", err) + } + + input := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what is the weather like in Beijing"), + } + + am_, err := am.WithTools([]*schema.ToolInfo{ + { + Name: "get_weather", + Desc: "get the weather in a city", + ParamsOneOf: schema.NewParamsOneOfByJSONSchema(&jsonschema.Schema{ + Type: "object", + Properties: orderedmap.New[string, *jsonschema.Schema]( + orderedmap.WithInitialData( + orderedmap.Pair[string, *jsonschema.Schema]{ + Key: "city", + Value: &jsonschema.Schema{ + Type: "string", + Description: "the city to get the weather", + }, + }, + ), + ), + Required: []string{"city"}, + }), + }, + }) + if err != nil { + log.Fatalf("failed to create agentic model with tools, err: %v", err) + } + + msg, err := am_.Generate(ctx, input) + if err != nil { + log.Fatalf("failed to generate, err: %v", err) + } + + meta := msg.ResponseMeta.Extension.(*ark.ResponseMetaExtension) + + log.Printf("request_id: %s\n", meta.ID) + respBody, _ := sonic.MarshalIndent(msg, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} diff --git a/components/agentic/ark/examples/prefix_cache/main.go b/components/agentic/ark/examples/prefix_cache/main.go new file mode 100644 index 000000000..d6d407a8b --- /dev/null +++ b/components/agentic/ark/examples/prefix_cache/main.go @@ -0,0 +1,166 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "encoding/json" + "log" + "os" + "time" + + "github.com/cloudwego/eino-ext/components/agentic/ark" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" +) + +func main() { + ctx := context.Background() + + // Get ARK_API_KEY and ARK_MODEL_ID: https://www.volcengine.com/docs/82379/1399008 + am, err := ark.New(ctx, &ark.Config{ + APIKey: os.Getenv("ARK_API_KEY"), + Model: os.Getenv("ARK_MODEL_ID"), + Thinking: &responses.ResponsesThinking{ + Type: responses.ThinkingType_disabled.Enum(), + }, + }) + if err != nil { + log.Fatalf("NewChatModel failed, err=%v", err) + } + + functionTools := []*schema.ToolInfo{ + { + Name: "article_content_extractor", + Desc: "Extract key statements and chapter summaries from the provided article content", + ParamsOneOf: schema.NewParamsOneOfByParams( + map[string]*schema.ParameterInfo{ + "content": { + Type: schema.String, + Desc: "The full article content to analyze and extract key information from", + Required: true, + }, + }), + }, + } + + allowedTools := []*schema.AllowedTool{ + { + FunctionToolName: "article_content_extractor", + }, + } + + opts := []agentic.Option{ + agentic.WithToolChoice(schema.ToolChoiceForced, allowedTools...), + agentic.WithTools(functionTools), + } + + expireAtSec := time.Now().Add(10 * time.Minute).Unix() + + prefix := []*schema.AgenticMessage{ + schema.SystemAgenticMessage(`Once upon a time, in a quaint little village surrounded by vast green forests and blooming meadows, there lived a spirited young girl known as Little Red Riding Hood. She earned her name from the vibrant red cape that her beloved grandmother had sewn for her, a gift that she cherished deeply. This cape was more than just a piece of clothing; it was a symbol of the bond between her and her grandmother, who lived on the other side of the great woods, near a sparkling brook that bubbled merrily all year round. + + One sunny morning, Little Red Riding Hood's mother called her into the cozy kitchen, where the aroma of freshly baked bread filled the air. “My dear,” she said, “your grandmother isn’t feeling well today. I want you to take her this basket of treats. There are some delicious cakes, a jar of honey, and her favorite herbal tea. Can you do that for me?” + + Little Red Riding Hood’s eyes sparkled with excitement as she nodded eagerly. “Yes, Mama! I’ll take good care of them!” Her mother handed her a beautifully woven basket, filled to the brim with goodies, and reminded her, “Remember to stay on the path and don’t talk to strangers.” + + “I promise, Mama!” she replied confidently, pulling her red hood over her head and setting off on her adventure. The sun shone brightly, and birds chirped merrily as she walked, making her feel like she was in a fairy tale. + + As she journeyed through the woods, the tall trees whispered secrets to one another, and colorful flowers danced in the gentle breeze. Little Red Riding Hood was so enchanted by the beauty around her that she began to hum a tune, her voice harmonizing with the sounds of nature. + + However, unbeknownst to her, lurking in the shadows was a cunning wolf. The wolf was known throughout the forest for his deceptive wit and insatiable hunger. He watched Little Red Riding Hood with keen interest, contemplating his next meal. + + “Good day, little girl!” the wolf called out, stepping onto the path with a friendly yet sly smile. + + Startled, she halted and took a step back. “Hello there! I’m just on my way to visit my grandmother,” she replied, clutching the basket tightly. + + “Ah, your grandmother! I know her well,” the wolf said, his eyes glinting with mischief. “Why don’t you pick some lovely flowers for her? I’m sure she would love them, and I’m sure there are many beautiful ones just off the path.” + + Little Red Riding Hood hesitated for a moment but was easily convinced by the wolf’s charming suggestion. “That’s a wonderful idea! Thank you!” she exclaimed, letting her curiosity pull her away from the safety of the path. As she wandered deeper into the woods, her gaze fixed on the vibrant blooms, the wolf took a shortcut towards her grandmother’s house. + + When the wolf arrived at Grandma’s quaint cottage, he knocked on the door with a confident swagger. “It’s me, Little Red Riding Hood!” he shouted in a high-pitched voice to mimic the girl. + + “Come in, dear!” came the frail voice of the grandmother, who had been resting on her cozy bed, wrapped in warm blankets. The wolf burst through the door, his eyes gleaming with the thrill of his plan. + + With astonishing speed, the wolf gulped down the unsuspecting grandmother whole. Afterward, he dressed in her nightgown, donning her nightcap and climbing into her bed. He lay there, waiting for Little Red Riding Hood to arrive, concealing his wicked smile behind a facade of innocence. + + Meanwhile, Little Red Riding Hood was merrily picking flowers, completely unaware of the impending danger. After gathering a beautiful bouquet of wildflowers, she finally made her way back to the path and excitedly skipped towards her grandmother’s cottage. + + Upon arriving, she noticed the door was slightly ajar. “Grandmother, it’s me!” she called out, entering the dimly lit home. It was silent, with only the faint sound of an old clock ticking in the background. She stepped into the small living room, a feeling of unease creeping over her. + + “Grandmother, are you here?” she asked, peeking into the bedroom. There, she saw a figure lying under the covers. + + “Grandmother, what big ears you have!” she exclaimed, taking a few cautious steps closer. + + “All the better to hear you with, my dear,” the wolf replied in a voice that was deceptively sweet. + + “Grandmother, what big eyes you have!” Little Red Riding Hood continued, now feeling an unsettling chill in the air. + + “All the better to see you with, my dear,” the wolf said, his eyes narrowing as he tried to contain his glee. + + “Grandmother, what big teeth you have!” she exclaimed, the terror flooding her senses as she began to realize this was no ordinary visit. + + “All the better to eat you with!” the wolf roared, springing out of the bed with startling speed. + + Just as the wolf lunged towards her, a brave woodsman, who had been passing by the cottage and heard the commotion, burst through the door. His strong presence was a beacon of hope in the dire situation. “Stay back, wolf!” he shouted with authority, brandishing his axe. + + The wolf, taken aback by the sudden intrusion, hesitated for a moment. Before he could react, the woodsman swung his axe with determination, and with a swift motion, he drove the wolf away, rescuing Little Red Riding Hood and her grandmother from certain doom. + + Little Red Riding Hood was shaking with fright, but relief washed over her as the woodsman helped her grandmother out from behind the bed where the wolf had hidden her. The grandmother, though shaken, was immensely grateful to the woodsman for his bravery. “Thank you so much! You saved us!” she cried, embracing him warmly. + + Little Red Riding Hood, still in shock but filled with gratitude, looked up at the woodsman and said, “I promise I will never stray from the path again. Thank you for being our hero!” + + From that day on, the woodland creatures spoke of the brave woodsman who saved Little Red Riding Hood and her grandmother. Little Red Riding Hood learned a valuable lesson about being cautious and listening to her mother’s advice. The bond between her and her grandmother grew stronger, and they often reminisced about that day’s adventure over cups of tea, surrounded by cookies and laughter. + + To ensure safety, Little Red Riding Hood always took extra precautions when traveling through the woods, carrying a small whistle her grandmother had given her. It would alert anyone nearby if she ever found herself in trouble again. + + And so, in the heart of that small village, life continued, filled with love, laughter, and the occasional adventure, as Little Red Riding Hood and her grandmother thrived, forever grateful for the friendship of the woodsman who had acted as their guardian that fateful day. + + And they all lived happily ever after. + + The end.`), + } + + // create response prefix cache, note: more than 1024 tokens are required, otherwise the prefix cache cannot be created + cacheInfo, err := am.CreatePrefixCache(ctx, prefix, &expireAtSec, opts...) + if err != nil { + log.Fatalf("CreatePrefixCache failed, err=%v", err) + } + + // use cache information in subsequent requests + cacheOpt := &ark.CacheOption{ + HeadPreviousResponseID: &cacheInfo.ResponseID, + } + + input := []*schema.AgenticMessage{ + schema.UserAgenticMessage("What is the main idea expressed above?"), + } + + opts = append(opts, ark.WithCache(cacheOpt)) + outMsg, err := am.Generate(ctx, input, opts...) + if err != nil { + log.Fatalf("Generate failed, err=%v", err) + } + + meta := outMsg.ResponseMeta.Extension.(*ark.ResponseMetaExtension) + log.Printf("request_id: %s\n", meta.ID) + + log.Printf("\ngenerate output: \n") + respBody, _ := json.MarshalIndent(outMsg, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} diff --git a/components/agentic/ark/examples/session_cache/main.go b/components/agentic/ark/examples/session_cache/main.go new file mode 100644 index 000000000..454f51d96 --- /dev/null +++ b/components/agentic/ark/examples/session_cache/main.go @@ -0,0 +1,93 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "encoding/json" + "io" + "log" + "os" + "time" + + "github.com/cloudwego/eino-ext/components/agentic/ark" + "github.com/cloudwego/eino/schema" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" +) + +func main() { + ctx := context.Background() + + expireAtSec := time.Now().Add(10 * time.Minute).Unix() + + // Get ARK_API_KEY and ARK_MODEL_ID: https://www.volcengine.com/docs/82379/1399008 + am, err := ark.New(ctx, &ark.Config{ + APIKey: os.Getenv("ARK_API_KEY"), + Model: os.Getenv("ARK_MODEL_ID"), + Thinking: &responses.ResponsesThinking{ + Type: responses.ThinkingType_disabled.Enum(), + }, + Cache: &ark.CacheConfig{ + SessionCache: &ark.SessionCacheConfig{ + EnableCache: true, + ExpireAtSec: expireAtSec, + }, + }, + }) + if err != nil { + log.Fatalf("failed to create chat model, err=%v", err) + } + + useMsgs := []*schema.AgenticMessage{ + schema.UserAgenticMessage("Your name is superman"), + schema.UserAgenticMessage("What's your name?"), + schema.UserAgenticMessage("What do I ask you last time?"), + } + + var input []*schema.AgenticMessage + for _, msg := range useMsgs { + input = append(input, msg) + + streamResp, err := am.Stream(ctx, input) + if err != nil { + log.Fatalf("failed to stream, err: %v", err) + } + + var messages []*schema.AgenticMessage + for { + chunk, err := streamResp.Recv() + if err == io.EOF { + break + } + if err != nil { + log.Fatalf("failed to receive stream response, err: %v", err) + } + messages = append(messages, chunk) + } + + resp, err := schema.ConcatAgenticMessages(messages) + if err != nil { + log.Fatalf("failed to concat agentic messages, err: %v", err) + } + + jsonBody, _ := json.MarshalIndent(resp, " ", " ") + + log.Printf("stream output json: \n%v\n\n", string(jsonBody)) + + input = append(input, resp) + } +} diff --git a/components/agentic/ark/examples/stream_with_function_tool/main.go b/components/agentic/ark/examples/stream_with_function_tool/main.go new file mode 100644 index 000000000..9311619fb --- /dev/null +++ b/components/agentic/ark/examples/stream_with_function_tool/main.go @@ -0,0 +1,128 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package main + +import ( + "context" + "errors" + "io" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/ark" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/eino-contrib/jsonschema" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +func main() { + ctx := context.Background() + + // Get ARK_API_KEY and ARK_MODEL_ID: https://www.volcengine.com/docs/82379/1399008 + am, err := ark.New(ctx, &ark.Config{ + Model: os.Getenv("ARK_MODEL_ID"), + APIKey: os.Getenv("ARK_API_KEY"), + Thinking: &responses.ResponsesThinking{ + Type: responses.ThinkingType_disabled.Enum(), + }, + }) + if err != nil { + log.Fatalf("failed to create agentic model, err=%v", err) + } + + functionTools := []*schema.ToolInfo{ + { + Name: "get_weather", + Desc: "get the weather in a city", + ParamsOneOf: schema.NewParamsOneOfByJSONSchema(&jsonschema.Schema{ + Type: "object", + Properties: orderedmap.New[string, *jsonschema.Schema]( + orderedmap.WithInitialData( + orderedmap.Pair[string, *jsonschema.Schema]{ + Key: "city", + Value: &jsonschema.Schema{ + Type: "string", + Description: "the city to get the weather", + }, + }, + ), + ), + Required: []string{"city"}, + }), + }, + } + + allowedTools := []*schema.AllowedTool{ + { + FunctionToolName: "get_weather", + }, + } + + opts := []agentic.Option{ + agentic.WithToolChoice(schema.ToolChoiceForced, allowedTools...), + agentic.WithTools(functionTools), + } + + firstInput := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what's the weather like in Beijing today"), + } + + sResp, err := am.Stream(ctx, firstInput, opts...) + if err != nil { + log.Fatalf("failed to stream, err: %v", err) + } + + var msgs []*schema.AgenticMessage + for { + msg, err := sResp.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + log.Fatalf("failed to receive stream response, err: %v", err) + } + msgs = append(msgs, msg) + } + + concatenated, err := schema.ConcatAgenticMessages(msgs) + if err != nil { + log.Fatalf("failed to concat agentic messages, err: %v", err) + } + + lastBlock := concatenated.ContentBlocks[len(concatenated.ContentBlocks)-1] + if lastBlock.Type != schema.ContentBlockTypeFunctionToolCall { + log.Fatalf("last block is not function tool call, type: %s", lastBlock.Type) + } + + toolCall := lastBlock.FunctionToolCall + toolResultMsg := schema.FunctionToolResultAgenticMessage(toolCall.CallID, toolCall.Name, "20 degrees") + + secondInput := append(firstInput, concatenated, toolResultMsg) + + gResp, err := am.Generate(ctx, secondInput, opts...) + if err != nil { + log.Fatalf("failed to generate, err: %v", err) + } + + meta := concatenated.ResponseMeta.Extension.(*ark.ResponseMetaExtension) + log.Printf("request_id: %s\n", meta.ID) + + respBody, _ := sonic.MarshalIndent(gResp, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} diff --git a/components/agentic/ark/examples/stream_with_mcp_tool/main.go b/components/agentic/ark/examples/stream_with_mcp_tool/main.go new file mode 100644 index 000000000..320276c7d --- /dev/null +++ b/components/agentic/ark/examples/stream_with_mcp_tool/main.go @@ -0,0 +1,112 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "errors" + "io" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/ark" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" +) + +func main() { + ctx := context.Background() + + // Get ARK_API_KEY and ARK_MODEL_ID: https://www.volcengine.com/docs/82379/1399008 + am, err := ark.New(ctx, &ark.Config{ + Model: os.Getenv("ARK_MODEL_ID"), + APIKey: os.Getenv("ARK_API_KEY"), + }) + if err != nil { + log.Fatalf("failed to create agentic model, err=%v", err) + } + + mcpTools := []*responses.ToolMcp{ + { + Type: responses.ToolType_mcp, + ServerLabel: "test_mcp_server", + RequireApproval: &responses.McpRequireApproval{ + Union: &responses.McpRequireApproval_Mode{ + Mode: responses.ApprovalMode_never, + }, + }, + ServerUrl: "server url", + Headers: map[string]string{ + "X-API-KEY": "x-api-key if needed", + }, + }, + } + + allowedTools := []*schema.AllowedTool{ + { + MCPTool: &schema.AllowedMCPTool{ + ServerLabel: "test_mcp_server", + Name: "amap/maps_weather", + }, + }, + } + + opts := []agentic.Option{ + ark.WithMCPTools(mcpTools), + agentic.WithToolChoice(schema.ToolChoiceForced, allowedTools...), + ark.WithThinking(&responses.ResponsesThinking{ + Type: responses.ThinkingType_disabled.Enum(), + }), + ark.WithCustomHeaders(map[string]string{ + "ark-beta-mcp": "true", + }), + } + + input := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what's the weather like in Beijing today"), + } + + resp, err := am.Stream(ctx, input, opts...) + if err != nil { + log.Fatalf("failed to stream, err: %v", err) + } + + var msgs []*schema.AgenticMessage + for { + msg, err := resp.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + log.Fatalf("failed to receive stream response, err: %v", err) + } + msgs = append(msgs, msg) + } + + concatenated, err := schema.ConcatAgenticMessages(msgs) + if err != nil { + log.Fatalf("failed to concat agentic messages, err: %v", err) + } + + meta := concatenated.ResponseMeta.Extension.(*ark.ResponseMetaExtension) + + log.Printf("request_id: %s\n", meta.ID) + respBody, _ := sonic.MarshalIndent(concatenated, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} diff --git a/components/agentic/ark/examples/stream_with_sever_tool/main.go b/components/agentic/ark/examples/stream_with_sever_tool/main.go new file mode 100644 index 000000000..e555c7e45 --- /dev/null +++ b/components/agentic/ark/examples/stream_with_sever_tool/main.go @@ -0,0 +1,110 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "errors" + "io" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/ark" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" +) + +func main() { + ctx := context.Background() + + // Get ARK_API_KEY and ARK_MODEL_ID: https://www.volcengine.com/docs/82379/1399008 + am, err := ark.New(ctx, &ark.Config{ + Model: os.Getenv("ARK_MODEL_ID"), + APIKey: os.Getenv("ARK_API_KEY"), + }) + if err != nil { + log.Fatalf("failed to create agentic model, err=%v", err) + } + + serverTools := []*ark.ServerToolConfig{ + { + WebSearch: &responses.ToolWebSearch{ + Type: responses.ToolType_web_search, + }, + }, + } + + allowedTools := []*schema.AllowedTool{ + { + ServerTool: &schema.AllowedServerTool{ + Name: string(ark.ServerToolNameWebSearch), + }, + }, + } + + opts := []agentic.Option{ + ark.WithServerTools(serverTools), + agentic.WithToolChoice(schema.ToolChoiceForced, allowedTools...), + ark.WithThinking(&responses.ResponsesThinking{ + Type: responses.ThinkingType_disabled.Enum(), + }), + } + + input := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what's the weather like in Beijing today"), + } + + resp, err := am.Stream(ctx, input, opts...) + if err != nil { + log.Fatalf("failed to stream, err: %v", err) + } + + var msgs []*schema.AgenticMessage + for { + msg, err := resp.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + log.Fatalf("failed to receive stream response, err: %v", err) + } + msgs = append(msgs, msg) + } + + concatenated, err := schema.ConcatAgenticMessages(msgs) + if err != nil { + log.Fatalf("failed to concat agentic messages, err: %v", err) + } + + meta := concatenated.ResponseMeta.Extension.(*ark.ResponseMetaExtension) + for _, block := range concatenated.ContentBlocks { + if block.ServerToolCall == nil { + continue + } + + serverToolArgs := block.ServerToolCall.Arguments.(*ark.ServerToolCallArguments) + + args, _ := sonic.MarshalIndent(serverToolArgs, " ", " ") + log.Printf("server_tool_args: %s\n", string(args)) + } + + log.Printf("request_id: %s\n", meta.ID) + respBody, _ := sonic.MarshalIndent(concatenated, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} diff --git a/components/agentic/ark/extension.go b/components/agentic/ark/extension.go new file mode 100644 index 000000000..30f36ff62 --- /dev/null +++ b/components/agentic/ark/extension.go @@ -0,0 +1,231 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ark + +import ( + "fmt" + "sort" + + "github.com/cloudwego/eino/schema" +) + +type ResponseMetaExtension struct { + ID string `json:"id,omitempty"` + Status ResponseStatus `json:"status,omitempty"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details"` + Error *ResponseError `json:"error"` + PreviousResponseID string `json:"previous_response_id,omitempty"` + Thinking *ResponseThinking `json:"thinking,omitempty"` + ExpireAt *int64 `json:"expire_at,omitempty"` + ServiceTier ServiceTier `json:"service_tier,omitempty"` + + StreamingError *StreamingResponseError `json:"streaming_error,omitempty"` +} + +type AssistantGenTextExtension struct { + Annotations []*TextAnnotation `json:"annotations,omitempty"` +} + +type ServerToolCallArguments struct { + WebSearch *WebSearchArguments `json:"web_search,omitempty"` +} + +func getResponseMeta(meta *schema.AgenticResponseMeta) *ResponseMetaExtension { + if meta == nil || meta.Extension == nil { + return nil + } + return meta.Extension.(*ResponseMetaExtension) +} + +func getServerToolCallArguments(call *schema.ServerToolCall) (*ServerToolCallArguments, error) { + if call == nil || call.Arguments == nil { + return nil, fmt.Errorf("server tool call arguments is nil") + } + arguments, ok := call.Arguments.(*ServerToolCallArguments) + if !ok { + return nil, fmt.Errorf("expected '*ServerToolCallArguments', but got '%T'", call.Arguments) + } + return arguments, nil +} + +type ResponseError struct { + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + +type StreamingResponseError struct { + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Param string `json:"param,omitempty"` +} + +type IncompleteDetails struct { + Reason string `json:"reason,omitempty"` + ContentFilter *ContentFilter `json:"content_filter,omitempty"` +} + +type ContentFilter struct { + Type string `json:"type,omitempty"` + Details string `json:"details,omitempty"` +} + +type ResponseThinking struct { + Type ThinkingType `json:"type,omitempty"` +} + +type WebSearchArguments struct { + ActionType WebSearchAction `json:"action_type,omitempty"` + + Search *WebSearchQuery `json:"search,omitempty"` +} + +type WebSearchQuery struct { + Query string `json:"query,omitempty"` +} + +type TextAnnotation struct { + Index int `json:"index,omitempty"` + + Type TextAnnotationType `json:"type,omitempty"` + + URLCitation *URLCitation `json:"url_citation,omitempty"` + DocCitation *DocCitation `json:"doc_citation,omitempty"` +} + +type URLCitation struct { + Title string `json:"title,omitempty"` + URL string `json:"url,omitempty"` + LogoURL string `json:"logo_url,omitempty"` + MobileURL string `json:"mobile_url,omitempty"` + SiteName string `json:"site_name,omitempty"` + PublishTime string `json:"publish_time,omitempty"` + CoverImage *CoverImage `json:"cover_image,omitempty"` + Summary string `json:"summary,omitempty"` + FreshnessInfo string `json:"freshness_info,omitempty"` +} + +type CoverImage struct { + URL string `json:"url,omitempty"` + Width *int64 `json:"width,omitempty"` + Height *int64 `json:"height,omitempty"` +} + +type DocCitation struct { + DocID string `json:"doc_id,omitempty"` + DocName string `json:"doc_name,omitempty"` + ChunkID *int32 `json:"chunk_id,omitempty"` + ChunkAttachment []map[string]any `json:"chunk_attachment,omitempty"` +} + +func concatResponseMetaExtensions(chunks []*ResponseMetaExtension) (ret *ResponseMetaExtension, err error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no response meta extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret = &ResponseMetaExtension{} + + for _, chunk := range chunks { + if chunk == nil { + continue + } + if chunk.ID != "" { + ret.ID = chunk.ID + } + if chunk.Status != "" { + ret.Status = chunk.Status + } + if chunk.IncompleteDetails != nil { + ret.IncompleteDetails = chunk.IncompleteDetails + } + if chunk.Error != nil { + ret.Error = chunk.Error + } + if chunk.PreviousResponseID != "" { + ret.PreviousResponseID = chunk.PreviousResponseID + } + if chunk.Thinking != nil { + ret.Thinking = chunk.Thinking + } + if chunk.ExpireAt != nil { + ret.ExpireAt = chunk.ExpireAt + } + if chunk.ServiceTier != "" { + ret.ServiceTier = chunk.ServiceTier + } + if chunk.StreamingError != nil { + ret.StreamingError = chunk.StreamingError + } + } + + return ret, nil +} + +func concatAssistantGenTextExtensions(chunks []*AssistantGenTextExtension) (ret *AssistantGenTextExtension, err error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no assistant generated text extension found") + } + + ret = &AssistantGenTextExtension{} + + var allAnnotations []*TextAnnotation + for _, ext := range chunks { + allAnnotations = append(allAnnotations, ext.Annotations...) + } + + var ( + indices []int + indexToAnnotation = map[int]*TextAnnotation{} + ) + + for _, an := range allAnnotations { + if an == nil { + continue + } + if indexToAnnotation[an.Index] == nil { + indexToAnnotation[an.Index] = an + indices = append(indices, an.Index) + } else { + return nil, fmt.Errorf("duplicate annotation index %d", an.Index) + } + } + + sort.Slice(indices, func(i, j int) bool { + return indices[i] < indices[j] + }) + + ret.Annotations = make([]*TextAnnotation, 0, len(indices)) + for _, idx := range indices { + an := *indexToAnnotation[idx] + an.Index = 0 // clear index + ret.Annotations = append(ret.Annotations, &an) + } + + return ret, nil +} + +func concatServerToolCallArguments(chunks []*ServerToolCallArguments) (ret *ServerToolCallArguments, err error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no server tool call arguments found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + return nil, fmt.Errorf("cannot concat multiple server tool call arguments") +} diff --git a/components/agentic/ark/extension_test.go b/components/agentic/ark/extension_test.go new file mode 100644 index 000000000..c803c1fe7 --- /dev/null +++ b/components/agentic/ark/extension_test.go @@ -0,0 +1,142 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ark + +import ( + "testing" + + "github.com/cloudwego/eino/schema" + "github.com/stretchr/testify/assert" +) + +func TestGetResponseMeta(t *testing.T) { + var nilMeta *schema.AgenticResponseMeta + assert.Nil(t, getResponseMeta(nilMeta)) + + metaWithoutExt := &schema.AgenticResponseMeta{} + assert.Nil(t, getResponseMeta(metaWithoutExt)) + + meta := &schema.AgenticResponseMeta{ + Extension: &ResponseMetaExtension{ + ID: "id", + Status: "ok", + }, + } + ext := getResponseMeta(meta) + assert.NotNil(t, ext) + assert.Equal(t, "id", ext.ID) + assert.Equal(t, ResponseStatus("ok"), ext.Status) +} + +func TestGetServerToolCallArguments(t *testing.T) { + args, err := getServerToolCallArguments(nil) + assert.Error(t, err) + assert.Nil(t, args) + + callWithNilArgs := &schema.ServerToolCall{} + args, err = getServerToolCallArguments(callWithNilArgs) + assert.Error(t, err) + assert.Nil(t, args) + + callWithWrongType := &schema.ServerToolCall{ + Arguments: struct{ X string }{X: "v"}, + } + args, err = getServerToolCallArguments(callWithWrongType) + assert.Error(t, err) + assert.Nil(t, args) + + expected := &ServerToolCallArguments{ + WebSearch: &WebSearchArguments{ + ActionType: WebSearchActionSearch, + Search: &WebSearchQuery{ + Query: "q", + }, + }, + } + callWithCorrectArgs := &schema.ServerToolCall{ + Arguments: expected, + } + args, err = getServerToolCallArguments(callWithCorrectArgs) + assert.NoError(t, err) + assert.Equal(t, expected, args) +} + +func TestConcatResponseMetaExtensions(t *testing.T) { + ret, err := concatResponseMetaExtensions(nil) + assert.Error(t, err) + assert.Nil(t, ret) + + one := &ResponseMetaExtension{ID: "id1"} + ret, err = concatResponseMetaExtensions([]*ResponseMetaExtension{one}) + assert.NoError(t, err) + assert.Equal(t, one, ret) + + id2 := &ResponseMetaExtension{ID: "id2"} + err2 := &ResponseError{Code: "c"} + meta1 := &ResponseMetaExtension{ + ID: "base", + Status: "s1", + IncompleteDetails: &IncompleteDetails{Reason: "r"}, + Error: err2, + } + meta2 := &ResponseMetaExtension{ + ID: id2.ID, + Status: "s2", + PreviousResponseID: "prev", + } + ret, err = concatResponseMetaExtensions([]*ResponseMetaExtension{meta1, meta2, nil}) + assert.NoError(t, err) + assert.Equal(t, meta2.ID, ret.ID) + assert.Equal(t, ResponseStatus("s2"), ret.Status) + assert.Equal(t, meta1.IncompleteDetails, ret.IncompleteDetails) + assert.Equal(t, err2, ret.Error) + assert.Equal(t, "prev", ret.PreviousResponseID) +} + +func TestConcatAssistantGenTextExtensions(t *testing.T) { + a0 := &TextAnnotation{Index: 0} + a1 := &TextAnnotation{Index: 1} + e0 := &AssistantGenTextExtension{Annotations: []*TextAnnotation{a0}} + e1 := &AssistantGenTextExtension{Annotations: []*TextAnnotation{a1}} + ret, err := concatAssistantGenTextExtensions([]*AssistantGenTextExtension{e0, e1}) + assert.NoError(t, err) + assert.Len(t, ret.Annotations, 2) + assert.Equal(t, &TextAnnotation{Index: 0}, ret.Annotations[0]) + assert.Equal(t, &TextAnnotation{Index: 0}, ret.Annotations[1]) + + dup := &TextAnnotation{Index: 0} + _, err = concatAssistantGenTextExtensions([]*AssistantGenTextExtension{ + {Annotations: []*TextAnnotation{a0}}, + {Annotations: []*TextAnnotation{dup}}, + }) + assert.Error(t, err) +} + +func TestConcatServerToolCallArguments(t *testing.T) { + ret, err := concatServerToolCallArguments(nil) + assert.Error(t, err) + assert.Nil(t, ret) + + one := &ServerToolCallArguments{} + ret, err = concatServerToolCallArguments([]*ServerToolCallArguments{one}) + assert.NoError(t, err) + assert.Equal(t, one, ret) + + two := &ServerToolCallArguments{} + _, err = concatServerToolCallArguments([]*ServerToolCallArguments{one, two}) + assert.Error(t, err) +} diff --git a/components/agentic/ark/go.mod b/components/agentic/ark/go.mod new file mode 100644 index 000000000..b158ae4d0 --- /dev/null +++ b/components/agentic/ark/go.mod @@ -0,0 +1,51 @@ +module github.com/cloudwego/eino-ext/components/agentic/ark + +go 1.18 + +require ( + github.com/bytedance/mockey v1.4.0 + github.com/bytedance/sonic v1.14.1 + github.com/cloudwego/eino v0.7.19-0.20260108113617-d04d4b5bda31 + github.com/eino-contrib/jsonschema v1.0.3 + github.com/stretchr/testify v1.10.0 + github.com/volcengine/volcengine-go-sdk v1.2.4 + github.com/wk8/go-ordered-map/v2 v2.1.8 + golang.org/x/sync v0.8.0 + google.golang.org/protobuf v1.31.0 +) + +require ( + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/goph/emperror v0.17.2 // indirect + github.com/gopherjs/gopherjs v1.17.2 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/jtolds/gls v4.20.0+incompatible // indirect + github.com/klauspost/cpuid/v2 v2.2.9 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/nikolalohinski/gonja v1.5.3 // indirect + github.com/pelletier/go-toml/v2 v2.0.9 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect + github.com/smarty/assertions v1.15.0 // indirect + github.com/smartystreets/goconvey v1.8.1 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/volcengine/volc-sdk-golang v1.0.23 // indirect + github.com/yargevad/filepathx v1.0.0 // indirect + golang.org/x/arch v0.11.0 // indirect + golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect + golang.org/x/sys v0.29.0 // indirect + gopkg.in/yaml.v2 v2.2.8 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/components/agentic/ark/go.sum b/components/agentic/ark/go.sum new file mode 100644 index 000000000..163c8738a --- /dev/null +++ b/components/agentic/ark/go.sum @@ -0,0 +1,216 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o= +github.com/avast/retry-go v3.0.0+incompatible/go.mod h1:XtSnn+n/sHqQIpZ10K1qAevBhOOCWBLXXy3hyiqqBrY= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= +github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/mockey v1.4.0 h1:xwuZ3rr4mpbGkkBOYoSM+cO112dvzQ/sY0cVdP9FBSA= +github.com/bytedance/mockey v1.4.0/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY= +github.com/bytedance/sonic v1.14.1 h1:FBMC0zVz5XUmE4z9wF4Jey0An5FueFvOsTKKKtwIl7w= +github.com/bytedance/sonic v1.14.1/go.mod h1:gi6uhQLMbTdeP0muCnrjHLeCUPyb70ujhnNlhOylAFc= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/cloudwego/eino v0.7.19-0.20260108113617-d04d4b5bda31 h1:Vj2VKfW6A+FpzGdU4MJyIOEDVcI5Zyr0uEQanPa7PyE= +github.com/cloudwego/eino v0.7.19-0.20260108113617-d04d4b5bda31/go.mod h1:OdDJi17QawFUJRIFrVJRgdc9grjrh3eFDD0k34ZRH8M= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/eino-contrib/jsonschema v1.0.3 h1:2Kfsm1xlMV0ssY2nuxshS4AwbLFuqmPmzIjLVJ1Fsp0= +github.com/eino-contrib/jsonschema v1.0.3/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= +github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/goph/emperror v0.17.2 h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18= +github.com/goph/emperror v0.17.2/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic= +github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= +github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= +github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= +github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= +github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c= +github.com/nikolalohinski/gonja v1.5.3/go.mod h1:RmjwxNiXAEqcq1HeK5SSMmqFJvKOfTfXhkJv6YBtPa4= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0= +github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtISeXco0L5PKQ= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg= +github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= +github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= +github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= +github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/volcengine/volc-sdk-golang v1.0.23 h1:anOslb2Qp6ywnsbyq9jqR0ljuO63kg9PY+4OehIk5R8= +github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU= +github.com/volcengine/volcengine-go-sdk v1.2.4 h1:smBdDwwkXoXldCuumuZDJQASlAgVUIeL/RQ26D0OgI4= +github.com/volcengine/volcengine-go-sdk v1.2.4/go.mod h1:oxoVo+A17kvkwPkIeIHPVLjSw7EQAm+l/Vau1YGHN+A= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= +github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc= +github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +golang.org/x/arch v0.11.0 h1:KXV8WWKCXm6tRpLirl2szsO5j/oOODwZf4hATmGVNs4= +golang.org/x/arch v0.11.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/components/agentic/ark/model.go b/components/agentic/ark/model.go new file mode 100644 index 000000000..8b3f09d41 --- /dev/null +++ b/components/agentic/ark/model.go @@ -0,0 +1,941 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ark + +import ( + "context" + "errors" + "fmt" + "net/http" + "runtime/debug" + "time" + + "github.com/bytedance/sonic" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" +) + +var _ agentic.Model = (*Model)(nil) + +type Config struct { + // Timeout specifies the maximum duration to wait for API responses. + // If HTTPClient is set, Timeout will not be used. + // Optional. + Timeout *time.Duration + + // HTTPClient specifies the HTTP client used to send requests. + // If HTTPClient is set, Timeout will not be used. + // Optional. Default: &http.Client{Timeout: Timeout} + HTTPClient *http.Client + + // RetryTimes specifies the number of retry attempts for failed API calls. + // Optional. + RetryTimes *int + + // BaseURL specifies the base URL for the Ark service endpoint. + // Optional. + BaseURL string + + // Region specifies the geographic region where the Ark service is located. + // Optional. + Region string + + // APIKey specifies the API key for authentication. + // Either APIKey or both AccessKey and SecretKey must be provided. + // APIKey takes precedence if both authentication methods are provided. + // For details, see: https://www.volcengine.com/docs/82379/1298459 + APIKey string + + // AccessKey specifies the access key for authentication. + // Must be used together with SecretKey. + AccessKey string + + // SecretKey specifies the secret key for authentication. + // Must be used together with AccessKey. + SecretKey string + + // Model specifies the identifier of the model endpoint on the Ark platform. + // For details, see: https://www.volcengine.com/docs/82379/1298454 + // Required. + Model string + + // MaxOutputTokens specifies the maximum number of tokens to generate in the response. + // Optional. + MaxOutputTokens *int64 + + // Temperature controls the randomness of the model's output. + // Lower values (e.g., 0.2) make the output more focused and deterministic. + // Higher values (e.g., 1.0) make the output more creative and varied. + // Range: 0.0 to 2.0. + // Optional. + Temperature *float64 + + // TopP controls diversity via nucleus sampling, an alternative to Temperature. + // TopP specifies the cumulative probability threshold for token selection. + // For example, 0.1 means only tokens comprising the top 10% probability mass are considered. + // We recommend using either Temperature or TopP, but not both. + // Range: 0.0 to 1.0. + // Optional. + TopP *float64 + + // ServiceTier specifies the service tier to use for the request. + // Optional. + ServiceTier *responses.ResponsesServiceTier_Enum + + // Text specifies text generation configuration options. + // Optional. + Text *responses.ResponsesText + + // Thinking controls whether the model uses deep thinking mode. + // Optional. + Thinking *responses.ResponsesThinking + + // Reasoning specifies the effort level for the model's reasoning process. + // Optional. + Reasoning *responses.ResponsesReasoning + + // EnablePassBackReasoning controls whether the model passes back reasoning items in the response. + // Optional. Default: true + EnablePassBackReasoning *bool + + // MaxToolCalls specifies the maximum number of tool calls the model can make in a single response. + // Optional. + MaxToolCalls *int64 + + // ParallelToolCalls determines whether the model can invoke multiple tools simultaneously. + // Optional. + ParallelToolCalls *bool + + // ServerTools specifies server-side tools available to the model. + // Optional. + ServerTools []*ServerToolConfig + + // MCPTools specifies Model Context Protocol tools available to the model. + // Optional. + MCPTools []*responses.ToolMcp + + // Cache specifies response caching configuration for the session. + // Optional. + Cache *CacheConfig + + // CustomHeader specifies custom HTTP headers to include in API requests. + // CustomHeader allows passing additional metadata or authentication information. + // Optional. + CustomHeader map[string]string +} + +type CacheConfig struct { + // SessionCache can be overridden by [WithCache]. + // Optional. + SessionCache *SessionCacheConfig +} + +type SessionCacheConfig struct { + // EnableCache controls whether session caching is active. + // When enabled, conversation turns are stored, and the model automatically maintains context + // by locating the most recent cached message in the input (via Response ID in ResponseMeta). + // This cached message and all preceding inputs are excluded from the request. + // The detected Response ID takes precedence over HeadPreviousResponseID. + EnableCache bool + + ExpireAtSec int64 +} + +type ServerToolConfig struct { + WebSearch *responses.ToolWebSearch +} + +func New(_ context.Context, config *Config) (*Model, error) { + if config == nil { + config = &Config{} + } + + c, err := buildClient(config) + if err != nil { + return nil, err + } + + return c, nil +} + +func buildClient(config *Config) (*Model, error) { + var opts []arkruntime.ConfigOption + + if config.Region != "" { + opts = append(opts, arkruntime.WithRegion(config.Region)) + } + if config.Timeout != nil { + opts = append(opts, arkruntime.WithTimeout(*config.Timeout)) + } + if config.HTTPClient != nil { + opts = append(opts, arkruntime.WithHTTPClient(config.HTTPClient)) + } + if config.RetryTimes != nil { + opts = append(opts, arkruntime.WithRetryTimes(*config.RetryTimes)) + } + if config.BaseURL != "" { + opts = append(opts, arkruntime.WithBaseUrl(config.BaseURL)) + } + + var client *arkruntime.Client + if len(config.APIKey) > 0 { + client = arkruntime.NewClientWithApiKey(config.APIKey, opts...) + } else if config.AccessKey != "" && config.SecretKey != "" { + client = arkruntime.NewClientWithAkSk(config.AccessKey, config.SecretKey, opts...) + } else { + return nil, fmt.Errorf("new client fail, missing credentials: set 'APIKey' or both 'AccessKey' and 'SecretKey'") + } + + cm := &Model{ + cli: client, + model: config.Model, + maxOutputTokens: config.MaxOutputTokens, + temperature: config.Temperature, + topP: config.TopP, + serviceTier: config.ServiceTier, + text: config.Text, + thinking: config.Thinking, + reasoning: config.Reasoning, + enablePassBackReasoning: config.EnablePassBackReasoning, + maxToolCalls: config.MaxToolCalls, + parallelToolCalls: config.ParallelToolCalls, + serverTools: config.ServerTools, + mcpTools: config.MCPTools, + cache: config.Cache, + customHeader: config.CustomHeader, + } + + return cm, nil +} + +type Model struct { + cli *arkruntime.Client + + rawFunctionTools []*schema.ToolInfo + functionTools []*responses.ResponsesTool + + model string + maxOutputTokens *int64 + temperature *float64 + topP *float64 + serviceTier *responses.ResponsesServiceTier_Enum + text *responses.ResponsesText + thinking *responses.ResponsesThinking + reasoning *responses.ResponsesReasoning + maxToolCalls *int64 + parallelToolCalls *bool + serverTools []*ServerToolConfig + mcpTools []*responses.ToolMcp + + cache *CacheConfig + enablePassBackReasoning *bool + customHeader map[string]string +} + +func (m *Model) Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...agentic.Option) ( + outMsg *schema.AgenticMessage, err error) { + + ctx = callbacks.EnsureRunInfo(ctx, m.GetType(), components.ComponentOfAgenticModel) + + options, specOptions, err := m.getOptions(opts) + if err != nil { + return nil, err + } + + responseReq, err := m.genRequestAndOptions(input, options, specOptions) + if err != nil { + return nil, fmt.Errorf("genRequestAndOptions failed: %w", err) + } + + config := m.toCallbackConfig(responseReq) + + tools := m.rawFunctionTools + if options.Tools != nil { + tools = options.Tools + } + + ctx = callbacks.OnStart(ctx, &agentic.CallbackInput{ + Messages: input, + Tools: tools, + ToolChoice: options.ToolChoice, + Config: config, + }) + + defer func() { + if err != nil { + callbacks.OnError(ctx, err) + } + }() + + responseObject, err := m.cli.CreateResponses(ctx, responseReq, arkruntime.WithCustomHeaders(specOptions.customHeaders)) + if err != nil { + return nil, fmt.Errorf("failed to create responses, err: %w", err) + } + + outMsg, err = toOutputMessage(responseObject) + if err != nil { + return nil, fmt.Errorf("failed to convert output to message, err: %w", err) + } + + callbacks.OnEnd(ctx, &agentic.CallbackOutput{ + Message: outMsg, + Config: config, + }) + + return outMsg, nil +} + +func (m *Model) Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...agentic.Option) ( + outStream *schema.StreamReader[*schema.AgenticMessage], err error) { + + ctx = callbacks.EnsureRunInfo(ctx, m.GetType(), components.ComponentOfAgenticModel) + + options, specOptions, err := m.getOptions(opts) + if err != nil { + return nil, err + } + + responseReq, err := m.genRequestAndOptions(input, options, specOptions) + if err != nil { + return nil, fmt.Errorf("genRequestAndOptions failed: %w", err) + } + + config := m.toCallbackConfig(responseReq) + tools := m.rawFunctionTools + if options.Tools != nil { + tools = options.Tools + } + + ctx = callbacks.OnStart(ctx, &agentic.CallbackInput{ + Messages: input, + Tools: tools, + ToolChoice: options.ToolChoice, + Config: config, + }) + + defer func() { + if err != nil { + callbacks.OnError(ctx, err) + } + }() + + responseStreamReader, err := m.cli.CreateResponsesStream(ctx, responseReq, arkruntime.WithCustomHeaders(specOptions.customHeaders)) + if err != nil { + return nil, fmt.Errorf("failed to create responses, err: %w", err) + } + + sr, sw := schema.Pipe[*agentic.CallbackOutput](1) + + go func() { + defer func() { + pe := recover() + if pe != nil { + _ = sw.Send(nil, newPanicErr(pe, debug.Stack())) + } + + _ = responseStreamReader.Close() + sw.Close() + }() + + receivedStreamResponse(responseStreamReader, config, sw) + + }() + + ctx, nsr := callbacks.OnEndWithStreamOutput(ctx, schema.StreamReaderWithConvert(sr, + func(src *agentic.CallbackOutput) (callbacks.CallbackOutput, error) { + if src.Extra == nil { + src.Extra = make(map[string]any) + } + return src, nil + }, + )) + + outStream = schema.StreamReaderWithConvert(nsr, + func(src callbacks.CallbackOutput) (*schema.AgenticMessage, error) { + s := src.(*agentic.CallbackOutput) + if s.Message == nil { + return nil, schema.ErrNoValue + } + return s.Message, nil + }, + ) + + return outStream, err +} + +func (m *Model) WithTools(functionTools []*schema.ToolInfo) (agentic.Model, error) { + if len(functionTools) == 0 { + return nil, errors.New("function tools are required") + } + + fts, err := m.toFunctionTools(functionTools) + if err != nil { + return nil, fmt.Errorf("failed to convert function tools, err: %w", err) + } + + m_ := *m + m_.rawFunctionTools = functionTools + m_.functionTools = fts + + return &m_, nil +} + +func (m *Model) GetType() string { + return implType +} + +func (m *Model) IsCallbacksEnabled() bool { + return true +} + +type CacheInfo struct { + // ResponseID return by ResponsesAPI, it's specifies the id of prefix that can be used with [WithCache.HeadPreviousResponseID] option. + ResponseID string + // Usage specifies the token usage of prefix + Usage schema.TokenUsage +} + +// CreatePrefixCache creates a prefix context on the server side. +// The server will input the prefix cached context and this turn of input into the model for processing. +// This improves efficiency by reducing token usage and request size. +// +// Parameters: +// - ctx: The context for the request +// - prefix: Initial messages to be cached as prefix context +// - expireAtSec: Expiration Unix timestamp (in seconds) for the cached prefix. Defaults to 3 days from now if not specified. +// +// Returns: +// - info: Information about the created prefix cache, including the context ID and token usage +// - err: Any error encountered during the operation +// +// ref: https://www.volcengine.com/docs/82379/1396490#_1-%E5%88%9B%E5%BB%BA%E5%89%8D%E7%BC%80%E7%BC%93%E5%AD%98 +// +// Note: +// - It is unavailable for doubao models of version 1.6 and above. +func (m *Model) CreatePrefixCache(ctx context.Context, prefix []*schema.AgenticMessage, expireAtSec *int64, + opts ...agentic.Option) (info *CacheInfo, err error) { + + responseReq := &responses.ResponsesRequest{ + Model: m.model, + ExpireAt: expireAtSec, + Store: ptrOf(true), + Caching: &responses.ResponsesCaching{ + Type: responses.CacheType_enabled.Enum(), + Prefix: ptrOf(true), + }, + } + + options, specOptions, err := m.getOptions(opts) + if err != nil { + return nil, fmt.Errorf("failed to getOptions, err: %w", err) + } + + err = m.prePopulateConfig(responseReq, options, specOptions) + if err != nil { + return nil, fmt.Errorf("failed to prePopulateConfig, err: %w", err) + } + + err = m.populateInput(prefix, responseReq) + if err != nil { + return nil, fmt.Errorf("failed to populateInput, err: %w", err) + } + + err = m.populateTools(responseReq, options, specOptions) + if err != nil { + return nil, fmt.Errorf("failed to populateTools, err: %w", err) + } + + err = m.populateToolChoice(responseReq, options) + if err != nil { + return nil, fmt.Errorf("failed to populateToolChoice, err: %w", err) + } + + responseObj, err := m.cli.CreateResponses(ctx, responseReq) + if err != nil { + return nil, err + } + + info = &CacheInfo{ + ResponseID: responseObj.Id, + Usage: *toTokenUsage(responseObj), + } + + return info, nil +} + +func (m *Model) toCallbackConfig(req *responses.ResponsesRequest) *agentic.Config { + return &agentic.Config{ + Model: req.Model, + Temperature: ptrFromOrZero(req.Temperature), + TopP: ptrFromOrZero(req.TopP), + } +} + +func (m *Model) toFunctionTools(functionTools []*schema.ToolInfo) ([]*responses.ResponsesTool, error) { + tools := make([]*responses.ResponsesTool, len(functionTools)) + for i := range functionTools { + ti := functionTools[i] + + paramsJSONSchema, err := ti.ParamsOneOf.ToJSONSchema() + if err != nil { + return nil, fmt.Errorf("failed to convert tool parameters to JSONSchema, err: %w", err) + } + + b, err := sonic.Marshal(paramsJSONSchema) + if err != nil { + return nil, fmt.Errorf("failed to marshal JSONSchema, err: %w", err) + } + + tools[i] = &responses.ResponsesTool{ + Union: &responses.ResponsesTool_ToolFunction{ + ToolFunction: &responses.ToolFunction{ + Name: ti.Name, + Type: responses.ToolType_function, + Description: &ti.Desc, + Parameters: &responses.Bytes{ + Value: b, + }, + }, + }, + } + } + + return tools, nil +} + +func (m *Model) toServerTools(serverTools []*ServerToolConfig) (tools []*responses.ResponsesTool, rr error) { + tools = make([]*responses.ResponsesTool, len(serverTools)) + + for i := range serverTools { + ti := serverTools[i] + switch { + case ti.WebSearch != nil: + tools[i] = &responses.ResponsesTool{ + Union: &responses.ResponsesTool_ToolWebSearch{ + ToolWebSearch: ti.WebSearch, + }, + } + + default: + continue + } + } + + return tools, nil +} + +func (m *Model) toMCPTools(mcpTools []*responses.ToolMcp) []*responses.ResponsesTool { + tools := make([]*responses.ResponsesTool, len(mcpTools)) + for i := range mcpTools { + tools[i] = &responses.ResponsesTool{ + Union: &responses.ResponsesTool_ToolMcp{ + ToolMcp: mcpTools[i], + }, + } + } + return tools +} + +func (m *Model) getOptions(opts []agentic.Option) (*agentic.Options, *arkOptions, error) { + options := agentic.GetCommonOptions(&agentic.Options{ + Temperature: m.temperature, + Model: &m.model, + TopP: m.topP, + }, opts...) + + arkOpts := agentic.GetImplSpecificOptions(&arkOptions{ + reasoning: m.reasoning, + thinking: m.thinking, + text: m.text, + maxToolCalls: m.maxToolCalls, + parallelToolCalls: m.parallelToolCalls, + maxOutputTokens: m.maxOutputTokens, + serverTools: m.serverTools, + mcpTools: m.mcpTools, + customHeaders: m.customHeader, + }, opts...) + + return options, arkOpts, nil +} + +func (m *Model) genRequestAndOptions(in []*schema.AgenticMessage, options *agentic.Options, + specOptions *arkOptions) (responseReq *responses.ResponsesRequest, err error) { + + responseReq = &responses.ResponsesRequest{} + + err = m.prePopulateConfig(responseReq, options, specOptions) + if err != nil { + return nil, fmt.Errorf("failed to prePopulateConfig, err: %w", err) + } + + in, err = m.populateCache(in, responseReq, specOptions) + if err != nil { + return nil, fmt.Errorf("failed to populateCache, err: %w", err) + } + + err = m.populateInput(in, responseReq) + if err != nil { + return nil, fmt.Errorf("failed to populateInput, err: %w", err) + } + + err = m.populateTools(responseReq, options, specOptions) + if err != nil { + return nil, fmt.Errorf("failed to populateTools, err: %w", err) + } + + err = m.populateToolChoice(responseReq, options) + if err != nil { + return nil, fmt.Errorf("failed to populateToolChoice, err: %w", err) + } + + return responseReq, nil +} + +func (m *Model) prePopulateConfig(responseReq *responses.ResponsesRequest, options *agentic.Options, + specOptions *arkOptions) error { + + // instance configuration + responseReq.ServiceTier = m.serviceTier + + // options configuration + responseReq.TopP = options.TopP + responseReq.Temperature = options.Temperature + if options.Model != nil { + responseReq.Model = *options.Model + } + + // specific options configuration + responseReq.Thinking = specOptions.thinking + responseReq.Reasoning = specOptions.reasoning + responseReq.Text = specOptions.text + responseReq.MaxOutputTokens = specOptions.maxOutputTokens + responseReq.MaxToolCalls = specOptions.maxToolCalls + responseReq.ParallelToolCalls = specOptions.parallelToolCalls + + return nil +} + +func (m *Model) populateCache(in []*schema.AgenticMessage, responseReq *responses.ResponsesRequest, + arkOpts *arkOptions) ([]*schema.AgenticMessage, error) { + + var ( + store = false + enableCache = false + expireAtSec *int64 + headRespID *string + ) + + if m.cache != nil { + if sCache := m.cache.SessionCache; sCache != nil { + if sCache.EnableCache { + store = true + enableCache = true + } + expireAtSec = &sCache.ExpireAtSec + } + } + + if cacheOpt := arkOpts.cache; cacheOpt != nil { + headRespID = cacheOpt.HeadPreviousResponseID + + if sCacheOpt := cacheOpt.SessionCache; sCacheOpt != nil { + expireAtSec = &sCacheOpt.ExpireAtSec + + if sCacheOpt.EnableCache { + store = true + enableCache = true + } else { + store = false + enableCache = false + } + } + } + + var ( + preRespID *string + inputIdx int + ) + + now := time.Now().Unix() + + if enableCache { + for i := len(in) - 1; i >= 0; i-- { + msg := in[i] + if msg.ResponseMeta == nil { + continue + } + + extensions := getResponseMeta(msg.ResponseMeta) + if extensions == nil || ptrFromOrZero(extensions.ExpireAt) <= now { + continue + } + + inputIdx = i + preRespID = &extensions.ID + + break + } + } + + if preRespID != nil { + if inputIdx+1 >= len(in) { + return in, fmt.Errorf("not found incremental input after ResponseID") + } + in = in[inputIdx+1:] + } + + // ResponseID has a higher priority than HeadPreviousResponseID + if preRespID == nil { + preRespID = headRespID + } + + responseReq.PreviousResponseId = preRespID + responseReq.Store = &store + + if expireAtSec != nil { + responseReq.ExpireAt = expireAtSec + } + + responseReq.Caching = &responses.ResponsesCaching{ + Type: func() *responses.CacheType_Enum { + if enableCache { + return responses.CacheType_enabled.Enum() + } + return responses.CacheType_disabled.Enum() + }(), + } + + return in, nil +} + +func (m *Model) populateInput(in []*schema.AgenticMessage, responseReq *responses.ResponsesRequest) (err error) { + if len(in) == 0 { + return nil + } + + itemList := make([]*responses.InputItem, 0, len(in)) + + for _, msg := range in { + var inputItems []*responses.InputItem + + switch msg.Role { + case schema.AgenticRoleTypeUser: + inputItems, err = toUserRoleInputItems(msg) + if err != nil { + return err + } + + case schema.AgenticRoleTypeAssistant: + inputItems, err = toAssistantRoleInputItems(msg) + if err != nil { + return err + } + + case schema.AgenticRoleTypeDeveloper: + inputItems, err = toDeveloperRoleInputItems(msg) + if err != nil { + return err + } + + case schema.AgenticRoleTypeSystem: + inputItems, err = toSystemRoleInputItems(msg) + if err != nil { + return err + } + + default: + return fmt.Errorf("invalid role: %s", msg.Role) + } + + itemList = append(itemList, inputItems...) + } + + if m.enablePassBackReasoning != nil && !*m.enablePassBackReasoning { + itemList = removeReasoningItems(itemList) + } + + responseReq.Input = &responses.ResponsesInput{ + Union: &responses.ResponsesInput_ListValue{ + ListValue: &responses.InputItemList{ + ListValue: itemList, + }, + }, + } + + return nil +} + +func removeReasoningItems(itemList []*responses.InputItem) []*responses.InputItem { + newItemList := make([]*responses.InputItem, 0, len(itemList)) + + for i := len(itemList) - 1; i >= 0; i-- { + if itemList[i].Union == nil { + continue + } + if _, ok := itemList[i].Union.(*responses.InputItem_Reasoning); ok { + continue + } + newItemList = append(newItemList, itemList[i]) + } + + return newItemList +} + +func (m *Model) populateTools(responseReq *responses.ResponsesRequest, options *agentic.Options, specOptions *arkOptions) (err error) { + if responseReq.PreviousResponseId != nil { + return nil + } + + var functionTools []*responses.ResponsesTool + if options.Tools != nil { + functionTools, err = m.toFunctionTools(options.Tools) + if err != nil { + return err + } + } else { + functionTools = m.functionTools + } + + responseReq.Tools = append(responseReq.Tools, functionTools...) + + serverTools, err := m.toServerTools(specOptions.serverTools) + if err != nil { + return err + } + + responseReq.Tools = append(responseReq.Tools, serverTools...) + + mcpTools := m.toMCPTools(specOptions.mcpTools) + + responseReq.Tools = append(responseReq.Tools, mcpTools...) + + return nil +} + +func (m *Model) populateToolChoice(responseReq *responses.ResponsesRequest, options *agentic.Options) (err error) { + if responseReq.PreviousResponseId != nil { + return nil + } + + if options.ToolChoice == nil && len(options.AllowedTools) > 0 { + return fmt.Errorf("tool choice must be specified when allowed tools are provided") + } + if options.ToolChoice == nil { + return nil + } + + switch *options.ToolChoice { + case schema.ToolChoiceForbidden: + responseReq.ToolChoice = &responses.ResponsesToolChoice{ + Union: &responses.ResponsesToolChoice_Mode{ + Mode: responses.ToolChoiceMode_none, + }, + } + + case schema.ToolChoiceAllowed: + if len(options.AllowedTools) > 1 { + return fmt.Errorf("only one allowed tool is supported when tool choice is 'allowed'") + } + if len(options.AllowedTools) == 0 { + responseReq.ToolChoice = &responses.ResponsesToolChoice{ + Union: &responses.ResponsesToolChoice_Mode{ + Mode: responses.ToolChoiceMode_auto, + }, + } + return nil + } + + responseReq.ToolChoice, err = toForcedToolChoice(options.AllowedTools[0]) + if err != nil { + return err + } + + case schema.ToolChoiceForced: + if len(options.AllowedTools) > 1 { + return fmt.Errorf("only one allowed tool is supported when tool choice is 'forced'") + } + if len(options.AllowedTools) == 0 { + responseReq.ToolChoice = &responses.ResponsesToolChoice{ + Union: &responses.ResponsesToolChoice_Mode{ + Mode: responses.ToolChoiceMode_required, + }, + } + return nil + } + + responseReq.ToolChoice, err = toForcedToolChoice(options.AllowedTools[0]) + if err != nil { + return err + } + + default: + return fmt.Errorf("invalid tool choice: %s", *options.ToolChoice) + } + + return nil +} + +func toForcedToolChoice(tool *schema.AllowedTool) (*responses.ResponsesToolChoice, error) { + var toolChoice *responses.ResponsesToolChoice + + switch { + case tool.FunctionToolName != "": + toolChoice = &responses.ResponsesToolChoice{ + Union: &responses.ResponsesToolChoice_FunctionToolChoice{ + FunctionToolChoice: &responses.FunctionToolChoice{ + Type: responses.ToolType_function, + Name: tool.FunctionToolName, + }, + }, + } + + case tool.MCPTool != nil: + toolChoice = &responses.ResponsesToolChoice{ + Union: &responses.ResponsesToolChoice_McpToolChoice{ + McpToolChoice: &responses.McpToolChoice{ + Type: responses.ToolType_mcp, + Name: ptrIfNonZero(tool.MCPTool.Name), + ServerLabel: tool.MCPTool.ServerLabel, + }, + }, + } + + case tool.ServerTool != nil: + switch tool.ServerTool.Name { + case string(ServerToolNameWebSearch): + toolChoice = &responses.ResponsesToolChoice{ + Union: &responses.ResponsesToolChoice_WebSearchToolChoice{ + WebSearchToolChoice: &responses.WebSearchToolChoice{ + Type: responses.ToolType_web_search, + }, + }, + } + default: + return nil, fmt.Errorf("invalid server tool name: %s", tool.ServerTool.Name) + } + + default: + return nil, fmt.Errorf("found unknown allowed tool") + } + + return toolChoice, nil +} diff --git a/components/agentic/ark/model_test.go b/components/agentic/ark/model_test.go new file mode 100644 index 000000000..2f60d7070 --- /dev/null +++ b/components/agentic/ark/model_test.go @@ -0,0 +1,575 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ark + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/bytedance/mockey" + "github.com/bytedance/sonic" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/stretchr/testify/assert" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/utils" +) + +func TestNew(t *testing.T) { + mockey.PatchConvey("TestNew", t, func() { + ctx := context.Background() + config := &Config{ + Model: "test-model", + APIKey: "test-api-key", + } + + mockey.Mock(arkruntime.NewClientWithApiKey).Return(&arkruntime.Client{}).Build() + mockey.Mock(arkruntime.NewClientWithAkSk).Return(&arkruntime.Client{}).Build() + + mockey.PatchConvey("success with api key", func() { + m, err := New(ctx, config) + assert.NoError(t, err) + assert.NotNil(t, m) + assert.Equal(t, "test-model", m.model) + }) + + mockey.PatchConvey("success with ak/sk", func() { + config.APIKey = "" + config.AccessKey = "ak" + config.SecretKey = "sk" + m, err := New(ctx, config) + assert.NoError(t, err) + assert.NotNil(t, m) + }) + + mockey.PatchConvey("fail with missing credentials", func() { + config.APIKey = "" + config.AccessKey = "" + m, err := New(ctx, config) + assert.Error(t, err) + assert.Nil(t, m) + }) + + mockey.PatchConvey("full config", func() { + timeout := 10 * time.Second + retry := 3 + config.APIKey = "key" + config.Timeout = &timeout + config.RetryTimes = &retry + config.Region = "cn-beijing" + config.BaseURL = "http://test.com" + config.HTTPClient = &http.Client{} + + m, err := New(ctx, config) + assert.NoError(t, err) + assert.NotNil(t, m) + }) + }) +} + +func TestModelWithTools(t *testing.T) { + mockey.PatchConvey("TestModelWithTools", t, func() { + m := &Model{} + tools := []*schema.ToolInfo{ + { + Name: "test_tool", + Desc: "test tool desc", + }, + } + + mockey.Mock((*Model).toFunctionTools).Return([]*responses.ResponsesTool{{}}, nil).Build() + + mockey.PatchConvey("success", func() { + nm, err := m.WithTools(tools) + assert.NoError(t, err) + assert.NotNil(t, nm) + assert.Equal(t, tools, nm.(*Model).rawFunctionTools) + assert.Equal(t, 1, len(nm.(*Model).functionTools)) + }) + + mockey.PatchConvey("empty tools", func() { + _, err := m.WithTools(nil) + assert.Error(t, err) + }) + }) +} + +func TestModelGetType(t *testing.T) { + mockey.PatchConvey("TestModelGetType", t, func() { + m := &Model{} + assert.Equal(t, implType, m.GetType()) + }) +} + +func TestModelIsCallbacksEnabled(t *testing.T) { + mockey.PatchConvey("TestModelIsCallbacksEnabled", t, func() { + m := &Model{} + assert.True(t, m.IsCallbacksEnabled()) + }) +} + +func TestModelToCallbackConfig(t *testing.T) { + mockey.PatchConvey("TestModelToCallbackConfig", t, func() { + m := &Model{} + temp := 0.7 + topP := 0.9 + req := &responses.ResponsesRequest{ + Model: "m", + Temperature: &temp, + TopP: &topP, + } + cfg := m.toCallbackConfig(req) + assert.Equal(t, "m", cfg.Model) + assert.Equal(t, temp, cfg.Temperature) + assert.Equal(t, topP, cfg.TopP) + }) +} + +func TestModelGenerate(t *testing.T) { + mockey.PatchConvey("TestModelGenerate", t, func() { + ctx := context.Background() + m := &Model{ + cli: &arkruntime.Client{}, + model: "m", + } + input := []*schema.AgenticMessage{ + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.UserInputText{Text: "hello"}), + }, + }, + } + + mockey.Mock((*arkruntime.Client).CreateResponses).Return(&responses.ResponseObject{ + Id: "rid", + Output: []*responses.OutputItem{ + { + Union: &responses.OutputItem_OutputMessage{ + OutputMessage: &responses.ItemOutputMessage{ + Role: responses.MessageRole_assistant, + Status: responses.ItemStatus_completed, + Content: []*responses.OutputContentItem{ + { + Union: &responses.OutputContentItem_Text{ + Text: &responses.OutputContentItemText{Text: "hi"}, + }, + }, + }, + }, + }, + }, + }, + Usage: &responses.Usage{ + InputTokens: 10, + }, + }, nil).Build() + + mockey.PatchConvey("success", func() { + out, err := m.Generate(ctx, input) + assert.NoError(t, err) + assert.NotNil(t, out) + assert.Equal(t, "hi", out.ContentBlocks[0].AssistantGenText.Text) + }) + + mockey.PatchConvey("error", func() { + mockey.Mock((*arkruntime.Client).CreateResponses).Return(nil, errors.New("err")).Build() + _, err := m.Generate(ctx, input) + assert.Error(t, err) + }) + }) +} + +func TestModelStream(t *testing.T) { + mockey.PatchConvey("TestModelStream", t, func() { + ctx := context.Background() + m := &Model{ + cli: &arkruntime.Client{}, + model: "m", + } + input := []*schema.AgenticMessage{ + {Role: schema.AgenticRoleTypeUser}, + } + + mockey.PatchConvey("error creating stream", func() { + mockey.Mock((*arkruntime.Client).CreateResponsesStream).Return(nil, errors.New("err")).Build() + _, err := m.Stream(ctx, input) + assert.Error(t, err) + }) + + mockey.PatchConvey("success", func() { + mockey.Mock((*arkruntime.Client).CreateResponsesStream).Return(&utils.ResponsesStreamReader{}, nil).Build() + mockey.Mock((*utils.ChatCompletionStreamReader).Close).Return(nil).Build() + mockey.Mock(receivedStreamResponse).Return().Build() + + sr, err := m.Stream(ctx, input) + assert.NoError(t, err) + assert.NotNil(t, sr) + + sr.Close() + time.Sleep(10 * time.Millisecond) // wait for goroutine to finish + }) + }) +} + +func TestModelCreatePrefixCache(t *testing.T) { + mockey.PatchConvey("TestModelCreatePrefixCache", t, func() { + ctx := context.Background() + m := &Model{ + cli: &arkruntime.Client{}, + model: "m", + } + prefix := []*schema.AgenticMessage{ + {Role: schema.AgenticRoleTypeUser}, + } + + mockey.Mock((*arkruntime.Client).CreateResponses).Return(&responses.ResponseObject{ + Id: "rid", + Usage: &responses.Usage{ + InputTokens: 10, + }, + }, nil).Build() + + info, err := m.CreatePrefixCache(ctx, prefix, ptrOf(int64(3600))) + assert.NoError(t, err) + assert.NotNil(t, info) + assert.Equal(t, "rid", info.ResponseID) + }) +} + +func TestModelToServerTools(t *testing.T) { + mockey.PatchConvey("TestModelToServerTools", t, func() { + m := &Model{} + serverTools := []*ServerToolConfig{ + { + WebSearch: &responses.ToolWebSearch{}, + }, + {}, // empty one to trigger continue + } + + tools, err := m.toServerTools(serverTools) + assert.NoError(t, err) + assert.Equal(t, 2, len(tools)) + assert.NotNil(t, tools[0].Union.(*responses.ResponsesTool_ToolWebSearch)) + }) +} + +func TestModelToMCPTools(t *testing.T) { + mockey.PatchConvey("TestModelToMCPTools", t, func() { + m := &Model{} + mcpTools := []*responses.ToolMcp{ + {}, + } + + tools := m.toMCPTools(mcpTools) + assert.Equal(t, 1, len(tools)) + assert.NotNil(t, tools[0].Union.(*responses.ResponsesTool_ToolMcp)) + }) +} + +func TestModelToFunctionTools(t *testing.T) { + mockey.PatchConvey("TestModelToFunctionTools", t, func() { + m := &Model{} + tools := []*schema.ToolInfo{ + { + Name: "t", + Desc: "d", + }, + } + + mockey.Mock(sonic.Marshal).Return([]byte("{}"), nil).Build() + + res, err := m.toFunctionTools(tools) + assert.NoError(t, err) + assert.Equal(t, 1, len(res)) + assert.Equal(t, "t", res[0].Union.(*responses.ResponsesTool_ToolFunction).ToolFunction.Name) + }) +} + +func TestModelGetOptions(t *testing.T) { + mockey.PatchConvey("TestModelGetOptions", t, func() { + m := &Model{ + model: "m", + temperature: ptrOf(0.7), + } + + mockey.PatchConvey("default", func() { + opts, specOpts, err := m.getOptions(nil) + assert.NoError(t, err) + assert.Equal(t, "m", *opts.Model) + assert.Equal(t, 0.7, *opts.Temperature) + assert.NotNil(t, specOpts) + }) + + mockey.PatchConvey("override", func() { + opts, _, err := m.getOptions([]agentic.Option{ + agentic.WithTemperature(0.9), + }) + assert.NoError(t, err) + assert.Equal(t, 0.9, *opts.Temperature) + }) + }) +} + +func TestModelGenRequestAndOptions(t *testing.T) { + mockey.PatchConvey("TestModelGenRequestAndOptions", t, func() { + m := &Model{model: "m"} + input := []*schema.AgenticMessage{{Role: schema.AgenticRoleTypeUser}} + opts := &agentic.Options{Model: ptrOf("m")} + arkOpts := &arkOptions{} + + req, err := m.genRequestAndOptions(input, opts, arkOpts) + assert.NoError(t, err) + assert.NotNil(t, req) + assert.Equal(t, "m", req.Model) + }) +} + +func TestModelPrePopulateConfig(t *testing.T) { + mockey.PatchConvey("TestModelPrePopulateConfig", t, func() { + m := &Model{serviceTier: responses.ResponsesServiceTier_default.Enum()} + req := &responses.ResponsesRequest{} + opts := &agentic.Options{ + TopP: ptrOf(0.9), + Temperature: ptrOf(0.7), + Model: ptrOf("m2"), + } + specOpts := &arkOptions{ + thinking: &responses.ResponsesThinking{Type: responses.ThinkingType_enabled.Enum()}, + } + + err := m.prePopulateConfig(req, opts, specOpts) + assert.NoError(t, err) + assert.Equal(t, "m2", req.Model) + assert.Equal(t, 0.9, *req.TopP) + assert.Equal(t, responses.ThinkingType_enabled, *req.Thinking.Type) + assert.Equal(t, responses.ResponsesServiceTier_default, *req.ServiceTier) + }) +} + +func TestModelPopulateCache(t *testing.T) { + mockey.PatchConvey("TestModelPopulateCache", t, func() { + m := &Model{} + req := &responses.ResponsesRequest{} + specOpts := &arkOptions{} + + mockey.PatchConvey("no cache", func() { + in := []*schema.AgenticMessage{{Role: schema.AgenticRoleTypeUser}} + outIn, err := m.populateCache(in, req, specOpts) + assert.NoError(t, err) + assert.Equal(t, in, outIn) + assert.False(t, *req.Store) + }) + + mockey.PatchConvey("session cache enabled in model", func() { + m.cache = &CacheConfig{ + SessionCache: &SessionCacheConfig{ + EnableCache: true, + ExpireAtSec: 3600, + }, + } + in := []*schema.AgenticMessage{{Role: schema.AgenticRoleTypeUser}} + _, err := m.populateCache(in, req, specOpts) + assert.NoError(t, err) + assert.True(t, *req.Store) + assert.Equal(t, responses.CacheType_enabled, *req.Caching.Type) + }) + + mockey.PatchConvey("response id in messages", func() { + m.cache = &CacheConfig{ + SessionCache: &SessionCacheConfig{ + EnableCache: true, + ExpireAtSec: 3600, + }, + } + + // Mock time.Now to control expiration check + mockey.Mock(time.Now).Return(time.Unix(1000, 0)).Build() + + in := []*schema.AgenticMessage{ + { + Role: schema.AgenticRoleTypeAssistant, + ResponseMeta: &schema.AgenticResponseMeta{ + Extension: &ResponseMetaExtension{ + ID: "rid", + ExpireAt: ptrOf(int64(2000)), + }, + }, + }, + {Role: schema.AgenticRoleTypeUser}, + } + + outIn, err := m.populateCache(in, req, specOpts) + assert.NoError(t, err) + assert.Equal(t, 1, len(outIn)) + assert.Equal(t, "rid", *req.PreviousResponseId) + }) + + mockey.PatchConvey("response id in messages - no incremental input", func() { + m.cache = &CacheConfig{ + SessionCache: &SessionCacheConfig{ + EnableCache: true, + ExpireAtSec: 3600, + }, + } + mockey.Mock(time.Now).Return(time.Unix(1000, 0)).Build() + in := []*schema.AgenticMessage{ + { + Role: schema.AgenticRoleTypeAssistant, + ResponseMeta: &schema.AgenticResponseMeta{ + Extension: &ResponseMetaExtension{ + ID: "rid", + ExpireAt: ptrOf(int64(2000)), + }, + }, + }, + } + _, err := m.populateCache(in, req, specOpts) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found incremental input") + }) + }) +} + +func TestModelPopulateInput(t *testing.T) { + mockey.PatchConvey("TestModelPopulateInput", t, func() { + m := &Model{} + req := &responses.ResponsesRequest{} + + mockey.PatchConvey("mixed roles", func() { + in := []*schema.AgenticMessage{ + {Role: schema.AgenticRoleTypeSystem}, + {Role: schema.AgenticRoleTypeUser}, + {Role: schema.AgenticRoleTypeAssistant}, + } + err := m.populateInput(in, req) + assert.NoError(t, err) + assert.NotNil(t, req.Input) + }) + + mockey.PatchConvey("invalid role", func() { + in := []*schema.AgenticMessage{ + {Role: "invalid"}, + } + err := m.populateInput(in, req) + assert.Error(t, err) + }) + }) +} + +func TestModelPopulateTools(t *testing.T) { + mockey.PatchConvey("TestModelPopulateTools", t, func() { + m := &Model{ + functionTools: []*responses.ResponsesTool{{}}, + } + req := &responses.ResponsesRequest{} + opts := &agentic.Options{} + specOpts := &arkOptions{} + + mockey.PatchConvey("default tools", func() { + err := m.populateTools(req, opts, specOpts) + assert.NoError(t, err) + assert.Equal(t, 1, len(req.Tools)) + }) + + mockey.PatchConvey("override tools", func() { + opts.Tools = []*schema.ToolInfo{{Name: "t"}} + err := m.populateTools(req, opts, specOpts) + assert.NoError(t, err) + assert.Equal(t, 1, len(req.Tools)) + }) + + mockey.PatchConvey("previous response id exists", func() { + req.PreviousResponseId = ptrOf("rid") + err := m.populateTools(req, opts, specOpts) + assert.NoError(t, err) + assert.Equal(t, 0, len(req.Tools)) + }) + }) +} + +func TestModelPopulateToolChoice(t *testing.T) { + mockey.PatchConvey("TestModelPopulateToolChoice", t, func() { + m := &Model{} + req := &responses.ResponsesRequest{} + opts := &agentic.Options{} + + mockey.PatchConvey("forbidden", func() { + choice := schema.ToolChoiceForbidden + opts.ToolChoice = &choice + err := m.populateToolChoice(req, opts) + assert.NoError(t, err) + assert.Equal(t, responses.ToolChoiceMode_none, req.ToolChoice.Union.(*responses.ResponsesToolChoice_Mode).Mode) + }) + + mockey.PatchConvey("allowed - auto", func() { + choice := schema.ToolChoiceAllowed + opts.ToolChoice = &choice + err := m.populateToolChoice(req, opts) + assert.NoError(t, err) + assert.Equal(t, responses.ToolChoiceMode_auto, req.ToolChoice.Union.(*responses.ResponsesToolChoice_Mode).Mode) + }) + + mockey.PatchConvey("forced - required", func() { + choice := schema.ToolChoiceForced + opts.ToolChoice = &choice + err := m.populateToolChoice(req, opts) + assert.NoError(t, err) + assert.Equal(t, responses.ToolChoiceMode_required, req.ToolChoice.Union.(*responses.ResponsesToolChoice_Mode).Mode) + }) + + mockey.PatchConvey("forced - specific", func() { + choice := schema.ToolChoiceForced + opts.ToolChoice = &choice + opts.AllowedTools = []*schema.AllowedTool{{FunctionToolName: "f"}} + err := m.populateToolChoice(req, opts) + assert.NoError(t, err) + assert.Equal(t, "f", req.ToolChoice.Union.(*responses.ResponsesToolChoice_FunctionToolChoice).FunctionToolChoice.Name) + }) + }) +} + +func TestToForcedToolChoice(t *testing.T) { + mockey.PatchConvey("TestToForcedToolChoice", t, func() { + mockey.PatchConvey("function", func() { + res, err := toForcedToolChoice(&schema.AllowedTool{FunctionToolName: "f"}) + assert.NoError(t, err) + assert.Equal(t, "f", res.Union.(*responses.ResponsesToolChoice_FunctionToolChoice).FunctionToolChoice.Name) + }) + + mockey.PatchConvey("mcp", func() { + res, err := toForcedToolChoice(&schema.AllowedTool{MCPTool: &schema.AllowedMCPTool{Name: "m", ServerLabel: "s"}}) + assert.NoError(t, err) + assert.Equal(t, "m", *res.Union.(*responses.ResponsesToolChoice_McpToolChoice).McpToolChoice.Name) + }) + + mockey.PatchConvey("server tool - web search", func() { + res, err := toForcedToolChoice(&schema.AllowedTool{ServerTool: &schema.AllowedServerTool{Name: string(ServerToolNameWebSearch)}}) + assert.NoError(t, err) + assert.Equal(t, responses.ToolType_web_search, res.Union.(*responses.ResponsesToolChoice_WebSearchToolChoice).WebSearchToolChoice.Type) + }) + + mockey.PatchConvey("unknown", func() { + _, err := toForcedToolChoice(&schema.AllowedTool{}) + assert.Error(t, err) + }) + }) +} diff --git a/components/agentic/ark/option.go b/components/agentic/ark/option.go new file mode 100644 index 000000000..0910c2798 --- /dev/null +++ b/components/agentic/ark/option.go @@ -0,0 +1,110 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ark + +import ( + "github.com/cloudwego/eino/components/agentic" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" +) + +type arkOptions struct { + reasoning *responses.ResponsesReasoning + thinking *responses.ResponsesThinking + maxOutputTokens *int64 + maxToolCalls *int64 + parallelToolCalls *bool + text *responses.ResponsesText + + serverTools []*ServerToolConfig + mcpTools []*responses.ToolMcp + + customHeaders map[string]string + cache *CacheOption +} + +type CacheOption struct { + // HeadPreviousResponseID is a response ID from a previous ResponsesAPI call. + // This ID links the current request to a previous conversation context, enabling + // features like conversation continuation and prefix caching. + // The referenced response must be cached before use. + // Optional. + HeadPreviousResponseID *string + + // SessionCache is the configuration of ResponsesAPI session cache. + // Optional. + SessionCache *SessionCacheConfig +} + +func WithReasoning(reasoning *responses.ResponsesReasoning) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *arkOptions) { + o.reasoning = reasoning + }) +} + +func WithThinking(thinking *responses.ResponsesThinking) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *arkOptions) { + o.thinking = thinking + }) +} + +func WithText(text *responses.ResponsesText) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *arkOptions) { + o.text = text + }) +} + +func WithMaxOutputTokens(maxOutputTokens int64) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *arkOptions) { + o.maxOutputTokens = &maxOutputTokens + }) +} + +func WithMaxToolCalls(maxToolCalls int64) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *arkOptions) { + o.maxToolCalls = &maxToolCalls + }) +} + +func WithParallelToolCalls(parallelToolCalls bool) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *arkOptions) { + o.parallelToolCalls = ¶llelToolCalls + }) +} + +func WithServerTools(tools []*ServerToolConfig) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *arkOptions) { + o.serverTools = tools + }) +} + +func WithMCPTools(tools []*responses.ToolMcp) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *arkOptions) { + o.mcpTools = tools + }) +} + +func WithCustomHeaders(headers map[string]string) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *arkOptions) { + o.customHeaders = headers + }) +} + +func WithCache(option *CacheOption) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *arkOptions) { + o.cache = option + }) +} diff --git a/components/agentic/ark/option_test.go b/components/agentic/ark/option_test.go new file mode 100644 index 000000000..1917f91b9 --- /dev/null +++ b/components/agentic/ark/option_test.go @@ -0,0 +1,102 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ark + +import ( + "testing" + + "github.com/cloudwego/eino/components/agentic" + "github.com/stretchr/testify/assert" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" +) + +func TestWithReasoning(t *testing.T) { + base := &arkOptions{} + reasoning := &responses.ResponsesReasoning{} + got := agentic.GetImplSpecificOptions(base, WithReasoning(reasoning)) + assert.Same(t, reasoning, got.reasoning) +} + +func TestWithThinking(t *testing.T) { + base := &arkOptions{} + thinking := &responses.ResponsesThinking{} + got := agentic.GetImplSpecificOptions(base, WithThinking(thinking)) + assert.Same(t, thinking, got.thinking) +} + +func TestWithText(t *testing.T) { + base := &arkOptions{} + text := &responses.ResponsesText{} + got := agentic.GetImplSpecificOptions(base, WithText(text)) + assert.Same(t, text, got.text) +} + +func TestWithMaxOutputTokens(t *testing.T) { + base := &arkOptions{} + got := agentic.GetImplSpecificOptions(base, WithMaxOutputTokens(128)) + if assert.NotNil(t, got.maxOutputTokens) { + assert.Equal(t, int64(128), *got.maxOutputTokens) + } +} + +func TestWithMaxToolCalls(t *testing.T) { + base := &arkOptions{} + got := agentic.GetImplSpecificOptions(base, WithMaxToolCalls(10)) + if assert.NotNil(t, got.maxToolCalls) { + assert.Equal(t, int64(10), *got.maxToolCalls) + } +} + +func TestWithParallelToolCalls(t *testing.T) { + base := &arkOptions{} + got := agentic.GetImplSpecificOptions(base, WithParallelToolCalls(true)) + if assert.NotNil(t, got.parallelToolCalls) { + assert.True(t, *got.parallelToolCalls) + } +} + +func TestWithServerTools(t *testing.T) { + base := &arkOptions{} + serverTools := []*ServerToolConfig{ + {WebSearch: &responses.ToolWebSearch{}}, + } + got := agentic.GetImplSpecificOptions(base, WithServerTools(serverTools)) + assert.Equal(t, serverTools, got.serverTools) +} + +func TestWithMCPTools(t *testing.T) { + base := &arkOptions{} + mcpTools := []*responses.ToolMcp{ + {Type: responses.ToolType_mcp}, + } + got := agentic.GetImplSpecificOptions(base, WithMCPTools(mcpTools)) + assert.Equal(t, mcpTools, got.mcpTools) +} + +func TestWithCustomHeaders(t *testing.T) { + base := &arkOptions{} + headers := map[string]string{"K": "V"} + got := agentic.GetImplSpecificOptions(base, WithCustomHeaders(headers)) + assert.Equal(t, headers, got.customHeaders) +} + +func TestWithCache(t *testing.T) { + base := &arkOptions{} + cache := &CacheOption{} + got := agentic.GetImplSpecificOptions(base, WithCache(cache)) + assert.Same(t, cache, got.cache) +} diff --git a/components/agentic/ark/register.go b/components/agentic/ark/register.go new file mode 100644 index 000000000..ab7956416 --- /dev/null +++ b/components/agentic/ark/register.go @@ -0,0 +1,36 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ark + +import ( + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +func init() { + schema.RegisterName[blockExtraItemID]("_eino_ext_ark_content_block_extra_item_id") + schema.RegisterName[blockExtraItemStatus]("_eino_ext_ark_content_block_extra_item_status") + schema.RegisterName[*ResponseMetaExtension]("_eino_ext_ark_response_meta_extension") + schema.RegisterName[*AssistantGenTextExtension]("_eino_ext_ark_assistant_gen_text_extension") + schema.RegisterName[*ServerToolCallArguments]("_eino_ext_ark_server_tool_call_arguments") + + compose.RegisterStreamChunkConcatFunc(concatFirstNonZero[blockExtraItemID]) + compose.RegisterStreamChunkConcatFunc(concatLast[blockExtraItemStatus]) + compose.RegisterStreamChunkConcatFunc(concatResponseMetaExtensions) + compose.RegisterStreamChunkConcatFunc(concatAssistantGenTextExtensions) + compose.RegisterStreamChunkConcatFunc(concatServerToolCallArguments) +} diff --git a/components/agentic/ark/utils.go b/components/agentic/ark/utils.go new file mode 100644 index 000000000..1dd22e219 --- /dev/null +++ b/components/agentic/ark/utils.go @@ -0,0 +1,140 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ark + +import ( + "fmt" + "reflect" + "strconv" + "strings" + + "github.com/eino-contrib/jsonschema" +) + +func ptrIfNonZero[T any](v T) *T { + if reflect.ValueOf(v).IsZero() { + return nil + } + return &v +} + +func coalesce[T any](x, y T) T { + if !reflect.ValueOf(x).IsZero() { + return x + } + return y +} + +func ptrFromOrZero[T any](v *T) T { + if v == nil { + var t T + return t + } + return *v +} + +func ptrOf[T any](v T) *T { + return &v +} + +func int64ToStr(i int64) string { + return strconv.FormatInt(i, 10) +} + +type panicErr struct { + info any + stack []byte +} + +func (p *panicErr) Error() string { + return fmt.Sprintf("panic error: %v, \nstack: %s", p.info, string(p.stack)) +} + +func newPanicErr(info any, stack []byte) error { + return &panicErr{ + info: info, + stack: stack, + } +} + +func jsonschemaToMap(sc *jsonschema.Schema) (map[string]any, error) { + if sc == nil { + return nil, fmt.Errorf("jsonschema is nil") + } + + val := reflect.ValueOf(sc) + val = val.Elem() + + if val.Kind() != reflect.Struct { + return nil, fmt.Errorf("expected struct, got %v", val.Kind()) + } + + typ := val.Type() + result := make(map[string]any) + + for i := 0; i < val.NumField(); i++ { + field := typ.Field(i) + fieldValue := val.Field(i) + + if field.Name == "Extra" && !fieldValue.IsZero() { + return nil, fmt.Errorf("extra field must be nil") + } + + if !field.IsExported() { + continue + } + + jsonTag := field.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { + continue + } + + tagParts := strings.Split(jsonTag, ",") + keyName := tagParts[0] + + omitempty := false + for _, opt := range tagParts[1:] { + if opt == "omitempty" { + omitempty = true + break + } + } + + if omitempty && fieldValue.IsZero() { + continue + } + + result[keyName] = fieldValue.Interface() + } + + if sc.Extras != nil { + for k, v := range sc.Extras { + if _, ok := result[k]; ok { + return nil, fmt.Errorf("extra field %q is duplicated", k) + } + result[k] = v + } + } + + if sc.Type != "" { + result["type"] = sc.Type + } else if sc.TypeEnhanced != nil { + result["type"] = sc.TypeEnhanced + } + + return result, nil +} diff --git a/components/agentic/ark/utils_test.go b/components/agentic/ark/utils_test.go new file mode 100644 index 000000000..5df0631b6 --- /dev/null +++ b/components/agentic/ark/utils_test.go @@ -0,0 +1,48 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ark + +import ( + "github.com/bytedance/mockey" + "github.com/stretchr/testify/assert" + + "testing" +) + +func TestCoalesce(t *testing.T) { + mockey.PatchConvey("TestCoalesce", t, func() { + mockey.PatchConvey("x is not zero", func() { + x := "1" + y := "" + got := coalesce(x, y) + assert.Equal(t, x, got) + }) + + mockey.PatchConvey("x and y is pointer", func() { + x := ptrOf("1") + y := ptrOf("") + got := coalesce(x, y) + assert.Equal(t, x, got) + }) + + mockey.PatchConvey("x and y is nil pointer", func() { + var x, y *string + got := coalesce(x, y) + assert.Equal(t, x, got) + }) + }) +} diff --git a/components/agentic/openai/README.md b/components/agentic/openai/README.md new file mode 100644 index 000000000..a5f39eb63 --- /dev/null +++ b/components/agentic/openai/README.md @@ -0,0 +1,432 @@ +# OpenAI Agentic Model + +An OpenAI model implementation for [Eino](https://github.com/cloudwego/eino) that implements the `Model` interface in `agentic` component. This enables seamless integration with Eino's Agent capabilities for enhanced natural language processing and generation. + +## Features + +- Implements `github.com/cloudwego/eino/components/agentic.Model` +- Easy integration with Eino's agent system +- Configurable model parameters +- Support for responses api +- Support for streaming responses +- Support for tool calling (Function Tools, MCP Tools, Server Tools) +- Support for Azure OpenAI + +## Installation + +```bash +go get github.com/cloudwego/eino-ext/components/agentic/openai@latest +``` + +## Quick Start + +Here's a quick example of how to use the `Model`: + +```go +package main + +import ( + "context" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/openai" + "github.com/cloudwego/eino/schema" + openaischema "github.com/cloudwego/eino/schema/openai" + "github.com/eino-contrib/jsonschema" + "github.com/openai/openai-go/v3/responses" + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +func main() { + ctx := context.Background() + + am, err := openai.New(ctx, &openai.Config{ + BaseURL: "https://api.openai.com/v1", + Model: os.Getenv("OPENAI_MODEL_ID"), + APIKey: os.Getenv("OPENAI_API_KEY"), + Reasoning: &responses.ReasoningParam{ + Effort: responses.ReasoningEffortLow, + Summary: responses.ReasoningSummaryDetailed, + }, + }) + if err != nil { + log.Fatalf("failed to create agentic model, err: %v", err) + } + + input := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what is the weather like in Beijing"), + } + + am_, err := am.WithTools([]*schema.ToolInfo{ + { + Name: "get_weather", + Desc: "get the weather in a city", + ParamsOneOf: schema.NewParamsOneOfByJSONSchema(&jsonschema.Schema{ + Type: "object", + Properties: orderedmap.New[string, *jsonschema.Schema]( + orderedmap.WithInitialData( + orderedmap.Pair[string, *jsonschema.Schema]{ + Key: "city", + Value: &jsonschema.Schema{ + Type: "string", + Description: "the city to get the weather", + }, + }, + ), + ), + Required: []string{"city"}, + }), + }, + }) + if err != nil { + log.Fatalf("failed to create agentic model with tools, err: %v", err) + } + + msg, err := am_.Generate(ctx, input) + if err != nil { + log.Fatalf("failed to generate, err: %v", err) + } + + meta := msg.ResponseMeta.Extension.(*openaischema.ResponseMetaExtension) + + log.Printf("request_id: %s\n", meta.ID) + respBody, _ := sonic.MarshalIndent(msg, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} +``` + +## Configuration + +The `Model` can be configured using the `openai.Config` struct: + +```go +type Config struct { + // ByAzure specifies whether to use Azure OpenAI service. + // Optional. + ByAzure bool + + // BaseURL specifies the base URL for the OpenAI service endpoint. + // Optional. + BaseURL string + + // APIKey specifies the API key for authentication. + // Required. + APIKey string + + // Timeout specifies the maximum duration to wait for API responses. + // Optional. + Timeout *time.Duration + + // HTTPClient specifies the HTTP client used to send requests. + // Optional. + HTTPClient *http.Client + + // MaxRetries specifies the maximum number of retry attempts for failed requests. + // Optional. + MaxRetries *int + + // Model specifies the ID of the model to use for the response. + // Required. + Model string + + // MaxOutputTokens specifies the maximum number of tokens to generate in the response. + // Optional. + MaxOutputTokens *int64 + + // Temperature controls the randomness of the model's output. + // Higher values (e.g., 0.8) make the output more random, while lower values (e.g., 0.2) make it more focused and deterministic. + // Range: 0.0 to 2.0. + // Optional. + Temperature *float64 + + // TopP controls diversity via nucleus sampling. + // It specifies the cumulative probability threshold for token selection. + // Recommended to use either Temperature or TopP, but not both. + // Range: 0.0 to 1.0. + // Optional. + TopP *float64 + + // ServiceTier specifies the latency tier for processing the request. + // Optional. + ServiceTier *responses.ResponseNewParamsServiceTier + + // Text specifies configuration for text generation output. + // Optional. + Text *responses.ResponseTextConfigParam + + // Reasoning specifies configuration for reasoning models. + // Optional. + Reasoning *responses.ReasoningParam + + // Store specifies whether to store the response on the server. + // Optional. + Store *bool + + // MaxToolCalls specifies the maximum number of tool calls allowed in a single turn. + // Optional. + MaxToolCalls *int + + // ParallelToolCalls specifies whether to allow multiple tool calls in a single turn. + // Optional. + ParallelToolCalls *bool + + // Include specifies a list of additional fields to include in the response. + // Optional. + Include []responses.ResponseIncludable + + // ServerTools specifies server-side tools available to the model. + // Optional. + ServerTools []*ServerToolConfig + + // MCPTools specifies Model Context Protocol tools available to the model. + // Optional. + MCPTools []*responses.ToolMcpParam + + // CustomHeader specifies custom HTTP headers to include in API requests. + // CustomHeader allows passing additional metadata or authentication information. + // Optional. + CustomHeader map[string]string + + // ExtraFields specifies additional fields that will be directly added to the HTTP request body. + // This allows for vendor-specific or future parameters not yet explicitly supported. + // Optional. + ExtraFields map[string]any +} +``` + +## Advanced Usage + +### Tool Calling + +The `Model` supports tool calling, including Function Tools, MCP Tools, and Server Tools. + +#### Function Tool Example + +```go +package main + +import ( + "context" + "errors" + "io" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/openai" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/eino-contrib/jsonschema" + "github.com/openai/openai-go/v3/responses" + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +func main() { + ctx := context.Background() + + am, err := openai.New(ctx, &openai.Config{ + BaseURL: "https://api.openai.com/v1", + Model: os.Getenv("OPENAI_MODEL_ID"), + APIKey: os.Getenv("OPENAI_API_KEY"), + Reasoning: &responses.ReasoningParam{ + Effort: responses.ReasoningEffortLow, + Summary: responses.ReasoningSummaryDetailed, + }, + }) + if err != nil { + log.Fatalf("failed to create agentic model, err=%v", err) + } + + functionTools := []*schema.ToolInfo{ + { + Name: "get_weather", + Desc: "get the weather in a city", + ParamsOneOf: schema.NewParamsOneOfByJSONSchema(&jsonschema.Schema{ + Type: "object", + Properties: orderedmap.New[string, *jsonschema.Schema]( + orderedmap.WithInitialData( + orderedmap.Pair[string, *jsonschema.Schema]{ + Key: "city", + Value: &jsonschema.Schema{ + Type: "string", + Description: "the city to get the weather", + }, + }, + ), + ), + Required: []string{"city"}, + }), + }, + } + + allowedTools := []*schema.AllowedTool{ + { + FunctionToolName: "get_weather", + }, + } + + opts := []agentic.Option{ + agentic.WithToolChoice(schema.ToolChoiceForced, allowedTools...), + agentic.WithTools(functionTools), + } + + firstInput := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what's the weather like in Beijing today"), + } + + sResp, err := am.Stream(ctx, firstInput, opts...) + if err != nil { + log.Fatalf("failed to stream, err: %v", err) + } + + var msgs []*schema.AgenticMessage + for { + msg, err := sResp.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + log.Fatalf("failed to receive stream response, err: %v", err) + } + msgs = append(msgs, msg) + } + + concatenated, err := schema.ConcatAgenticMessages(msgs) + if err != nil { + log.Fatalf("failed to concat agentic messages, err: %v", err) + } + + lastBlock := concatenated.ContentBlocks[len(concatenated.ContentBlocks)-1] + if lastBlock.Type != schema.ContentBlockTypeFunctionToolCall { + log.Fatalf("last block is not function tool call, type: %s", lastBlock.Type) + } + + toolCall := lastBlock.FunctionToolCall + toolResultMsg := schema.FunctionToolResultAgenticMessage(toolCall.CallID, toolCall.Name, "20 degrees") + + secondInput := append(firstInput, concatenated, toolResultMsg) + + gResp, err := am.Generate(ctx, secondInput, opts...) + if err != nil { + log.Fatalf("failed to generate, err: %v", err) + } + + meta := concatenated.ResponseMeta.OpenAIExtension + log.Printf("request_id: %s\n", meta.ID) + + respBody, _ := sonic.MarshalIndent(gResp, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} +``` + + +#### Server Tool Example + +```go +package main + +import ( + "context" + "errors" + "io" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/openai" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/openai/openai-go/v3/responses" +) + +func main() { + ctx := context.Background() + + am, err := openai.New(ctx, &openai.Config{ + BaseURL: "https://api.openai.com/v1", + Model: os.Getenv("OPENAI_MODEL_ID"), + APIKey: os.Getenv("OPENAI_API_KEY"), + Reasoning: &responses.ReasoningParam{ + Effort: responses.ReasoningEffortLow, + Summary: responses.ReasoningSummaryDetailed, + }, + Include: []responses.ResponseIncludable{ + responses.ResponseIncludableWebSearchCallActionSources, + }, + }) + if err != nil { + log.Fatalf("failed to create agentic model, err=%v", err) + } + + serverTools := []*openai.ServerToolConfig{ + { + WebSearch: &responses.WebSearchToolParam{ + Type: responses.WebSearchToolTypeWebSearch, + }, + }, + } + + allowedTools := []*schema.AllowedTool{ + { + ServerTool: &schema.AllowedServerTool{ + Name: string(openai.ServerToolNameWebSearch), + }, + }, + } + + opts := []agentic.Option{ + agentic.WithToolChoice(schema.ToolChoiceForced, allowedTools...), + openai.WithServerTools(serverTools), + } + + input := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what's cloudwego/eino"), + } + + resp, err := am.Stream(ctx, input, opts...) + if err != nil { + log.Fatalf("failed to stream, err: %v", err) + } + + var msgs []*schema.AgenticMessage + for { + msg, err := resp.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + log.Fatalf("failed to receive stream response, err: %v", err) + } + msgs = append(msgs, msg) + } + + concatenated, err := schema.ConcatAgenticMessages(msgs) + if err != nil { + log.Fatalf("failed to concat agentic messages, err: %v", err) + } + + for _, block := range concatenated.ContentBlocks { + if block.ServerToolCall != nil { + serverToolArgs := block.ServerToolCall.Arguments.(*openai.ServerToolCallArguments) + args, _ := sonic.MarshalIndent(serverToolArgs, " ", " ") + log.Printf("server_tool_args: %s\n", string(args)) + } + + if block.ServerToolResult != nil { + result := block.ServerToolResult.Result.(*openai.ServerToolResult) + resultJSON, _ := sonic.MarshalIndent(result, " ", " ") + log.Printf("server_tool_result: %s\n", string(resultJSON)) + } + } + + meta := concatenated.ResponseMeta.OpenAIExtension + log.Printf("request_id: %s\n", meta.ID) + + respBody, _ := sonic.MarshalIndent(concatenated, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} +``` + +For more examples, please refer to the `examples` directory. diff --git a/components/agentic/openai/README.zh_CN.md b/components/agentic/openai/README.zh_CN.md new file mode 100644 index 000000000..be4f8a341 --- /dev/null +++ b/components/agentic/openai/README.zh_CN.md @@ -0,0 +1,432 @@ +# OpenAI Agentic Model + +基于 [Eino](https://github.com/cloudwego/eino) 的 OpenAI 模型实现,实现了 `agentic` 组件中的 `Model` 接口。这使得该模型能够无缝集成到 Eino 的 Agent 能力中,提供增强的自然语言处理和生成功能。 + +## 功能特性 + +- 实现了 `github.com/cloudwego/eino/components/agentic.Model` 接口 +- 易于集成到 Eino 的 agent 系统中 +- 可配置的模型参数 +- 支持 Responses API +- 支持流式响应 (Streaming) +- 支持工具调用 (Tools),包括函数工具 (Function Tools)、MCP 工具 (MCP Tools) 和服务器工具 (Server Tools) +- 支持 Azure OpenAI + +## 安装 + +```bash +go get github.com/cloudwego/eino-ext/components/agentic/openai@latest +``` + +## 快速开始 + +以下是如何使用 `Model` 的一个快速示例: + +```go +package main + +import ( + "context" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/openai" + "github.com/cloudwego/eino/schema" + openaischema "github.com/cloudwego/eino/schema/openai" + "github.com/eino-contrib/jsonschema" + "github.com/openai/openai-go/v3/responses" + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +func main() { + ctx := context.Background() + + am, err := openai.New(ctx, &openai.Config{ + BaseURL: "https://api.openai.com/v1", + Model: os.Getenv("OPENAI_MODEL_ID"), + APIKey: os.Getenv("OPENAI_API_KEY"), + Reasoning: &responses.ReasoningParam{ + Effort: responses.ReasoningEffortLow, + Summary: responses.ReasoningSummaryDetailed, + }, + }) + if err != nil { + log.Fatalf("failed to create agentic model, err: %v", err) + } + + input := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what is the weather like in Beijing"), + } + + am_, err := am.WithTools([]*schema.ToolInfo{ + { + Name: "get_weather", + Desc: "get the weather in a city", + ParamsOneOf: schema.NewParamsOneOfByJSONSchema(&jsonschema.Schema{ + Type: "object", + Properties: orderedmap.New[string, *jsonschema.Schema]( + orderedmap.WithInitialData( + orderedmap.Pair[string, *jsonschema.Schema]{ + Key: "city", + Value: &jsonschema.Schema{ + Type: "string", + Description: "the city to get the weather", + }, + }, + ), + ), + Required: []string{"city"}, + }), + }, + }) + if err != nil { + log.Fatalf("failed to create agentic model with tools, err: %v", err) + } + + msg, err := am_.Generate(ctx, input) + if err != nil { + log.Fatalf("failed to generate, err: %v", err) + } + + meta := msg.ResponseMeta.Extension.(*openaischema.ResponseMetaExtension) + + log.Printf("request_id: %s\n", meta.ID) + respBody, _ := sonic.MarshalIndent(msg, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} +``` + +## 配置 + +可以使用 `openai.Config` 结构体配置 `Model`: + +```go +type Config struct { + // ByAzure 指定是否使用 Azure OpenAI 服务。 + // 可选。 + ByAzure bool + + // BaseURL 指定 OpenAI 服务端点的基准 URL。 + // 可选。 + BaseURL string + + // APIKey 指定用于认证的 API 密钥。 + // 必填。 + APIKey string + + // Timeout 指定等待 API 响应的最大持续时间。 + // 可选。 + Timeout *time.Duration + + // HTTPClient 指定用于发送 HTTP 请求的客户端。 + // 可选。 + HTTPClient *http.Client + + // MaxRetries 指定失败请求的最大重试次数。 + // 可选。 + MaxRetries *int + + // Model 指定用于响应的模型 ID。 + // 必填。 + Model string + + // MaxOutputTokens 指定响应中生成的最大 token 数。 + // 可选。 + MaxOutputTokens *int64 + + // Temperature 控制模型输出的随机性。 + // 较高的值(如 0.8)使输出更随机,而较低的值(如 0.2)使输出更集中和确定。 + // 范围:0.0 到 2.0。 + // 可选。 + Temperature *float64 + + // TopP 通过核心采样控制多样性。 + // 它指定 token 选择的累积概率阈值。 + // 建议修改此项或 Temperature,但不要同时修改。 + // 范围:0.0 到 1.0。 + // 可选。 + TopP *float64 + + // ServiceTier 指定处理请求的延迟层级。 + // 可选。 + ServiceTier *responses.ResponseNewParamsServiceTier + + // Text 指定文本生成输出的配置。 + // 可选。 + Text *responses.ResponseTextConfigParam + + // Reasoning 指定推理模型的配置。 + // 可选。 + Reasoning *responses.ReasoningParam + + // Store 指定是否在服务器上存储响应。 + // 可选。 + Store *bool + + // MaxToolCalls 指定单轮中允许的最大工具调用次数。 + // 可选。 + MaxToolCalls *int + + // ParallelToolCalls 指定是否允许在单轮中进行多次工具调用。 + // 可选。 + ParallelToolCalls *bool + + // Include 指定响应中包含的额外字段列表。 + // 可选。 + Include []responses.ResponseIncludable + + // ServerTools 指定模型可用的服务器端工具。 + // 可选。 + ServerTools []*ServerToolConfig + + // MCPTools 指定模型可用的 Model Context Protocol 工具。 + // 可选。 + MCPTools []*responses.ToolMcpParam + + // CustomHeader 指定 API 请求中包含的自定义 HTTP 标头。 + // CustomHeader 允许传递额外的元数据或身份验证信息。 + // 可选。 + CustomHeader map[string]string + + // ExtraFields 指定将直接添加到 HTTP 请求体的额外字段。 + // 这允许支持尚未显式支持的供应商特定或未来参数。 + // 可选。 + ExtraFields map[string]any +} +``` + +## 高级用法 + +### 工具调用 (Tool Calling) + +`Model` 支持工具调用,包括函数工具、MCP 工具和服务器工具。 + +#### 函数工具示例 + +```go +package main + +import ( + "context" + "errors" + "io" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/openai" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/eino-contrib/jsonschema" + "github.com/openai/openai-go/v3/responses" + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +func main() { + ctx := context.Background() + + am, err := openai.New(ctx, &openai.Config{ + BaseURL: "https://api.openai.com/v1", + Model: os.Getenv("OPENAI_MODEL_ID"), + APIKey: os.Getenv("OPENAI_API_KEY"), + Reasoning: &responses.ReasoningParam{ + Effort: responses.ReasoningEffortLow, + Summary: responses.ReasoningSummaryDetailed, + }, + }) + if err != nil { + log.Fatalf("failed to create agentic model, err=%v", err) + } + + functionTools := []*schema.ToolInfo{ + { + Name: "get_weather", + Desc: "get the weather in a city", + ParamsOneOf: schema.NewParamsOneOfByJSONSchema(&jsonschema.Schema{ + Type: "object", + Properties: orderedmap.New[string, *jsonschema.Schema]( + orderedmap.WithInitialData( + orderedmap.Pair[string, *jsonschema.Schema]{ + Key: "city", + Value: &jsonschema.Schema{ + Type: "string", + Description: "the city to get the weather", + }, + }, + ), + ), + Required: []string{"city"}, + }), + }, + } + + allowedTools := []*schema.AllowedTool{ + { + FunctionToolName: "get_weather", + }, + } + + opts := []agentic.Option{ + agentic.WithToolChoice(schema.ToolChoiceForced, allowedTools...), + agentic.WithTools(functionTools), + } + + firstInput := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what's the weather like in Beijing today"), + } + + sResp, err := am.Stream(ctx, firstInput, opts...) + if err != nil { + log.Fatalf("failed to stream, err: %v", err) + } + + var msgs []*schema.AgenticMessage + for { + msg, err := sResp.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + log.Fatalf("failed to receive stream response, err: %v", err) + } + msgs = append(msgs, msg) + } + + concatenated, err := schema.ConcatAgenticMessages(msgs) + if err != nil { + log.Fatalf("failed to concat agentic messages, err: %v", err) + } + + lastBlock := concatenated.ContentBlocks[len(concatenated.ContentBlocks)-1] + if lastBlock.Type != schema.ContentBlockTypeFunctionToolCall { + log.Fatalf("last block is not function tool call, type: %s", lastBlock.Type) + } + + toolCall := lastBlock.FunctionToolCall + toolResultMsg := schema.FunctionToolResultAgenticMessage(toolCall.CallID, toolCall.Name, "20 degrees") + + secondInput := append(firstInput, concatenated, toolResultMsg) + + gResp, err := am.Generate(ctx, secondInput, opts...) + if err != nil { + log.Fatalf("failed to generate, err: %v", err) + } + + meta := concatenated.ResponseMeta.OpenAIExtension + log.Printf("request_id: %s\n", meta.ID) + + respBody, _ := sonic.MarshalIndent(gResp, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} +``` + + +#### 服务器工具示例 + +```go +package main + +import ( + "context" + "errors" + "io" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/openai" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/openai/openai-go/v3/responses" +) + +func main() { + ctx := context.Background() + + am, err := openai.New(ctx, &openai.Config{ + BaseURL: "https://api.openai.com/v1", + Model: os.Getenv("OPENAI_MODEL_ID"), + APIKey: os.Getenv("OPENAI_API_KEY"), + Reasoning: &responses.ReasoningParam{ + Effort: responses.ReasoningEffortLow, + Summary: responses.ReasoningSummaryDetailed, + }, + Include: []responses.ResponseIncludable{ + responses.ResponseIncludableWebSearchCallActionSources, + }, + }) + if err != nil { + log.Fatalf("failed to create agentic model, err=%v", err) + } + + serverTools := []*openai.ServerToolConfig{ + { + WebSearch: &responses.WebSearchToolParam{ + Type: responses.WebSearchToolTypeWebSearch, + }, + }, + } + + allowedTools := []*schema.AllowedTool{ + { + ServerTool: &schema.AllowedServerTool{ + Name: string(openai.ServerToolNameWebSearch), + }, + }, + } + + opts := []agentic.Option{ + agentic.WithToolChoice(schema.ToolChoiceForced, allowedTools...), + openai.WithServerTools(serverTools), + } + + input := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what's cloudwego/eino"), + } + + resp, err := am.Stream(ctx, input, opts...) + if err != nil { + log.Fatalf("failed to stream, err: %v", err) + } + + var msgs []*schema.AgenticMessage + for { + msg, err := resp.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + log.Fatalf("failed to receive stream response, err: %v", err) + } + msgs = append(msgs, msg) + } + + concatenated, err := schema.ConcatAgenticMessages(msgs) + if err != nil { + log.Fatalf("failed to concat agentic messages, err: %v", err) + } + + for _, block := range concatenated.ContentBlocks { + if block.ServerToolCall != nil { + serverToolArgs := block.ServerToolCall.Arguments.(*openai.ServerToolCallArguments) + args, _ := sonic.MarshalIndent(serverToolArgs, " ", " ") + log.Printf("server_tool_args: %s\n", string(args)) + } + + if block.ServerToolResult != nil { + result := block.ServerToolResult.Result.(*openai.ServerToolResult) + resultJSON, _ := sonic.MarshalIndent(result, " ", " ") + log.Printf("server_tool_result: %s\n", string(resultJSON)) + } + } + + meta := concatenated.ResponseMeta.OpenAIExtension + log.Printf("request_id: %s\n", meta.ID) + + respBody, _ := sonic.MarshalIndent(concatenated, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} +``` + +更多示例请参考 `examples` 目录。 diff --git a/components/agentic/openai/consts.go b/components/agentic/openai/consts.go new file mode 100644 index 000000000..9e5a620ac --- /dev/null +++ b/components/agentic/openai/consts.go @@ -0,0 +1,33 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +const implType = "OpenAI" + +type WebSearchAction string + +const ( + WebSearchActionSearch WebSearchAction = "search" + WebSearchActionOpenPage WebSearchAction = "open_page" + WebSearchActionFind WebSearchAction = "find" +) + +type ServerToolName string + +const ( + ServerToolNameWebSearch ServerToolName = "web_search" +) diff --git a/components/agentic/openai/content_block_extra.go b/components/agentic/openai/content_block_extra.go new file mode 100644 index 000000000..f278c625f --- /dev/null +++ b/components/agentic/openai/content_block_extra.go @@ -0,0 +1,98 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "reflect" + + "github.com/cloudwego/eino/schema" +) + +type blockExtraItemID string +type blockExtraItemStatus string + +const ( + itemIDKey = "openai-item-id" + itemStatusKey = "openai-item-status" +) + +func setItemID(block *schema.ContentBlock, itemID string) { + setBlockExtraValue(block, itemIDKey, blockExtraItemID(itemID)) +} + +func getItemID(block *schema.ContentBlock) (string, bool) { + itemID, ok := getBlockExtraValue[blockExtraItemID](block, itemIDKey) + if !ok { + return "", false + } + return string(itemID), true +} + +func setItemStatus(block *schema.ContentBlock, status string) { + setBlockExtraValue(block, itemStatusKey, blockExtraItemStatus(status)) +} + +func GetItemStatus(block *schema.ContentBlock) (string, bool) { + itemStatus, ok := getBlockExtraValue[blockExtraItemStatus](block, itemStatusKey) + if !ok { + return "", false + } + return string(itemStatus), true +} + +func setBlockExtraValue[T any](block *schema.ContentBlock, key string, value T) { + if block == nil { + return + } + if block.Extra == nil { + block.Extra = map[string]any{} + } + block.Extra[key] = value +} + +func getBlockExtraValue[T any](block *schema.ContentBlock, key string) (T, bool) { + var zero T + if block == nil { + return zero, false + } + if block.Extra == nil { + return zero, false + } + val, ok := block.Extra[key].(T) + if !ok { + return zero, false + } + return val, true +} + +func concatFirstNonZero[T any](chunks []T) (T, error) { + for _, chunk := range chunks { + if !reflect.ValueOf(chunk).IsZero() { + return chunk, nil + } + } + var zero T + return zero, nil +} + +func concatLast[T any](chunks []T) (T, error) { + if len(chunks) == 0 { + var zero T + return zero, nil + } + return chunks[len(chunks)-1], nil +} diff --git a/components/agentic/openai/content_block_extra_test.go b/components/agentic/openai/content_block_extra_test.go new file mode 100644 index 000000000..ab137a15a --- /dev/null +++ b/components/agentic/openai/content_block_extra_test.go @@ -0,0 +1,164 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "testing" + + "github.com/bytedance/mockey" + "github.com/cloudwego/eino/schema" + "github.com/stretchr/testify/assert" +) + +func TestSetItemID(t *testing.T) { + mockey.PatchConvey("TestSetItemID", t, func() { + mockey.PatchConvey("set value into Extra", func() { + block := &schema.ContentBlock{} + setItemID(block, "id-1") + + val, ok := block.Extra[itemIDKey] + assert.True(t, ok) + assert.Equal(t, blockExtraItemID("id-1"), val) + }) + + mockey.PatchConvey("nil block should not panic", func() { + assert.NotPanics(t, func() { + setItemID(nil, "id-1") + }) + }) + }) +} + +func TestGetItemID(t *testing.T) { + mockey.PatchConvey("TestGetItemID", t, func() { + mockey.PatchConvey("found", func() { + block := &schema.ContentBlock{Extra: map[string]any{itemIDKey: blockExtraItemID("id-2")}} + itemID, ok := getItemID(block) + assert.True(t, ok) + assert.Equal(t, "id-2", itemID) + }) + + mockey.PatchConvey("not found", func() { + block := &schema.ContentBlock{Extra: map[string]any{}} + itemID, ok := getItemID(block) + assert.False(t, ok) + assert.Equal(t, "", itemID) + }) + }) +} + +func TestSetItemStatus(t *testing.T) { + mockey.PatchConvey("TestSetItemStatus", t, func() { + mockey.PatchConvey("set value into Extra", func() { + block := &schema.ContentBlock{} + setItemStatus(block, "in_progress") + + val, ok := block.Extra[itemStatusKey] + assert.True(t, ok) + assert.Equal(t, blockExtraItemStatus("in_progress"), val) + }) + + mockey.PatchConvey("nil block should not panic", func() { + assert.NotPanics(t, func() { + setItemStatus(nil, "in_progress") + }) + }) + }) +} + +func TestGetItemStatus(t *testing.T) { + mockey.PatchConvey("TestGetItemStatus", t, func() { + mockey.PatchConvey("found", func() { + block := &schema.ContentBlock{Extra: map[string]any{itemStatusKey: blockExtraItemStatus("completed")}} + status, ok := GetItemStatus(block) + assert.True(t, ok) + assert.Equal(t, "completed", status) + }) + + mockey.PatchConvey("not found", func() { + block := &schema.ContentBlock{Extra: map[string]any{}} + status, ok := GetItemStatus(block) + assert.False(t, ok) + assert.Equal(t, "", status) + }) + }) +} + +func TestSetBlockExtraValue(t *testing.T) { + mockey.PatchConvey("TestSetBlockExtraValue", t, func() { + mockey.PatchConvey("nil block should not panic", func() { + assert.NotPanics(t, func() { + setBlockExtraValue[*schema.ContentBlock](nil, "k", nil) + }) + }) + + mockey.PatchConvey("init Extra map when nil", func() { + block := &schema.ContentBlock{} + setBlockExtraValue(block, "k", 123) + assert.Equal(t, 123, block.Extra["k"]) + }) + }) +} + +func TestGetBlockExtraValue(t *testing.T) { + mockey.PatchConvey("TestGetBlockExtraValue", t, func() { + mockey.PatchConvey("nil block", func() { + v, ok := getBlockExtraValue[int](nil, "k") + assert.False(t, ok) + assert.Equal(t, 0, v) + }) + + mockey.PatchConvey("type mismatch", func() { + block := &schema.ContentBlock{Extra: map[string]any{"k": "v"}} + v, ok := getBlockExtraValue[int](block, "k") + assert.False(t, ok) + assert.Equal(t, 0, v) + }) + }) +} + +func TestConcatFirstNonZero(t *testing.T) { + mockey.PatchConvey("TestConcatFirstNonZero", t, func() { + mockey.PatchConvey("pick first non-zero", func() { + v, err := concatFirstNonZero([]blockExtraItemID{"", "id"}) + assert.NoError(t, err) + assert.Equal(t, blockExtraItemID("id"), v) + }) + + mockey.PatchConvey("all zero", func() { + v, err := concatFirstNonZero([]blockExtraItemID{"", ""}) + assert.NoError(t, err) + assert.Equal(t, blockExtraItemID(""), v) + }) + }) +} + +func TestConcatLast(t *testing.T) { + mockey.PatchConvey("TestConcatLast", t, func() { + mockey.PatchConvey("non-empty", func() { + v, err := concatLast([]blockExtraItemStatus{"a", "b"}) + assert.NoError(t, err) + assert.Equal(t, blockExtraItemStatus("b"), v) + }) + + mockey.PatchConvey("empty", func() { + v, err := concatLast([]blockExtraItemStatus{}) + assert.NoError(t, err) + assert.Equal(t, blockExtraItemStatus(""), v) + }) + }) +} diff --git a/components/agentic/openai/convertor.go b/components/agentic/openai/convertor.go new file mode 100644 index 000000000..1d1e06442 --- /dev/null +++ b/components/agentic/openai/convertor.go @@ -0,0 +1,1351 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "fmt" + "strings" + "sync" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino/schema" + "github.com/cloudwego/eino/schema/openai" + "github.com/eino-contrib/jsonschema" + "github.com/openai/openai-go/v3/packages/param" + "github.com/openai/openai-go/v3/responses" + "golang.org/x/sync/errgroup" +) + +func toSystemRoleInputItems(msg *schema.AgenticMessage) (items []responses.ResponseInputItemUnionParam, err error) { + items = make([]responses.ResponseInputItemUnionParam, 0, len(msg.ContentBlocks)) + + for _, block := range msg.ContentBlocks { + var item responses.ResponseInputItemUnionParam + + switch block.Type { + case schema.ContentBlockTypeUserInputText: + item, err = userInputTextToInputItem(responses.EasyInputMessageRoleSystem, block.UserInputText) + if err != nil { + return nil, fmt.Errorf("failed to convert user input text to input item, err: %w", err) + } + + case schema.ContentBlockTypeUserInputImage: + item, err = userInputImageToInputItem(responses.EasyInputMessageRoleSystem, block.UserInputImage) + if err != nil { + return nil, fmt.Errorf("failed to convert user input image to input item, err: %w", err) + } + + default: + return nil, fmt.Errorf("invalid content block type %q with system role", block.Type) + } + + items = append(items, item) + } + + return items, nil +} + +func toDeveloperRoleInputItems(msg *schema.AgenticMessage) (items []responses.ResponseInputItemUnionParam, err error) { + items = make([]responses.ResponseInputItemUnionParam, 0, len(msg.ContentBlocks)) + + for _, block := range msg.ContentBlocks { + var item responses.ResponseInputItemUnionParam + + switch block.Type { + case schema.ContentBlockTypeUserInputText: + item, err = userInputTextToInputItem(responses.EasyInputMessageRoleDeveloper, block.UserInputText) + if err != nil { + return nil, fmt.Errorf("failed to convert user input text to input item, err: %w", err) + } + + case schema.ContentBlockTypeUserInputImage: + item, err = userInputImageToInputItem(responses.EasyInputMessageRoleDeveloper, block.UserInputImage) + if err != nil { + return nil, fmt.Errorf("failed to convert user input image to input item, err: %w", err) + } + + default: + return nil, fmt.Errorf("invalid content block type '%s' with developer role", block.Type) + } + + items = append(items, item) + } + + return items, nil +} + +func toAssistantRoleInputItems(msg *schema.AgenticMessage) (items []responses.ResponseInputItemUnionParam, err error) { + items = make([]responses.ResponseInputItemUnionParam, 0, len(msg.ContentBlocks)) + + for _, block := range msg.ContentBlocks { + var item responses.ResponseInputItemUnionParam + + switch block.Type { + case schema.ContentBlockTypeAssistantGenText: + item, err = assistantGenTextToInputItem(block) + if err != nil { + return nil, fmt.Errorf("failed to convert assistant generated text to input item, err: %w", err) + } + + case schema.ContentBlockTypeReasoning: + item, err = reasoningToInputItem(block) + if err != nil { + return nil, fmt.Errorf("failed to convert reasoning to input item, err: %w", err) + } + + case schema.ContentBlockTypeFunctionToolCall: + item, err = functionToolCallToInputItem(block) + if err != nil { + return nil, fmt.Errorf("failed to convert function tool call to input item, err: %w", err) + } + + case schema.ContentBlockTypeServerToolCall: + item, err = serverToolCallToInputItem(block) + if err != nil { + return nil, fmt.Errorf("failed to convert server tool call to input item, err: %w", err) + } + + case schema.ContentBlockTypeServerToolResult: + item, err = serverToolResultToInputItem(block) + if err != nil { + return nil, fmt.Errorf("failed to convert server tool result to input item, err: %w", err) + } + + case schema.ContentBlockTypeMCPToolApprovalRequest: + item, err = mcpToolApprovalRequestToInputItem(block) + if err != nil { + return nil, fmt.Errorf("failed to convert mcp tool approval request to input item, err: %w", err) + } + + case schema.ContentBlockTypeMCPListToolsResult: + item, err = mcpListToolsResultToInputItem(block) + if err != nil { + return nil, fmt.Errorf("failed to convert mcp list tools result to input item, err: %w", err) + } + + case schema.ContentBlockTypeMCPToolCall: + item, err = mcpToolCallToInputItem(block) + if err != nil { + return nil, fmt.Errorf("failed to convert mcp tool call to input item, err: %w", err) + } + + case schema.ContentBlockTypeMCPToolResult: + item, err = mcpToolResultToInputItem(block) + if err != nil { + return nil, fmt.Errorf("failed to convert mcp tool result to input item, err: %w", err) + } + + default: + return nil, fmt.Errorf("invalid content block type %q with assistant role", block.Type) + } + + items = append(items, item) + } + + items, err = pairMCPToolCallItems(items) + if err != nil { + return nil, fmt.Errorf("pairMCPToolCallItems failed: %w", err) + } + + items, err = pairWebServerToolCallItems(items) + if err != nil { + return nil, fmt.Errorf("pairWebServerToolCallItems failed: %w", err) + } + + return items, nil +} + +func pairMCPToolCallItems(items []responses.ResponseInputItemUnionParam) (newItems []responses.ResponseInputItemUnionParam, err error) { + processed := make(map[int]bool) + mcpCallItemIDIndices := make(map[string][]int) + + for i, item := range items { + mcpCall := item.OfMcpCall + if mcpCall == nil { + continue + } + + id := mcpCall.ID + if id == "" { + return nil, fmt.Errorf("found mcp tool call item with empty id") + } + + mcpCallItemIDIndices[id] = append(mcpCallItemIDIndices[id], i) + } + + for id, indices := range mcpCallItemIDIndices { + if len(indices) != 2 { + return nil, fmt.Errorf("mcp tool call %q should have exactly 2 items (call and result), "+ + "but found %d", id, len(indices)) + } + } + + for i, item := range items { + if processed[i] { + continue + } + + mcpCall := item.OfMcpCall + if mcpCall == nil { + newItems = append(newItems, item) + continue + } + + id := mcpCall.ID + indices := mcpCallItemIDIndices[id] + + var pairIndex int + if indices[0] == i { + pairIndex = indices[1] + } else { + pairIndex = indices[0] + } + + pairMcpCall := items[pairIndex].OfMcpCall + + mergedItem := responses.ResponseInputItemUnionParam{ + OfMcpCall: &responses.ResponseInputItemMcpCallParam{ + ID: mcpCall.ID, + ServerLabel: coalesce(mcpCall.ServerLabel, pairMcpCall.ServerLabel), + ApprovalRequestID: coalesce(mcpCall.ApprovalRequestID, pairMcpCall.ApprovalRequestID), + Name: mcpCall.Name, + Arguments: coalesce(mcpCall.Arguments, pairMcpCall.Arguments), + Output: coalesce(mcpCall.Output, pairMcpCall.Output), + Error: coalesce(mcpCall.Error, pairMcpCall.Error), + Status: coalesce(mcpCall.Status, pairMcpCall.Status), + }, + } + + newItems = append(newItems, mergedItem) + + processed[i] = true + processed[pairIndex] = true + } + + return newItems, nil +} + +func pairWebServerToolCallItems(items []responses.ResponseInputItemUnionParam) (newItems []responses.ResponseInputItemUnionParam, err error) { + processed := make(map[int]bool) + serverCallItemIDIndices := make(map[string][]int) + + for i, item := range items { + serverCall := item.OfWebSearchCall + if serverCall == nil { + continue + } + + id := serverCall.ID + if id == "" { + return nil, fmt.Errorf("found server tool call item with empty id at index %d", i) + } + + serverCallItemIDIndices[id] = append(serverCallItemIDIndices[id], i) + } + + for id, indices := range serverCallItemIDIndices { + if len(indices) != 2 { + return nil, fmt.Errorf("server tool call %q should have exactly 2 items (call and result), "+ + "but found %d", id, len(indices)) + } + } + + for i, item := range items { + if processed[i] { + continue + } + + serverCall := item.OfWebSearchCall + if serverCall == nil { + newItems = append(newItems, item) + continue + } + + id := serverCall.ID + indices := serverCallItemIDIndices[id] + + var pairIndex int + if indices[0] == i { + pairIndex = indices[1] + } else { + pairIndex = indices[0] + } + + pairServerCall := items[pairIndex].OfWebSearchCall + + mergedItem := responses.ResponseInputItemUnionParam{ + OfWebSearchCall: &responses.ResponseFunctionWebSearchParam{ + ID: serverCall.ID, + Action: pairWebSearchAction(serverCall.Action, pairServerCall.Action), + Status: coalesce(serverCall.Status, pairServerCall.Status), + }, + } + + newItems = append(newItems, mergedItem) + + processed[i] = true + processed[pairIndex] = true + } + + return newItems, nil +} + +func pairWebSearchAction(action, pairAction responses.ResponseFunctionWebSearchActionUnionParam) responses.ResponseFunctionWebSearchActionUnionParam { + ret := responses.ResponseFunctionWebSearchActionUnionParam{} + + if action.OfFind != nil { + ret.OfFind = action.OfFind + } else if pairAction.OfFind != nil { + ret.OfFind = pairAction.OfFind + } + + if action.OfOpenPage != nil { + ret.OfOpenPage = action.OfOpenPage + } else if pairAction.OfOpenPage != nil { + ret.OfOpenPage = pairAction.OfOpenPage + } + + if action.OfSearch == nil { + ret.OfSearch = pairAction.OfSearch + } + if pairAction.OfSearch == nil { + ret.OfSearch = action.OfSearch + } + if action.OfSearch != nil && pairAction.OfSearch != nil { + ret.OfSearch = action.OfSearch + if pairAction.OfSearch.Query != "" { + ret.OfSearch.Query = pairAction.OfSearch.Query + } + if len(pairAction.OfSearch.Sources) > 0 { + ret.OfSearch.Sources = pairAction.OfSearch.Sources + } + } + + return ret +} + +func toUserRoleInputItems(msg *schema.AgenticMessage) (items []responses.ResponseInputItemUnionParam, err error) { + items = make([]responses.ResponseInputItemUnionParam, 0, len(msg.ContentBlocks)) + + for _, block := range msg.ContentBlocks { + var item responses.ResponseInputItemUnionParam + + switch block.Type { + case schema.ContentBlockTypeUserInputText: + item, err = userInputTextToInputItem(responses.EasyInputMessageRoleUser, block.UserInputText) + if err != nil { + return nil, fmt.Errorf("failed to convert user input text to input item, err: %w", err) + } + + case schema.ContentBlockTypeUserInputImage: + item, err = userInputImageToInputItem(responses.EasyInputMessageRoleUser, block.UserInputImage) + if err != nil { + return nil, fmt.Errorf("failed to convert user input image to input item, err: %w", err) + } + + case schema.ContentBlockTypeUserInputFile: + item, err = userInputFileToInputItem(responses.EasyInputMessageRoleUser, block.UserInputFile) + if err != nil { + return nil, fmt.Errorf("failed to convert user input file to input item, err: %w", err) + } + + case schema.ContentBlockTypeFunctionToolResult: + item, err = functionToolResultToInputItem(block.FunctionToolResult) + if err != nil { + return nil, fmt.Errorf("failed to convert function tool result to input item, err: %w", err) + } + + case schema.ContentBlockTypeMCPToolApprovalResponse: + item, err = mcpToolApprovalResponseToInputItem(block.MCPToolApprovalResponse) + if err != nil { + return nil, fmt.Errorf("failed to convert mcp tool approval response to input item, err: %w", err) + } + + default: + return nil, fmt.Errorf("invalid content block type %q with user role", block.Type) + } + + items = append(items, item) + } + + return items, nil +} + +func userInputTextToInputItem(role responses.EasyInputMessageRole, block *schema.UserInputText) (item responses.ResponseInputItemUnionParam, err error) { + item = responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: role, + Content: responses.EasyInputMessageContentUnionParam{ + OfString: param.NewOpt(block.Text), + }, + }, + } + + return item, nil +} + +func userInputImageToInputItem(role responses.EasyInputMessageRole, block *schema.UserInputImage) (item responses.ResponseInputItemUnionParam, err error) { + imageURL, err := resolveURL(block.URL, block.Base64Data, block.MIMEType) + if err != nil { + return item, err + } + + detail, err := toInputItemImageDetail(block.Detail) + if err != nil { + return item, err + } + + contentItem := responses.ResponseInputContentUnionParam{ + OfInputImage: &responses.ResponseInputImageParam{ + ImageURL: newOpenaiStrOpt(imageURL), + Detail: detail, + }, + } + + msgItem := &responses.EasyInputMessageParam{ + Role: role, + Content: responses.EasyInputMessageContentUnionParam{ + OfInputItemContentList: []responses.ResponseInputContentUnionParam{ + contentItem, + }, + }, + } + + item = responses.ResponseInputItemUnionParam{ + OfMessage: msgItem, + } + + return item, nil +} + +func toInputItemImageDetail(detail schema.ImageURLDetail) (responses.ResponseInputImageDetail, error) { + if detail == "" { + return "", nil + } + switch detail { + case schema.ImageURLDetailHigh: + return responses.ResponseInputImageDetailHigh, nil + case schema.ImageURLDetailLow: + return responses.ResponseInputImageDetailLow, nil + case schema.ImageURLDetailAuto: + return responses.ResponseInputImageDetailAuto, nil + default: + return "", fmt.Errorf("invalid image detail: %s", detail) + } +} + +func userInputFileToInputItem(role responses.EasyInputMessageRole, block *schema.UserInputFile) (item responses.ResponseInputItemUnionParam, err error) { + fileURl, err := resolveURL(block.URL, block.Base64Data, block.MIMEType) + if err != nil { + return item, err + } + + contentItem := responses.ResponseInputContentUnionParam{ + OfInputFile: &responses.ResponseInputFileParam{ + Filename: newOpenaiStrOpt(block.Name), + }, + } + if block.URL != "" { + contentItem.OfInputFile.FileURL = newOpenaiStrOpt(fileURl) + } else if block.Base64Data != "" { + contentItem.OfInputFile.FileData = newOpenaiStrOpt(block.Base64Data) + } + + item = responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: role, + Content: responses.EasyInputMessageContentUnionParam{ + OfInputItemContentList: []responses.ResponseInputContentUnionParam{ + contentItem, + }, + }, + }, + } + + return item, nil +} + +func functionToolResultToInputItem(block *schema.FunctionToolResult) (item responses.ResponseInputItemUnionParam, err error) { + item = responses.ResponseInputItemUnionParam{ + OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ + CallID: block.CallID, + Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{ + OfString: param.NewOpt(block.Result), + }, + }, + } + + return item, nil +} + +func assistantGenTextToInputItem(block *schema.ContentBlock) (item responses.ResponseInputItemUnionParam, err error) { + content := block.AssistantGenText + if content == nil { + return item, fmt.Errorf("assistant generated text is nil") + } + + var annotations []responses.ResponseOutputTextAnnotationUnionParam + if content.OpenAIExtension != nil { + annotations = make([]responses.ResponseOutputTextAnnotationUnionParam, 0, len(content.OpenAIExtension.Annotations)) + for _, anno := range content.OpenAIExtension.Annotations { + if anno == nil { + return item, fmt.Errorf("text annotation is nil") + } + anno_, err := textAnnotationToOutputTextAnnotation(anno) + if err != nil { + return item, fmt.Errorf("failed to convert text annotation to output text annotation, err: %w", err) + } + annotations = append(annotations, anno_) + } + } + + id, _ := getItemID(block) + status, _ := GetItemStatus(block) + + contentItem := responses.ResponseOutputMessageContentUnionParam{ + OfOutputText: &responses.ResponseOutputTextParam{ + Annotations: annotations, + Text: content.Text, + }, + } + + item = responses.ResponseInputItemUnionParam{ + OfOutputMessage: &responses.ResponseOutputMessageParam{ + ID: id, + Status: responses.ResponseOutputMessageStatus(status), + Content: []responses.ResponseOutputMessageContentUnionParam{contentItem}, + }, + } + + return item, nil +} + +func textAnnotationToOutputTextAnnotation(annotation *openai.TextAnnotation) (param responses.ResponseOutputTextAnnotationUnionParam, err error) { + switch annotation.Type { + case openai.TextAnnotationTypeFileCitation: + citation := annotation.FileCitation + if citation == nil { + return param, fmt.Errorf("file citation is nil") + } + return responses.ResponseOutputTextAnnotationUnionParam{ + OfFileCitation: &responses.ResponseOutputTextAnnotationFileCitationParam{ + Index: int64(citation.Index), + FileID: citation.FileID, + Filename: citation.Filename, + }, + }, nil + + case openai.TextAnnotationTypeURLCitation: + citation := annotation.URLCitation + if citation == nil { + return param, fmt.Errorf("url citation is nil") + } + return responses.ResponseOutputTextAnnotationUnionParam{ + OfURLCitation: &responses.ResponseOutputTextAnnotationURLCitationParam{ + Title: citation.Title, + URL: citation.URL, + StartIndex: int64(citation.StartIndex), + EndIndex: int64(citation.EndIndex), + }, + }, nil + + case openai.TextAnnotationTypeContainerFileCitation: + citation := annotation.ContainerFileCitation + if citation == nil { + return param, fmt.Errorf("container file citation is nil") + } + return responses.ResponseOutputTextAnnotationUnionParam{ + OfContainerFileCitation: &responses.ResponseOutputTextAnnotationContainerFileCitationParam{ + ContainerID: citation.ContainerID, + StartIndex: int64(citation.StartIndex), + EndIndex: int64(citation.EndIndex), + FileID: citation.FileID, + Filename: citation.Filename, + }, + }, nil + + case openai.TextAnnotationTypeFilePath: + filePath := annotation.FilePath + if filePath == nil { + return param, fmt.Errorf("file path is nil") + } + return responses.ResponseOutputTextAnnotationUnionParam{ + OfFilePath: &responses.ResponseOutputTextAnnotationFilePathParam{ + FileID: filePath.FileID, + Index: int64(filePath.Index), + }, + }, nil + + default: + return param, fmt.Errorf("invalid text annotation type: %s", annotation.Type) + } +} + +func functionToolCallToInputItem(block *schema.ContentBlock) (item responses.ResponseInputItemUnionParam, err error) { + content := block.FunctionToolCall + if content == nil { + return item, fmt.Errorf("function tool call is nil") + } + + id, _ := getItemID(block) + status, _ := GetItemStatus(block) + + item = responses.ResponseInputItemUnionParam{ + OfFunctionCall: &responses.ResponseFunctionToolCallParam{ + ID: newOpenaiStrOpt(id), + Status: responses.ResponseFunctionToolCallStatus(status), + CallID: content.CallID, + Name: content.Name, + Arguments: content.Arguments, + }, + } + + return item, nil +} + +func reasoningToInputItem(block *schema.ContentBlock) (item responses.ResponseInputItemUnionParam, err error) { + content := block.Reasoning + if content == nil { + return item, fmt.Errorf("reasoning is nil") + } + + id, _ := getItemID(block) + status, _ := GetItemStatus(block) + + summary := make([]responses.ResponseReasoningItemSummaryParam, 0, len(content.Summary)) + for _, s := range content.Summary { + summary = append(summary, responses.ResponseReasoningItemSummaryParam{ + Text: s.Text, + }) + } + + item = responses.ResponseInputItemUnionParam{ + OfReasoning: &responses.ResponseReasoningItemParam{ + ID: id, + Status: responses.ResponseReasoningItemStatus(status), + Summary: summary, + EncryptedContent: newOpenaiStrOpt(content.EncryptedContent), + }, + } + + return item, nil +} + +func serverToolCallToInputItem(block *schema.ContentBlock) (item responses.ResponseInputItemUnionParam, err error) { + content := block.ServerToolCall + if content == nil { + return item, fmt.Errorf("server tool call is nil") + } + + id, _ := getItemID(block) + status, _ := GetItemStatus(block) + + arguments, err := getServerToolCallArguments(content) + if err != nil { + return item, err + } + + var action responses.ResponseFunctionWebSearchActionUnionParam + switch { + case arguments.WebSearch != nil: + action, err = getWebSearchToolCallActionParam(arguments.WebSearch) + default: + return item, fmt.Errorf("server tool call arguments are nil") + } + if err != nil { + return item, err + } + + item = responses.ResponseInputItemUnionParam{ + OfWebSearchCall: &responses.ResponseFunctionWebSearchParam{ + ID: id, + Status: responses.ResponseFunctionWebSearchStatus(status), + Action: action, + }, + } + + return item, nil +} + +func getWebSearchToolCallActionParam(ws *WebSearchArguments) (action responses.ResponseFunctionWebSearchActionUnionParam, err error) { + switch ws.ActionType { + case WebSearchActionSearch: + return responses.ResponseFunctionWebSearchActionUnionParam{ + OfSearch: &responses.ResponseFunctionWebSearchActionSearchParam{ + Query: ws.Search.Query, + }, + }, nil + + case WebSearchActionOpenPage: + return responses.ResponseFunctionWebSearchActionUnionParam{ + OfOpenPage: &responses.ResponseFunctionWebSearchActionOpenPageParam{ + URL: ws.OpenPage.URL, + }, + }, nil + + case WebSearchActionFind: + return responses.ResponseFunctionWebSearchActionUnionParam{ + OfFind: &responses.ResponseFunctionWebSearchActionFindParam{ + URL: ws.Find.URL, + Pattern: ws.Find.Pattern, + }, + }, nil + + default: + return action, fmt.Errorf("invalid web search action type: %s", ws.ActionType) + } +} + +func serverToolResultToInputItem(block *schema.ContentBlock) (item responses.ResponseInputItemUnionParam, err error) { + content := block.ServerToolResult + if content == nil { + return item, fmt.Errorf("server tool result is nil") + } + + id, _ := getItemID(block) + status, _ := GetItemStatus(block) + + result, err := getServerToolResult(content) + if err != nil { + return item, err + } + + var action responses.ResponseFunctionWebSearchActionUnionParam + switch { + case result.WebSearch != nil: + action, err = getWebSearchToolResultActionParam(result.WebSearch) + default: + return item, fmt.Errorf("server tool result is nil") + } + if err != nil { + return item, err + } + + item = responses.ResponseInputItemUnionParam{ + OfWebSearchCall: &responses.ResponseFunctionWebSearchParam{ + ID: id, + Status: responses.ResponseFunctionWebSearchStatus(status), + Action: action, + }, + } + + return item, nil +} + +func getWebSearchToolResultActionParam(ws *WebSearchResult) (action responses.ResponseFunctionWebSearchActionUnionParam, err error) { + switch ws.ActionType { + case WebSearchActionSearch: + sources := make([]responses.ResponseFunctionWebSearchActionSearchSourceParam, 0, len(ws.Search.Sources)) + for _, s := range ws.Search.Sources { + sources = append(sources, responses.ResponseFunctionWebSearchActionSearchSourceParam{ + URL: s.URL, + }) + } + return responses.ResponseFunctionWebSearchActionUnionParam{ + OfSearch: &responses.ResponseFunctionWebSearchActionSearchParam{ + Sources: sources, + }, + }, nil + + default: + return action, fmt.Errorf("invalid web search result action type: %s", ws.ActionType) + } +} + +func mcpToolApprovalRequestToInputItem(block *schema.ContentBlock) (item responses.ResponseInputItemUnionParam, err error) { + content := block.MCPToolApprovalRequest + if content == nil { + return item, fmt.Errorf("mcp tool approval request is nil") + } + + id, _ := getItemID(block) + + item = responses.ResponseInputItemUnionParam{ + OfMcpApprovalRequest: &responses.ResponseInputItemMcpApprovalRequestParam{ + ID: id, + ServerLabel: content.ServerLabel, + Name: content.Name, + Arguments: content.Arguments, + }, + } + + return item, nil +} + +func mcpToolApprovalResponseToInputItem(block *schema.MCPToolApprovalResponse) (item responses.ResponseInputItemUnionParam, err error) { + item = responses.ResponseInputItemUnionParam{ + OfMcpApprovalResponse: &responses.ResponseInputItemMcpApprovalResponseParam{ + ApprovalRequestID: block.ApprovalRequestID, + Approve: block.Approve, + Reason: newOpenaiStrOpt(block.Reason), + }, + } + + return item, nil +} + +func mcpListToolsResultToInputItem(block *schema.ContentBlock) (item responses.ResponseInputItemUnionParam, err error) { + content := block.MCPListToolsResult + if content == nil { + return item, fmt.Errorf("mcp list tools result is nil") + } + + tools := make([]responses.ResponseInputItemMcpListToolsToolParam, 0, len(content.Tools)) + for i := range content.Tools { + tool := content.Tools[i] + + tools = append(tools, responses.ResponseInputItemMcpListToolsToolParam{ + Name: tool.Name, + Description: newOpenaiStrOpt(tool.Description), + InputSchema: tool.InputSchema, + }) + } + + id, _ := getItemID(block) + + item = responses.ResponseInputItemUnionParam{ + OfMcpListTools: &responses.ResponseInputItemMcpListToolsParam{ + ID: id, + ServerLabel: content.ServerLabel, + Tools: tools, + Error: newOpenaiStrOpt(content.Error), + }, + } + + return item, nil +} + +func mcpToolCallToInputItem(block *schema.ContentBlock) (item responses.ResponseInputItemUnionParam, err error) { + content := block.MCPToolCall + if content == nil { + return item, fmt.Errorf("mcp tool call is nil") + } + + id, _ := getItemID(block) + status, _ := GetItemStatus(block) + + item = responses.ResponseInputItemUnionParam{ + OfMcpCall: &responses.ResponseInputItemMcpCallParam{ + ID: id, + ApprovalRequestID: newOpenaiStrOpt(content.ApprovalRequestID), + ServerLabel: content.ServerLabel, + Arguments: content.Arguments, + Name: content.Name, + Status: status, + }, + } + + return item, nil +} + +func mcpToolResultToInputItem(block *schema.ContentBlock) (item responses.ResponseInputItemUnionParam, err error) { + content := block.MCPToolResult + if content == nil { + return item, fmt.Errorf("mcp tool result is nil") + } + + id, _ := getItemID(block) + status, _ := GetItemStatus(block) + + var errorMsg string + if content.Error != nil { + errorMsg = content.Error.Message + } + + item = responses.ResponseInputItemUnionParam{ + OfMcpCall: &responses.ResponseInputItemMcpCallParam{ + ID: id, + ServerLabel: content.ServerLabel, + Name: content.Name, + Error: newOpenaiStrOpt(errorMsg), + Output: newOpenaiStrOpt(content.Result), + Status: status, + }, + } + + return item, nil +} + +func toOutputMessage(resp *responses.Response) (msg *schema.AgenticMessage, err error) { + blocks := make([]*schema.ContentBlock, 0, len(resp.Output)) + + for _, item := range resp.Output { + var tmpBlocks []*schema.ContentBlock + + switch variant := item.AsAny().(type) { + case responses.ResponseReasoningItem: + block, err := reasoningToContentBlocks(variant) + if err != nil { + return nil, fmt.Errorf("failed to convert reasoning to content block, err: %w", err) + } + + tmpBlocks = append(tmpBlocks, block) + + case responses.ResponseOutputMessage: + tmpBlocks, err = outputMessageToContentBlocks(variant) + if err != nil { + return nil, fmt.Errorf("failed to convert output message to content blocks, err: %w", err) + } + + case responses.ResponseFunctionToolCall: + block, err := functionToolCallToContentBlock(variant) + if err != nil { + return nil, fmt.Errorf("failed to convert function tool call to content block, err: %w", err) + } + + tmpBlocks = append(tmpBlocks, block) + + case responses.ResponseOutputItemMcpListTools: + block, err := mcpListToolsToContentBlock(variant) + if err != nil { + return nil, fmt.Errorf("failed to convert function mcp list tools to content block, err: %w", err) + } + + tmpBlocks = append(tmpBlocks, block) + + case responses.ResponseOutputItemMcpCall: + tmpBlocks, err = mcpCallToContentBlocks(variant) + if err != nil { + return nil, fmt.Errorf("failed to convert function mcp call to content block, err: %w", err) + } + + case responses.ResponseOutputItemMcpApprovalRequest: + block, err := mcpApprovalRequestToContentBlock(variant) + if err != nil { + return nil, fmt.Errorf("failed to convert function mcp approval request to content block, err: %w", err) + } + + tmpBlocks = append(tmpBlocks, block) + + case responses.ResponseFunctionWebSearch: + tmpBlocks, err = webSearchToContentBlocks(variant) + if err != nil { + return nil, fmt.Errorf("failed to convert function web search to content block, err: %w", err) + } + + default: + return nil, fmt.Errorf("invalid output item type: %T", variant) + } + + blocks = append(blocks, tmpBlocks...) + } + + msg = &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: blocks, + ResponseMeta: responseObjectToResponseMeta(resp), + } + + return msg, nil +} + +func outputMessageToContentBlocks(item responses.ResponseOutputMessage) (blocks []*schema.ContentBlock, err error) { + blocks = make([]*schema.ContentBlock, 0, len(item.Content)) + + for _, content := range item.Content { + var block *schema.ContentBlock + + switch variant := content.AsAny().(type) { + case responses.ResponseOutputText: + block, err = outputContentTextToContentBlock(variant) + if err != nil { + return nil, fmt.Errorf("failed to convert output text to content block, err: %w", err) + } + + case responses.ResponseOutputRefusal: + block = schema.NewContentBlock(&schema.AssistantGenText{ + OpenAIExtension: &openai.AssistantGenTextExtension{ + Refusal: &openai.OutputRefusal{ + Reason: variant.Refusal, + }, + }, + }) + + default: + return nil, fmt.Errorf("invalid output message content type: %s", content.Type) + } + + setItemID(block, item.ID) + if s := string(item.Status); s != "" { + setItemStatus(block, s) + } + + blocks = append(blocks, block) + } + + return blocks, nil +} + +func outputContentTextToContentBlock(text responses.ResponseOutputText) (block *schema.ContentBlock, err error) { + annotations := make([]*openai.TextAnnotation, 0, len(text.Annotations)) + for _, union := range text.Annotations { + anno, err := outputTextAnnotationToTextAnnotation(union) + if err != nil { + return nil, fmt.Errorf("failed to convert text annotation to text annotation, err: %w", err) + } + annotations = append(annotations, anno) + } + + block = schema.NewContentBlock(&schema.AssistantGenText{ + Text: text.Text, + OpenAIExtension: &openai.AssistantGenTextExtension{ + Annotations: annotations, + }, + }) + + return block, nil +} + +func outputTextAnnotationToTextAnnotation(anno responses.ResponseOutputTextAnnotationUnion) (*openai.TextAnnotation, error) { + switch variant := anno.AsAny().(type) { + case responses.ResponseOutputTextAnnotationFileCitation: + return &openai.TextAnnotation{ + Type: openai.TextAnnotationTypeFileCitation, + FileCitation: &openai.TextAnnotationFileCitation{ + Index: int(variant.Index), + FileID: variant.FileID, + Filename: variant.Filename, + }, + }, nil + + case responses.ResponseOutputTextAnnotationURLCitation: + return &openai.TextAnnotation{ + Type: openai.TextAnnotationTypeURLCitation, + URLCitation: &openai.TextAnnotationURLCitation{ + Title: variant.Title, + URL: variant.URL, + StartIndex: int(variant.StartIndex), + EndIndex: int(variant.EndIndex), + }, + }, nil + + case responses.ResponseOutputTextAnnotationContainerFileCitation: + return &openai.TextAnnotation{ + Type: openai.TextAnnotationTypeContainerFileCitation, + ContainerFileCitation: &openai.TextAnnotationContainerFileCitation{ + ContainerID: variant.ContainerID, + FileID: variant.FileID, + Filename: variant.Filename, + StartIndex: int(variant.StartIndex), + EndIndex: int(variant.EndIndex), + }, + }, nil + + case responses.ResponseOutputTextAnnotationFilePath: + return &openai.TextAnnotation{ + Type: openai.TextAnnotationTypeFilePath, + FilePath: &openai.TextAnnotationFilePath{ + FileID: variant.FileID, + Index: int(variant.Index), + }, + }, nil + + default: + return nil, fmt.Errorf("invalid annotation type: %s", anno.Type) + } +} + +func functionToolCallToContentBlock(item responses.ResponseFunctionToolCall) (block *schema.ContentBlock, err error) { + block = schema.NewContentBlock(&schema.FunctionToolCall{ + CallID: item.CallID, + Name: item.Name, + Arguments: item.Arguments, + }) + + setItemID(block, item.ID) + if s := string(item.Status); s != "" { + setItemStatus(block, s) + } + + return block, nil +} + +func webSearchToContentBlocks(item responses.ResponseFunctionWebSearch) (blocks []*schema.ContentBlock, err error) { + var ( + args *ServerToolCallArguments + res *ServerToolResult + ) + + switch variant := item.Action.AsAny().(type) { + case responses.ResponseFunctionWebSearchActionSearch: + args = &ServerToolCallArguments{ + WebSearch: &WebSearchArguments{ + ActionType: WebSearchActionSearch, + Search: &WebSearchQuery{ + Query: variant.Query, + }, + }, + } + + sources := make([]*WebSearchQuerySource, 0, len(variant.Sources)) + for _, src := range variant.Sources { + sources = append(sources, &WebSearchQuerySource{ + URL: src.URL, + }) + } + res = &ServerToolResult{ + WebSearch: &WebSearchResult{ + ActionType: WebSearchActionSearch, + Search: &WebSearchQueryResult{ + Sources: sources, + }, + }, + } + + case responses.ResponseFunctionWebSearchActionOpenPage: + args = &ServerToolCallArguments{ + WebSearch: &WebSearchArguments{ + ActionType: WebSearchActionOpenPage, + OpenPage: &WebSearchOpenPage{ + URL: variant.URL, + }, + }, + } + + case responses.ResponseFunctionWebSearchActionFind: + args = &ServerToolCallArguments{ + WebSearch: &WebSearchArguments{ + ActionType: WebSearchActionFind, + Find: &WebSearchFind{ + URL: variant.URL, + Pattern: variant.Pattern, + }, + }, + } + + default: + return nil, fmt.Errorf("invalid web search variant type: %s", item.Type) + } + + callBlock := schema.NewContentBlock(&schema.ServerToolCall{ + Name: string(ServerToolNameWebSearch), + Arguments: args, + }) + setItemID(callBlock, item.ID) + if s := string(item.Status); s != "" { + setItemStatus(callBlock, s) + } + + resBlock := schema.NewContentBlock(&schema.ServerToolResult{ + Name: string(ServerToolNameWebSearch), + Result: res, + }) + setItemID(resBlock, item.ID) + if s := string(item.Status); s != "" { + setItemStatus(resBlock, s) + } + + blocks = []*schema.ContentBlock{callBlock, resBlock} + + return blocks, nil +} + +func reasoningToContentBlocks(item responses.ResponseReasoningItem) (block *schema.ContentBlock, err error) { + summary := make([]*schema.ReasoningSummary, 0, len(item.Summary)) + for _, s := range item.Summary { + summary = append(summary, &schema.ReasoningSummary{ + Text: s.Text, + }) + } + + block = schema.NewContentBlock(&schema.Reasoning{ + Summary: summary, + }) + + setItemID(block, item.ID) + if s := string(item.Status); s != "" { + setItemStatus(block, s) + } + + return block, nil +} + +func mcpCallToContentBlocks(item responses.ResponseOutputItemMcpCall) (blocks []*schema.ContentBlock, err error) { + callBlock := schema.NewContentBlock(&schema.MCPToolCall{ + ServerLabel: item.ServerLabel, + ApprovalRequestID: item.ApprovalRequestID, + Name: item.Name, + Arguments: item.Arguments, + }) + setItemID(callBlock, item.ID) + + resultBlock := schema.NewContentBlock(&schema.MCPToolResult{ + ServerLabel: item.ServerLabel, + Name: item.Name, + Result: item.Output, + Error: func() *schema.MCPToolCallError { + if item.Error == "" { + return nil + } + return &schema.MCPToolCallError{ + Message: item.Error, + } + }(), + }) + setItemID(resultBlock, item.ID) + + blocks = []*schema.ContentBlock{callBlock, resultBlock} + + return blocks, nil +} + +func mcpListToolsToContentBlock(item responses.ResponseOutputItemMcpListTools) (block *schema.ContentBlock, err error) { + group := &errgroup.Group{} + group.SetLimit(5) + mu := sync.Mutex{} + + tools := make([]*schema.MCPListToolsItem, 0, len(item.Tools)) + for i := range item.Tools { + tool := item.Tools[i] + + group.Go(func() error { + b, err := sonic.Marshal(tool.InputSchema) + if err != nil { + return fmt.Errorf("failed to marshal tool input schema, err: %w", err) + } + + sc := &jsonschema.Schema{} + if err := sonic.Unmarshal(b, sc); err != nil { + return fmt.Errorf("failed to unmarshal tool input schema, err: %w", err) + } + + mu.Lock() + defer mu.Unlock() + + tools = append(tools, &schema.MCPListToolsItem{ + Name: tool.Name, + Description: tool.Description, + InputSchema: sc, + }) + + return nil + }) + } + + if err = group.Wait(); err != nil { + return nil, err + } + + block = schema.NewContentBlock(&schema.MCPListToolsResult{ + ServerLabel: item.ServerLabel, + Tools: tools, + Error: item.Error, + }) + + setItemID(block, item.ID) + + return block, nil +} + +func mcpApprovalRequestToContentBlock(item responses.ResponseOutputItemMcpApprovalRequest) (block *schema.ContentBlock, err error) { + block = schema.NewContentBlock(&schema.MCPToolApprovalRequest{ + ID: item.ID, + ServerLabel: item.ServerLabel, + Name: item.Name, + Arguments: item.Arguments, + }) + + setItemID(block, item.ID) + + return block, nil +} + +func responseObjectToResponseMeta(res *responses.Response) *schema.AgenticResponseMeta { + return &schema.AgenticResponseMeta{ + TokenUsage: toTokenUsage(res), + OpenAIExtension: toResponseMetaExtension(res), + } +} + +func toTokenUsage(resp *responses.Response) (tokenUsage *schema.TokenUsage) { + usage := &schema.TokenUsage{ + PromptTokens: int(resp.Usage.InputTokens), + PromptTokenDetails: schema.PromptTokenDetails{ + CachedTokens: int(resp.Usage.InputTokensDetails.CachedTokens), + }, + CompletionTokens: int(resp.Usage.OutputTokens), + CompletionTokensDetails: schema.CompletionTokensDetails{ + ReasoningTokens: int(resp.Usage.OutputTokensDetails.ReasoningTokens), + }, + TotalTokens: int(resp.Usage.TotalTokens), + } + + return usage +} + +func toResponseMetaExtension(resp *responses.Response) *openai.ResponseMetaExtension { + var incompleteDetails *openai.IncompleteDetails + if resp.IncompleteDetails.Reason != "" { + incompleteDetails = &openai.IncompleteDetails{ + Reason: resp.IncompleteDetails.Reason, + } + } + + var respErr *openai.ResponseError + if resp.Error.Code != "" || resp.Error.Message != "" { + respErr = &openai.ResponseError{ + Code: openai.ResponseErrorCode(resp.Error.Code), + Message: resp.Error.Message, + } + } + + reasoning := &openai.Reasoning{ + Effort: openai.ReasoningEffort(resp.Reasoning.Effort), + Summary: openai.ReasoningSummary(resp.Reasoning.Summary), + } + + extension := &openai.ResponseMetaExtension{ + ID: resp.ID, + Status: openai.ResponseStatus(resp.Status), + Error: respErr, + IncompleteDetails: incompleteDetails, + PreviousResponseID: resp.PreviousResponseID, + Reasoning: reasoning, + ServiceTier: openai.ServiceTier(resp.ServiceTier), + CreatedAt: int64(resp.CreatedAt), + PromptCacheRetention: openai.PromptCacheRetention(resp.PromptCacheRetention), + } + + return extension +} + +func resolveURL(url string, base64Data string, mimeType string) (real string, err error) { + if url != "" { + return url, nil + } + + if mimeType == "" { + return "", fmt.Errorf("mimeType is required when using base64Data") + } + + real, err = ensureDataURL(base64Data, mimeType) + if err != nil { + return "", err + } + + return real, nil +} + +func ensureDataURL(base64Data, mimeType string) (string, error) { + if strings.HasPrefix(base64Data, "data:") { + return "", fmt.Errorf("base64Data field must be a raw base64 string, but got a string with prefix 'data:'") + } + if mimeType == "" { + return "", fmt.Errorf("mimeType is required") + } + return fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data), nil +} diff --git a/components/agentic/openai/convertor_test.go b/components/agentic/openai/convertor_test.go new file mode 100644 index 000000000..76273a713 --- /dev/null +++ b/components/agentic/openai/convertor_test.go @@ -0,0 +1,838 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "testing" + + "github.com/bytedance/mockey" + "github.com/cloudwego/eino/schema" + openaischema "github.com/cloudwego/eino/schema/openai" + "github.com/eino-contrib/jsonschema" + "github.com/openai/openai-go/v3/packages/param" + "github.com/openai/openai-go/v3/responses" + "github.com/stretchr/testify/assert" +) + +func TestToSystemRoleInputItems(t *testing.T) { + mockey.PatchConvey("toSystemRoleInputItems", t, func() { + mockey.PatchConvey("user_input_text", func() { + msg := &schema.AgenticMessage{ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.UserInputText{Text: "hi"}), + }} + items, err := toSystemRoleInputItems(msg) + assert.NoError(t, err) + if assert.Len(t, items, 1) { + assert.NotNil(t, items[0].OfMessage) + assert.Equal(t, responses.EasyInputMessageRoleSystem, items[0].OfMessage.Role) + assert.True(t, items[0].OfMessage.Content.OfString.Valid()) + assert.Equal(t, "hi", items[0].OfMessage.Content.OfString.Value) + } + }) + + mockey.PatchConvey("invalid_block_type", func() { + msg := &schema.AgenticMessage{ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "x"}), + }} + _, err := toSystemRoleInputItems(msg) + assert.Error(t, err) + }) + }) +} + +func TestToDeveloperRoleInputItems(t *testing.T) { + mockey.PatchConvey("toDeveloperRoleInputItems", t, func() { + mockey.PatchConvey("user_input_image", func() { + msg := &schema.AgenticMessage{ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.UserInputImage{URL: "http://example.com/a.png", Detail: schema.ImageURLDetailHigh}), + }} + items, err := toDeveloperRoleInputItems(msg) + assert.NoError(t, err) + if assert.Len(t, items, 1) { + assert.NotNil(t, items[0].OfMessage) + assert.Equal(t, responses.EasyInputMessageRoleDeveloper, items[0].OfMessage.Role) + list := items[0].OfMessage.Content.OfInputItemContentList + if assert.Len(t, list, 1) { + img := list[0].OfInputImage + assert.NotNil(t, img) + assert.True(t, img.ImageURL.Valid()) + assert.Equal(t, "http://example.com/a.png", img.ImageURL.Value) + assert.Equal(t, responses.ResponseInputImageDetailHigh, img.Detail) + } + } + }) + + mockey.PatchConvey("invalid_block_type", func() { + msg := &schema.AgenticMessage{ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.FunctionToolCall{Name: "t"}), + }} + _, err := toDeveloperRoleInputItems(msg) + assert.Error(t, err) + }) + }) +} + +func TestToAssistantRoleInputItems(t *testing.T) { + mockey.PatchConvey("toAssistantRoleInputItems", t, func() { + msg := &schema.AgenticMessage{ContentBlocks: []*schema.ContentBlock{}} + + assistantText := schema.NewContentBlock(&schema.AssistantGenText{Text: "ok"}) + setItemID(assistantText, "msg1") + setItemStatus(assistantText, "completed") + msg.ContentBlocks = append(msg.ContentBlocks, assistantText) + + reasoning := schema.NewContentBlock(&schema.Reasoning{Summary: []*schema.ReasoningSummary{{Text: "r"}}}) + setItemID(reasoning, "r1") + setItemStatus(reasoning, "completed") + msg.ContentBlocks = append(msg.ContentBlocks, reasoning) + + fc := schema.NewContentBlock(&schema.FunctionToolCall{CallID: "c1", Name: "f", Arguments: "{}"}) + setItemID(fc, "f1") + setItemStatus(fc, "completed") + msg.ContentBlocks = append(msg.ContentBlocks, fc) + + wsCall := schema.NewContentBlock(&schema.ServerToolCall{ + Name: string(ServerToolNameWebSearch), + Arguments: &ServerToolCallArguments{WebSearch: &WebSearchArguments{ActionType: WebSearchActionSearch, Search: &WebSearchQuery{Query: "q"}}}, + }) + setItemID(wsCall, "ws1") + setItemStatus(wsCall, "in_progress") + msg.ContentBlocks = append(msg.ContentBlocks, wsCall) + + wsRes := schema.NewContentBlock(&schema.ServerToolResult{ + Name: string(ServerToolNameWebSearch), + Result: &ServerToolResult{WebSearch: &WebSearchResult{ActionType: WebSearchActionSearch, Search: &WebSearchQueryResult{Sources: []*WebSearchQuerySource{{URL: "u"}}}}}, + }) + setItemID(wsRes, "ws1") + setItemStatus(wsRes, "completed") + msg.ContentBlocks = append(msg.ContentBlocks, wsRes) + + mcpCall := schema.NewContentBlock(&schema.MCPToolCall{ServerLabel: "srv", Name: "tool", Arguments: "{\"a\":1}"}) + setItemID(mcpCall, "m1") + setItemStatus(mcpCall, "calling") + msg.ContentBlocks = append(msg.ContentBlocks, mcpCall) + + mcpRes := schema.NewContentBlock(&schema.MCPToolResult{ServerLabel: "srv", Name: "tool", Result: "out"}) + setItemID(mcpRes, "m1") + setItemStatus(mcpRes, "completed") + msg.ContentBlocks = append(msg.ContentBlocks, mcpRes) + + items, err := toAssistantRoleInputItems(msg) + assert.NoError(t, err) + assert.Len(t, items, 5) + + var gotMcp *responses.ResponseInputItemMcpCallParam + var gotWS *responses.ResponseFunctionWebSearchParam + for i := range items { + if items[i].OfMcpCall != nil { + gotMcp = items[i].OfMcpCall + } + if items[i].OfWebSearchCall != nil { + gotWS = items[i].OfWebSearchCall + } + } + if assert.NotNil(t, gotMcp) { + assert.Equal(t, "m1", gotMcp.ID) + assert.Equal(t, "srv", gotMcp.ServerLabel) + assert.Equal(t, "tool", gotMcp.Name) + assert.Equal(t, "{\"a\":1}", gotMcp.Arguments) + assert.True(t, gotMcp.Output.Valid()) + assert.Equal(t, "out", gotMcp.Output.Value) + } + if assert.NotNil(t, gotWS) { + assert.Equal(t, "ws1", gotWS.ID) + assert.NotNil(t, gotWS.Action.OfSearch) + assert.Equal(t, "q", gotWS.Action.OfSearch.Query) + assert.Len(t, gotWS.Action.OfSearch.Sources, 1) + assert.Equal(t, "u", gotWS.Action.OfSearch.Sources[0].URL) + } + }) +} + +func TestPairMCPToolCallItems(t *testing.T) { + mockey.PatchConvey("pairMCPToolCallItems", t, func() { + mockey.PatchConvey("merge_call_and_result", func() { + items := []responses.ResponseInputItemUnionParam{ + {OfMcpCall: &responses.ResponseInputItemMcpCallParam{ID: "m1", ServerLabel: "s", Name: "n", Arguments: "{}"}}, + {OfMcpCall: &responses.ResponseInputItemMcpCallParam{ID: "m1", ServerLabel: "s", Name: "n", Output: param.NewOpt("out")}}, + } + newItems, err := pairMCPToolCallItems(items) + assert.NoError(t, err) + if assert.Len(t, newItems, 1) { + m := newItems[0].OfMcpCall + assert.NotNil(t, m) + assert.Equal(t, "{}", m.Arguments) + assert.True(t, m.Output.Valid()) + assert.Equal(t, "out", m.Output.Value) + } + }) + + mockey.PatchConvey("missing_pair", func() { + items := []responses.ResponseInputItemUnionParam{{OfMcpCall: &responses.ResponseInputItemMcpCallParam{ID: "m1", ServerLabel: "s", Name: "n", Arguments: "{}"}}} + _, err := pairMCPToolCallItems(items) + assert.Error(t, err) + }) + }) +} + +func TestPairWebServerToolCallItems(t *testing.T) { + mockey.PatchConvey("pairWebServerToolCallItems", t, func() { + mockey.PatchConvey("merge_call_and_result", func() { + items := []responses.ResponseInputItemUnionParam{ + {OfWebSearchCall: &responses.ResponseFunctionWebSearchParam{ID: "ws1", Action: responses.ResponseFunctionWebSearchActionUnionParam{OfSearch: &responses.ResponseFunctionWebSearchActionSearchParam{Query: "q"}}}}, + {OfWebSearchCall: &responses.ResponseFunctionWebSearchParam{ID: "ws1", Action: responses.ResponseFunctionWebSearchActionUnionParam{OfSearch: &responses.ResponseFunctionWebSearchActionSearchParam{Sources: []responses.ResponseFunctionWebSearchActionSearchSourceParam{{URL: "u"}}}}}}, + } + newItems, err := pairWebServerToolCallItems(items) + assert.NoError(t, err) + if assert.Len(t, newItems, 1) { + ws := newItems[0].OfWebSearchCall + assert.NotNil(t, ws) + assert.NotNil(t, ws.Action.OfSearch) + assert.Equal(t, "q", ws.Action.OfSearch.Query) + assert.Len(t, ws.Action.OfSearch.Sources, 1) + assert.Equal(t, "u", ws.Action.OfSearch.Sources[0].URL) + } + }) + + mockey.PatchConvey("missing_pair", func() { + items := []responses.ResponseInputItemUnionParam{{OfWebSearchCall: &responses.ResponseFunctionWebSearchParam{ID: "ws1"}}} + _, err := pairWebServerToolCallItems(items) + assert.Error(t, err) + }) + }) +} + +func TestPairWebSearchAction(t *testing.T) { + mockey.PatchConvey("pairWebSearchAction", t, func() { + a := responses.ResponseFunctionWebSearchActionUnionParam{OfSearch: &responses.ResponseFunctionWebSearchActionSearchParam{Query: "q"}} + b := responses.ResponseFunctionWebSearchActionUnionParam{OfSearch: &responses.ResponseFunctionWebSearchActionSearchParam{Query: "q2", Sources: []responses.ResponseFunctionWebSearchActionSearchSourceParam{{URL: "u"}}}} + merged := pairWebSearchAction(a, b) + assert.NotNil(t, merged.OfSearch) + assert.Equal(t, "q2", merged.OfSearch.Query) + assert.Len(t, merged.OfSearch.Sources, 1) + assert.Equal(t, "u", merged.OfSearch.Sources[0].URL) + }) +} + +func TestToUserRoleInputItems(t *testing.T) { + mockey.PatchConvey("toUserRoleInputItems", t, func() { + mockey.PatchConvey("mix_user_inputs", func() { + msg := &schema.AgenticMessage{ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.UserInputText{Text: "hi"}), + schema.NewContentBlock(&schema.FunctionToolResult{CallID: "c", Result: "r"}), + schema.NewContentBlock(&schema.MCPToolApprovalResponse{ApprovalRequestID: "a", Approve: true, Reason: "ok"}), + }} + items, err := toUserRoleInputItems(msg) + assert.NoError(t, err) + assert.Len(t, items, 3) + }) + + mockey.PatchConvey("invalid_block_type", func() { + msg := &schema.AgenticMessage{ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.Reasoning{}), + }} + _, err := toUserRoleInputItems(msg) + assert.Error(t, err) + }) + }) +} + +func TestUserInputTextToInputItem(t *testing.T) { + mockey.PatchConvey("userInputTextToInputItem", t, func() { + item, err := userInputTextToInputItem(responses.EasyInputMessageRoleUser, &schema.UserInputText{Text: "hi"}) + assert.NoError(t, err) + assert.NotNil(t, item.OfMessage) + assert.True(t, item.OfMessage.Content.OfString.Valid()) + assert.Equal(t, "hi", item.OfMessage.Content.OfString.Value) + }) +} + +func TestUserInputImageToInputItem(t *testing.T) { + mockey.PatchConvey("userInputImageToInputItem", t, func() { + mockey.PatchConvey("url", func() { + item, err := userInputImageToInputItem(responses.EasyInputMessageRoleUser, &schema.UserInputImage{URL: "http://x", Detail: schema.ImageURLDetailAuto}) + assert.NoError(t, err) + assert.NotNil(t, item.OfMessage) + list := item.OfMessage.Content.OfInputItemContentList + if assert.Len(t, list, 1) { + img := list[0].OfInputImage + assert.NotNil(t, img) + assert.True(t, img.ImageURL.Valid()) + assert.Equal(t, "http://x", img.ImageURL.Value) + assert.Equal(t, responses.ResponseInputImageDetailAuto, img.Detail) + } + }) + + mockey.PatchConvey("base64_missing_mime", func() { + _, err := userInputImageToInputItem(responses.EasyInputMessageRoleUser, &schema.UserInputImage{Base64Data: "abc"}) + assert.Error(t, err) + }) + }) +} + +func TestToInputItemImageDetail(t *testing.T) { + mockey.PatchConvey("toInputItemImageDetail", t, func() { + mockey.PatchConvey("empty", func() { + d, err := toInputItemImageDetail("") + assert.NoError(t, err) + assert.Equal(t, responses.ResponseInputImageDetail(""), d) + }) + mockey.PatchConvey("invalid", func() { + _, err := toInputItemImageDetail("bad") + assert.Error(t, err) + }) + }) +} + +func TestUserInputFileToInputItem(t *testing.T) { + mockey.PatchConvey("userInputFileToInputItem", t, func() { + mockey.PatchConvey("url", func() { + item, err := userInputFileToInputItem(responses.EasyInputMessageRoleUser, &schema.UserInputFile{URL: "http://f", Name: "a.txt"}) + assert.NoError(t, err) + assert.NotNil(t, item.OfMessage) + list := item.OfMessage.Content.OfInputItemContentList + if assert.Len(t, list, 1) { + f := list[0].OfInputFile + assert.NotNil(t, f) + assert.True(t, f.FileURL.Valid()) + assert.Equal(t, "http://f", f.FileURL.Value) + assert.True(t, f.Filename.Valid()) + assert.Equal(t, "a.txt", f.Filename.Value) + } + }) + + mockey.PatchConvey("base64", func() { + item, err := userInputFileToInputItem(responses.EasyInputMessageRoleUser, &schema.UserInputFile{Base64Data: "abc", MIMEType: "text/plain", Name: "a.txt"}) + assert.NoError(t, err) + list := item.OfMessage.Content.OfInputItemContentList + if assert.Len(t, list, 1) { + f := list[0].OfInputFile + assert.NotNil(t, f) + assert.True(t, f.FileData.Valid()) + assert.Equal(t, "abc", f.FileData.Value) + assert.False(t, f.FileURL.Valid()) + } + }) + }) +} + +func TestFunctionToolResultToInputItem(t *testing.T) { + mockey.PatchConvey("functionToolResultToInputItem", t, func() { + item, err := functionToolResultToInputItem(&schema.FunctionToolResult{CallID: "c", Result: "r"}) + assert.NoError(t, err) + assert.NotNil(t, item.OfFunctionCallOutput) + assert.Equal(t, "c", item.OfFunctionCallOutput.CallID) + assert.True(t, item.OfFunctionCallOutput.Output.OfString.Valid()) + assert.Equal(t, "r", item.OfFunctionCallOutput.Output.OfString.Value) + }) +} + +func TestAssistantGenTextToInputItem(t *testing.T) { + mockey.PatchConvey("assistantGenTextToInputItem", t, func() { + mockey.PatchConvey("nil_content", func() { + _, err := assistantGenTextToInputItem(&schema.ContentBlock{Type: schema.ContentBlockTypeAssistantGenText}) + assert.Error(t, err) + }) + + mockey.PatchConvey("with_annotations", func() { + block := schema.NewContentBlock(&schema.AssistantGenText{ + Text: "t", + OpenAIExtension: &openaischema.AssistantGenTextExtension{Annotations: []*openaischema.TextAnnotation{ + {Type: openaischema.TextAnnotationTypeURLCitation, URLCitation: &openaischema.TextAnnotationURLCitation{Title: "tt", URL: "u", StartIndex: 1, EndIndex: 2}}, + }}, + }) + setItemID(block, "msg1") + setItemStatus(block, "completed") + + item, err := assistantGenTextToInputItem(block) + assert.NoError(t, err) + assert.NotNil(t, item.OfOutputMessage) + assert.Equal(t, "msg1", item.OfOutputMessage.ID) + assert.Equal(t, responses.ResponseOutputMessageStatus("completed"), item.OfOutputMessage.Status) + if assert.Len(t, item.OfOutputMessage.Content, 1) { + ot := item.OfOutputMessage.Content[0].OfOutputText + assert.NotNil(t, ot) + assert.Equal(t, "t", ot.Text) + assert.Len(t, ot.Annotations, 1) + assert.NotNil(t, ot.Annotations[0].OfURLCitation) + assert.Equal(t, "u", ot.Annotations[0].OfURLCitation.URL) + } + }) + }) +} + +func TestTextAnnotationToOutputTextAnnotation(t *testing.T) { + mockey.PatchConvey("textAnnotationToOutputTextAnnotation", t, func() { + mockey.PatchConvey("file_citation", func() { + p, err := textAnnotationToOutputTextAnnotation(&openaischema.TextAnnotation{Type: openaischema.TextAnnotationTypeFileCitation, FileCitation: &openaischema.TextAnnotationFileCitation{FileID: "f", Filename: "n", Index: 3}}) + assert.NoError(t, err) + assert.NotNil(t, p.OfFileCitation) + assert.Equal(t, int64(3), p.OfFileCitation.Index) + }) + + mockey.PatchConvey("invalid", func() { + _, err := textAnnotationToOutputTextAnnotation(&openaischema.TextAnnotation{Type: "bad"}) + assert.Error(t, err) + }) + }) +} + +func TestFunctionToolCallToInputItem(t *testing.T) { + mockey.PatchConvey("functionToolCallToInputItem", t, func() { + mockey.PatchConvey("nil_content", func() { + _, err := functionToolCallToInputItem(&schema.ContentBlock{Type: schema.ContentBlockTypeFunctionToolCall}) + assert.Error(t, err) + }) + + mockey.PatchConvey("normal", func() { + block := schema.NewContentBlock(&schema.FunctionToolCall{CallID: "c", Name: "n", Arguments: "{}"}) + setItemID(block, "id") + setItemStatus(block, "completed") + item, err := functionToolCallToInputItem(block) + assert.NoError(t, err) + assert.NotNil(t, item.OfFunctionCall) + assert.True(t, item.OfFunctionCall.ID.Valid()) + assert.Equal(t, "id", item.OfFunctionCall.ID.Value) + assert.Equal(t, "c", item.OfFunctionCall.CallID) + }) + }) +} + +func TestReasoningToInputItem(t *testing.T) { + mockey.PatchConvey("reasoningToInputItem", t, func() { + block := schema.NewContentBlock(&schema.Reasoning{Summary: []*schema.ReasoningSummary{{Text: "s"}}, EncryptedContent: "e"}) + setItemID(block, "r") + setItemStatus(block, "completed") + item, err := reasoningToInputItem(block) + assert.NoError(t, err) + assert.NotNil(t, item.OfReasoning) + assert.Equal(t, "r", item.OfReasoning.ID) + assert.True(t, item.OfReasoning.EncryptedContent.Valid()) + assert.Equal(t, "e", item.OfReasoning.EncryptedContent.Value) + }) +} + +func TestServerToolCallToInputItem(t *testing.T) { + mockey.PatchConvey("serverToolCallToInputItem", t, func() { + block := schema.NewContentBlock(&schema.ServerToolCall{ + Name: string(ServerToolNameWebSearch), + Arguments: &ServerToolCallArguments{WebSearch: &WebSearchArguments{ActionType: WebSearchActionSearch, Search: &WebSearchQuery{Query: "q"}}}, + }) + setItemID(block, "ws1") + setItemStatus(block, "searching") + item, err := serverToolCallToInputItem(block) + assert.NoError(t, err) + assert.NotNil(t, item.OfWebSearchCall) + assert.Equal(t, "ws1", item.OfWebSearchCall.ID) + assert.NotNil(t, item.OfWebSearchCall.Action.OfSearch) + assert.Equal(t, "q", item.OfWebSearchCall.Action.OfSearch.Query) + }) +} + +func TestGetWebSearchToolCallActionParam(t *testing.T) { + mockey.PatchConvey("getWebSearchToolCallActionParam", t, func() { + a, err := getWebSearchToolCallActionParam(&WebSearchArguments{ActionType: WebSearchActionFind, Find: &WebSearchFind{URL: "u", Pattern: "p"}}) + assert.NoError(t, err) + assert.NotNil(t, a.OfFind) + assert.Equal(t, "u", a.OfFind.URL) + + _, err = getWebSearchToolCallActionParam(&WebSearchArguments{ActionType: "bad"}) + assert.Error(t, err) + }) +} + +func TestServerToolResultToInputItem(t *testing.T) { + mockey.PatchConvey("serverToolResultToInputItem", t, func() { + block := schema.NewContentBlock(&schema.ServerToolResult{ + Name: string(ServerToolNameWebSearch), + Result: &ServerToolResult{WebSearch: &WebSearchResult{ActionType: WebSearchActionSearch, Search: &WebSearchQueryResult{Sources: []*WebSearchQuerySource{{URL: "u"}}}}}, + }) + setItemID(block, "ws1") + setItemStatus(block, "completed") + item, err := serverToolResultToInputItem(block) + assert.NoError(t, err) + assert.NotNil(t, item.OfWebSearchCall) + assert.Len(t, item.OfWebSearchCall.Action.OfSearch.Sources, 1) + assert.Equal(t, "u", item.OfWebSearchCall.Action.OfSearch.Sources[0].URL) + }) +} + +func TestGetWebSearchToolResultActionParam(t *testing.T) { + mockey.PatchConvey("getWebSearchToolResultActionParam", t, func() { + a, err := getWebSearchToolResultActionParam(&WebSearchResult{ActionType: WebSearchActionSearch, Search: &WebSearchQueryResult{Sources: []*WebSearchQuerySource{{URL: "u"}}}}) + assert.NoError(t, err) + assert.NotNil(t, a.OfSearch) + assert.Len(t, a.OfSearch.Sources, 1) + assert.Equal(t, "u", a.OfSearch.Sources[0].URL) + + _, err = getWebSearchToolResultActionParam(&WebSearchResult{ActionType: "bad"}) + assert.Error(t, err) + }) +} + +func TestMcpToolApprovalRequestToInputItem(t *testing.T) { + mockey.PatchConvey("mcpToolApprovalRequestToInputItem", t, func() { + block := schema.NewContentBlock(&schema.MCPToolApprovalRequest{ID: "a", Name: "n", Arguments: "{}", ServerLabel: "s"}) + setItemID(block, "a") + item, err := mcpToolApprovalRequestToInputItem(block) + assert.NoError(t, err) + assert.NotNil(t, item.OfMcpApprovalRequest) + assert.Equal(t, "a", item.OfMcpApprovalRequest.ID) + assert.Equal(t, "n", item.OfMcpApprovalRequest.Name) + }) +} + +func TestMcpToolApprovalResponseToInputItem(t *testing.T) { + mockey.PatchConvey("mcpToolApprovalResponseToInputItem", t, func() { + mockey.PatchConvey("empty_reason", func() { + item, err := mcpToolApprovalResponseToInputItem(&schema.MCPToolApprovalResponse{ApprovalRequestID: "a", Approve: true}) + assert.NoError(t, err) + assert.NotNil(t, item.OfMcpApprovalResponse) + assert.False(t, item.OfMcpApprovalResponse.Reason.Valid()) + }) + + mockey.PatchConvey("with_reason", func() { + item, err := mcpToolApprovalResponseToInputItem(&schema.MCPToolApprovalResponse{ApprovalRequestID: "a", Approve: false, Reason: "r"}) + assert.NoError(t, err) + assert.True(t, item.OfMcpApprovalResponse.Reason.Valid()) + assert.Equal(t, "r", item.OfMcpApprovalResponse.Reason.Value) + }) + }) +} + +func TestMcpListToolsResultToInputItem(t *testing.T) { + mockey.PatchConvey("mcpListToolsResultToInputItem", t, func() { + block := schema.NewContentBlock(&schema.MCPListToolsResult{ServerLabel: "s", Tools: []*schema.MCPListToolsItem{{Name: "t", Description: "", InputSchema: &jsonschema.Schema{}}}}) + setItemID(block, "id") + item, err := mcpListToolsResultToInputItem(block) + assert.NoError(t, err) + assert.NotNil(t, item.OfMcpListTools) + assert.Equal(t, "id", item.OfMcpListTools.ID) + if assert.Len(t, item.OfMcpListTools.Tools, 1) { + assert.False(t, item.OfMcpListTools.Tools[0].Description.Valid()) + } + }) +} + +func TestMcpToolCallToInputItem(t *testing.T) { + mockey.PatchConvey("mcpToolCallToInputItem", t, func() { + block := schema.NewContentBlock(&schema.MCPToolCall{ServerLabel: "s", Name: "n", Arguments: "{}"}) + setItemID(block, "id") + setItemStatus(block, "calling") + item, err := mcpToolCallToInputItem(block) + assert.NoError(t, err) + assert.NotNil(t, item.OfMcpCall) + assert.Equal(t, "id", item.OfMcpCall.ID) + assert.Equal(t, "{}", item.OfMcpCall.Arguments) + }) +} + +func TestMcpToolResultToInputItem(t *testing.T) { + mockey.PatchConvey("mcpToolResultToInputItem", t, func() { + block := schema.NewContentBlock(&schema.MCPToolResult{ServerLabel: "s", Name: "n", Result: "out"}) + setItemID(block, "id") + setItemStatus(block, "completed") + item, err := mcpToolResultToInputItem(block) + assert.NoError(t, err) + assert.NotNil(t, item.OfMcpCall) + assert.True(t, item.OfMcpCall.Output.Valid()) + assert.Equal(t, "out", item.OfMcpCall.Output.Value) + assert.False(t, item.OfMcpCall.Error.Valid()) + }) +} + +func TestToOutputMessage(t *testing.T) { + mockey.PatchConvey("toOutputMessage", t, func() { + resp := &responses.Response{ + Output: []responses.ResponseOutputItemUnion{ + { + Type: "message", + ID: "m1", + Status: "completed", + Content: []responses.ResponseOutputMessageContentUnion{ + {Type: "output_text", Text: "hi", Annotations: []responses.ResponseOutputTextAnnotationUnion{}}, + }, + }, + { + Type: "reasoning", + ID: "r1", + Status: "completed", + Summary: []responses.ResponseReasoningItemSummary{{Text: "s"}}, + }, + }, + Usage: responses.ResponseUsage{ + InputTokens: 1, + InputTokensDetails: responses.ResponseUsageInputTokensDetails{CachedTokens: 2}, + OutputTokens: 3, + OutputTokensDetails: responses.ResponseUsageOutputTokensDetails{ReasoningTokens: 4}, + TotalTokens: 5, + }, + } + + mockey.Mock(mockey.GetMethod(resp.Output[0], "AsAny")).Return(mockey.Sequence( + responses.ResponseOutputMessage{ + Type: "message", + ID: "m1", + Status: "completed", + Content: []responses.ResponseOutputMessageContentUnion{ + {Type: "output_text", Text: "hi", Annotations: []responses.ResponseOutputTextAnnotationUnion{}}, + }, + }).Then(responses.ResponseReasoningItem{ + Type: "reasoning", + ID: "r1", + Status: "completed", + Summary: []responses.ResponseReasoningItemSummary{{Text: "s"}}, + })).Build() + msg, err := toOutputMessage(resp) + assert.NoError(t, err) + assert.NotNil(t, msg) + assert.Equal(t, schema.AgenticRoleTypeAssistant, msg.Role) + assert.Len(t, msg.ContentBlocks, 2) + assert.NotNil(t, msg.ResponseMeta) + }) +} + +func TestOutputMessageToContentBlocks(t *testing.T) { + mockey.PatchConvey("outputMessageToContentBlocks", t, func() { + item := responses.ResponseOutputMessage{ + ID: "m1", + Status: "completed", + Content: []responses.ResponseOutputMessageContentUnion{ + {Type: "output_text", Text: "hi", Annotations: []responses.ResponseOutputTextAnnotationUnion{}}, + {Type: "refusal", Refusal: "no"}, + }, + } + blocks, err := outputMessageToContentBlocks(item) + assert.NoError(t, err) + assert.Len(t, blocks, 2) + for _, b := range blocks { + id, ok := getItemID(b) + assert.True(t, ok) + assert.Equal(t, "m1", id) + } + }) +} + +func TestOutputContentTextToContentBlock(t *testing.T) { + mockey.PatchConvey("outputContentTextToContentBlock", t, func() { + text := responses.ResponseOutputText{Text: "hi", Annotations: []responses.ResponseOutputTextAnnotationUnion{{Type: "url_citation", Title: "t", URL: "u", StartIndex: 1, EndIndex: 2}}} + block, err := outputContentTextToContentBlock(text) + assert.NoError(t, err) + assert.NotNil(t, block) + assert.NotNil(t, block.AssistantGenText) + assert.Equal(t, "hi", block.AssistantGenText.Text) + if assert.NotNil(t, block.AssistantGenText.OpenAIExtension) { + assert.Len(t, block.AssistantGenText.OpenAIExtension.Annotations, 1) + } + }) +} + +func TestOutputTextAnnotationToTextAnnotation(t *testing.T) { + mockey.PatchConvey("outputTextAnnotationToTextAnnotation", t, func() { + mockey.PatchConvey("file_citation_index_should_preserve", func() { + a := responses.ResponseOutputTextAnnotationUnion{Type: "file_citation", FileID: "f", Filename: "n", Index: 5} + + mockey.Mock(responses.ResponseOutputTextAnnotationUnion.AsAny).Return(responses.ResponseOutputTextAnnotationFileCitation{ + FileID: "f", + Filename: "n", + Index: 5, + }).Build() + + ta, err := outputTextAnnotationToTextAnnotation(a) + assert.NoError(t, err) + assert.NotNil(t, ta) + assert.NotNil(t, ta.FileCitation) + assert.Equal(t, 5, ta.FileCitation.Index) + }) + }) +} + +func TestFunctionToolCallToContentBlock(t *testing.T) { + mockey.PatchConvey("functionToolCallToContentBlock", t, func() { + item := responses.ResponseFunctionToolCall{ID: "id", Status: "completed", CallID: "c", Name: "n", Arguments: "{}"} + block, err := functionToolCallToContentBlock(item) + assert.NoError(t, err) + assert.NotNil(t, block) + assert.NotNil(t, block.FunctionToolCall) + id, ok := getItemID(block) + assert.True(t, ok) + assert.Equal(t, "id", id) + }) +} + +func TestWebSearchToContentBlocks(t *testing.T) { + mockey.PatchConvey("webSearchToContentBlocks", t, func() { + item := responses.ResponseFunctionWebSearch{ + ID: "ws1", + Status: "completed", + Action: responses.ResponseFunctionWebSearchActionUnion{Type: "search", Query: "q", Sources: []responses.ResponseFunctionWebSearchActionSearchSource{{URL: "u"}}}, + } + blocks, err := webSearchToContentBlocks(item) + assert.NoError(t, err) + assert.Len(t, blocks, 2) + for _, b := range blocks { + id, ok := getItemID(b) + assert.True(t, ok) + assert.Equal(t, "ws1", id) + } + }) +} + +func TestReasoningToContentBlocks(t *testing.T) { + mockey.PatchConvey("reasoningToContentBlocks", t, func() { + item := responses.ResponseReasoningItem{ID: "r1", Status: "completed", Summary: []responses.ResponseReasoningItemSummary{{Text: "s"}}} + block, err := reasoningToContentBlocks(item) + assert.NoError(t, err) + id, ok := getItemID(block) + assert.True(t, ok) + assert.Equal(t, "r1", id) + assert.NotNil(t, block.Reasoning) + assert.Len(t, block.Reasoning.Summary, 1) + }) +} + +func TestMcpCallToContentBlocks(t *testing.T) { + mockey.PatchConvey("mcpCallToContentBlocks", t, func() { + item := responses.ResponseOutputItemMcpCall{ID: "m1", ServerLabel: "s", Name: "n", Arguments: "{}", Output: "out"} + blocks, err := mcpCallToContentBlocks(item) + assert.NoError(t, err) + assert.Len(t, blocks, 2) + for _, b := range blocks { + id, ok := getItemID(b) + assert.True(t, ok) + assert.Equal(t, "m1", id) + } + }) +} + +func TestMcpListToolsToContentBlock(t *testing.T) { + mockey.PatchConvey("mcpListToolsToContentBlock", t, func() { + item := responses.ResponseOutputItemMcpListTools{ + ID: "l1", + ServerLabel: "s", + Tools: []responses.ResponseOutputItemMcpListToolsTool{ + {Name: "t", Description: "d", InputSchema: map[string]any{"type": "object"}}, + }, + } + block, err := mcpListToolsToContentBlock(item) + assert.NoError(t, err) + assert.NotNil(t, block) + assert.NotNil(t, block.MCPListToolsResult) + id, ok := getItemID(block) + assert.True(t, ok) + assert.Equal(t, "l1", id) + assert.Len(t, block.MCPListToolsResult.Tools, 1) + }) +} + +func TestMcpApprovalRequestToContentBlock(t *testing.T) { + mockey.PatchConvey("mcpApprovalRequestToContentBlock", t, func() { + item := responses.ResponseOutputItemMcpApprovalRequest{ID: "a1", ServerLabel: "s", Name: "n", Arguments: "{}"} + block, err := mcpApprovalRequestToContentBlock(item) + assert.NoError(t, err) + assert.NotNil(t, block) + id, ok := getItemID(block) + assert.True(t, ok) + assert.Equal(t, "a1", id) + }) +} + +func TestResponseObjectToResponseMeta(t *testing.T) { + mockey.PatchConvey("responseObjectToResponseMeta", t, func() { + resp := &responses.Response{Usage: responses.ResponseUsage{InputTokensDetails: responses.ResponseUsageInputTokensDetails{}, OutputTokensDetails: responses.ResponseUsageOutputTokensDetails{}}} + meta := responseObjectToResponseMeta(resp) + assert.NotNil(t, meta) + assert.NotNil(t, meta.TokenUsage) + assert.NotNil(t, meta.OpenAIExtension) + }) +} + +func TestToTokenUsage(t *testing.T) { + mockey.PatchConvey("toTokenUsage", t, func() { + resp := &responses.Response{Usage: responses.ResponseUsage{InputTokens: 1, InputTokensDetails: responses.ResponseUsageInputTokensDetails{CachedTokens: 2}, OutputTokens: 3, OutputTokensDetails: responses.ResponseUsageOutputTokensDetails{ReasoningTokens: 4}, TotalTokens: 5}} + u := toTokenUsage(resp) + assert.NotNil(t, u) + assert.Equal(t, 1, u.PromptTokens) + assert.Equal(t, 2, u.PromptTokenDetails.CachedTokens) + assert.Equal(t, 3, u.CompletionTokens) + assert.Equal(t, 4, u.CompletionTokensDetails.ReasoningTokens) + assert.Equal(t, 5, u.TotalTokens) + }) +} + +func TestToResponseMetaExtension(t *testing.T) { + mockey.PatchConvey("toResponseMetaExtension", t, func() { + resp := &responses.Response{} + resp.ID = "r" + resp.Status = "completed" + resp.Error.Code = "c" + resp.Error.Message = "m" + resp.IncompleteDetails.Reason = "x" + resp.Reasoning.Effort = "low" + resp.Reasoning.Summary = "sum" + resp.ServiceTier = "auto" + resp.CreatedAt = 123 + ext := toResponseMetaExtension(resp) + assert.NotNil(t, ext) + assert.Equal(t, "r", ext.ID) + assert.NotNil(t, ext.Error) + assert.Equal(t, openaischema.ResponseErrorCode("c"), ext.Error.Code) + assert.NotNil(t, ext.IncompleteDetails) + assert.Equal(t, "x", ext.IncompleteDetails.Reason) + }) +} + +func TestResolveURL(t *testing.T) { + mockey.PatchConvey("resolveURL", t, func() { + mockey.PatchConvey("url", func() { + u, err := resolveURL("http://x", "", "") + assert.NoError(t, err) + assert.Equal(t, "http://x", u) + }) + + mockey.PatchConvey("base64_without_mime", func() { + _, err := resolveURL("", "abc", "") + assert.Error(t, err) + }) + + mockey.PatchConvey("base64", func() { + u, err := resolveURL("", "abc", "text/plain") + assert.NoError(t, err) + assert.Equal(t, "data:text/plain;base64,abc", u) + }) + }) +} + +func TestEnsureDataURL(t *testing.T) { + mockey.PatchConvey("ensureDataURL", t, func() { + mockey.PatchConvey("already_data_url", func() { + _, err := ensureDataURL("data:text/plain;base64,abc", "text/plain") + assert.Error(t, err) + }) + + mockey.PatchConvey("missing_mime", func() { + _, err := ensureDataURL("abc", "") + assert.Error(t, err) + }) + + mockey.PatchConvey("ok", func() { + u, err := ensureDataURL("abc", "text/plain") + assert.NoError(t, err) + assert.Equal(t, "data:text/plain;base64,abc", u) + }) + }) +} diff --git a/components/agentic/openai/event_convertor.go b/components/agentic/openai/event_convertor.go new file mode 100644 index 000000000..8ca877477 --- /dev/null +++ b/components/agentic/openai/event_convertor.go @@ -0,0 +1,772 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "fmt" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/cloudwego/eino/schema/openai" + "github.com/openai/openai-go/v3/responses" +) + +type streamScanner interface { + Next() bool + Current() responses.ResponseStreamEventUnion + Err() error +} + +func receivedStreamingResponse(sr streamScanner, + config *agentic.Config, sw *schema.StreamWriter[*agentic.CallbackOutput]) { + + receiver := newStreamReceiver() + sender := newCallbackSender(sw, config) + + if sr.Err() != nil { + _ = sw.Send(nil, fmt.Errorf("failed to read stream, err: %w", sr.Err())) + return + } + + for sr.Next() { + event := sr.Current() + + sender.errHeader = fmt.Sprintf("failed to convert event '%s'", event.Type) + + switch variant := event.AsAny().(type) { + case responses.ResponseTextDoneEvent, + responses.ResponseReasoningSummaryPartAddedEvent, + responses.ResponseReasoningSummaryPartDoneEvent, + responses.ResponseReasoningSummaryTextDoneEvent, + responses.ResponseFunctionCallArgumentsDoneEvent, + responses.ResponseMcpCallArgumentsDoneEvent, + responses.ResponseRefusalDoneEvent: + + // Do nothing. + continue + + case responses.ResponseErrorEvent: + _ = sw.Send(nil, fmt.Errorf("received error event: code=%s message=%s", variant.Code, variant.Message)) + + case responses.ResponseCreatedEvent: + meta := responseObjectToResponseMeta(&variant.Response) + sender.sendMeta(meta, nil) + + case responses.ResponseInProgressEvent: + meta := responseObjectToResponseMeta(&variant.Response) + sender.sendMeta(meta, nil) + + case responses.ResponseCompletedEvent: + meta := responseObjectToResponseMeta(&variant.Response) + sender.sendMeta(meta, nil) + + case responses.ResponseIncompleteEvent: + meta := responseObjectToResponseMeta(&variant.Response) + sender.sendMeta(meta, nil) + + case responses.ResponseFailedEvent: + meta := responseObjectToResponseMeta(&variant.Response) + sender.sendMeta(meta, nil) + + case responses.ResponseOutputItemAddedEvent: + blocks, err := receiver.itemAddedEventToContentBlock(variant) + for _, block := range blocks { + sender.sendBlock(block, err) + } + + case responses.ResponseOutputItemDoneEvent: + blocks, err := receiver.itemDoneEventToContentBlocks(variant) + for _, block := range blocks { + sender.sendBlock(block, err) + } + + case responses.ResponseContentPartAddedEvent: + block, err := receiver.contentPartAddedEventToContentBlock(variant) + sender.sendBlock(block, err) + + case responses.ResponseContentPartDoneEvent: + block, err := receiver.contentPartDoneEventToContentBlock(variant) + sender.sendBlock(block, err) + + case responses.ResponseRefusalDeltaEvent: + block := receiver.refusalDeltaEventToContentBlock(variant) + sender.sendBlock(block, nil) + + case responses.ResponseTextDeltaEvent: + block := receiver.outputTextDeltaEventToContentBlock(variant) + sender.sendBlock(block, nil) + + case responses.ResponseOutputTextAnnotationAddedEvent: + block, err := receiver.annotationAddedEventToContentBlock(variant) + sender.sendBlock(block, err) + + case responses.ResponseReasoningSummaryTextDeltaEvent: + block := receiver.reasoningSummaryTextDeltaEventToContentBlock(variant) + sender.sendBlock(block, nil) + + case responses.ResponseFunctionCallArgumentsDeltaEvent: + block := receiver.functionCallArgumentsDeltaEventToContentBlock(variant) + sender.sendBlock(block, nil) + + case responses.ResponseMcpListToolsInProgressEvent: + block := receiver.mcpListToolsPhaseToContentBlock(variant.ItemID, variant.OutputIndex, string(responses.ResponseStatusInProgress)) + sender.sendBlock(block, nil) + + case responses.ResponseMcpListToolsFailedEvent: + block := receiver.mcpListToolsPhaseToContentBlock(variant.ItemID, variant.OutputIndex, string(responses.ResponseStatusFailed)) + sender.sendBlock(block, nil) + + case responses.ResponseMcpListToolsCompletedEvent: + block := receiver.mcpListToolsPhaseToContentBlock(variant.ItemID, variant.OutputIndex, string(responses.ResponseStatusCompleted)) + sender.sendBlock(block, nil) + + case responses.ResponseMcpCallArgumentsDeltaEvent: + block := receiver.mcpCallArgumentsDeltaEventToContentBlock(variant) + sender.sendBlock(block, nil) + + case responses.ResponseMcpCallInProgressEvent: + block := receiver.mcpCallPhaseToContentBlock(variant.ItemID, variant.OutputIndex, string(responses.ResponseStatusInProgress)) + sender.sendBlock(block, nil) + + case responses.ResponseMcpCallCompletedEvent: + block := receiver.mcpCallPhaseToContentBlock(variant.ItemID, variant.OutputIndex, string(responses.ResponseStatusCompleted)) + sender.sendBlock(block, nil) + + case responses.ResponseMcpCallFailedEvent: + block := receiver.mcpCallPhaseToContentBlock(variant.ItemID, variant.OutputIndex, string(responses.ResponseStatusFailed)) + sender.sendBlock(block, nil) + + case responses.ResponseWebSearchCallInProgressEvent: + block := receiver.webSearchPhaseToContentBlock(variant.ItemID, variant.OutputIndex, string(responses.ResponseStatusInProgress)) + sender.sendBlock(block, nil) + + case responses.ResponseWebSearchCallSearchingEvent: + const phase = "searching" + block := receiver.webSearchPhaseToContentBlock(variant.ItemID, variant.OutputIndex, phase) + sender.sendBlock(block, nil) + + case responses.ResponseWebSearchCallCompletedEvent: + block := receiver.webSearchPhaseToContentBlock(variant.ItemID, variant.OutputIndex, string(responses.ResponseStatusCompleted)) + sender.sendBlock(block, nil) + + default: + sw.Send(nil, fmt.Errorf("invalid event type: %s", event.Type)) + } + } + + if sr.Err() != nil { + _ = sw.Send(nil, fmt.Errorf("failed to read stream, err: %w", sr.Err())) + return + } +} + +type callbackSender struct { + sw *schema.StreamWriter[*agentic.CallbackOutput] + config *agentic.Config + errHeader string +} + +func newCallbackSender(sw *schema.StreamWriter[*agentic.CallbackOutput], config *agentic.Config) *callbackSender { + return &callbackSender{ + sw: sw, + config: config, + } +} + +func (s *callbackSender) sendMeta(meta *schema.AgenticResponseMeta, err error) { + s.send(meta, nil, err) +} + +func (s *callbackSender) sendBlock(block *schema.ContentBlock, err error) { + s.send(nil, block, err) +} + +func (s *callbackSender) send(meta *schema.AgenticResponseMeta, block *schema.ContentBlock, err error) { + if err != nil { + _ = s.sw.Send(nil, fmt.Errorf("%s: %w", s.errHeader, err)) + return + } + + msg := &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ResponseMeta: meta, + } + if block != nil { + msg.ContentBlocks = []*schema.ContentBlock{block} + } + + s.sw.Send(&agentic.CallbackOutput{ + Message: msg, + Config: s.config, + }, nil) +} + +type streamReceiver struct { + ProcessingAssistantGenTextBlockIndex map[string]map[int]bool + + MaxBlockIndex int + IndexMapper map[string]int + + MaxReasoningSummaryIndex map[string]int + ReasoningSummaryIndexMapper map[string]int + + MaxTextAnnotationIndex map[string]int + TextAnnotationIndexMapper map[string]int +} + +func newStreamReceiver() *streamReceiver { + return &streamReceiver{ + ProcessingAssistantGenTextBlockIndex: map[string]map[int]bool{}, + MaxBlockIndex: -1, + IndexMapper: map[string]int{}, + MaxReasoningSummaryIndex: map[string]int{}, + ReasoningSummaryIndexMapper: map[string]int{}, + TextAnnotationIndexMapper: map[string]int{}, + MaxTextAnnotationIndex: map[string]int{}, + } +} + +func (r *streamReceiver) getBlockIndex(key string) int { + if idx, ok := r.IndexMapper[key]; ok { + return idx + } + + r.MaxBlockIndex++ + r.IndexMapper[key] = r.MaxBlockIndex + + return r.MaxBlockIndex +} + +func (r *streamReceiver) getReasoningSummaryIndex(outputIdx, summaryIdx int64) int { + maxSummaryIndex := -1 + if idx, ok := r.MaxReasoningSummaryIndex[int64ToStr(outputIdx)]; ok { + maxSummaryIndex = idx + } + + idxKey := fmt.Sprintf("%d:%d", outputIdx, summaryIdx) + if idx, ok := r.ReasoningSummaryIndexMapper[idxKey]; ok { + return idx + } + + maxSummaryIndex++ + r.ReasoningSummaryIndexMapper[idxKey] = maxSummaryIndex + r.MaxReasoningSummaryIndex[int64ToStr(outputIdx)] = maxSummaryIndex + + return maxSummaryIndex +} + +func (r *streamReceiver) getTextAnnotationIndex(outputIdx, contentIdx, annotationIdx int64) int { + maxAnnotationIndex := -1 + + maxIdxKey := fmt.Sprintf("%d:%d", outputIdx, contentIdx) + if idx, ok := r.MaxTextAnnotationIndex[maxIdxKey]; ok { + maxAnnotationIndex = idx + } + + idxKey := fmt.Sprintf("%d:%d:%d", outputIdx, contentIdx, annotationIdx) + if idx, ok := r.TextAnnotationIndexMapper[idxKey]; ok { + return idx + } + + maxAnnotationIndex++ + r.TextAnnotationIndexMapper[idxKey] = maxAnnotationIndex + r.MaxTextAnnotationIndex[maxIdxKey] = maxAnnotationIndex + + return maxAnnotationIndex +} + +func (r *streamReceiver) itemAddedEventToContentBlock(ev responses.ResponseOutputItemAddedEvent) (blocks []*schema.ContentBlock, err error) { + switch item := ev.Item.AsAny().(type) { + case responses.ResponseFunctionToolCall: + block, err := r.itemAddedEventFunctionToolCallToContentBlock(ev.OutputIndex, item) + if err != nil { + return nil, err + } + + blocks = append(blocks, block) + + case responses.ResponseReasoningItem: + block, err := r.itemAddedEventReasoningToContentBlock(ev.OutputIndex, item) + if err != nil { + return nil, err + } + + blocks = append(blocks, block) + + case responses.ResponseOutputMessage, + responses.ResponseFunctionWebSearch, + responses.ResponseOutputItemMcpListTools, + responses.ResponseOutputItemMcpApprovalRequest, + responses.ResponseOutputItemMcpCall: + + // Do nothing. + + default: + return nil, fmt.Errorf("invalid item type %T with 'output_item.added' event", item) + } + + return blocks, nil +} + +func (r *streamReceiver) itemAddedEventFunctionToolCallToContentBlock(outputIdx int64, item responses.ResponseFunctionToolCall) (block *schema.ContentBlock, err error) { + block, err = functionToolCallToContentBlock(item) + if err != nil { + return nil, err + } + + block.StreamingMeta = &schema.StreamingMeta{ + Index: r.getBlockIndex(makeFunctionToolCallIndexKey(outputIdx)), + } + + return block, nil +} + +func (r *streamReceiver) itemAddedEventReasoningToContentBlock(outputIdx int64, item responses.ResponseReasoningItem) (block *schema.ContentBlock, err error) { + block, err = reasoningToContentBlocks(item) + if err != nil { + return nil, err + } + + block.StreamingMeta = &schema.StreamingMeta{ + Index: r.getBlockIndex(makeReasoningIndexKey(outputIdx)), + } + + return block, nil +} + +func (r *streamReceiver) itemDoneEventToContentBlocks(ev responses.ResponseOutputItemDoneEvent) (blocks []*schema.ContentBlock, err error) { + switch item := ev.Item.AsAny().(type) { + case responses.ResponseOutputMessage: + blocks, err = r.itemDoneEventOutputMessageToContentBlock(item) + if err != nil { + return nil, err + } + + case responses.ResponseReasoningItem: + block, err := r.itemDoneEventReasoningToContentBlock(ev.OutputIndex, item) + if err != nil { + return nil, err + } + + blocks = append(blocks, block) + + case responses.ResponseFunctionToolCall: + block, err := r.itemDoneEventFunctionToolCallToContentBlock(ev.OutputIndex, item) + if err != nil { + return nil, err + } + + blocks = append(blocks, block) + + case responses.ResponseFunctionWebSearch: + blocks, err = r.itemDoneEventFunctionWebSearchToContentBlocks(ev.OutputIndex, item) + if err != nil { + return nil, err + } + + case responses.ResponseOutputItemMcpCall: + blocks, err = r.itemDoneEventFunctionMCPCallToContentBlocks(ev.OutputIndex, item) + if err != nil { + return nil, err + } + + case responses.ResponseOutputItemMcpListTools: + block, err := r.itemDoneEventFunctionMCPListToolsToContentBlock(ev.OutputIndex, item) + if err != nil { + return nil, err + } + + blocks = append(blocks, block) + + case responses.ResponseOutputItemMcpApprovalRequest: + block, err := r.itemDoneEventFunctionMCPApprovalRequestToContentBlock(ev.OutputIndex, item) + if err != nil { + return nil, err + } + + blocks = append(blocks, block) + + default: + return nil, fmt.Errorf("invalid item type %T with 'output_item.done' event", item) + } + + return blocks, nil +} + +func (r *streamReceiver) itemDoneEventOutputMessageToContentBlock(item responses.ResponseOutputMessage) (blocks []*schema.ContentBlock, err error) { + indices, ok := r.ProcessingAssistantGenTextBlockIndex[item.ID] + if !ok { + return nil, fmt.Errorf("item %s not found in processing queue", item.ID) + } + + for idx := range indices { + meta := &schema.StreamingMeta{Index: idx} + block := schema.NewContentBlockChunk(&schema.AssistantGenText{}, meta) + + setItemID(block, item.ID) + if string(item.Status) != "" { + setItemStatus(block, string(item.Status)) + } + + blocks = append(blocks, block) + } + + return blocks, nil +} + +func (r *streamReceiver) itemDoneEventReasoningToContentBlock(outputIdx int64, item responses.ResponseReasoningItem) (block *schema.ContentBlock, err error) { + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeReasoningIndexKey(outputIdx)), + } + block = schema.NewContentBlockChunk(&schema.Reasoning{}, meta) + + setItemID(block, item.ID) + if s := string(item.Status); s != "" { + setItemStatus(block, s) + } + + return block, nil +} + +func (r *streamReceiver) itemDoneEventFunctionToolCallToContentBlock(outputIdx int64, item responses.ResponseFunctionToolCall) (block *schema.ContentBlock, err error) { + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeFunctionToolCallIndexKey(outputIdx)), + } + block = schema.NewContentBlockChunk(&schema.FunctionToolCall{ + CallID: item.CallID, + Name: item.Name, + }, meta) + + setItemID(block, item.ID) + if s := string(item.Status); s != "" { + setItemStatus(block, s) + } + + return block, nil +} + +func (r *streamReceiver) itemDoneEventFunctionWebSearchToContentBlocks(outputIdx int64, item responses.ResponseFunctionWebSearch) (blocks []*schema.ContentBlock, err error) { + blocks, err = webSearchToContentBlocks(item) + if err != nil { + return nil, err + } + + blocks[0].StreamingMeta = &schema.StreamingMeta{ + Index: r.getBlockIndex(makeServerToolCallIndexKey(outputIdx)), + } + blocks[1].StreamingMeta = &schema.StreamingMeta{ + Index: r.getBlockIndex(makeServerToolResultIndexKey(outputIdx)), + } + + return blocks, nil +} + +func (r *streamReceiver) itemDoneEventFunctionMCPCallToContentBlocks(outputIdx int64, item responses.ResponseOutputItemMcpCall) (blocks []*schema.ContentBlock, err error) { + blocks, err = mcpCallToContentBlocks(item) + if err != nil { + return nil, err + } + + for _, block := range blocks { + switch block.Type { + case schema.ContentBlockTypeMCPToolCall: + block.StreamingMeta = &schema.StreamingMeta{ + Index: r.getBlockIndex(makeMCPToolCallIndexKey(outputIdx)), + } + case schema.ContentBlockTypeMCPToolResult: + block.StreamingMeta = &schema.StreamingMeta{ + Index: r.getBlockIndex(makeMCPToolResultIndexKey(outputIdx)), + } + default: + return nil, fmt.Errorf("expected mcp tool call or result block, but got '%s'", block.Type) + } + } + + return blocks, nil +} + +func (r *streamReceiver) itemDoneEventFunctionMCPListToolsToContentBlock(outputIdx int64, item responses.ResponseOutputItemMcpListTools) (block *schema.ContentBlock, err error) { + block, err = mcpListToolsToContentBlock(item) + if err != nil { + return nil, err + } + + block.StreamingMeta = &schema.StreamingMeta{ + Index: r.getBlockIndex(makeMCPListToolsResultIndexKey(outputIdx)), + } + + return block, nil +} + +func (r *streamReceiver) itemDoneEventFunctionMCPApprovalRequestToContentBlock(outputIdx int64, item responses.ResponseOutputItemMcpApprovalRequest) (block *schema.ContentBlock, err error) { + block, err = mcpApprovalRequestToContentBlock(item) + if err != nil { + return nil, err + } + + block.StreamingMeta = &schema.StreamingMeta{ + Index: r.getBlockIndex(makeMCPToolApprovalRequestIndexKey(outputIdx)), + } + + return block, nil +} + +func (r *streamReceiver) contentPartAddedEventToContentBlock(ev responses.ResponseContentPartAddedEvent) (block *schema.ContentBlock, err error) { + key := makeAssistantGenTextIndexKey(ev.OutputIndex, ev.ContentIndex) + blockIdx := r.getBlockIndex(key) + + indices, ok := r.ProcessingAssistantGenTextBlockIndex[ev.ItemID] + if !ok { + indices = map[int]bool{} + r.ProcessingAssistantGenTextBlockIndex[ev.ItemID] = indices + } + + indices[blockIdx] = true + + meta := &schema.StreamingMeta{Index: blockIdx} + + switch ev.Part.AsAny().(type) { + case responses.ResponseOutputText, responses.ResponseOutputRefusal: + block = schema.NewContentBlockChunk(&schema.AssistantGenText{}, meta) + default: + return nil, fmt.Errorf("invalid content part type: %T", ev.Part) + } + + setItemStatus(block, string(responses.ResponseStatusInProgress)) + setItemID(block, ev.ItemID) + + return block, nil +} + +func (r *streamReceiver) contentPartDoneEventToContentBlock(ev responses.ResponseContentPartDoneEvent) (block *schema.ContentBlock, err error) { + key := makeAssistantGenTextIndexKey(ev.OutputIndex, ev.ContentIndex) + blockIdx := r.getBlockIndex(key) + + indices, ok := r.ProcessingAssistantGenTextBlockIndex[ev.ItemID] + if !ok { + return nil, fmt.Errorf("item '%s' has no processing assistant gen text block index", ev.ItemID) + } + + delete(indices, blockIdx) + + meta := &schema.StreamingMeta{Index: blockIdx} + + switch ev.Part.AsAny().(type) { + case responses.ResponseOutputText: + block = schema.NewContentBlockChunk(&schema.AssistantGenText{}, meta) + default: + return nil, fmt.Errorf("invalid content part type: %T", ev.Part) + } + + block.StreamingMeta = &schema.StreamingMeta{ + Index: blockIdx, + } + + setItemStatus(block, string(responses.ResponseStatusCompleted)) + setItemID(block, ev.ItemID) + + return block, nil +} + +func (r *streamReceiver) refusalDeltaEventToContentBlock(ev responses.ResponseRefusalDeltaEvent) *schema.ContentBlock { + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeAssistantGenTextIndexKey(ev.OutputIndex, ev.ContentIndex)), + } + block := schema.NewContentBlockChunk(&schema.AssistantGenText{ + OpenAIExtension: &openai.AssistantGenTextExtension{ + Refusal: &openai.OutputRefusal{ + Reason: ev.Delta, + }, + }, + }, meta) + + setItemID(block, ev.ItemID) + + return block +} + +func (r *streamReceiver) outputTextDeltaEventToContentBlock(ev responses.ResponseTextDeltaEvent) *schema.ContentBlock { + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeAssistantGenTextIndexKey(ev.OutputIndex, ev.ContentIndex)), + } + block := schema.NewContentBlockChunk(&schema.AssistantGenText{ + Text: ev.Delta, + }, meta) + + setItemID(block, ev.ItemID) + + return block +} + +func (r *streamReceiver) annotationAddedEventToContentBlock(ev responses.ResponseOutputTextAnnotationAddedEvent) (block *schema.ContentBlock, err error) { + annoBytes, err := sonic.Marshal(ev.Annotation) + if err != nil { + return nil, fmt.Errorf("failed to marshal annotation, err: %w", err) + } + + anno := responses.ResponseOutputTextAnnotationUnion{} + err = sonic.Unmarshal(annoBytes, &anno) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal annotation, err: %w", err) + } + + annotation, err := outputTextAnnotationToTextAnnotation(anno) + if err != nil { + return nil, fmt.Errorf("failed to convert annotation, err: %w", err) + } + + annotation.Index = r.getTextAnnotationIndex(ev.OutputIndex, ev.ContentIndex, ev.AnnotationIndex) + + genText := &schema.AssistantGenText{ + OpenAIExtension: &openai.AssistantGenTextExtension{ + Annotations: []*openai.TextAnnotation{annotation}, + }, + } + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeAssistantGenTextIndexKey(ev.OutputIndex, ev.ContentIndex)), + } + block = schema.NewContentBlockChunk(genText, meta) + + setItemID(block, ev.ItemID) + + return block, nil +} + +func (r *streamReceiver) reasoningSummaryTextDeltaEventToContentBlock(ev responses.ResponseReasoningSummaryTextDeltaEvent) *schema.ContentBlock { + reasoning := &schema.Reasoning{ + Summary: []*schema.ReasoningSummary{ + { + Index: r.getReasoningSummaryIndex(ev.OutputIndex, ev.SummaryIndex), + Text: ev.Delta, + }, + }, + } + + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeReasoningIndexKey(ev.OutputIndex)), + } + block := schema.NewContentBlockChunk(reasoning, meta) + + setItemID(block, ev.ItemID) + + return block +} + +func (r *streamReceiver) functionCallArgumentsDeltaEventToContentBlock(ev responses.ResponseFunctionCallArgumentsDeltaEvent) *schema.ContentBlock { + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeFunctionToolCallIndexKey(ev.OutputIndex)), + } + block := schema.NewContentBlockChunk(&schema.FunctionToolCall{ + Arguments: ev.Delta, + }, meta) + + setItemID(block, ev.ItemID) + + return block +} + +func (r *streamReceiver) mcpListToolsPhaseToContentBlock(itemID string, outputIdx int64, status string) *schema.ContentBlock { + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeMCPListToolsResultIndexKey(outputIdx)), + } + block := schema.NewContentBlockChunk(&schema.MCPListToolsResult{}, meta) + + setItemID(block, itemID) + if status != "" { + setItemStatus(block, status) + } + + return block +} + +func (r *streamReceiver) mcpCallArgumentsDeltaEventToContentBlock(ev responses.ResponseMcpCallArgumentsDeltaEvent) *schema.ContentBlock { + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeMCPToolCallIndexKey(ev.OutputIndex)), + } + block := schema.NewContentBlockChunk(&schema.MCPToolCall{ + Arguments: ev.Delta, + }, meta) + + setItemID(block, ev.ItemID) + + return block +} + +func (r *streamReceiver) mcpCallPhaseToContentBlock(itemID string, outputIdx int64, status string) *schema.ContentBlock { + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeMCPToolCallIndexKey(outputIdx)), + } + block := schema.NewContentBlockChunk(&schema.MCPToolCall{}, meta) + + setItemID(block, itemID) + if status != "" { + setItemStatus(block, status) + } + + return block +} + +func (r *streamReceiver) webSearchPhaseToContentBlock(itemID string, outputIdx int64, status string) *schema.ContentBlock { + meta := &schema.StreamingMeta{ + Index: r.getBlockIndex(makeServerToolCallIndexKey(outputIdx)), + } + block := schema.NewContentBlockChunk(&schema.ServerToolCall{}, meta) + + setItemID(block, itemID) + if status != "" { + setItemStatus(block, status) + } + + return block +} + +func makeAssistantGenTextIndexKey(outputIndex, contentIndex int64) string { + return fmt.Sprintf("assistant_gen_text:%d:%d", outputIndex, contentIndex) +} + +func makeReasoningIndexKey(outputIndex int64) string { + return fmt.Sprintf("reasoning:%d", outputIndex) +} + +func makeFunctionToolCallIndexKey(outputIndex int64) string { + return fmt.Sprintf("function_tool_call:%d", outputIndex) +} + +func makeServerToolCallIndexKey(outputIndex int64) string { + return fmt.Sprintf("server_tool_call:%d", outputIndex) +} + +func makeServerToolResultIndexKey(outputIndex int64) string { + return fmt.Sprintf("server_tool_result:%d", outputIndex) +} + +func makeMCPListToolsResultIndexKey(outputIndex int64) string { + return fmt.Sprintf("mcp_list_tools_result:%d", outputIndex) +} + +func makeMCPToolApprovalRequestIndexKey(outputIndex int64) string { + return fmt.Sprintf("mcp_tool_approval_request:%d", outputIndex) +} + +func makeMCPToolCallIndexKey(outputIndex int64) string { + return fmt.Sprintf("mcp_tool_call:%d", outputIndex) +} + +func makeMCPToolResultIndexKey(outputIndex int64) string { + return fmt.Sprintf("mcp_tool_result:%d", outputIndex) +} diff --git a/components/agentic/openai/event_convertor_test.go b/components/agentic/openai/event_convertor_test.go new file mode 100644 index 000000000..93ceff059 --- /dev/null +++ b/components/agentic/openai/event_convertor_test.go @@ -0,0 +1,580 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "errors" + "testing" + + "github.com/bytedance/mockey" + "github.com/bytedance/sonic" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/openai/openai-go/v3/responses" + "github.com/stretchr/testify/assert" +) + +func TestNewStreamReceiverInit(t *testing.T) { + r := newStreamReceiver() + assert.NotNil(t, r.ProcessingAssistantGenTextBlockIndex) + assert.Equal(t, -1, r.MaxBlockIndex) + assert.NotNil(t, r.IndexMapper) + assert.NotNil(t, r.MaxReasoningSummaryIndex) + assert.NotNil(t, r.ReasoningSummaryIndexMapper) + assert.NotNil(t, r.TextAnnotationIndexMapper) + assert.NotNil(t, r.MaxTextAnnotationIndex) +} + +func TestGetBlockIndexAndReuse(t *testing.T) { + r := newStreamReceiver() + a := r.getBlockIndex("k1") + b := r.getBlockIndex("k2") + c := r.getBlockIndex("k1") + assert.Equal(t, a, c) + assert.NotEqual(t, a, b) + assert.GreaterOrEqual(t, r.MaxBlockIndex, 1) +} + +func TestGetReasoningSummaryIndex(t *testing.T) { + r := newStreamReceiver() + i1 := r.getReasoningSummaryIndex(1, 1) + i2 := r.getReasoningSummaryIndex(1, 2) + i3 := r.getReasoningSummaryIndex(2, 1) + i4 := r.getReasoningSummaryIndex(1, 1) + assert.Equal(t, 0, i1) + assert.Equal(t, 1, i2) + assert.Equal(t, 0, i3) + assert.Equal(t, i1, i4) +} + +func TestGetTextAnnotationIndex(t *testing.T) { + r := newStreamReceiver() + i1 := r.getTextAnnotationIndex(1, 1, 1) + i2 := r.getTextAnnotationIndex(1, 1, 2) + i3 := r.getTextAnnotationIndex(1, 2, 1) + i4 := r.getTextAnnotationIndex(2, 1, 1) + i5 := r.getTextAnnotationIndex(1, 1, 1) + assert.Equal(t, 0, i1) + assert.Equal(t, 1, i2) + assert.Equal(t, 0, i3) + assert.Equal(t, 0, i4) + assert.Equal(t, i1, i5) +} + +func TestItemAddedEventToContentBlockFunctionToolCall(t *testing.T) { + mockey.PatchConvey("TestItemAddedEventToContentBlockFunctionToolCall", t, func() { + r := newStreamReceiver() + ev := responses.ResponseOutputItemAddedEvent{ + OutputIndex: 1, + } + + mockey.Mock(responses.ResponseOutputItemUnion.AsAny).Return(responses.ResponseFunctionToolCall{ + ID: "id1", + CallID: "cid", + Name: "name", + }).Build() + + blocks, err := r.itemAddedEventToContentBlock(ev) + assert.NoError(t, err) + assert.Equal(t, 1, len(blocks)) + assert.NotNil(t, blocks[0].FunctionToolCall) + assert.NotNil(t, blocks[0].StreamingMeta) + }) +} + +func TestItemAddedEventToContentBlockIgnoredTypes(t *testing.T) { + mockey.PatchConvey("TestItemAddedEventToContentBlockIgnoredTypes", t, func() { + r := newStreamReceiver() + + // Mock AsAny to return different types in sequence + mockey.Mock(responses.ResponseOutputItemUnion.AsAny).Return( + mockey.Sequence( + responses.ResponseOutputMessage{}, + ).Then( + responses.ResponseFunctionWebSearch{}, + ).Then( + responses.ResponseOutputItemMcpListTools{}, + ).Then( + responses.ResponseOutputItemMcpApprovalRequest{}, + ).Then( + responses.ResponseOutputItemMcpCall{}, + ), + ).Build() + + ignoredTypes := []string{"OutputMessage", "WebSearch", "McpListTools", "McpApprovalRequest", "McpCall"} + + for range ignoredTypes { + ev := responses.ResponseOutputItemAddedEvent{ + OutputIndex: 1, + } + blocks, err := r.itemAddedEventToContentBlock(ev) + assert.NoError(t, err) + assert.Equal(t, 0, len(blocks)) + } + }) +} + +func TestItemDoneEventToContentBlocksOutputMessage(t *testing.T) { + mockey.PatchConvey("TestItemDoneEventToContentBlocksOutputMessage", t, func() { + r := newStreamReceiver() + r.ProcessingAssistantGenTextBlockIndex["mid"] = map[int]bool{0: true, 2: true} + + ev := responses.ResponseOutputItemDoneEvent{ + OutputIndex: 1, + } + + mockey.Mock(responses.ResponseOutputItemUnion.AsAny).Return(responses.ResponseOutputMessage{ + ID: "mid", + Status: responses.ResponseOutputMessageStatusCompleted, + }).Build() + + blocks, err := r.itemDoneEventToContentBlocks(ev) + assert.NoError(t, err) + assert.Equal(t, 2, len(blocks)) + id, ok := getItemID(blocks[0]) + assert.True(t, ok) + assert.Equal(t, "mid", id) + }) +} + +func TestItemDoneEventToContentBlocksReasoning(t *testing.T) { + mockey.PatchConvey("TestItemDoneEventToContentBlocksReasoning", t, func() { + r := newStreamReceiver() + ev := responses.ResponseOutputItemDoneEvent{ + OutputIndex: 3, + } + + mockey.Mock(responses.ResponseOutputItemUnion.AsAny).Return(responses.ResponseReasoningItem{ + ID: "rid", + Status: responses.ResponseReasoningItemStatusCompleted, + }).Build() + + blocks, err := r.itemDoneEventToContentBlocks(ev) + assert.NoError(t, err) + assert.Equal(t, 1, len(blocks)) + assert.NotNil(t, blocks[0].Reasoning) + }) +} + +func TestItemDoneEventToContentBlocksFunctionToolCall(t *testing.T) { + mockey.PatchConvey("TestItemDoneEventToContentBlocksFunctionToolCall", t, func() { + r := newStreamReceiver() + ev := responses.ResponseOutputItemDoneEvent{ + OutputIndex: 4, + } + + mockey.Mock(responses.ResponseOutputItemUnion.AsAny).Return(responses.ResponseFunctionToolCall{ + ID: "fid", + CallID: "cid", + Name: "nm", + }).Build() + + blocks, err := r.itemDoneEventToContentBlocks(ev) + assert.NoError(t, err) + assert.Equal(t, 1, len(blocks)) + assert.NotNil(t, blocks[0].FunctionToolCall) + }) +} + +func TestItemDoneEventToContentBlocksWebSearch(t *testing.T) { + mockey.PatchConvey("TestItemDoneEventToContentBlocksWebSearch", t, func() { + r := newStreamReceiver() + ev := responses.ResponseOutputItemDoneEvent{ + OutputIndex: 5, + } + + action := responses.ResponseFunctionWebSearchActionUnion{} + mockey.Mock(responses.ResponseFunctionWebSearchActionUnion.AsAny).Return(responses.ResponseFunctionWebSearchActionSearch{ + Query: "test", + }).Build() + + mockey.Mock(responses.ResponseOutputItemUnion.AsAny).Return(responses.ResponseFunctionWebSearch{ + ID: "wid", + Status: responses.ResponseFunctionWebSearchStatusCompleted, + Action: action, + }).Build() + + blocks, err := r.itemDoneEventToContentBlocks(ev) + assert.NoError(t, err) + assert.Equal(t, 2, len(blocks)) + assert.NotNil(t, blocks[0].ServerToolCall) + assert.NotNil(t, blocks[1].ServerToolResult) + }) +} + +func TestItemDoneEventToContentBlocksMCPCall(t *testing.T) { + mockey.PatchConvey("TestItemDoneEventToContentBlocksMCPCall", t, func() { + r := newStreamReceiver() + ev := responses.ResponseOutputItemDoneEvent{ + OutputIndex: 6, + } + + mockey.Mock(responses.ResponseOutputItemUnion.AsAny).Return(responses.ResponseOutputItemMcpCall{ + ID: "mid", + ServerLabel: "server", + Name: "tool", + Arguments: "{}", + Output: "result", + }).Build() + + blocks, err := r.itemDoneEventToContentBlocks(ev) + assert.NoError(t, err) + assert.Equal(t, 2, len(blocks)) + assert.NotNil(t, blocks[0].MCPToolCall) + assert.NotNil(t, blocks[1].MCPToolResult) + }) +} + +func TestItemDoneEventToContentBlocksMCPListTools(t *testing.T) { + mockey.PatchConvey("TestItemDoneEventToContentBlocksMCPListTools", t, func() { + r := newStreamReceiver() + ev := responses.ResponseOutputItemDoneEvent{ + OutputIndex: 7, + } + + mockey.Mock(responses.ResponseOutputItemUnion.AsAny).Return(responses.ResponseOutputItemMcpListTools{ + ID: "lid", + ServerLabel: "server", + }).Build() + + blocks, err := r.itemDoneEventToContentBlocks(ev) + assert.NoError(t, err) + assert.Equal(t, 1, len(blocks)) + assert.NotNil(t, blocks[0].MCPListToolsResult) + }) +} + +func TestItemDoneEventToContentBlocksMCPApprovalRequest(t *testing.T) { + mockey.PatchConvey("TestItemDoneEventToContentBlocksMCPApprovalRequest", t, func() { + r := newStreamReceiver() + ev := responses.ResponseOutputItemDoneEvent{ + OutputIndex: 8, + } + + mockey.Mock(responses.ResponseOutputItemUnion.AsAny).Return(responses.ResponseOutputItemMcpApprovalRequest{ + ID: "aid", + ServerLabel: "server", + Name: "tool", + Arguments: "{}", + }).Build() + + blocks, err := r.itemDoneEventToContentBlocks(ev) + assert.NoError(t, err) + assert.Equal(t, 1, len(blocks)) + assert.NotNil(t, blocks[0].MCPToolApprovalRequest) + }) +} + +func TestItemDoneEventOutputMessageToContentBlockMissingProcessing(t *testing.T) { + r := newStreamReceiver() + item := responses.ResponseOutputMessage{ + ID: "mid", + Status: responses.ResponseOutputMessageStatusCompleted, + } + _, err := r.itemDoneEventOutputMessageToContentBlock(item) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found in processing queue") +} + +func TestItemDoneEventOutputMessageToContentBlockSuccess(t *testing.T) { + r := newStreamReceiver() + r.ProcessingAssistantGenTextBlockIndex["mid"] = map[int]bool{0: true, 2: true} + + item := responses.ResponseOutputMessage{ + ID: "mid", + Status: responses.ResponseOutputMessageStatusCompleted, + } + blocks, err := r.itemDoneEventOutputMessageToContentBlock(item) + assert.NoError(t, err) + assert.Equal(t, 2, len(blocks)) + + for _, block := range blocks { + id, ok := getItemID(block) + assert.True(t, ok) + assert.Equal(t, "mid", id) + status, ok := GetItemStatus(block) + assert.True(t, ok) + assert.Equal(t, string(responses.ResponseOutputMessageStatusCompleted), status) + } +} + +func TestContentPartAddedEventToContentBlockInvalidType(t *testing.T) { + mockey.PatchConvey("TestContentPartAddedEventToContentBlockInvalidType", t, func() { + r := newStreamReceiver() + ev := responses.ResponseContentPartAddedEvent{} + + mockey.Mock(responses.ResponseOutputMessageContentUnion.AsAny).Return("invalid").Build() + + _, err := r.contentPartAddedEventToContentBlock(ev) + assert.Error(t, err) + }) +} + +func TestContentPartDoneEventToContentBlockNoIndex(t *testing.T) { + mockey.PatchConvey("TestContentPartDoneEventToContentBlockNoIndex", t, func() { + r := newStreamReceiver() + ev := responses.ResponseContentPartDoneEvent{ + ItemID: "mid", + OutputIndex: 1, + ContentIndex: 2, + } + + mockey.Mock(responses.ResponseOutputMessageContentUnion.AsAny).Return(responses.ResponseOutputText{}).Build() + + _, err := r.contentPartDoneEventToContentBlock(ev) + assert.Error(t, err) + assert.Contains(t, err.Error(), "has no processing assistant gen text block index") + }) +} + +func TestRefusalDeltaEventToContentBlock(t *testing.T) { + r := newStreamReceiver() + ev := responses.ResponseRefusalDeltaEvent{ + ItemID: "iid", + OutputIndex: 1, + ContentIndex: 1, + Delta: "refused", + } + + block := r.refusalDeltaEventToContentBlock(ev) + assert.NotNil(t, block.AssistantGenText) + assert.NotNil(t, block.AssistantGenText.OpenAIExtension) + assert.NotNil(t, block.AssistantGenText.OpenAIExtension.Refusal) + assert.Equal(t, "refused", block.AssistantGenText.OpenAIExtension.Refusal.Reason) + + id, ok := getItemID(block) + assert.True(t, ok) + assert.Equal(t, "iid", id) +} + +func TestOutputTextDeltaEventToContentBlock(t *testing.T) { + r := newStreamReceiver() + ev := responses.ResponseTextDeltaEvent{ + ItemID: "iid", + OutputIndex: 1, + ContentIndex: 1, + Delta: "delta text", + } + + block := r.outputTextDeltaEventToContentBlock(ev) + assert.NotNil(t, block.AssistantGenText) + assert.Equal(t, "delta text", block.AssistantGenText.Text) + assert.NotNil(t, block.StreamingMeta) + + id, ok := getItemID(block) + assert.True(t, ok) + assert.Equal(t, "iid", id) +} + +func TestAnnotationAddedEventToContentBlockFileCitation(t *testing.T) { + mockey.PatchConvey("TestAnnotationAddedEventToContentBlockFileCitation", t, func() { + r := newStreamReceiver() + ev := responses.ResponseOutputTextAnnotationAddedEvent{ + ItemID: "iid", + OutputIndex: 1, + ContentIndex: 1, + AnnotationIndex: 0, + } + + annoData := responses.ResponseOutputTextAnnotationFileCitation{ + Index: 10, + FileID: "fid", + Filename: "file.txt", + } + + mockey.Mock(sonic.Marshal).To(func(val any) ([]byte, error) { + return []byte(`{"type":"file_citation","index":10,"file_id":"fid","filename":"file.txt"}`), nil + }).Build() + + mockey.Mock(sonic.Unmarshal).To(func(data []byte, v any) error { + return nil + }).Build() + + mockey.Mock(responses.ResponseOutputTextAnnotationUnion.AsAny).Return(annoData).Build() + + block, err := r.annotationAddedEventToContentBlock(ev) + assert.NoError(t, err) + assert.NotNil(t, block.AssistantGenText) + assert.NotNil(t, block.AssistantGenText.OpenAIExtension) + assert.Len(t, block.AssistantGenText.OpenAIExtension.Annotations, 1) + }) +} + +func TestReasoningSummaryTextDeltaEventToContentBlock(t *testing.T) { + r := newStreamReceiver() + ev := responses.ResponseReasoningSummaryTextDeltaEvent{ + ItemID: "iid", + OutputIndex: 2, + SummaryIndex: 0, + Delta: "summary text", + } + + block := r.reasoningSummaryTextDeltaEventToContentBlock(ev) + assert.NotNil(t, block.Reasoning) + assert.Len(t, block.Reasoning.Summary, 1) + assert.Equal(t, "summary text", block.Reasoning.Summary[0].Text) + assert.Equal(t, 0, block.Reasoning.Summary[0].Index) + + id, ok := getItemID(block) + assert.True(t, ok) + assert.Equal(t, "iid", id) +} + +func TestFunctionCallArgumentsDeltaEventToContentBlock(t *testing.T) { + r := newStreamReceiver() + ev := responses.ResponseFunctionCallArgumentsDeltaEvent{ + ItemID: "iid", + OutputIndex: 3, + Delta: `{"arg":"val"}`, + } + + block := r.functionCallArgumentsDeltaEventToContentBlock(ev) + assert.NotNil(t, block.FunctionToolCall) + assert.Equal(t, `{"arg":"val"}`, block.FunctionToolCall.Arguments) + + id, ok := getItemID(block) + assert.True(t, ok) + assert.Equal(t, "iid", id) +} + +func TestMcpListToolsPhaseToContentBlock(t *testing.T) { + r := newStreamReceiver() + + block := r.mcpListToolsPhaseToContentBlock("iid", 4, string(responses.ResponseStatusInProgress)) + assert.NotNil(t, block.MCPListToolsResult) + assert.NotNil(t, block.StreamingMeta) + + id, ok := getItemID(block) + assert.True(t, ok) + assert.Equal(t, "iid", id) + + status, ok := GetItemStatus(block) + assert.True(t, ok) + assert.Equal(t, string(responses.ResponseStatusInProgress), status) +} + +func TestMcpListToolsPhaseToContentBlockEmptyStatus(t *testing.T) { + r := newStreamReceiver() + + block := r.mcpListToolsPhaseToContentBlock("iid", 4, "") + assert.NotNil(t, block.MCPListToolsResult) + + _, ok := GetItemStatus(block) + assert.False(t, ok) +} + +func TestMcpCallArgumentsDeltaEventToContentBlock(t *testing.T) { + r := newStreamReceiver() + ev := responses.ResponseMcpCallArgumentsDeltaEvent{ + ItemID: "iid", + OutputIndex: 6, + Delta: `{"key":"value"}`, + } + + block := r.mcpCallArgumentsDeltaEventToContentBlock(ev) + assert.NotNil(t, block.MCPToolCall) + assert.Equal(t, `{"key":"value"}`, block.MCPToolCall.Arguments) + + id, ok := getItemID(block) + assert.True(t, ok) + assert.Equal(t, "iid", id) +} + +func TestMcpCallPhaseToContentBlock(t *testing.T) { + r := newStreamReceiver() + + block := r.mcpCallPhaseToContentBlock("iid", 7, string(responses.ResponseStatusFailed)) + assert.NotNil(t, block.MCPToolCall) + + status, ok := GetItemStatus(block) + assert.True(t, ok) + assert.Equal(t, string(responses.ResponseStatusFailed), status) +} + +func TestWebSearchPhaseToContentBlock(t *testing.T) { + r := newStreamReceiver() + + block := r.webSearchPhaseToContentBlock("iid", 8, string(responses.ResponseStatusCompleted)) + assert.NotNil(t, block.ServerToolCall) + + status, ok := GetItemStatus(block) + assert.True(t, ok) + assert.Equal(t, string(responses.ResponseStatusCompleted), status) +} + +func TestMakeIndexKeyFunctions(t *testing.T) { + assert.Equal(t, "assistant_gen_text:1:2", makeAssistantGenTextIndexKey(1, 2)) + assert.Equal(t, "reasoning:3", makeReasoningIndexKey(3)) + assert.Equal(t, "function_tool_call:4", makeFunctionToolCallIndexKey(4)) + assert.Equal(t, "server_tool_call:5", makeServerToolCallIndexKey(5)) + assert.Equal(t, "server_tool_result:6", makeServerToolResultIndexKey(6)) + assert.Equal(t, "mcp_list_tools_result:7", makeMCPListToolsResultIndexKey(7)) + assert.Equal(t, "mcp_tool_approval_request:8", makeMCPToolApprovalRequestIndexKey(8)) + assert.Equal(t, "mcp_tool_call:9", makeMCPToolCallIndexKey(9)) + assert.Equal(t, "mcp_tool_result:10", makeMCPToolResultIndexKey(10)) +} + +func TestNewCallbackSender(t *testing.T) { + _, sw := schema.Pipe[*agentic.CallbackOutput](8) + config := &agentic.Config{} + + s := newCallbackSender(sw, config) + assert.NotNil(t, s) + assert.Equal(t, sw, s.sw) + assert.Equal(t, config, s.config) +} + +func TestCallbackSenderSendMeta(t *testing.T) { + sr, sw := schema.Pipe[*agentic.CallbackOutput](8) + s := newCallbackSender(sw, &agentic.Config{}) + + meta := &schema.AgenticResponseMeta{} + s.sendMeta(meta, nil) + + r := sr.Copy(1)[0] + out, err := r.Recv() + assert.NoError(t, err) + assert.NotNil(t, out) + assert.NotNil(t, out.Message.ResponseMeta) +} + +func TestCallbackSenderSendBlock(t *testing.T) { + sr, sw := schema.Pipe[*agentic.CallbackOutput](8) + s := newCallbackSender(sw, &agentic.Config{}) + + block := schema.NewContentBlock(&schema.AssistantGenText{Text: "test"}) + s.sendBlock(block, nil) + + r := sr.Copy(1)[0] + out, err := r.Recv() + assert.NoError(t, err) + assert.NotNil(t, out) + assert.Len(t, out.Message.ContentBlocks, 1) +} + +func TestCallbackSenderSendError(t *testing.T) { + sr, sw := schema.Pipe[*agentic.CallbackOutput](8) + s := newCallbackSender(sw, &agentic.Config{}) + s.errHeader = "test error" + + s.sendMeta(nil, errors.New("error")) + + r := sr.Copy(1)[0] + _, err := r.Recv() + assert.Error(t, err) + assert.Contains(t, err.Error(), "test error") +} diff --git a/components/agentic/openai/examples/generate/main.go b/components/agentic/openai/examples/generate/main.go new file mode 100644 index 000000000..8da346ef2 --- /dev/null +++ b/components/agentic/openai/examples/generate/main.go @@ -0,0 +1,88 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/openai" + "github.com/cloudwego/eino/schema" + openaischema "github.com/cloudwego/eino/schema/openai" + "github.com/eino-contrib/jsonschema" + "github.com/openai/openai-go/v3/responses" + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +func main() { + ctx := context.Background() + + am, err := openai.New(ctx, &openai.Config{ + BaseURL: "https://api.openai.com/v1", + Model: os.Getenv("OPENAI_MODEL_ID"), + APIKey: os.Getenv("OPENAI_API_KEY"), + Reasoning: &responses.ReasoningParam{ + Effort: responses.ReasoningEffortLow, + Summary: responses.ReasoningSummaryDetailed, + }, + }) + if err != nil { + log.Fatalf("failed to create agentic model, err: %v", err) + } + + input := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what is the weather like in Beijing"), + } + + am_, err := am.WithTools([]*schema.ToolInfo{ + { + Name: "get_weather", + Desc: "get the weather in a city", + ParamsOneOf: schema.NewParamsOneOfByJSONSchema(&jsonschema.Schema{ + Type: "object", + Properties: orderedmap.New[string, *jsonschema.Schema]( + orderedmap.WithInitialData( + orderedmap.Pair[string, *jsonschema.Schema]{ + Key: "city", + Value: &jsonschema.Schema{ + Type: "string", + Description: "the city to get the weather", + }, + }, + ), + ), + Required: []string{"city"}, + }), + }, + }) + if err != nil { + log.Fatalf("failed to create agentic model with tools, err: %v", err) + } + + msg, err := am_.Generate(ctx, input) + if err != nil { + log.Fatalf("failed to generate, err: %v", err) + } + + meta := msg.ResponseMeta.Extension.(*openaischema.ResponseMetaExtension) + + log.Printf("request_id: %s\n", meta.ID) + respBody, _ := sonic.MarshalIndent(msg, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} diff --git a/components/agentic/openai/examples/sstream_with_mcp_tool/main.go b/components/agentic/openai/examples/sstream_with_mcp_tool/main.go new file mode 100644 index 000000000..6ff26f664 --- /dev/null +++ b/components/agentic/openai/examples/sstream_with_mcp_tool/main.go @@ -0,0 +1,105 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "errors" + "io" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/openai" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/openai/openai-go/v3/packages/param" + "github.com/openai/openai-go/v3/responses" +) + +func main() { + ctx := context.Background() + + am, err := openai.New(ctx, &openai.Config{ + BaseURL: "https://api.openai.com/v1", + Model: os.Getenv("OPENAI_MODEL_ID"), + APIKey: os.Getenv("OPENAI_API_KEY"), + Reasoning: &responses.ReasoningParam{ + Effort: responses.ReasoningEffortLow, + Summary: responses.ReasoningSummaryDetailed, + }, + }) + if err != nil { + log.Fatalf("failed to create agentic model, err=%v", err) + } + + mcpTools := []*responses.ToolMcpParam{ + { + ServerLabel: "test_mcp_server", + RequireApproval: responses.ToolMcpRequireApprovalUnionParam{ + OfMcpToolApprovalSetting: param.NewOpt("never"), + }, + ServerURL: param.NewOpt("server url"), + }, + } + + allowedTools := []*schema.AllowedTool{ + { + MCPTool: &schema.AllowedMCPTool{ + ServerLabel: "test_mcp_server", + Name: "amap/maps_weather", + }, + }, + } + + opts := []agentic.Option{ + agentic.WithToolChoice(schema.ToolChoiceForced, allowedTools...), + openai.WithMCPTools(mcpTools), + } + + input := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what's the weather like in Beijing today"), + } + + resp, err := am.Stream(ctx, input, opts...) + if err != nil { + log.Fatalf("failed to stream, err: %v", err) + } + + var msgs []*schema.AgenticMessage + for { + msg, err := resp.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + log.Fatalf("failed to receive stream response, err: %v", err) + } + msgs = append(msgs, msg) + } + + concatenated, err := schema.ConcatAgenticMessages(msgs) + if err != nil { + log.Fatalf("failed to concat agentic messages, err: %v", err) + } + + meta := concatenated.ResponseMeta.OpenAIExtension + log.Printf("request_id: %s\n", meta.ID) + + respBody, _ := sonic.MarshalIndent(concatenated, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} diff --git a/components/agentic/openai/examples/stream_with_function_tool/main.go b/components/agentic/openai/examples/stream_with_function_tool/main.go new file mode 100644 index 000000000..c21857eb4 --- /dev/null +++ b/components/agentic/openai/examples/stream_with_function_tool/main.go @@ -0,0 +1,130 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "errors" + "io" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/openai" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/eino-contrib/jsonschema" + "github.com/openai/openai-go/v3/responses" + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +func main() { + ctx := context.Background() + + am, err := openai.New(ctx, &openai.Config{ + BaseURL: "https://api.openai.com/v1", + Model: os.Getenv("OPENAI_MODEL_ID"), + APIKey: os.Getenv("OPENAI_API_KEY"), + Reasoning: &responses.ReasoningParam{ + Effort: responses.ReasoningEffortLow, + Summary: responses.ReasoningSummaryDetailed, + }, + }) + if err != nil { + log.Fatalf("failed to create agentic model, err=%v", err) + } + + functionTools := []*schema.ToolInfo{ + { + Name: "get_weather", + Desc: "get the weather in a city", + ParamsOneOf: schema.NewParamsOneOfByJSONSchema(&jsonschema.Schema{ + Type: "object", + Properties: orderedmap.New[string, *jsonschema.Schema]( + orderedmap.WithInitialData( + orderedmap.Pair[string, *jsonschema.Schema]{ + Key: "city", + Value: &jsonschema.Schema{ + Type: "string", + Description: "the city to get the weather", + }, + }, + ), + ), + Required: []string{"city"}, + }), + }, + } + + allowedTools := []*schema.AllowedTool{ + { + FunctionToolName: "get_weather", + }, + } + + opts := []agentic.Option{ + agentic.WithToolChoice(schema.ToolChoiceForced, allowedTools...), + agentic.WithTools(functionTools), + } + + firstInput := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what's the weather like in Beijing today"), + } + + sResp, err := am.Stream(ctx, firstInput, opts...) + if err != nil { + log.Fatalf("failed to stream, err: %v", err) + } + + var msgs []*schema.AgenticMessage + for { + msg, err := sResp.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + log.Fatalf("failed to receive stream response, err: %v", err) + } + msgs = append(msgs, msg) + } + + concatenated, err := schema.ConcatAgenticMessages(msgs) + if err != nil { + log.Fatalf("failed to concat agentic messages, err: %v", err) + } + + lastBlock := concatenated.ContentBlocks[len(concatenated.ContentBlocks)-1] + if lastBlock.Type != schema.ContentBlockTypeFunctionToolCall { + log.Fatalf("last block is not function tool call, type: %s", lastBlock.Type) + } + + toolCall := lastBlock.FunctionToolCall + toolResultMsg := schema.FunctionToolResultAgenticMessage(toolCall.CallID, toolCall.Name, "20 degrees") + + secondInput := append(firstInput, concatenated, toolResultMsg) + + gResp, err := am.Generate(ctx, secondInput, opts...) + if err != nil { + log.Fatalf("failed to generate, err: %v", err) + } + + meta := concatenated.ResponseMeta.OpenAIExtension + log.Printf("request_id: %s\n", meta.ID) + + respBody, _ := sonic.MarshalIndent(gResp, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} diff --git a/components/agentic/openai/examples/stream_with_server_tool/main.go b/components/agentic/openai/examples/stream_with_server_tool/main.go new file mode 100644 index 000000000..997f690eb --- /dev/null +++ b/components/agentic/openai/examples/stream_with_server_tool/main.go @@ -0,0 +1,118 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "errors" + "io" + "log" + "os" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino-ext/components/agentic/openai" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/openai/openai-go/v3/responses" +) + +func main() { + ctx := context.Background() + + am, err := openai.New(ctx, &openai.Config{ + BaseURL: "https://api.openai.com/v1", + Model: os.Getenv("OPENAI_MODEL_ID"), + APIKey: os.Getenv("OPENAI_API_KEY"), + Reasoning: &responses.ReasoningParam{ + Effort: responses.ReasoningEffortLow, + Summary: responses.ReasoningSummaryDetailed, + }, + Include: []responses.ResponseIncludable{ + responses.ResponseIncludableWebSearchCallActionSources, + }, + }) + if err != nil { + log.Fatalf("failed to create agentic model, err=%v", err) + } + + serverTools := []*openai.ServerToolConfig{ + { + WebSearch: &responses.WebSearchToolParam{ + Type: responses.WebSearchToolTypeWebSearch, + }, + }, + } + + allowedTools := []*schema.AllowedTool{ + { + ServerTool: &schema.AllowedServerTool{ + Name: string(openai.ServerToolNameWebSearch), + }, + }, + } + + opts := []agentic.Option{ + agentic.WithToolChoice(schema.ToolChoiceForced, allowedTools...), + openai.WithServerTools(serverTools), + } + + input := []*schema.AgenticMessage{ + schema.UserAgenticMessage("what's cloudwego/eino"), + } + + resp, err := am.Stream(ctx, input, opts...) + if err != nil { + log.Fatalf("failed to stream, err: %v", err) + } + + var msgs []*schema.AgenticMessage + for { + msg, err := resp.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + log.Fatalf("failed to receive stream response, err: %v", err) + } + msgs = append(msgs, msg) + } + + concatenated, err := schema.ConcatAgenticMessages(msgs) + if err != nil { + log.Fatalf("failed to concat agentic messages, err: %v", err) + } + + for _, block := range concatenated.ContentBlocks { + if block.ServerToolCall != nil { + serverToolArgs := block.ServerToolCall.Arguments.(*openai.ServerToolCallArguments) + args, _ := sonic.MarshalIndent(serverToolArgs, " ", " ") + log.Printf("server_tool_args: %s\n", string(args)) + } + + if block.ServerToolResult != nil { + result := block.ServerToolResult.Result.(*openai.ServerToolResult) + resultJSON, _ := sonic.MarshalIndent(result, " ", " ") + log.Printf("server_tool_result: %s\n", string(resultJSON)) + } + } + + meta := concatenated.ResponseMeta.OpenAIExtension + log.Printf("request_id: %s\n", meta.ID) + + respBody, _ := sonic.MarshalIndent(concatenated, " ", " ") + log.Printf(" body: %s\n", string(respBody)) +} diff --git a/components/agentic/openai/extension.go b/components/agentic/openai/extension.go new file mode 100644 index 000000000..c5c041362 --- /dev/null +++ b/components/agentic/openai/extension.go @@ -0,0 +1,108 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "fmt" + + "github.com/cloudwego/eino/schema" +) + +type ServerToolCallArguments struct { + WebSearch *WebSearchArguments `json:"web_search,omitempty"` +} + +type ServerToolResult struct { + WebSearch *WebSearchResult `json:"web_search,omitempty"` +} + +type WebSearchArguments struct { + ActionType WebSearchAction `json:"action_type,omitempty"` + + Search *WebSearchQuery `json:"search,omitempty"` + OpenPage *WebSearchOpenPage `json:"open_page,omitempty"` + Find *WebSearchFind `json:"find,omitempty"` +} + +type WebSearchQuery struct { + Query string `json:"query,omitempty"` +} + +type WebSearchOpenPage struct { + URL string `json:"url,omitempty"` +} + +type WebSearchFind struct { + URL string `json:"url,omitempty"` + Pattern string `json:"pattern,omitempty"` +} + +type WebSearchQueryResult struct { + Sources []*WebSearchQuerySource `json:"sources,omitempty"` +} + +type WebSearchQuerySource struct { + URL string `json:"url,omitempty"` +} + +type WebSearchResult struct { + ActionType WebSearchAction `json:"action_type,omitempty"` + + Search *WebSearchQueryResult `json:"search,omitempty"` +} + +func getServerToolCallArguments(call *schema.ServerToolCall) (*ServerToolCallArguments, error) { + if call == nil || call.Arguments == nil { + return nil, fmt.Errorf("server tool call arguments are nil") + } + arguments, ok := call.Arguments.(*ServerToolCallArguments) + if !ok { + return nil, fmt.Errorf("expected '%T', but got '%T'", &ServerToolCallArguments{}, call.Arguments) + } + return arguments, nil +} + +func getServerToolResult(content *schema.ServerToolResult) (*ServerToolResult, error) { + if content == nil || content.Result == nil { + return nil, fmt.Errorf("server tool result is nil") + } + result, ok := content.Result.(*ServerToolResult) + if !ok { + return nil, fmt.Errorf("expected '%T', but got '%T'", &ServerToolResult{}, content.Result) + } + return result, nil +} + +func concatServerToolCallArguments(chunks []*ServerToolCallArguments) (ret *ServerToolCallArguments, err error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no server tool call arguments found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + return nil, fmt.Errorf("cannot concat multiple server tool call arguments") +} + +func concatServerToolResult(chunks []*ServerToolResult) (ret *ServerToolResult, err error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no server tool result found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + return nil, fmt.Errorf("cannot concat multiple server tool result") +} diff --git a/components/agentic/openai/extension_test.go b/components/agentic/openai/extension_test.go new file mode 100644 index 000000000..ba65dfbba --- /dev/null +++ b/components/agentic/openai/extension_test.go @@ -0,0 +1,151 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "testing" + + "github.com/bytedance/mockey" + "github.com/cloudwego/eino/schema" + "github.com/stretchr/testify/assert" +) + +func TestGetServerToolCallArguments(t *testing.T) { + mockey.PatchConvey("getServerToolCallArguments", t, func() { + mockey.PatchConvey("success", func() { + args := &ServerToolCallArguments{} + call := &schema.ServerToolCall{ + Arguments: args, + } + res, err := getServerToolCallArguments(call) + assert.NoError(t, err) + assert.Equal(t, args, res) + }) + + mockey.PatchConvey("nil input", func() { + res, err := getServerToolCallArguments(nil) + assert.Error(t, err) + assert.Nil(t, res) + }) + + mockey.PatchConvey("nil arguments", func() { + call := &schema.ServerToolCall{ + Arguments: nil, + } + res, err := getServerToolCallArguments(call) + assert.Error(t, err) + assert.Nil(t, res) + }) + + mockey.PatchConvey("wrong type", func() { + call := &schema.ServerToolCall{ + Arguments: "wrong type", + } + res, err := getServerToolCallArguments(call) + assert.Error(t, err) + assert.Nil(t, res) + }) + }) +} + +func TestGetServerToolResult(t *testing.T) { + mockey.PatchConvey("getServerToolResult", t, func() { + mockey.PatchConvey("success", func() { + result := &ServerToolResult{} + content := &schema.ServerToolResult{ + Result: result, + } + res, err := getServerToolResult(content) + assert.NoError(t, err) + assert.Equal(t, result, res) + }) + + mockey.PatchConvey("nil input", func() { + res, err := getServerToolResult(nil) + assert.Error(t, err) + assert.Nil(t, res) + }) + + mockey.PatchConvey("nil result", func() { + content := &schema.ServerToolResult{ + Result: nil, + } + res, err := getServerToolResult(content) + assert.Error(t, err) + assert.Nil(t, res) + }) + + mockey.PatchConvey("wrong type", func() { + content := &schema.ServerToolResult{ + Result: "wrong type", + } + res, err := getServerToolResult(content) + assert.Error(t, err) + assert.Nil(t, res) + }) + }) +} + +func TestConcatServerToolCallArguments(t *testing.T) { + mockey.PatchConvey("concatServerToolCallArguments", t, func() { + mockey.PatchConvey("empty chunks", func() { + res, err := concatServerToolCallArguments(nil) + assert.Error(t, err) + assert.Nil(t, res) + }) + + mockey.PatchConvey("one chunk", func() { + args := &ServerToolCallArguments{} + res, err := concatServerToolCallArguments([]*ServerToolCallArguments{args}) + assert.NoError(t, err) + assert.Equal(t, args, res) + }) + + mockey.PatchConvey("multiple chunks", func() { + args1 := &ServerToolCallArguments{} + args2 := &ServerToolCallArguments{} + res, err := concatServerToolCallArguments([]*ServerToolCallArguments{args1, args2}) + assert.Error(t, err) + assert.Nil(t, res) + }) + }) +} + +func TestConcatServerToolResult(t *testing.T) { + mockey.PatchConvey("concatServerToolResult", t, func() { + mockey.PatchConvey("empty chunks", func() { + res, err := concatServerToolResult(nil) + assert.Error(t, err) + assert.Nil(t, res) + }) + + mockey.PatchConvey("one chunk", func() { + result := &ServerToolResult{} + res, err := concatServerToolResult([]*ServerToolResult{result}) + assert.NoError(t, err) + assert.Equal(t, result, res) + }) + + mockey.PatchConvey("multiple chunks", func() { + result1 := &ServerToolResult{} + result2 := &ServerToolResult{} + res, err := concatServerToolResult([]*ServerToolResult{result1, result2}) + assert.Error(t, err) + assert.Nil(t, res) + }) + }) +} diff --git a/components/agentic/openai/go.mod b/components/agentic/openai/go.mod new file mode 100644 index 000000000..81265a211 --- /dev/null +++ b/components/agentic/openai/go.mod @@ -0,0 +1,55 @@ +module github.com/cloudwego/eino-ext/components/agentic/openai + +go 1.22 + +require ( + github.com/bytedance/mockey v1.3.0 + github.com/bytedance/sonic v1.14.1 + github.com/cloudwego/eino v0.7.19-0.20260108113617-d04d4b5bda31 + github.com/eino-contrib/jsonschema v1.0.3 + github.com/openai/openai-go/v3 v3.15.0 + github.com/stretchr/testify v1.10.0 + github.com/wk8/go-ordered-map/v2 v2.1.8 + golang.org/x/sync v0.10.0 +) + +require ( + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/goph/emperror v0.17.2 // indirect + github.com/gopherjs/gopherjs v1.17.2 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/jtolds/gls v4.20.0+incompatible // indirect + github.com/klauspost/cpuid/v2 v2.2.9 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/nikolalohinski/gonja v1.5.3 // indirect + github.com/pelletier/go-toml/v2 v2.0.9 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect + github.com/smarty/assertions v1.15.0 // indirect + github.com/smartystreets/goconvey v1.8.1 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/yargevad/filepathx v1.0.0 // indirect + golang.org/x/arch v0.11.0 // indirect + golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect + golang.org/x/net v0.34.0 // indirect + golang.org/x/sys v0.29.0 // indirect + golang.org/x/text v0.21.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/components/agentic/openai/go.sum b/components/agentic/openai/go.sum new file mode 100644 index 000000000..865de028f --- /dev/null +++ b/components/agentic/openai/go.sum @@ -0,0 +1,177 @@ +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 h1:g0EZJwz7xkXQiZAI5xi9f3WWFYBlX1CPTrR+NDToRkQ= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0/go.mod h1:XCW7KnZet0Opnr7HccfUw1PLc4CjHqpcaxW8DHklNkQ= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= +github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU= +github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= +github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= +github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/mockey v1.3.0 h1:ONLRdvhqmCfr9rTasUB8ZKCfvbdD2tohOg4u+4Q/ed0= +github.com/bytedance/mockey v1.3.0/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY= +github.com/bytedance/sonic v1.14.1 h1:FBMC0zVz5XUmE4z9wF4Jey0An5FueFvOsTKKKtwIl7w= +github.com/bytedance/sonic v1.14.1/go.mod h1:gi6uhQLMbTdeP0muCnrjHLeCUPyb70ujhnNlhOylAFc= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/cloudwego/eino v0.7.19-0.20260108113617-d04d4b5bda31 h1:Vj2VKfW6A+FpzGdU4MJyIOEDVcI5Zyr0uEQanPa7PyE= +github.com/cloudwego/eino v0.7.19-0.20260108113617-d04d4b5bda31/go.mod h1:OdDJi17QawFUJRIFrVJRgdc9grjrh3eFDD0k34ZRH8M= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/eino-contrib/jsonschema v1.0.3 h1:2Kfsm1xlMV0ssY2nuxshS4AwbLFuqmPmzIjLVJ1Fsp0= +github.com/eino-contrib/jsonschema v1.0.3/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= +github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/goph/emperror v0.17.2 h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18= +github.com/goph/emperror v0.17.2/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic= +github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= +github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= +github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= +github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c= +github.com/nikolalohinski/gonja v1.5.3/go.mod h1:RmjwxNiXAEqcq1HeK5SSMmqFJvKOfTfXhkJv6YBtPa4= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/openai/openai-go/v3 v3.15.0 h1:hk99rM7YPz+M99/5B/zOQcVwFRLLMdprVGx1vaZ8XMo= +github.com/openai/openai-go/v3 v3.15.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= +github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0= +github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtISeXco0L5PKQ= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg= +github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= +github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= +github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= +github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= +github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= +github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc= +github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +golang.org/x/arch v0.11.0 h1:KXV8WWKCXm6tRpLirl2szsO5j/oOODwZf4hATmGVNs4= +golang.org/x/arch v0.11.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= +golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/components/agentic/openai/model.go b/components/agentic/openai/model.go new file mode 100644 index 000000000..38e4ebff2 --- /dev/null +++ b/components/agentic/openai/model.go @@ -0,0 +1,701 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "context" + "errors" + "fmt" + "net/http" + "runtime/debug" + "time" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/agentic" + "github.com/cloudwego/eino/schema" + "github.com/openai/openai-go/v3/azure" + "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/packages/param" + "github.com/openai/openai-go/v3/responses" +) + +var _ agentic.Model = (*Model)(nil) + +type Config struct { + // ByAzure specifies whether to use Azure OpenAI service. + // Optional. + ByAzure bool + + // BaseURL specifies the base URL for the OpenAI service endpoint. + // Optional. + BaseURL string + + // APIKey specifies the API key for authentication. + // Required. + APIKey string + + // Timeout specifies the maximum duration to wait for API responses. + // Optional. + Timeout *time.Duration + + // HTTPClient specifies the HTTP client used to send requests. + // Optional. + HTTPClient *http.Client + + // MaxRetries specifies the maximum number of retry attempts for failed requests. + // Optional. + MaxRetries *int + + // Model specifies the ID of the model to use for the response. + // Required. + Model string + + // MaxOutputTokens specifies the maximum number of tokens to generate in the response. + // Optional. + MaxOutputTokens *int64 + + // Temperature controls the randomness of the model's output. + // Higher values (e.g., 0.8) make the output more random, while lower values (e.g., 0.2) make it more focused and deterministic. + // Range: 0.0 to 2.0. + // Optional. + Temperature *float64 + + // TopP controls diversity via nucleus sampling. + // It specifies the cumulative probability threshold for token selection. + // Recommended to use either Temperature or TopP, but not both. + // Range: 0.0 to 1.0. + // Optional. + TopP *float64 + + // ServiceTier specifies the latency tier for processing the request. + // Optional. + ServiceTier *responses.ResponseNewParamsServiceTier + + // Text specifies configuration for text generation output. + // Optional. + Text *responses.ResponseTextConfigParam + + // Reasoning specifies configuration for reasoning models. + // Optional. + Reasoning *responses.ReasoningParam + + // Store specifies whether to store the response on the server. + // Optional. + Store *bool + + // MaxToolCalls specifies the maximum number of tool calls allowed in a single turn. + // Optional. + MaxToolCalls *int + + // ParallelToolCalls specifies whether to allow multiple tool calls in a single turn. + // Optional. + ParallelToolCalls *bool + + // Include specifies a list of additional fields to include in the response. + // Optional. + Include []responses.ResponseIncludable + + // ServerTools specifies server-side tools available to the model. + // Optional. + ServerTools []*ServerToolConfig + + // MCPTools specifies Model Context Protocol tools available to the model. + // Optional. + MCPTools []*responses.ToolMcpParam + + // CustomHeader specifies custom HTTP headers to include in API requests. + // CustomHeader allows passing additional metadata or authentication information. + // Optional. + CustomHeader map[string]string + + // ExtraFields specifies additional fields that will be directly added to the HTTP request body. + // This allows for vendor-specific or future parameters not yet explicitly supported. + // Optional. + ExtraFields map[string]any +} + +type ServerToolConfig struct { + WebSearch *responses.WebSearchToolParam +} + +func New(_ context.Context, config *Config) (*Model, error) { + if config == nil { + config = &Config{} + } + + c, err := buildClient(config) + if err != nil { + return nil, err + } + + return c, nil +} + +func buildClient(config *Config) (*Model, error) { + var opts []option.RequestOption + + if config.Timeout != nil { + opts = append(opts, option.WithRequestTimeout(*config.Timeout)) + } + if config.HTTPClient != nil { + opts = append(opts, option.WithHTTPClient(config.HTTPClient)) + } + if config.BaseURL != "" { + opts = append(opts, option.WithBaseURL(config.BaseURL)) + } + if config.APIKey != "" { + if config.ByAzure { + opts = append(opts, azure.WithAPIKey(config.APIKey)) + } else { + opts = append(opts, option.WithAPIKey(config.APIKey)) + } + } + if config.MaxRetries != nil { + opts = append(opts, option.WithMaxRetries(*config.MaxRetries)) + } + + client := responses.NewResponseService(opts...) + + cm := &Model{ + cli: client, + model: config.Model, + maxOutputTokens: config.MaxOutputTokens, + temperature: config.Temperature, + topP: config.TopP, + serviceTier: config.ServiceTier, + text: config.Text, + reasoning: config.Reasoning, + store: config.Store, + maxToolCalls: config.MaxToolCalls, + parallelToolCalls: config.ParallelToolCalls, + include: config.Include, + serverTools: config.ServerTools, + mcpTools: config.MCPTools, + customHeader: config.CustomHeader, + extraFields: config.ExtraFields, + } + + return cm, nil +} + +type Model struct { + cli responses.ResponseService + + rawFunctionTools []*schema.ToolInfo + functionTools []responses.ToolUnionParam + + model string + maxOutputTokens *int64 + temperature *float64 + topP *float64 + serviceTier *responses.ResponseNewParamsServiceTier + text *responses.ResponseTextConfigParam + reasoning *responses.ReasoningParam + store *bool + maxToolCalls *int + parallelToolCalls *bool + include []responses.ResponseIncludable + serverTools []*ServerToolConfig + mcpTools []*responses.ToolMcpParam + + customHeader map[string]string + extraFields map[string]any +} + +func (m *Model) Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...agentic.Option) ( + outMsg *schema.AgenticMessage, err error) { + + ctx = callbacks.EnsureRunInfo(ctx, m.GetType(), components.ComponentOfAgenticModel) + + options, specOptions, err := m.getOptions(opts) + if err != nil { + return nil, err + } + + req, reqOpts, err := m.genRequestAndOptions(input, options, specOptions) + if err != nil { + return nil, fmt.Errorf("genRequestAndOptions failed: %w", err) + } + + config := toCallbackConfig(req) + + tools := m.rawFunctionTools + if options.Tools != nil { + tools = options.Tools + } + + ctx = callbacks.OnStart(ctx, &agentic.CallbackInput{ + Messages: input, + Tools: tools, + ToolChoice: options.ToolChoice, + Config: config, + }) + + defer func() { + if err != nil { + callbacks.OnError(ctx, err) + } + }() + + responseObject, err := m.cli.New(ctx, *req, reqOpts...) + if err != nil { + return nil, fmt.Errorf("failed to create responses, err: %w", err) + } + + outMsg, err = toOutputMessage(responseObject) + if err != nil { + return nil, fmt.Errorf("failed to convert output to message, err: %w", err) + } + + callbacks.OnEnd(ctx, &agentic.CallbackOutput{ + Message: outMsg, + Config: config, + }) + + return outMsg, nil +} + +func (m *Model) Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...agentic.Option) ( + outStream *schema.StreamReader[*schema.AgenticMessage], err error) { + + ctx = callbacks.EnsureRunInfo(ctx, m.GetType(), components.ComponentOfAgenticModel) + + options, specOptions, err := m.getOptions(opts) + if err != nil { + return nil, err + } + + req, reqOpts, err := m.genRequestAndOptions(input, options, specOptions) + if err != nil { + return nil, fmt.Errorf("genRequestAndOptions failed: %w", err) + } + + config := toCallbackConfig(req) + tools := m.rawFunctionTools + if options.Tools != nil { + tools = options.Tools + } + + ctx = callbacks.OnStart(ctx, &agentic.CallbackInput{ + Messages: input, + Tools: tools, + ToolChoice: options.ToolChoice, + Config: config, + }) + + defer func() { + if err != nil { + callbacks.OnError(ctx, err) + } + }() + + respStreamReader := m.cli.NewStreaming(ctx, *req, reqOpts...) + + sr, sw := schema.Pipe[*agentic.CallbackOutput](1) + + go func() { + defer func() { + pe := recover() + if pe != nil { + _ = sw.Send(nil, newPanicErr(pe, debug.Stack())) + } + + _ = respStreamReader.Close() + sw.Close() + }() + + receivedStreamingResponse(respStreamReader, config, sw) + + }() + + ctx, nsr := callbacks.OnEndWithStreamOutput(ctx, schema.StreamReaderWithConvert(sr, + func(src *agentic.CallbackOutput) (callbacks.CallbackOutput, error) { + if src.Extra == nil { + src.Extra = make(map[string]any) + } + return src, nil + }, + )) + + outStream = schema.StreamReaderWithConvert(nsr, + func(src callbacks.CallbackOutput) (*schema.AgenticMessage, error) { + s := src.(*agentic.CallbackOutput) + if s.Message == nil { + return nil, schema.ErrNoValue + } + return s.Message, nil + }, + ) + + return outStream, err +} + +func (m *Model) WithTools(functionTools []*schema.ToolInfo) (agentic.Model, error) { + if len(functionTools) == 0 { + return nil, errors.New("function tools are required") + } + + fts, err := toFunctionTools(functionTools) + if err != nil { + return nil, fmt.Errorf("failed to convert function tools, err: %w", err) + } + + m_ := *m + m_.rawFunctionTools = functionTools + m_.functionTools = fts + + return &m_, nil +} + +func (m *Model) GetType() string { + return implType +} + +func (m *Model) IsCallbacksEnabled() bool { + return true +} + +type CacheInfo struct { + // ResponseID return by ResponsesAPI, it's specifies the id of prefix that can be used with [WithCache.HeadPreviousResponseID] option. + ResponseID string + // Usage specifies the token usage of prefix + Usage schema.TokenUsage +} + +func toCallbackConfig(req *responses.ResponseNewParams) *agentic.Config { + return &agentic.Config{ + Model: req.Model, + Temperature: req.Temperature.Value, + TopP: req.TopP.Value, + } +} + +func toFunctionTools(functionTools []*schema.ToolInfo) ([]responses.ToolUnionParam, error) { + tools := make([]responses.ToolUnionParam, len(functionTools)) + for i := range functionTools { + ti := functionTools[i] + + paramsJSONSchema, err := ti.ParamsOneOf.ToJSONSchema() + if err != nil { + return nil, fmt.Errorf("failed to convert tool parameters to JSONSchema, err: %w", err) + } + + b, err := sonic.Marshal(paramsJSONSchema) + if err != nil { + return nil, fmt.Errorf("failed to marshal JSONSchema, err: %w", err) + } + + var params map[string]any + err = sonic.Unmarshal(b, ¶ms) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal JSONSchema, err: %w", err) + } + + tools[i] = responses.ToolUnionParam{ + OfFunction: &responses.FunctionToolParam{ + Name: ti.Name, + Description: newOpenaiStrOpt(ti.Desc), + Parameters: params, + }, + } + } + + return tools, nil +} + +func toServerTools(serverTools []*ServerToolConfig) ([]responses.ToolUnionParam, error) { + tools := make([]responses.ToolUnionParam, len(serverTools)) + + for i := range serverTools { + ti := serverTools[i] + switch { + case ti.WebSearch != nil: + tools[i] = responses.ToolUnionParam{ + OfWebSearch: ti.WebSearch, + } + + default: + return nil, fmt.Errorf("found unknown server tool") + } + } + + return tools, nil +} + +func (m *Model) getOptions(opts []agentic.Option) (*agentic.Options, *openaiOptions, error) { + options := agentic.GetCommonOptions(&agentic.Options{ + Temperature: m.temperature, + Model: &m.model, + TopP: m.topP, + Tools: nil, + }, opts...) + + specOptions := agentic.GetImplSpecificOptions(&openaiOptions{ + reasoning: m.reasoning, + store: m.store, + text: m.text, + maxToolCalls: m.maxToolCalls, + parallelToolCalls: m.parallelToolCalls, + maxOutputTokens: m.maxOutputTokens, + serverTools: m.serverTools, + mcpTools: m.mcpTools, + customHeaders: m.customHeader, + }, opts...) + + return options, specOptions, nil +} + +func (m *Model) genRequestAndOptions(in []*schema.AgenticMessage, options *agentic.Options, + specOptions *openaiOptions) (req *responses.ResponseNewParams, reqOpts []option.RequestOption, err error) { + + req = &responses.ResponseNewParams{} + + err = m.prePopulateConfig(req, options, specOptions) + if err != nil { + return req, nil, fmt.Errorf("failed to prePopulateConfig, err: %w", err) + } + + err = m.populateInput(in, req) + if err != nil { + return req, nil, fmt.Errorf("failed to populateInput, err: %w", err) + } + + err = m.populateTools(req, options, specOptions) + if err != nil { + return req, nil, fmt.Errorf("failed to populateTools, err: %w", err) + } + + err = m.populateToolChoice(req, options) + if err != nil { + return req, nil, fmt.Errorf("failed to populateToolChoice, err: %w", err) + } + + for k, v := range specOptions.customHeaders { + reqOpts = append(reqOpts, option.WithHeaderAdd(k, v)) + } + + for k, v := range specOptions.extraFields { + reqOpts = append(reqOpts, option.WithJSONSet(k, v)) + } + + return req, reqOpts, nil +} + +func (m *Model) prePopulateConfig(responseReq *responses.ResponseNewParams, options *agentic.Options, + specOptions *openaiOptions) error { + + // instance configuration + if m.serviceTier != nil { + responseReq.ServiceTier = *m.serviceTier + } + responseReq.Include = m.include + + // options configuration + responseReq.TopP = newOpenaiOpt(options.TopP) + responseReq.Temperature = newOpenaiOpt(options.Temperature) + if options.Model != nil { + responseReq.Model = *options.Model + } + + // specific options configuration + if specOptions.reasoning != nil { + responseReq.Reasoning = *specOptions.reasoning + } + if specOptions.text != nil { + responseReq.Text = *specOptions.text + } + responseReq.MaxOutputTokens = newOpenaiOpt(specOptions.maxOutputTokens) + if specOptions.maxToolCalls != nil { + responseReq.MaxToolCalls = param.NewOpt(int64(*specOptions.maxToolCalls)) + } + responseReq.ParallelToolCalls = newOpenaiOpt(specOptions.parallelToolCalls) + responseReq.PromptCacheKey = newOpenaiOpt(specOptions.promptCacheKey) + responseReq.Store = newOpenaiOpt(specOptions.store) + + return nil +} + +func (m *Model) populateInput(in []*schema.AgenticMessage, responseReq *responses.ResponseNewParams) (err error) { + if len(in) == 0 { + return nil + } + + itemList := make([]responses.ResponseInputItemUnionParam, 0, len(in)) + + for _, msg := range in { + var inputItems []responses.ResponseInputItemUnionParam + + switch msg.Role { + case schema.AgenticRoleTypeUser: + inputItems, err = toUserRoleInputItems(msg) + if err != nil { + return err + } + + case schema.AgenticRoleTypeAssistant: + inputItems, err = toAssistantRoleInputItems(msg) + if err != nil { + return err + } + + case schema.AgenticRoleTypeDeveloper: + inputItems, err = toDeveloperRoleInputItems(msg) + if err != nil { + return err + } + + case schema.AgenticRoleTypeSystem: + inputItems, err = toSystemRoleInputItems(msg) + if err != nil { + return err + } + + default: + return fmt.Errorf("invalid role in message: %s", msg.Role) + } + + itemList = append(itemList, inputItems...) + } + + responseReq.Input = responses.ResponseNewParamsInputUnion{ + OfInputItemList: itemList, + } + + return nil +} + +func (m *Model) populateToolChoice(responseReq *responses.ResponseNewParams, options *agentic.Options) (err error) { + if options.ToolChoice == nil && len(options.AllowedTools) > 0 { + return fmt.Errorf("tool choice must be specified when allowed tools are provided") + } + if options.ToolChoice == nil { + return nil + } + + switch *options.ToolChoice { + case schema.ToolChoiceForbidden: + if len(options.AllowedTools) > 0 { + return fmt.Errorf("allowed tools must be empty when tool choice is 'forbidden'") + } + responseReq.ToolChoice = responses.ResponseNewParamsToolChoiceUnion{ + OfToolChoiceMode: param.NewOpt(responses.ToolChoiceOptionsNone), + } + + case schema.ToolChoiceAllowed: + if len(options.AllowedTools) == 0 { + return fmt.Errorf("allowed tools must be provided when tool choice is 'allowed'") + } + tools, err := toAllowedTools(options.AllowedTools) + if err != nil { + return err + } + responseReq.ToolChoice = responses.ResponseNewParamsToolChoiceUnion{ + OfAllowedTools: &responses.ToolChoiceAllowedParam{ + Mode: responses.ToolChoiceAllowedModeAuto, + Tools: tools, + }, + } + + case schema.ToolChoiceForced: + if len(options.AllowedTools) == 0 { + return fmt.Errorf("allowed tools must be provided when tool choice is 'forced'") + } + tools, err := toAllowedTools(options.AllowedTools) + if err != nil { + return err + } + responseReq.ToolChoice = responses.ResponseNewParamsToolChoiceUnion{ + OfAllowedTools: &responses.ToolChoiceAllowedParam{ + Mode: responses.ToolChoiceAllowedModeRequired, + Tools: tools, + }, + } + + default: + return fmt.Errorf("invalid tool choice: %s", *options.ToolChoice) + } + + return nil +} + +func toAllowedTools(tools []*schema.AllowedTool) ([]map[string]any, error) { + allowedTools := make([]map[string]any, 0, len(tools)) + for _, tool := range tools { + switch { + case tool.FunctionToolName != "": + allowedTools = append(allowedTools, map[string]any{ + "type": "function", + "name": tool.FunctionToolName, + }) + + case tool.MCPTool != nil: + tool_ := map[string]any{ + "type": "mcp", + "server_label": tool.MCPTool.ServerLabel, + } + if tool.MCPTool.Name != "" { + tool_["name"] = tool.MCPTool.Name + } + allowedTools = append(allowedTools, tool_) + + case tool.ServerTool != nil: + allowedTools = append(allowedTools, map[string]any{ + "type": tool.ServerTool.Name, + }) + + default: + return nil, fmt.Errorf("found unknown allowed tool") + } + } + + return allowedTools, nil +} + +func (m *Model) populateTools(responseReq *responses.ResponseNewParams, options *agentic.Options, specOptions *openaiOptions) (err error) { + var functionTools []responses.ToolUnionParam + if options.Tools != nil { + functionTools, err = toFunctionTools(options.Tools) + if err != nil { + return err + } + } else { + functionTools = m.functionTools + } + + responseReq.Tools = append(responseReq.Tools, functionTools...) + + serverTools, err := toServerTools(specOptions.serverTools) + if err != nil { + return err + } + + responseReq.Tools = append(responseReq.Tools, serverTools...) + + if len(specOptions.mcpTools) > 0 { + mcpTools := make([]responses.ToolUnionParam, 0, len(specOptions.mcpTools)) + for _, tool := range specOptions.mcpTools { + mcpTools = append(mcpTools, responses.ToolUnionParam{ + OfMcp: tool, + }) + } + responseReq.Tools = append(responseReq.Tools, mcpTools...) + } + + return nil +} diff --git a/components/agentic/openai/model_test.go b/components/agentic/openai/model_test.go new file mode 100644 index 000000000..a1549f3b6 --- /dev/null +++ b/components/agentic/openai/model_test.go @@ -0,0 +1,203 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "context" + "errors" + "testing" + + "github.com/bytedance/mockey" + "github.com/cloudwego/eino/schema" + "github.com/eino-contrib/jsonschema" + "github.com/openai/openai-go/v3/packages/ssestream" + "github.com/openai/openai-go/v3/responses" + "github.com/stretchr/testify/assert" +) + +func TestNew(t *testing.T) { + mockey.PatchConvey("TestNew", t, func() { + mockey.PatchConvey("success", func() { + config := &Config{ + APIKey: "test", + } + m, err := New(context.Background(), config) + assert.NoError(t, err) + assert.NotNil(t, m) + assert.Equal(t, implType, m.GetType()) + }) + mockey.PatchConvey("config nil", func() { + m, err := New(context.Background(), nil) + assert.NoError(t, err) + assert.NotNil(t, m) + }) + }) +} + +func TestModelGenerate(t *testing.T) { + mockey.PatchConvey("TestModelGenerate", t, func() { + ctx := context.Background() + config := &Config{APIKey: "test"} + m, err := New(ctx, config) + assert.NoError(t, err) + + input := []*schema.AgenticMessage{ + {Role: schema.AgenticRoleTypeUser, ContentBlocks: []*schema.ContentBlock{schema.NewContentBlock(&schema.UserInputText{Text: "hi"})}}, + } + + mockey.PatchConvey("success", func() { + mockey.Mock((*responses.ResponseService).New).Return(&responses.Response{}, nil).Build() + + mockey.Mock(toOutputMessage).Return(&schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "hello"}), + }, + }, nil).Build() + + out, err := m.Generate(ctx, input) + assert.NoError(t, err) + assert.NotNil(t, out) + if assert.NotEmpty(t, out.ContentBlocks) { + assert.Equal(t, "hello", out.ContentBlocks[0].AssistantGenText.Text) + } + }) + + mockey.PatchConvey("genRequest error", func() { + invalidInput := []*schema.AgenticMessage{ + {Role: "invalid", ContentBlocks: []*schema.ContentBlock{schema.NewContentBlock(&schema.UserInputText{Text: "hi"})}}, + } + out, err := m.Generate(ctx, invalidInput) + assert.Error(t, err) + assert.Nil(t, out) + assert.Contains(t, err.Error(), "invalid role") + }) + + mockey.PatchConvey("cli.New error", func() { + mockey.Mock((*responses.ResponseService).New).Return(nil, errors.New("api error")).Build() + out, err := m.Generate(ctx, input) + assert.Error(t, err) + assert.Nil(t, out) + assert.Contains(t, err.Error(), "api error") + }) + + mockey.PatchConvey("toOutputMessage error", func() { + mockey.Mock((*responses.ResponseService).New).Return(&responses.Response{}, nil).Build() + mockey.Mock(toOutputMessage).Return(nil, errors.New("convert error")).Build() + + out, err := m.Generate(ctx, input) + assert.Error(t, err) + assert.Nil(t, out) + }) + }) +} + +func TestModelStream(t *testing.T) { + mockey.PatchConvey("TestModelStream", t, func() { + ctx := context.Background() + config := &Config{APIKey: "test", Model: "gpt-4"} + m, err := New(ctx, config) + assert.NoError(t, err) + + input := []*schema.AgenticMessage{ + {Role: schema.AgenticRoleTypeUser, ContentBlocks: []*schema.ContentBlock{schema.NewContentBlock(&schema.UserInputText{Text: "hi"})}}, + } + + mockey.PatchConvey("success", func() { + // Create a mock stream that will return events + mockStream := &ssestream.Stream[responses.ResponseStreamEventUnion]{} + + mockey.Mock((*responses.ResponseService).NewStreaming).Return(mockStream).Build() + + // Mock the stream methods to simulate a successful stream + // Use Sequence to return true once, then false + mockey.Mock(mockey.GetMethod(mockStream, "Next")).Return(mockey.Sequence(true).Then(false)).Build() + + mockey.Mock(mockey.GetMethod(mockStream, "Current")).Return(responses.ResponseStreamEventUnion{ + Type: "response.completed", + }).Build() + + mockey.Mock(mockey.GetMethod(mockStream, "Err")).Return(nil).Build() + mockey.Mock(mockey.GetMethod(mockStream, "Close")).Return(nil).Build() + + // Mock AsAny to return a completed event + mockey.Mock(responses.ResponseStreamEventUnion.AsAny).Return(responses.ResponseCompletedEvent{ + Response: responses.Response{ + Output: []responses.ResponseOutputItemUnion{ + {Type: "message", ID: "m1", Status: "completed"}, + }, + }, + }).Build() + + s, err := m.Stream(ctx, input) + assert.NoError(t, err) + assert.NotNil(t, s) + defer s.Close() + + // The stream should eventually close without errors + // We just verify it was created successfully + }) + + mockey.PatchConvey("genRequest error", func() { + invalidInput := []*schema.AgenticMessage{ + {Role: "invalid", ContentBlocks: []*schema.ContentBlock{schema.NewContentBlock(&schema.UserInputText{Text: "hi"})}}, + } + s, err := m.Stream(ctx, invalidInput) + assert.Error(t, err) + assert.Nil(t, s) + }) + }) +} + +func TestModelWithTools(t *testing.T) { + mockey.PatchConvey("TestModelWithTools", t, func() { + ctx := context.Background() + m, err := New(ctx, &Config{APIKey: "test"}) + assert.NoError(t, err) + + mockey.PatchConvey("success", func() { + tools := []*schema.ToolInfo{ + {Name: "tool1", Desc: "d", ParamsOneOf: schema.NewParamsOneOfByJSONSchema(&jsonschema.Schema{Type: "object"})}, + } + nm, err := m.WithTools(tools) + assert.NoError(t, err) + assert.NotNil(t, nm) + }) + + mockey.PatchConvey("empty tools", func() { + nm, err := m.WithTools(nil) + assert.Error(t, err) + assert.Nil(t, nm) + }) + }) +} + +func TestModelGetType(t *testing.T) { + mockey.PatchConvey("TestModelGetType", t, func() { + m, err := New(context.Background(), &Config{APIKey: "test"}) + assert.NoError(t, err) + assert.Equal(t, "OpenAI", m.GetType()) + }) +} + +func TestModelIsCallbacksEnabled(t *testing.T) { + mockey.PatchConvey("TestModelIsCallbacksEnabled", t, func() { + m, err := New(context.Background(), &Config{APIKey: "test"}) + assert.NoError(t, err) + assert.True(t, m.IsCallbacksEnabled()) + }) +} diff --git a/components/agentic/openai/option.go b/components/agentic/openai/option.go new file mode 100644 index 000000000..cdd3c3a01 --- /dev/null +++ b/components/agentic/openai/option.go @@ -0,0 +1,104 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "github.com/cloudwego/eino/components/agentic" + "github.com/openai/openai-go/v3/responses" +) + +type openaiOptions struct { + reasoning *responses.ReasoningParam + maxOutputTokens *int64 + maxToolCalls *int + parallelToolCalls *bool + text *responses.ResponseTextConfigParam + store *bool + promptCacheKey *string + + serverTools []*ServerToolConfig + mcpTools []*responses.ToolMcpParam + + customHeaders map[string]string + extraFields map[string]any +} + +func WithStore(store bool) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *openaiOptions) { + o.store = &store + }) +} + +func WithPromptCacheKey(key string) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *openaiOptions) { + o.promptCacheKey = &key + }) +} + +func WithReasoning(reasoning *responses.ReasoningParam) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *openaiOptions) { + o.reasoning = reasoning + }) +} + +func WithText(text *responses.ResponseTextConfigParam) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *openaiOptions) { + o.text = text + }) +} + +func WithMaxOutputTokens(maxOutputTokens int64) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *openaiOptions) { + o.maxOutputTokens = &maxOutputTokens + }) +} + +func WithMaxToolCalls(maxToolCalls int) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *openaiOptions) { + o.maxToolCalls = &maxToolCalls + }) +} + +func WithParallelToolCalls(parallelToolCalls bool) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *openaiOptions) { + o.parallelToolCalls = ¶llelToolCalls + }) +} + +func WithServerTools(tools []*ServerToolConfig) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *openaiOptions) { + o.serverTools = tools + }) +} + +func WithMCPTools(tools []*responses.ToolMcpParam) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *openaiOptions) { + o.mcpTools = tools + }) +} + +func WithCustomHeaders(headers map[string]string) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *openaiOptions) { + o.customHeaders = headers + }) +} + +func WithExtraFields(fields map[string]any) agentic.Option { + return agentic.WrapImplSpecificOptFn(func(o *openaiOptions) { + o.extraFields = fields + }) +} diff --git a/components/agentic/openai/option_test.go b/components/agentic/openai/option_test.go new file mode 100644 index 000000000..44c32412e --- /dev/null +++ b/components/agentic/openai/option_test.go @@ -0,0 +1,278 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "testing" + + "github.com/bytedance/mockey" + "github.com/cloudwego/eino/components/agentic" + "github.com/openai/openai-go/v3/responses" + "github.com/stretchr/testify/assert" +) + +func TestWithStore(t *testing.T) { + mockey.PatchConvey("WithStore", t, func() { + mockey.PatchConvey("set store to true", func() { + opt := WithStore(true) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.store) + assert.True(t, *opts.store) + }) + + mockey.PatchConvey("set store to false", func() { + opt := WithStore(false) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.store) + assert.False(t, *opts.store) + }) + }) +} + +func TestWithPromptCacheKey(t *testing.T) { + mockey.PatchConvey("WithPromptCacheKey", t, func() { + mockey.PatchConvey("set cache key", func() { + key := "test-cache-key" + opt := WithPromptCacheKey(key) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.promptCacheKey) + assert.Equal(t, key, *opts.promptCacheKey) + }) + + mockey.PatchConvey("set empty cache key", func() { + opt := WithPromptCacheKey("") + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.promptCacheKey) + assert.Equal(t, "", *opts.promptCacheKey) + }) + }) +} + +func TestWithReasoning(t *testing.T) { + mockey.PatchConvey("WithReasoning", t, func() { + mockey.PatchConvey("set reasoning param", func() { + reasoning := &responses.ReasoningParam{ + Effort: responses.ReasoningEffortLow, + } + opt := WithReasoning(reasoning) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.reasoning) + assert.Equal(t, reasoning, opts.reasoning) + }) + + mockey.PatchConvey("set nil reasoning", func() { + opt := WithReasoning(nil) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.Nil(t, opts.reasoning) + }) + }) +} + +func TestWithText(t *testing.T) { + mockey.PatchConvey("WithText", t, func() { + mockey.PatchConvey("set text config", func() { + text := &responses.ResponseTextConfigParam{ + Verbosity: responses.ResponseTextConfigVerbosityLow, + } + opt := WithText(text) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.text) + assert.Equal(t, text, opts.text) + }) + + mockey.PatchConvey("set nil text", func() { + opt := WithText(nil) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.Nil(t, opts.text) + }) + }) +} + +func TestWithMaxOutputTokens(t *testing.T) { + mockey.PatchConvey("WithMaxOutputTokens", t, func() { + mockey.PatchConvey("set positive value", func() { + maxTokens := int64(1000) + opt := WithMaxOutputTokens(maxTokens) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.maxOutputTokens) + assert.Equal(t, maxTokens, *opts.maxOutputTokens) + }) + + mockey.PatchConvey("set zero value", func() { + opt := WithMaxOutputTokens(0) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.maxOutputTokens) + assert.Equal(t, int64(0), *opts.maxOutputTokens) + }) + }) +} + +func TestWithMaxToolCalls(t *testing.T) { + mockey.PatchConvey("WithMaxToolCalls", t, func() { + mockey.PatchConvey("set positive value", func() { + maxCalls := 5 + opt := WithMaxToolCalls(maxCalls) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.maxToolCalls) + assert.Equal(t, maxCalls, *opts.maxToolCalls) + }) + + mockey.PatchConvey("set zero value", func() { + opt := WithMaxToolCalls(0) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.maxToolCalls) + assert.Equal(t, 0, *opts.maxToolCalls) + }) + }) +} + +func TestWithParallelToolCalls(t *testing.T) { + mockey.PatchConvey("WithParallelToolCalls", t, func() { + mockey.PatchConvey("set to true", func() { + opt := WithParallelToolCalls(true) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.parallelToolCalls) + assert.True(t, *opts.parallelToolCalls) + }) + + mockey.PatchConvey("set to false", func() { + opt := WithParallelToolCalls(false) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.parallelToolCalls) + assert.False(t, *opts.parallelToolCalls) + }) + }) +} + +func TestWithServerTools(t *testing.T) { + mockey.PatchConvey("WithServerTools", t, func() { + mockey.PatchConvey("set server tools", func() { + tools := []*ServerToolConfig{ + { + WebSearch: &responses.WebSearchToolParam{ + Type: responses.WebSearchToolTypeWebSearch, + }, + }, + } + opt := WithServerTools(tools) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.serverTools) + assert.Len(t, opts.serverTools, 1) + assert.Equal(t, tools, opts.serverTools) + }) + + mockey.PatchConvey("set empty tools", func() { + opt := WithServerTools([]*ServerToolConfig{}) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.serverTools) + assert.Len(t, opts.serverTools, 0) + }) + + mockey.PatchConvey("set nil tools", func() { + opt := WithServerTools(nil) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.Nil(t, opts.serverTools) + }) + }) +} + +func TestWithMCPTools(t *testing.T) { + mockey.PatchConvey("WithMCPTools", t, func() { + mockey.PatchConvey("set mcp tools", func() { + tools := []*responses.ToolMcpParam{ + { + ServerLabel: "test-server", + }, + } + opt := WithMCPTools(tools) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.mcpTools) + assert.Len(t, opts.mcpTools, 1) + assert.Equal(t, tools, opts.mcpTools) + }) + + mockey.PatchConvey("set empty tools", func() { + opt := WithMCPTools([]*responses.ToolMcpParam{}) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.mcpTools) + assert.Len(t, opts.mcpTools, 0) + }) + + mockey.PatchConvey("set nil tools", func() { + opt := WithMCPTools(nil) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.Nil(t, opts.mcpTools) + }) + }) +} + +func TestWithCustomHeaders(t *testing.T) { + mockey.PatchConvey("WithCustomHeaders", t, func() { + mockey.PatchConvey("set custom headers", func() { + headers := map[string]string{ + "X-Custom-Header": "value", + "Authorization": "Bearer token", + } + opt := WithCustomHeaders(headers) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.customHeaders) + assert.Equal(t, headers, opts.customHeaders) + }) + + mockey.PatchConvey("set empty headers", func() { + opt := WithCustomHeaders(map[string]string{}) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.customHeaders) + assert.Len(t, opts.customHeaders, 0) + }) + + mockey.PatchConvey("set nil headers", func() { + opt := WithCustomHeaders(nil) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.Nil(t, opts.customHeaders) + }) + }) +} + +func TestWithExtraFields(t *testing.T) { + mockey.PatchConvey("WithExtraFields", t, func() { + mockey.PatchConvey("set extra fields", func() { + fields := map[string]any{ + "field1": "value1", + "field2": 123, + "field3": true, + } + opt := WithExtraFields(fields) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.extraFields) + assert.Equal(t, fields, opts.extraFields) + }) + + mockey.PatchConvey("set empty fields", func() { + opt := WithExtraFields(map[string]any{}) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.NotNil(t, opts.extraFields) + assert.Len(t, opts.extraFields, 0) + }) + + mockey.PatchConvey("set nil fields", func() { + opt := WithExtraFields(nil) + opts := agentic.GetImplSpecificOptions(&openaiOptions{}, opt) + assert.Nil(t, opts.extraFields) + }) + }) +} diff --git a/components/agentic/openai/register.go b/components/agentic/openai/register.go new file mode 100644 index 000000000..44f8b4a26 --- /dev/null +++ b/components/agentic/openai/register.go @@ -0,0 +1,34 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +func init() { + schema.RegisterName[blockExtraItemID]("_eino_ext_openai_block_extra_item_id") + schema.RegisterName[blockExtraItemStatus]("_eino_ext_openai_block_extra_item_status") + schema.RegisterName[*ServerToolCallArguments]("_eino_ext_openai_server_tool_call_arguments") + schema.RegisterName[*ServerToolResult]("_eino_ext_openai_server_tool_result") + + compose.RegisterStreamChunkConcatFunc(concatFirstNonZero[blockExtraItemID]) + compose.RegisterStreamChunkConcatFunc(concatLast[blockExtraItemStatus]) + compose.RegisterStreamChunkConcatFunc(concatServerToolCallArguments) + compose.RegisterStreamChunkConcatFunc(concatServerToolResult) +} diff --git a/components/agentic/openai/utils.go b/components/agentic/openai/utils.go new file mode 100644 index 000000000..7b33045ae --- /dev/null +++ b/components/agentic/openai/utils.go @@ -0,0 +1,70 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "fmt" + "reflect" + "strconv" + + "github.com/openai/openai-go/v3/packages/param" +) + +func newOpenaiOpt[T comparable](optVal *T) param.Opt[T] { + if optVal == nil { + return param.Opt[T]{} + } + return param.NewOpt(*optVal) +} + +func newOpenaiStrOpt(optVal string) param.Opt[string] { + if optVal == "" { + return param.Opt[string]{} + } + return param.NewOpt(optVal) +} + +func coalesce[T any](x, y T) T { + if !reflect.ValueOf(x).IsZero() { + return x + } + return y +} + +func ptrOf[T any](v T) *T { + return &v +} + +func int64ToStr(i int64) string { + return strconv.FormatInt(i, 10) +} + +type panicErr struct { + info any + stack []byte +} + +func (p *panicErr) Error() string { + return fmt.Sprintf("panic error: %v, \nstack: %s", p.info, string(p.stack)) +} + +func newPanicErr(info any, stack []byte) error { + return &panicErr{ + info: info, + stack: stack, + } +} diff --git a/components/agentic/openai/utils_test.go b/components/agentic/openai/utils_test.go new file mode 100644 index 000000000..985483010 --- /dev/null +++ b/components/agentic/openai/utils_test.go @@ -0,0 +1,117 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "testing" + + "github.com/bytedance/mockey" + "github.com/openai/openai-go/v3/packages/param" + "github.com/stretchr/testify/assert" +) + +func TestNewOpenaiOpt(t *testing.T) { + mockey.PatchConvey("newOpenaiOpt", t, func() { + mockey.PatchConvey("non-nil value", func() { + val := 123 + opt := newOpenaiOpt(&val) + assert.Equal(t, param.NewOpt(val), opt) + }) + + mockey.PatchConvey("nil value", func() { + var val *int + opt := newOpenaiOpt(val) + assert.Equal(t, param.Opt[int]{}, opt) + }) + }) +} + +func TestNewOpenaiStrOpt(t *testing.T) { + mockey.PatchConvey("newOpenaiStrOpt", t, func() { + mockey.PatchConvey("non-empty string", func() { + val := "hello" + opt := newOpenaiStrOpt(val) + assert.Equal(t, param.NewOpt(val), opt) + }) + + mockey.PatchConvey("empty string", func() { + opt := newOpenaiStrOpt("") + assert.Equal(t, param.Opt[string]{}, opt) + }) + }) +} + +func TestCoalesce(t *testing.T) { + mockey.PatchConvey("coalesce", t, func() { + mockey.PatchConvey("x is non-zero", func() { + x := "a" + y := "b" + assert.Equal(t, x, coalesce(x, y)) + }) + + mockey.PatchConvey("x is zero", func() { + x := "" + y := "b" + assert.Equal(t, y, coalesce(x, y)) + }) + }) +} + +func TestPtrOf(t *testing.T) { + mockey.PatchConvey("ptrOf", t, func() { + mockey.PatchConvey("int", func() { + val := 123 + res := ptrOf(val) + assert.Equal(t, val, *res) + }) + + mockey.PatchConvey("string", func() { + val := "hello" + res := ptrOf(val) + assert.Equal(t, val, *res) + }) + }) +} + +func TestInt64ToStr(t *testing.T) { + mockey.PatchConvey("int64ToStr", t, func() { + mockey.PatchConvey("positive", func() { + assert.Equal(t, "123", int64ToStr(123)) + }) + + mockey.PatchConvey("negative", func() { + assert.Equal(t, "-123", int64ToStr(-123)) + }) + + mockey.PatchConvey("zero", func() { + assert.Equal(t, "0", int64ToStr(0)) + }) + }) +} + +func TestNewPanicErr(t *testing.T) { + mockey.PatchConvey("newPanicErr", t, func() { + mockey.PatchConvey("error interface and formatting", func() { + info := "something went wrong" + stack := []byte("stack trace") + err := newPanicErr(info, stack) + assert.Error(t, err) + assert.Contains(t, err.Error(), info) + assert.Contains(t, err.Error(), string(stack)) + }) + }) +}