diff --git a/src/hooks/runtime-fallback/fallback-state.ts b/src/hooks/runtime-fallback/fallback-state.ts index 15348a21d6..a1eb122de5 100644 --- a/src/hooks/runtime-fallback/fallback-state.ts +++ b/src/hooks/runtime-fallback/fallback-state.ts @@ -2,6 +2,68 @@ import type { FallbackState, FallbackResult } from "./types" import { HOOK_NAME } from "./constants" import { log } from "../../shared/logger" import type { RuntimeFallbackConfig } from "../../config" +import { parseModelString } from "../../tools/delegate-task/model-string-parser" + +function canonicalizeModelID(modelID: string): string { + const loweredModelID = modelID.toLowerCase() + const dottedModelID = loweredModelID.replace(/\./g, "-") + + if ( + dottedModelID.startsWith("claude-opus-") || + dottedModelID.startsWith("claude-sonnet-") || + dottedModelID.startsWith("claude-haiku-") + ) { + return dottedModelID + .replace(/-thinking$/i, "") + .replace(/-max$/i, "") + .replace(/-high$/i, "") + } + + return modelID + .toLowerCase() + .replace(/\./g, "-") +} + +function canonicalizeProviderFamily(providerID: string, modelID: string): string { + const canonicalModelID = canonicalizeModelID(modelID) + + if ( + canonicalModelID.startsWith("claude-opus-") || + canonicalModelID.startsWith("claude-sonnet-") || + canonicalModelID.startsWith("claude-haiku-") + ) { + return "anthropic-compatible-claude" + } + + return providerID.toLowerCase() +} + +function parseCanonicalModel(model: string): { providerID: string; modelID: string } | undefined { + const parsed = parseModelString(model) + if (!parsed?.providerID || !parsed.modelID) return undefined + + const canonicalModelID = canonicalizeModelID(parsed.modelID) + const variant = parsed.variant?.toLowerCase() + + return { + providerID: canonicalizeProviderFamily(parsed.providerID, parsed.modelID), + modelID: variant ? `${canonicalModelID}::${variant}` : canonicalModelID, + } +} + +function isEquivalentModel(candidate: string, current: string): boolean { + const parsedCandidate = parseCanonicalModel(candidate) + const parsedCurrent = parseCanonicalModel(current) + + if (!parsedCandidate || !parsedCurrent) { + return candidate.toLowerCase() === current.toLowerCase() + } + + return ( + parsedCandidate.providerID === parsedCurrent.providerID && + parsedCandidate.modelID === parsedCurrent.modelID + ) +} export function createFallbackState(originalModel: string): FallbackState { return { @@ -28,6 +90,15 @@ export function findNextAvailableFallback( ): string | undefined { for (let i = state.fallbackIndex + 1; i < fallbackModels.length; i++) { const candidate = fallbackModels[i] + if (isEquivalentModel(candidate, state.currentModel)) { + log(`[${HOOK_NAME}] Skipping equivalent fallback model`, { + model: candidate, + currentModel: state.currentModel, + index: i, + }) + continue + } + if (!isModelInCooldown(candidate, state, cooldownSeconds)) { return candidate } diff --git a/src/hooks/runtime-fallback/index.test.ts b/src/hooks/runtime-fallback/index.test.ts index 3956d26c81..e6401f86c0 100644 --- a/src/hooks/runtime-fallback/index.test.ts +++ b/src/hooks/runtime-fallback/index.test.ts @@ -1245,7 +1245,10 @@ describe("runtime-fallback", () => { expect(retriedModels.length).toBeGreaterThanOrEqual(2) expect(retriedModels[0]).toBe("github-copilot/claude-opus-4.6") - expect(retriedModels[1]).toBe("anthropic/claude-opus-4-6") + expect(retriedModels[1]).toBe("openai/gpt-5.4") + + const equivalentSkipLog = logCalls.find((c) => c.msg.includes("Skipping equivalent fallback model")) + expect(equivalentSkipLog).toBeDefined() void sessionErrorPromise }) @@ -1314,7 +1317,7 @@ describe("runtime-fallback", () => { await new Promise((resolve) => setTimeout(resolve, 50)) expect(retriedModels).toContain("github-copilot/claude-opus-4.6") - expect(retriedModels).toContain("anthropic/claude-opus-4-6") + expect(retriedModels).toContain("openai/gpt-5.4") expect(abortCalls.some((call) => call.path?.id === sessionID)).toBe(true) const timeoutLog = logCalls.find((c) => c.msg.includes("Session fallback timeout reached")) @@ -1392,7 +1395,7 @@ describe("runtime-fallback", () => { await new Promise((resolve) => setTimeout(resolve, 50)) expect(retriedModels).toContain("github-copilot/claude-opus-4.6") - expect(retriedModels).toContain("anthropic/claude-opus-4-6") + expect(retriedModels).toContain("openai/gpt-5.4") }) test("should abort in-flight fallback request before advancing on timeout", async () => { @@ -1466,7 +1469,7 @@ describe("runtime-fallback", () => { expect(abortCalls.some((call) => call.path?.id === sessionID)).toBe(true) expect(retriedModels).toContain("github-copilot/claude-opus-4.6") - expect(retriedModels).toContain("anthropic/claude-opus-4-6") + expect(retriedModels).toContain("openai/gpt-5.4") void sessionErrorPromise }) @@ -2385,7 +2388,10 @@ describe("runtime-fallback", () => { }), { config: createMockConfig({ notify_on_fallback: false }), - pluginConfig: createMockPluginConfigWithAgentFallback("prometheus", ["github-copilot/claude-opus-4.6"]), + pluginConfig: createMockPluginConfigWithAgentFallback("prometheus", [ + "github-copilot/claude-opus-4.6", + "openai/gpt-5.4", + ]), }, ) const sessionID = "test-preserve-agent-on-retry" @@ -2405,7 +2411,7 @@ describe("runtime-fallback", () => { expect(promptCalls.length).toBe(1) const callBody = promptCalls[0]?.body as Record expect(callBody?.agent).toBe("prometheus") - expect(callBody?.model).toEqual({ providerID: "github-copilot", modelID: "claude-opus-4.6" }) + expect(callBody?.model).toEqual({ providerID: "openai", modelID: "gpt-5.4" }) }) })