Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/core/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,7 @@ export class Agent {
}
}

assistantBlocks = this.compactContentBlocks(assistantBlocks);
assistantBlocks = this.splitThinkBlocksIfNeeded(assistantBlocks);

await this.hooks.runPostModel({ role: 'assistant', content: assistantBlocks } as any);
Expand Down Expand Up @@ -1821,17 +1822,24 @@ export class Agent {
await this.persistentStore.saveMediaCache(this.agentId, Array.from(this.mediaCache.values()));
}

private splitThinkBlocksIfNeeded(blocks: ContentBlock[]): ContentBlock[] {
private compactContentBlocks(blocks: Array<ContentBlock | undefined>): ContentBlock[] {
  // Drop undefined placeholders — presumably gaps left by sparse streamed
  // block indexes (TODO confirm against stream assembly) — while keeping
  // the relative order of the surviving blocks.
  const compacted: ContentBlock[] = [];
  for (const block of blocks) {
    if (block !== undefined) {
      compacted.push(block);
    }
  }
  return compacted;
}

private splitThinkBlocksIfNeeded(blocks: Array<ContentBlock | undefined>): ContentBlock[] {
const config = this.model.toConfig();
const transport =
config.reasoningTransport ??
(config.provider === 'openai' || config.provider === 'gemini' ? 'text' : 'provider');
if (transport !== 'text') {
return blocks;
return this.compactContentBlocks(blocks);
}

const output: ContentBlock[] = [];
for (const block of blocks) {
if (!block) {
continue;
}
if (block.type !== 'text') {
output.push(block);
continue;
Expand Down
47 changes: 38 additions & 9 deletions src/infra/providers/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -408,13 +408,40 @@ export class OpenAIProvider implements ModelProvider {
const decoder = new TextDecoder();
let buffer = '';
let textStarted = false;
const textIndex = 0;
let textIndex: number | undefined;
let reasoningStarted = false;
const reasoningIndex = 1000;
let reasoningIndex: number | undefined;
let nextBlockIndex = 0;
let sawFinishReason = false;
let usageEmitted = false;
const toolCallBuffers = new Map<number, { id?: string; name?: string; args: string }>();

const resolveTextIndex = (): number => {
  // Lazily allocate the text block's local index on first use so that
  // block numbering stays contiguous no matter which block type the
  // provider happens to emit first.
  if (textIndex !== undefined) {
    return textIndex;
  }
  textIndex = nextBlockIndex++;
  return textIndex;
};

const resolveReasoningIndex = (): number => {
  // Same lazy-allocation scheme as resolveTextIndex: the reasoning block
  // claims the next free local index the first time reasoning content arrives.
  if (reasoningIndex !== undefined) {
    return reasoningIndex;
  }
  reasoningIndex = nextBlockIndex++;
  return reasoningIndex;
};

const toolCallIndexMap = new Map<number, number>();
const resolveToolCallIndex = (providerIndex?: number): number => {
  // Translate the provider-reported tool_call index (defaulting to 0 when
  // absent) into a locally allocated, contiguous block index. Repeated
  // deltas for the same provider index reuse the previously assigned slot.
  const key = typeof providerIndex === 'number' ? providerIndex : 0;
  let mapped = toolCallIndexMap.get(key);
  if (mapped === undefined) {
    mapped = nextBlockIndex++;
    toolCallIndexMap.set(key, mapped);
  }
  return mapped;
};

function* flushToolCalls(): Generator<ModelStreamChunk> {
if (toolCallBuffers.size === 0) return;
const entries = Array.from(toolCallBuffers.entries()).sort((a, b) => a[0] - b[0]);
Expand Down Expand Up @@ -467,41 +494,43 @@ export class OpenAIProvider implements ModelProvider {

const delta = choice.delta ?? {};
if (typeof delta.content === 'string' && delta.content.length > 0) {
const index = resolveTextIndex();
if (!textStarted) {
textStarted = true;
yield {
type: 'content_block_start',
index: textIndex,
index,
content_block: { type: 'text', text: '' },
};
}
yield {
type: 'content_block_delta',
index: textIndex,
index,
delta: { type: 'text_delta', text: delta.content },
};
}

if (typeof (delta as any).reasoning_content === 'string') {
const reasoningText = (delta as any).reasoning_content;
const index = resolveReasoningIndex();
if (!reasoningStarted) {
reasoningStarted = true;
yield {
type: 'content_block_start',
index: reasoningIndex,
index,
content_block: { type: 'reasoning', reasoning: '' },
};
}
yield {
type: 'content_block_delta',
index: reasoningIndex,
index,
delta: { type: 'reasoning_delta', text: reasoningText },
};
}

const toolCalls = Array.isArray(delta.tool_calls) ? delta.tool_calls : [];
for (const call of toolCalls) {
const index = typeof call.index === 'number' ? call.index : 0;
const index = resolveToolCallIndex(call.index);
const entry = toolCallBuffers.get(index) ?? { args: '' };
if (call.id) entry.id = call.id;
if (call.function?.name) entry.name = call.function.name;
Expand Down Expand Up @@ -529,10 +558,10 @@ export class OpenAIProvider implements ModelProvider {
}

if (textStarted) {
yield { type: 'content_block_stop', index: textIndex };
yield { type: 'content_block_stop', index: textIndex! };
}
if (reasoningStarted) {
yield { type: 'content_block_stop', index: reasoningIndex };
yield { type: 'content_block_stop', index: reasoningIndex! };
}
if (toolCallBuffers.size > 0) {
yield* flushToolCalls();
Expand Down
94 changes: 93 additions & 1 deletion tests/unit/core/agent.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { createUnitTestAgent, ensureCleanDir } from '../../helpers/setup';
import { TestRunner, expect } from '../../helpers/utils';
import { ContentBlock } from '../../../src/core/types';
import { Hooks } from '../../../src/core/hooks';
import { ModelResponse } from '../../../src/infra/provider';
import { ModelProvider, ModelResponse, ModelStreamChunk } from '../../../src/infra/provider';
import { JSONStore, AgentTemplateRegistry, SandboxFactory, ToolRegistry } from '../../../src';
import { MockProvider } from '../../mock-provider';
import { TEST_ROOT } from '../../helpers/fixtures';
Expand Down Expand Up @@ -117,6 +117,98 @@ runner

await cleanup();
})
// Regression test: a provider that emits non-contiguous content block indexes
// (text at 0, reasoning at 1000) must not surface a model-phase error, and the
// final assistant text must survive the sparse gap.
.test('稀疏 streamed block index 不应触发模型阶段错误', async () => {
// Minimal in-test provider stub; only stream() is exercised.
class SparseIndexProvider implements ModelProvider {
readonly model = 'sparse-index-provider';
readonly maxWindowSize = 128000;
readonly maxOutputTokens = 4096;
readonly temperature = 0;

// Non-streaming completion must never be reached by this test.
async complete(): Promise<ModelResponse> {
throw new Error('complete() should not be called');
}

// Emits a text block at index 0 followed by a reasoning block at the
// deliberately sparse index 1000, then usage and message_stop.
async *stream(): AsyncIterable<ModelStreamChunk> {
yield {
type: 'content_block_start',
index: 0,
content_block: { type: 'text', text: '' },
};
yield {
type: 'content_block_delta',
index: 0,
delta: { type: 'text_delta', text: 'hello' },
};
yield { type: 'content_block_stop', index: 0 };
yield {
type: 'content_block_start',
index: 1000,
content_block: { type: 'reasoning', reasoning: '' },
};
yield {
type: 'content_block_delta',
index: 1000,
delta: { type: 'reasoning_delta', text: 'thinking' },
};
yield { type: 'content_block_stop', index: 1000 };
yield {
type: 'message_delta',
usage: { input_tokens: 1, output_tokens: 1 },
};
yield { type: 'message_stop' };
}

// reasoningTransport 'text' routes the agent through the think-block
// splitting path, which is where sparse indexes previously caused trouble.
toConfig() {
return {
provider: 'openai' as const,
model: this.model,
reasoningTransport: 'text' as const,
};
}
}

// Unique per-run directories so parallel/repeated runs do not collide.
const workDir = path.join(TEST_ROOT, `agent-sparse-work-${Date.now()}-${Math.random().toString(36).slice(2, 7)}`);
const storeDir = path.join(TEST_ROOT, `agent-sparse-store-${Date.now()}-${Math.random().toString(36).slice(2, 7)}`);
ensureCleanDir(workDir);
ensureCleanDir(storeDir);

// Bare template: no tools, auto permissions — keeps the chat to one model turn.
const templates = new AgentTemplateRegistry();
templates.register({
id: 'sparse-index-template',
systemPrompt: 'test sparse stream indexes',
tools: [],
permission: { mode: 'auto' },
});

const agent = await Agent.create(
{
templateId: 'sparse-index-template',
model: new SparseIndexProvider(),
sandbox: { kind: 'local', workDir, enforceBoundary: true, watchFiles: false },
},
{
store: new JSONStore(storeDir),
templateRegistry: templates,
sandboxFactory: new SandboxFactory(),
toolRegistry: new ToolRegistry(),
}
);

// Capture every error event; the test asserts none are emitted.
const errors: string[] = [];
const stop = agent.on('error', (event) => {
errors.push(event.message);
});

const result = await agent.chat('trigger sparse streamed blocks');
stop();

// The text delta must come through intact despite the index-1000 gap.
expect.toEqual(result.status, 'ok');
expect.toEqual(result.text, 'hello');
expect.toHaveLength(errors, 0);

fs.rmSync(workDir, { recursive: true, force: true });
fs.rmSync(storeDir, { recursive: true, force: true });
})

.test('Todo 管理API在启用时可用', async () => {
const template = {
Expand Down
45 changes: 45 additions & 0 deletions tests/unit/providers/openai.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,51 @@ runner
const assistant = capturedBody.messages.find((msg: any) => msg.role === 'assistant');
expect.toEqual(assistant?.reasoning_content, undefined);
expect.toEqual(assistant?.content, 'ok');
})
// Verifies the provider re-numbers stream blocks locally: even when one SSE
// chunk carries text, reasoning_content and a tool_call at once, the emitted
// content_block_start indexes must be contiguous (0, 1, 2) and non-colliding.
.test('stream 使用连续且不冲突的本地 block index', async () => {
const provider = new OpenAIProvider('test-key', 'gpt-4o', 'https://api.openai.com/v1');
const originalFetch = globalThis.fetch;
const encoder = new TextEncoder();
// Two SSE data lines: first mixes all three block kinds in one delta;
// second finishes the tool call arguments and reports usage.
const sseBody = [
'data: {"choices":[{"delta":{"content":"hello","reasoning_content":"think","tool_calls":[{"index":0,"id":"call_1","function":{"name":"sum","arguments":"{\\"value\\":"}}]}}]}',
'data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"1}"}}],"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":1,"completion_tokens":2}}',
'data: [DONE]',
'',
].join('\n');

// Stub fetch to return the canned SSE body as a one-shot ReadableStream.
globalThis.fetch = (async () => {
return {
ok: true,
body: new ReadableStream({
start(controller) {
controller.enqueue(encoder.encode(sseBody));
controller.close();
},
}),
} as any;
}) as any;

const chunks: any[] = [];
try {
for await (const chunk of provider.stream([{ role: 'user', content: [{ type: 'text', text: 'hi' }] }])) {
chunks.push(chunk);
}
} finally {
// Always restore the real fetch, even if the stream throws.
globalThis.fetch = originalFetch;
}

// Exactly three block starts, numbered in arrival order: text, reasoning, tool_use.
const starts = chunks.filter((chunk) => chunk.type === 'content_block_start');
expect.toHaveLength(starts, 3);
expect.toDeepEqual(
starts.map((chunk) => ({ index: chunk.index, type: chunk.content_block.type })),
[
{ index: 0, type: 'text' },
{ index: 1, type: 'reasoning' },
{ index: 2, type: 'tool_use' },
]
);
// The tool-call argument delta must reuse the tool block's allocated index.
const toolDelta = chunks.find((chunk) => chunk.type === 'content_block_delta' && chunk.delta?.type === 'input_json_delta');
expect.toEqual(toolDelta?.index, 2);
});

export async function run() {
Expand Down
Loading