Commit b7d3c5c

Fix tool use error and clean up the pipe processing

1 parent 43ae282 commit b7d3c5c

6 files changed: +219 -18 lines changed
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+---
+"kilo-code": patch
+---
+
+Fixed ZenMux tool-calling reliability to avoid repeated "tool not used" loops and preserve transformed request messages.
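A minimal sketch of the behavior this changeset describes, using the names that appear in the diffs below (this is not the repository's exact code): structured tool parameters are forwarded to the request only when the resolved protocol is native, so an XML-protocol task no longer advertises a tools array that the model is then blamed for not using.

// Sketch only; parameter names mirror the diffs below.
import OpenAI from "openai"

function buildToolParams(
	protocol: "native" | "xml",
	tools?: OpenAI.Chat.ChatCompletionTool[],
	toolChoice?: OpenAI.Chat.ChatCompletionCreateParams["tool_choice"],
	parallelToolCalls: boolean = false,
) {
	// Under the XML protocol, tool definitions travel inside the prompt text,
	// so the structured tool fields are omitted entirely.
	if (protocol !== "native" || !tools?.length) return {}
	return {
		tools,
		tool_choice: toolChoice ?? "auto", // honor the caller's choice; fall back to "auto"
		parallel_tool_calls: parallelToolCalls,
	}
}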

packages/types/src/providers/zenmux.ts

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,10 @@ export const zenmuxDefaultModelInfo: ModelInfo = {
 	contextWindow: 200_000,
 	supportsImages: true,
 	supportsPromptCache: true,
+	// kilocode_change start
+	supportsNativeTools: true,
+	defaultToolProtocol: "native",
+	// kilocode_change end
 	inputPrice: 15.0,
 	outputPrice: 75.0,
 	cacheWritesPrice: 18.75,
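These two ModelInfo fields presumably feed the protocol resolution used later in src/api/providers/zenmux.ts. resolveToolProtocol itself is not part of this commit, so the precedence sketched here is an assumption, not the actual implementation:

// Assumed precedence: per-task override > user/provider options > model default > xml.
type ToolProtocol = "native" | "xml"

function resolveToolProtocolSketch(
	optionsProtocol: ToolProtocol | undefined, // from ApiHandlerOptions (assumption)
	modelDefault: ToolProtocol | undefined, // ModelInfo.defaultToolProtocol
	taskOverride?: ToolProtocol, // metadata?.toolProtocol, the new third argument in the diff below
): ToolProtocol {
	return taskOverride ?? optionsProtocol ?? modelDefault ?? "xml"
}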
Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
+// kilocode_change - new file
+import OpenAI from "openai"
+
+import type { ApiHandlerCreateMessageMetadata } from "../../index"
+import type { ApiHandlerOptions } from "../../../shared/api"
+import { ZenMuxHandler } from "../zenmux"
+
+vi.mock("../fetchers/modelCache", () => ({
+	getModels: vi.fn().mockResolvedValue({}),
+}))
+
+function createMockStream() {
+	return {
+		async *[Symbol.asyncIterator]() {
+			yield {
+				choices: [{ delta: { content: "ok" }, finish_reason: "stop" }],
+				usage: { prompt_tokens: 1, completion_tokens: 1, cost: 0 },
+			}
+		},
+	}
+}
+
+async function consume(generator: AsyncGenerator<unknown>) {
+	for await (const _chunk of generator) {
+		// Consume all chunks
+	}
+}
+
+describe("ZenMuxHandler native tools and message pipeline", () => {
+	const baseOptions: ApiHandlerOptions = {
+		zenmuxApiKey: "test-key",
+		zenmuxModelId: "z-ai/glm-5",
+		zenmuxBaseUrl: "https://test.zenmux.ai/api/v1",
+	}
+
+	it("merges native tool defaults when model cache entry lacks native metadata", () => {
+		const handler = new ZenMuxHandler(baseOptions)
+		;(handler as unknown as { models: Record<string, unknown> }).models = {
+			"z-ai/glm-5": {
+				maxTokens: 8192,
+				contextWindow: 128000,
+				supportsImages: false,
+				supportsPromptCache: false,
+				inputPrice: 0,
+				outputPrice: 0,
+				description: "GLM 5",
+			},
+		}
+
+		const model = handler.getModel()
+		expect(model.info.supportsNativeTools).toBe(true)
+		expect(model.info.defaultToolProtocol).toBe("native")
+	})
+
+	it("passes tools and tool choice to stream creation when task protocol is native", async () => {
+		const handler = new ZenMuxHandler(baseOptions)
+
+		vi.spyOn(handler, "fetchModel").mockResolvedValue({
+			id: "z-ai/glm-5",
+			info: {
+				maxTokens: 8192,
+				contextWindow: 128000,
+				supportsNativeTools: true,
+				supportsImages: false,
+				supportsPromptCache: false,
+				inputPrice: 0,
+				outputPrice: 0,
+				description: "GLM 5",
+			},
+		} as any)
+
+		const streamSpy = vi.spyOn(handler, "createZenMuxStream").mockResolvedValue(createMockStream() as any)
+
+		const tools: OpenAI.Chat.ChatCompletionTool[] = [
+			{
+				type: "function",
+				function: {
+					name: "attempt_completion",
+					description: "Complete the task",
+					parameters: { type: "object", properties: {} },
+				},
+			},
+		]
+		const metadata: ApiHandlerCreateMessageMetadata = {
+			taskId: "task-native",
+			toolProtocol: "native",
+			tools,
+			tool_choice: "auto",
+			parallelToolCalls: true,
+		}
+
+		await consume(handler.createMessage("system", [{ role: "user", content: "hi" }], metadata))
+
+		expect(streamSpy).toHaveBeenCalledTimes(1)
+		expect(streamSpy.mock.calls[0][6]).toEqual(tools)
+		expect(streamSpy.mock.calls[0][7]).toBe("auto")
+		expect(streamSpy.mock.calls[0][8]).toBe(true)
+	})
+
+	it("omits tools when task protocol is xml even if tools are provided", async () => {
+		const handler = new ZenMuxHandler(baseOptions)
+
+		vi.spyOn(handler, "fetchModel").mockResolvedValue({
+			id: "z-ai/glm-5",
+			info: {
+				maxTokens: 8192,
+				contextWindow: 128000,
+				supportsNativeTools: true,
+				supportsImages: false,
+				supportsPromptCache: false,
+				inputPrice: 0,
+				outputPrice: 0,
+				description: "GLM 5",
+			},
+		} as any)
+
+		const streamSpy = vi.spyOn(handler, "createZenMuxStream").mockResolvedValue(createMockStream() as any)
+
+		const tools: OpenAI.Chat.ChatCompletionTool[] = [
+			{
+				type: "function",
+				function: {
+					name: "ask_followup_question",
+					description: "Ask a follow-up question",
+					parameters: { type: "object", properties: {} },
+				},
+			},
+		]
+
+		await consume(
+			handler.createMessage("system", [{ role: "user", content: "hi" }], {
+				taskId: "task-xml",
+				toolProtocol: "xml",
+				tools,
+				tool_choice: "auto",
+				parallelToolCalls: true,
+			}),
+		)
+
+		expect(streamSpy).toHaveBeenCalledTimes(1)
+		expect(streamSpy.mock.calls[0][6]).toBeUndefined()
+		expect(streamSpy.mock.calls[0][7]).toBeUndefined()
+		expect(streamSpy.mock.calls[0][8]).toBe(false)
+	})
+
+	it("passes transformed DeepSeek R1 messages into stream creation", async () => {
+		const handler = new ZenMuxHandler({
+			...baseOptions,
+			zenmuxModelId: "deepseek/deepseek-r1",
+		})
+
+		vi.spyOn(handler, "fetchModel").mockResolvedValue({
+			id: "deepseek/deepseek-r1",
+			info: {
+				maxTokens: 8192,
+				contextWindow: 128000,
+				supportsNativeTools: true,
+				supportsImages: false,
+				supportsPromptCache: false,
+				inputPrice: 0,
+				outputPrice: 0,
+				description: "DeepSeek R1",
+			},
+		} as any)
+
+		const streamSpy = vi.spyOn(handler, "createZenMuxStream").mockResolvedValue(createMockStream() as any)
+
+		await consume(handler.createMessage("system prompt", [{ role: "user", content: "hi" }], { taskId: "task-r1" }))
+
+		expect(streamSpy).toHaveBeenCalledTimes(1)
+		const sentMessages = streamSpy.mock.calls[0][1] as OpenAI.Chat.ChatCompletionMessageParam[]
+		expect(sentMessages.some((message: any) => message.role === "system")).toBe(false)
+		expect((sentMessages[0] as any).role).toBe("user")
+	})
+})

src/api/providers/fetchers/__tests__/zenmux.spec.ts

Lines changed: 4 additions & 0 deletions
@@ -34,6 +34,8 @@ describe("getZenmuxModels", () => {
 		contextWindow: 200000,
 		supportsImages: true,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
+		defaultToolProtocol: "native",
 		inputPrice: 0,
 		outputPrice: 0,
 		description: "anthropic model",
@@ -62,5 +64,7 @@ describe("getZenmuxModels", () => {
 
 		expect(models["openai/gpt-5"].contextWindow).toBe(zenmuxDefaultModelInfo.contextWindow)
 		expect(models["openai/gpt-5"].displayName).toBe("openai/gpt-5")
+		expect(models["openai/gpt-5"].supportsNativeTools).toBe(true)
+		expect(models["openai/gpt-5"].defaultToolProtocol).toBe("native")
 	})
})

src/api/providers/fetchers/zenmux.ts

Lines changed: 4 additions & 0 deletions
@@ -58,6 +58,10 @@ export async function getZenmuxModels(
 		contextWindow,
 		supportsImages: input_modalities?.includes("image") ?? false,
 		supportsPromptCache: false,
+		// kilocode_change start
+		supportsNativeTools: true,
+		defaultToolProtocol: "native",
+		// kilocode_change end
 		inputPrice: 0,
 		outputPrice: 0,
 		description: `${owned_by || "ZenMux"} model`,

src/api/providers/zenmux.ts

Lines changed: 27 additions & 18 deletions
@@ -2,7 +2,7 @@
 import OpenAI from "openai"
 import type Anthropic from "@anthropic-ai/sdk"
 import type { ModelInfo } from "@roo-code/types"
-import { zenmuxDefaultModelId, zenmuxDefaultModelInfo } from "@roo-code/types"
+import { NATIVE_TOOL_DEFAULTS, TOOL_PROTOCOL, zenmuxDefaultModelId, zenmuxDefaultModelInfo } from "@roo-code/types"
 import { ApiProviderError } from "@roo-code/types"
 import { TelemetryService } from "@roo-code/telemetry"
 
@@ -24,7 +24,6 @@ import { ChatCompletionTool } from "openai/resources"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { convertToR1Format } from "../transform/r1-format"
 import { resolveToolProtocol } from "../../utils/resolveToolProtocol"
-import { TOOL_PROTOCOL } from "@roo-code/types"
 import { ApiStreamChunk } from "../transform/stream"
 import { NativeToolCallParser } from "../../core/assistant-message/NativeToolCallParser"
 import { KiloCodeChunkSchema } from "./kilocode/chunk-schema"
@@ -117,21 +116,16 @@ export class ZenMuxHandler extends BaseProvider implements SingleCompletionHandl
 	}
 	async createZenMuxStream(
 		client: OpenAI,
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
+		openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[],
 		model: { id: string; info: ModelInfo },
 		_reasoningEffort?: string,
 		thinkingBudgetTokens?: number,
 		zenMuxProviderSorting?: string,
 		tools?: Array<ChatCompletionTool>,
+		toolChoice?: OpenAI.Chat.ChatCompletionCreateParams["tool_choice"],
+		parallelToolCalls: boolean = false,
 		_geminiThinkingLevel?: string,
 	) {
-		// Convert Anthropic messages to OpenAI format
-		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
-			{ role: "system", content: systemPrompt },
-			...convertToOpenAiMessages(messages),
-		]
-
 		// Build reasoning config if thinking budget is set
 		let reasoning: { max_tokens: number } | undefined
 		if (thinkingBudgetTokens && thinkingBudgetTokens > 0) {
@@ -155,16 +149,20 @@ export class ZenMuxHandler extends BaseProvider implements SingleCompletionHandl
 					},
 				}
 				: {}),
-			...this.getOpenAIToolParams(tools),
+			...this.getOpenAIToolParams(tools, toolChoice, parallelToolCalls),
 		})
 
 		return stream
 	}
-	getOpenAIToolParams(tools?: ChatCompletionTool[], enableParallelToolCalls: boolean = false) {
+	getOpenAIToolParams(
+		tools?: ChatCompletionTool[],
+		toolChoice?: OpenAI.Chat.ChatCompletionCreateParams["tool_choice"],
+		enableParallelToolCalls: boolean = false,
+	) {
 		return tools?.length
 			? {
 					tools,
-					tool_choice: tools ? "auto" : undefined,
+					tool_choice: toolChoice ?? "auto",
 					parallel_tool_calls: enableParallelToolCalls ? true : false,
 				}
 			: {
@@ -219,7 +217,9 @@ export class ZenMuxHandler extends BaseProvider implements SingleCompletionHandl
 		}
 
 		// Process reasoning_details when switching models to Gemini for native tool call compatibility
-		const toolProtocol = resolveToolProtocol(this.options, model.info)
+		// kilocode_change start
+		const toolProtocol = resolveToolProtocol(this.options, model.info, metadata?.toolProtocol)
+		// kilocode_change end
 		const isNativeProtocol = toolProtocol === TOOL_PROTOCOL.NATIVE
 		const isGemini = modelId.startsWith("google/gemini")
 
@@ -264,17 +264,24 @@ export class ZenMuxHandler extends BaseProvider implements SingleCompletionHandl
 			}
 		}
 
+		// kilocode_change start
+		const tools = isNativeProtocol ? metadata?.tools : undefined
+		const toolChoice = isNativeProtocol ? metadata?.tool_choice : undefined
+		const parallelToolCalls = isNativeProtocol ? (metadata?.parallelToolCalls ?? false) : false
+		// kilocode_change end
+
 		let stream
 		try {
 			stream = await this.createZenMuxStream(
 				this.client,
-				systemPrompt,
-				messages,
+				openAiMessages,
 				model,
 				this.options.reasoningEffort,
 				this.options.modelMaxThinkingTokens,
 				this.options.zenmuxProviderSort,
-				metadata?.tools,
+				tools,
+				toolChoice,
+				parallelToolCalls,
 			)
 		} catch (error) {
 			const errorMessage = error instanceof Error ? error.message : String(error)
@@ -447,7 +454,9 @@ export class ZenMuxHandler extends BaseProvider implements SingleCompletionHandl
 
 	override getModel() {
 		const id = this.options.zenmuxModelId ?? zenmuxDefaultModelId
-		let info = this.models[id] ?? zenmuxDefaultModelInfo
+		// kilocode_change start
+		let info = { ...NATIVE_TOOL_DEFAULTS, ...(this.models[id] ?? zenmuxDefaultModelInfo) }
+		// kilocode_change end
 
 		const isDeepSeekR1 = id.startsWith("deepseek/deepseek-r1") || id === "perplexity/sonar-reasoning"
 