Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions pkg/agent/subturn.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,16 @@ func (al *AgentLoop) generateSubTurnID() string {
return fmt.Sprintf("subturn-%d", al.subTurnCounter.Add(1))
}

func toolRegistryFromSlice(toolSlice []tools.Tool) *tools.ToolRegistry {
registry := tools.NewToolRegistry()
for _, tool := range toolSlice {
if tool != nil {
registry.Register(tool)
}
}
return registry
}

// ====================== Core Function: spawnSubTurn ======================

// AgentLoopSpawner implements tools.SubTurnSpawner interface.
Expand Down Expand Up @@ -344,10 +354,14 @@ func spawnSubTurn(
ephemeralStore := newEphemeralSession(nil)
agent := *baseAgent // shallow copy
agent.Sessions = ephemeralStore
// Clone the tool registry so child turn's tool registrations
// don't pollute the parent's registry.
if baseAgent.Tools != nil {
// Inherit the parent's tool registry snapshot so hidden/TTL metadata is preserved.
agent.Tools = baseAgent.Tools.Clone()
} else if len(cfg.Tools) > 0 {
// Fallback path for callers that provide explicit tool slices without a parent registry.
agent.Tools = toolRegistryFromSlice(cfg.Tools)
} else {
agent.Tools = tools.NewToolRegistry()
}

// Create processOptions for the child turn
Expand Down
236 changes: 236 additions & 0 deletions pkg/agent/subturn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,242 @@ func (m *simpleMockProviderAPI) GetDefaultModel() string {
return "gpt-4o-mini"
}

type toolCaptureProvider struct {
lastToolNames []string
}

func (p *toolCaptureProvider) Chat(
ctx context.Context,
messages []providers.Message,
toolDefs []providers.ToolDefinition,
model string,
options map[string]any,
) (*providers.LLMResponse, error) {
p.lastToolNames = p.lastToolNames[:0]
for _, td := range toolDefs {
p.lastToolNames = append(p.lastToolNames, td.Function.Name)
}
return &providers.LLMResponse{Content: "ok"}, nil
}

func (p *toolCaptureProvider) GetDefaultModel() string {
return "test-model"
}

type subturnProbeTool struct {
name string
}

func (t *subturnProbeTool) Name() string { return t.name }

func (t *subturnProbeTool) Description() string { return "subturn probe tool" }

func (t *subturnProbeTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
}
}

func (t *subturnProbeTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
return tools.SilentResult("ok")
}

func TestSpawnSubTurn_EmptyExplicitToolsStillInheritParentRegistry(t *testing.T) {
provider := &toolCaptureProvider{}
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: t.TempDir(),
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)

parentAgent := al.registry.GetDefaultAgent()
if parentAgent == nil {
t.Fatal("expected default agent")
}

parentAgent.Tools = tools.NewToolRegistry()
parentAgent.Tools.Register(&subturnProbeTool{name: "inherited_tool"})

parent := &turnState{
ctx: context.Background(),
turnID: "parent-empty-explicit-tools",
depth: 0,
pendingResults: make(chan *tools.ToolResult, 1),
concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns),
session: &ephemeralSessionStore{},
agent: parentAgent,
}

_, err := spawnSubTurn(context.Background(), al, parent, SubTurnConfig{
Model: "test-model",
SystemPrompt: "run task",
Tools: []tools.Tool{},
})
if err != nil {
t.Fatalf("spawnSubTurn returned error: %v", err)
}

if len(provider.lastToolNames) != 1 || provider.lastToolNames[0] != "inherited_tool" {
t.Fatalf("expected inherited parent tools, got %v", provider.lastToolNames)
}
}

func TestSpawnSubTurn_InheritsRuntimeAddedTools(t *testing.T) {
provider := &toolCaptureProvider{}
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: t.TempDir(),
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)

parentAgent := al.registry.GetDefaultAgent()
if parentAgent == nil {
t.Fatal("expected default agent")
}

parentAgent.Tools = tools.NewToolRegistry()
parentAgent.Tools.Register(&subturnProbeTool{name: "base_tool"})
parent := &turnState{
ctx: context.Background(),
turnID: "parent-runtime-tools",
depth: 0,
pendingResults: make(chan *tools.ToolResult, 1),
concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns),
session: &ephemeralSessionStore{},
agent: parentAgent,
}

// Simulate runtime-added tools registered after initial manager setup.
parentAgent.Tools.Register(&subturnProbeTool{name: "runtime_tool"})

_, err := spawnSubTurn(context.Background(), al, parent, SubTurnConfig{
Model: "test-model",
SystemPrompt: "run task",
})
if err != nil {
t.Fatalf("spawnSubTurn returned error: %v", err)
}

got := map[string]bool{}
for _, name := range provider.lastToolNames {
got[name] = true
}
if !got["base_tool"] || !got["runtime_tool"] {
t.Fatalf("expected runtime-added parent tools in provider defs, got %v", provider.lastToolNames)
}
}

func TestSpawnSubTurn_PreservesHiddenTTLSemantics(t *testing.T) {
provider := &toolCaptureProvider{}
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: t.TempDir(),
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)

parentAgent := al.registry.GetDefaultAgent()
if parentAgent == nil {
t.Fatal("expected default agent")
}

parentAgent.Tools = tools.NewToolRegistry()
parentAgent.Tools.Register(&subturnProbeTool{name: "core_tool"})
parentAgent.Tools.RegisterHidden(&subturnProbeTool{name: "hidden_active"})
parentAgent.Tools.RegisterHidden(&subturnProbeTool{name: "hidden_inactive"})
parentAgent.Tools.PromoteTools([]string{"hidden_active"}, 2)

