diff --git a/pkg/agent/subturn.go b/pkg/agent/subturn.go index f5ba412abb..b1a2fd8567 100644 --- a/pkg/agent/subturn.go +++ b/pkg/agent/subturn.go @@ -197,6 +197,16 @@ func (al *AgentLoop) generateSubTurnID() string { return fmt.Sprintf("subturn-%d", al.subTurnCounter.Add(1)) } +func toolRegistryFromSlice(toolSlice []tools.Tool) *tools.ToolRegistry { + registry := tools.NewToolRegistry() + for _, tool := range toolSlice { + if tool != nil { + registry.Register(tool) + } + } + return registry +} + // ====================== Core Function: spawnSubTurn ====================== // AgentLoopSpawner implements tools.SubTurnSpawner interface. @@ -344,10 +354,14 @@ func spawnSubTurn( ephemeralStore := newEphemeralSession(nil) agent := *baseAgent // shallow copy agent.Sessions = ephemeralStore - // Clone the tool registry so child turn's tool registrations - // don't pollute the parent's registry. if baseAgent.Tools != nil { + // Inherit the parent's tool registry snapshot so hidden/TTL metadata is preserved. agent.Tools = baseAgent.Tools.Clone() + } else if len(cfg.Tools) > 0 { + // Fallback path for callers that provide explicit tool slices without a parent registry. + agent.Tools = toolRegistryFromSlice(cfg.Tools) + } else { + agent.Tools = tools.NewToolRegistry() } // Create processOptions for the child turn diff --git a/pkg/agent/subturn_test.go b/pkg/agent/subturn_test.go index 6a2ba835d8..4b0d1882d8 100644 --- a/pkg/agent/subturn_test.go +++ b/pkg/agent/subturn_test.go @@ -933,6 +933,242 @@ func (m *simpleMockProviderAPI) GetDefaultModel() string { return "gpt-4o-mini" } +type toolCaptureProvider struct { + lastToolNames []string +} + +func (p *toolCaptureProvider) Chat( + ctx context.Context, + messages []providers.Message, + toolDefs []providers.ToolDefinition, + model string, + options map[string]any, +) (*providers.LLMResponse, error) { + p.lastToolNames = p.lastToolNames[:0] + for _, td := range toolDefs { + p.lastToolNames = append(p.lastToolNames, td.Function.Name) + } + return &providers.LLMResponse{Content: "ok"}, nil +} + +func (p *toolCaptureProvider) GetDefaultModel() string { + return "test-model" +} + +type subturnProbeTool struct { + name string +} + +func (t *subturnProbeTool) Name() string { return t.name } + +func (t *subturnProbeTool) Description() string { return "subturn probe tool" } + +func (t *subturnProbeTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + } +} + +func (t *subturnProbeTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + return tools.SilentResult("ok") +} + +func TestSpawnSubTurn_EmptyExplicitToolsStillInheritParentRegistry(t *testing.T) { + provider := &toolCaptureProvider{} + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: t.TempDir(), + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + al := NewAgentLoop(cfg, bus.NewMessageBus(), provider) + + parentAgent := al.registry.GetDefaultAgent() + if parentAgent == nil { + t.Fatal("expected default agent") + } + + parentAgent.Tools = tools.NewToolRegistry() + parentAgent.Tools.Register(&subturnProbeTool{name: "inherited_tool"}) + + parent := &turnState{ + ctx: context.Background(), + turnID: "parent-empty-explicit-tools", + depth: 0, + pendingResults: make(chan *tools.ToolResult, 1), + concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns), + session: &ephemeralSessionStore{}, + agent: parentAgent, + } + + _, err := spawnSubTurn(context.Background(), al, parent, SubTurnConfig{ + Model: "test-model", + SystemPrompt: "run task", + Tools: []tools.Tool{}, + }) + if err != nil { + t.Fatalf("spawnSubTurn returned error: %v", err) + } + + if len(provider.lastToolNames) != 1 || provider.lastToolNames[0] != "inherited_tool" { + t.Fatalf("expected inherited parent tools, got %v", provider.lastToolNames) + } +} + +func TestSpawnSubTurn_InheritsRuntimeAddedTools(t *testing.T) { + provider := &toolCaptureProvider{} + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: t.TempDir(), + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + al := NewAgentLoop(cfg, bus.NewMessageBus(), provider) + + parentAgent := al.registry.GetDefaultAgent() + if parentAgent == nil { + t.Fatal("expected default agent") + } + + parentAgent.Tools = tools.NewToolRegistry() + parentAgent.Tools.Register(&subturnProbeTool{name: "base_tool"}) + parent := &turnState{ + ctx: context.Background(), + turnID: "parent-runtime-tools", + depth: 0, + pendingResults: make(chan *tools.ToolResult, 1), + concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns), + session: &ephemeralSessionStore{}, + agent: parentAgent, + } + + // Simulate runtime-added tools registered after initial manager setup. + parentAgent.Tools.Register(&subturnProbeTool{name: "runtime_tool"}) + + _, err := spawnSubTurn(context.Background(), al, parent, SubTurnConfig{ + Model: "test-model", + SystemPrompt: "run task", + }) + if err != nil { + t.Fatalf("spawnSubTurn returned error: %v", err) + } + + got := map[string]bool{} + for _, name := range provider.lastToolNames { + got[name] = true + } + if !got["base_tool"] || !got["runtime_tool"] { + t.Fatalf("expected runtime-added parent tools in provider defs, got %v", provider.lastToolNames) + } +} + +func TestSpawnSubTurn_PreservesHiddenTTLSemantics(t *testing.T) { + provider := &toolCaptureProvider{} + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: t.TempDir(), + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + al := NewAgentLoop(cfg, bus.NewMessageBus(), provider) + + parentAgent := al.registry.GetDefaultAgent() + if parentAgent == nil { + t.Fatal("expected default agent") + } + + parentAgent.Tools = tools.NewToolRegistry() + parentAgent.Tools.Register(&subturnProbeTool{name: "core_tool"}) + parentAgent.Tools.RegisterHidden(&subturnProbeTool{name: "hidden_active"}) + parentAgent.Tools.RegisterHidden(&subturnProbeTool{name: "hidden_inactive"}) + parentAgent.Tools.PromoteTools([]string{"hidden_active"}, 2) + + parent := &turnState{ + ctx: context.Background(), + turnID: "parent-hidden-ttl", + depth: 0, + pendingResults: make(chan *tools.ToolResult, 1), + concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns), + session: &ephemeralSessionStore{}, + agent: parentAgent, + } + + _, err := spawnSubTurn(context.Background(), al, parent, SubTurnConfig{ + Model: "test-model", + SystemPrompt: "run task", + }) + if err != nil { + t.Fatalf("spawnSubTurn returned error: %v", err) + } + + got := map[string]bool{} + for _, name := range provider.lastToolNames { + got[name] = true + } + if !got["core_tool"] || !got["hidden_active"] { + t.Fatalf("expected core + active hidden tools, got %v", provider.lastToolNames) + } + if got["hidden_inactive"] { + t.Fatalf("inactive hidden tool leaked into provider defs: %v", provider.lastToolNames) + } +} + +func TestSpawnSubTurn_UsesExplicitToolsWhenParentRegistryUnavailable(t *testing.T) { + provider := &toolCaptureProvider{} + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: t.TempDir(), + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + al := NewAgentLoop(cfg, bus.NewMessageBus(), provider) + + parentAgent := al.registry.GetDefaultAgent() + if parentAgent == nil { + t.Fatal("expected default agent") + } + parentAgent.Tools = nil + + parent := &turnState{ + ctx: context.Background(), + turnID: "parent-explicit-fallback", + depth: 0, + pendingResults: make(chan *tools.ToolResult, 1), + concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns), + session: &ephemeralSessionStore{}, + agent: parentAgent, + } + + _, err := spawnSubTurn(context.Background(), al, parent, SubTurnConfig{ + Model: "test-model", + SystemPrompt: "run task", + Tools: []tools.Tool{&subturnProbeTool{name: "explicit_tool"}}, + }) + if err != nil { + t.Fatalf("spawnSubTurn returned error: %v", err) + } + + if len(provider.lastToolNames) != 1 || provider.lastToolNames[0] != "explicit_tool" { + t.Fatalf("expected explicit fallback tools, got %v", provider.lastToolNames) + } +} + // TestGetActiveTurn verifies that GetActiveTurn returns correct turn information func TestGetActiveTurn(t *testing.T) { cfg := &config.Config{ diff --git a/pkg/tools/spawn.go b/pkg/tools/spawn.go index d019d511ab..2929de4a09 100644 --- a/pkg/tools/spawn.go +++ b/pkg/tools/spawn.go @@ -124,7 +124,7 @@ Task: %s`, go func() { result, err := t.spawner.SpawnSubTurn(ctx, SubTurnConfig{ Model: t.defaultModel, - Tools: nil, // Will inherit from parent via context + Tools: nil, // Inherit from the parent turn registry at runtime. SystemPrompt: systemPrompt, MaxTokens: t.maxTokens, Temperature: t.temperature, diff --git a/pkg/tools/spawn_test.go b/pkg/tools/spawn_test.go index fda6bbd89b..e6ea75f583 100644 --- a/pkg/tools/spawn_test.go +++ b/pkg/tools/spawn_test.go @@ -4,6 +4,7 @@ import ( "context" "strings" "testing" + "time" ) // mockSpawner implements SubTurnSpawner for testing @@ -24,6 +25,50 @@ func (m *mockSpawner) SpawnSubTurn(ctx context.Context, cfg SubTurnConfig) (*Too }, nil } +type managerSnapshotTool struct { + name string +} + +func (t *managerSnapshotTool) Name() string { + return t.name +} + +func (t *managerSnapshotTool) Description() string { + return "test tool" +} + +func (t *managerSnapshotTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + } +} + +func (t *managerSnapshotTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + return SilentResult("ok") +} + +type recordingSpawner struct { + toolNames []string + toolsNil bool + done chan struct{} +} + +func (s *recordingSpawner) SpawnSubTurn(ctx context.Context, cfg SubTurnConfig) (*ToolResult, error) { + s.toolsNil = cfg.Tools == nil + for _, tool := range cfg.Tools { + if tool != nil { + s.toolNames = append(s.toolNames, tool.Name()) + } + } + if s.done != nil { + close(s.done) + } + return &ToolResult{ + ForLLM: "Task completed", + ForUser: "Task completed", + }, nil +} + func TestSpawnTool_Execute_EmptyTask(t *testing.T) { provider := &MockLLMProvider{} manager := NewSubagentManager(provider, "test-model", "/tmp/test") @@ -96,3 +141,34 @@ func TestSpawnTool_Execute_NilManager(t *testing.T) { t.Errorf("Error message should mention manager not configured, got: %s", result.ForLLM) } } + +func TestSpawnTool_Execute_LeavesToolsUnsetForRuntimeInheritance(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager.RegisterTool(&managerSnapshotTool{name: "snapshot_tool"}) + + tool := NewSpawnTool(manager) + spawner := &recordingSpawner{done: make(chan struct{})} + tool.SetSpawner(spawner) + + result := tool.Execute(context.Background(), map[string]any{"task": "inspect tools"}) + if result == nil { + t.Fatal("Result should not be nil") + } + if result.IsError { + t.Fatalf("Expected success, got error: %s", result.ForLLM) + } + if !result.Async { + t.Fatal("spawn result should be async") + } + + select { + case <-spawner.done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for async spawn execution") + } + + if !spawner.toolsNil { + t.Fatalf("expected cfg.Tools to be nil for runtime inheritance, got %v", spawner.toolNames) + } +} diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 9a1a8b802b..9e8b08dc1e 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -395,7 +395,7 @@ Task: %s`, if t.spawner != nil { result, err := t.spawner.SpawnSubTurn(ctx, SubTurnConfig{ Model: t.defaultModel, - Tools: nil, // Will inherit from parent via context + Tools: nil, // Inherit from the parent turn registry at runtime. SystemPrompt: systemPrompt, MaxTokens: t.maxTokens, Temperature: t.temperature, diff --git a/pkg/tools/subagent_tool_test.go b/pkg/tools/subagent_tool_test.go index 89ac7d4b57..3c8ecf9ec5 100644 --- a/pkg/tools/subagent_tool_test.go +++ b/pkg/tools/subagent_tool_test.go @@ -324,3 +324,27 @@ func TestSubagentTool_ForUserTruncation(t *testing.T) { t.Error("ForLLM should contain reference to original task") } } + +func TestSubagentTool_Execute_LeavesToolsUnsetForRuntimeInheritance(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager.RegisterTool(&managerSnapshotTool{name: "snapshot_tool"}) + + tool := NewSubagentTool(manager) + spawner := &recordingSpawner{} + tool.SetSpawner(spawner) + + result := tool.Execute(context.Background(), map[string]any{ + "task": "inspect tools", + }) + if result == nil { + t.Fatal("Result should not be nil") + } + if result.IsError { + t.Fatalf("Expected success, got error: %s", result.ForLLM) + } + + if !spawner.toolsNil { + t.Fatalf("expected cfg.Tools to be nil for runtime inheritance, got %v", spawner.toolNames) + } +}