parent := &turnState{
ctx: context.Background(),
turnID: "parent-hidden-ttl",
depth: 0,
pendingResults: make(chan *tools.ToolResult, 1),
concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns),
session: &ephemeralSessionStore{},
agent: parentAgent,
}

_, err := spawnSubTurn(context.Background(), al, parent, SubTurnConfig{
Model: "test-model",
SystemPrompt: "run task",
})
if err != nil {
t.Fatalf("spawnSubTurn returned error: %v", err)
}

got := map[string]bool{}
for _, name := range provider.lastToolNames {
got[name] = true
}
if !got["core_tool"] || !got["hidden_active"] {
t.Fatalf("expected core + active hidden tools, got %v", provider.lastToolNames)
}
if got["hidden_inactive"] {
t.Fatalf("inactive hidden tool leaked into provider defs: %v", provider.lastToolNames)
}
}

func TestSpawnSubTurn_UsesExplicitToolsWhenParentRegistryUnavailable(t *testing.T) {
provider := &toolCaptureProvider{}
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: t.TempDir(),
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)

parentAgent := al.registry.GetDefaultAgent()
if parentAgent == nil {
t.Fatal("expected default agent")
}
parentAgent.Tools = nil

parent := &turnState{
ctx: context.Background(),
turnID: "parent-explicit-fallback",
depth: 0,
pendingResults: make(chan *tools.ToolResult, 1),
concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns),
session: &ephemeralSessionStore{},
agent: parentAgent,
}

_, err := spawnSubTurn(context.Background(), al, parent, SubTurnConfig{
Model: "test-model",
SystemPrompt: "run task",
Tools: []tools.Tool{&subturnProbeTool{name: "explicit_tool"}},
})
if err != nil {
t.Fatalf("spawnSubTurn returned error: %v", err)
}

if len(provider.lastToolNames) != 1 || provider.lastToolNames[0] != "explicit_tool" {
t.Fatalf("expected explicit fallback tools, got %v", provider.lastToolNames)
}
}

// TestGetActiveTurn verifies that GetActiveTurn returns correct turn information
func TestGetActiveTurn(t *testing.T) {
cfg := &config.Config{
Expand Down
2 changes: 1 addition & 1 deletion pkg/tools/spawn.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ Task: %s`,
go func() {
result, err := t.spawner.SpawnSubTurn(ctx, SubTurnConfig{
Model: t.defaultModel,
Tools: nil, // Will inherit from parent via context
Tools: nil, // Inherit from the parent turn registry at runtime.
SystemPrompt: systemPrompt,
MaxTokens: t.maxTokens,
Temperature: t.temperature,
Expand Down
76 changes: 76 additions & 0 deletions pkg/tools/spawn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"strings"
"testing"
"time"
)

// mockSpawner implements SubTurnSpawner for testing
Expand All @@ -24,6 +25,50 @@ func (m *mockSpawner) SpawnSubTurn(ctx context.Context, cfg SubTurnConfig) (*Too
}, nil
}

type managerSnapshotTool struct {
name string
}

func (t *managerSnapshotTool) Name() string {
return t.name
}

func (t *managerSnapshotTool) Description() string {
return "test tool"
}

func (t *managerSnapshotTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
}
}

func (t *managerSnapshotTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
return SilentResult("ok")
}

type recordingSpawner struct {
toolNames []string
toolsNil bool
done chan struct{}
}

func (s *recordingSpawner) SpawnSubTurn(ctx context.Context, cfg SubTurnConfig) (*ToolResult, error) {
s.toolsNil = cfg.Tools == nil
for _, tool := range cfg.Tools {
if tool != nil {
s.toolNames = append(s.toolNames, tool.Name())
}
}
if s.done != nil {
close(s.done)
}
return &ToolResult{
ForLLM: "Task completed",
ForUser: "Task completed",
}, nil
}

func TestSpawnTool_Execute_EmptyTask(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
Expand Down Expand Up @@ -96,3 +141,34 @@ func TestSpawnTool_Execute_NilManager(t *testing.T) {
t.Errorf("Error message should mention manager not configured, got: %s", result.ForLLM)
}
}

func TestSpawnTool_Execute_LeavesToolsUnsetForRuntimeInheritance(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
manager.RegisterTool(&managerSnapshotTool{name: "snapshot_tool"})

tool := NewSpawnTool(manager)
spawner := &recordingSpawner{done: make(chan struct{})}
tool.SetSpawner(spawner)

result := tool.Execute(context.Background(), map[string]any{"task": "inspect tools"})
if result == nil {
t.Fatal("Result should not be nil")
}
if result.IsError {
t.Fatalf("Expected success, got error: %s", result.ForLLM)
}
if !result.Async {
t.Fatal("spawn result should be async")
}

select {
case <-spawner.done:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for async spawn execution")
}

if !spawner.toolsNil {
t.Fatalf("expected cfg.Tools to be nil for runtime inheritance, got %v", spawner.toolNames)
}
}
2 changes: 1 addition & 1 deletion pkg/tools/subagent.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ Task: %s`,
if t.spawner != nil {
result, err := t.spawner.SpawnSubTurn(ctx, SubTurnConfig{
Model: t.defaultModel,
Tools: nil, // Will inherit from parent via context
Tools: nil, // Inherit from the parent turn registry at runtime.
SystemPrompt: systemPrompt,
MaxTokens: t.maxTokens,
Temperature: t.temperature,
Expand Down
Loading
Loading