From 1c612520e4bc1f5552f62a1c99a37fa2ffc806ce Mon Sep 17 00:00:00 2001 From: yudduy Date: Wed, 13 May 2026 16:47:00 -0700 Subject: [PATCH] =?UTF-8?q?feat:=20weaver=20=E2=80=94=20ensemble=20inferen?= =?UTF-8?q?ce=20with=20verifier=20scoring?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a coordinator-side orchestration mode that runs multiple candidate generations in parallel and scores them with verifier models, returning the highest-scoring response with a confidence metric. Inspired by the Stanford Weaver paper (https://scalingintelligence.stanford.edu/pubs/weaver.pdf). Coordinator: - /v1/chat/completions accepts {"weaver": {"mode": "deep"|"deep_plus"}} - Streams typed SSE events (weaver_init, candidate_delta, verifier_score, weaver_final) with the full trace - Verifier selection picks diverse model families from the catalog, falls back to self-verification with a warning if no eligible verifiers exist - Adds Family + CanVerify to SupportedModel (memory + postgres backends) - Billing charges candidate and verifier calls independently Provider: - Reasoning parser opt-in by family (Gemma4, DeepSeek R1, Qwen3/QwQ); Qwen2.5 no longer mis-applies the qwen3 parser Console UI: - Fast / Deep / Deep+ mode selector in ChatInput - Live candidate strip + tabbed trace modal - New auth-checked proxy routes for device/provider/stats endpoints Known follow-ups (not blocking): - Billing reserve→refund is two non-atomic writes (no two-phase commit) - No model-level SupportsWeaver flag (any catalog model is eligible) - Confidence metric is 0 when only one candidate succeeds - E2E + weaver path lacks a dedicated test --- console-ui/__tests__/api.test.ts | 64 ++ console-ui/__tests__/components.test.tsx | 86 ++- .../src/app/api/device/approve/route.ts | 36 + .../api/provider/account-earnings/route.ts | 30 + .../app/api/providers/attestation/route.ts | 15 + console-ui/src/app/api/stats/route.ts | 12 + 
console-ui/src/app/link/DeviceLinkForm.tsx | 6 +- console-ui/src/app/page.tsx | 285 ++++++-- .../providers/ProviderDashboardContent.tsx | 6 +- .../providers/earnings/EarningsContent.tsx | 11 +- console-ui/src/app/stats/page.tsx | 4 +- console-ui/src/components/ChatInput.tsx | 109 ++- console-ui/src/components/ChatMessage.tsx | 25 + console-ui/src/components/GoogleAnalytics.tsx | 12 +- .../src/components/InviteCodeBanner.tsx | 15 +- .../src/components/PreSendTrustBanner.tsx | 4 +- .../src/components/VerificationPanel.tsx | 6 +- .../src/components/WeaverTraceModal.tsx | 162 +++++ .../src/components/WeaverTraceStrip.tsx | 62 ++ console-ui/src/lib/api.ts | 260 +++++++ console-ui/src/lib/store.ts | 3 +- coordinator/cmd/coordinator/main.go | 133 ++-- coordinator/cmd/coordinator/main_test.go | 28 +- coordinator/internal/api/billing_handlers.go | 11 +- coordinator/internal/api/consumer.go | 29 + coordinator/internal/api/server.go | 2 + coordinator/internal/api/weaver.go | 650 ++++++++++++++++++ coordinator/internal/api/weaver_test.go | 525 ++++++++++++++ coordinator/internal/registry/registry.go | 2 + coordinator/internal/store/interface.go | 2 + coordinator/internal/store/postgres.go | 27 +- coordinator/internal/store/store_test.go | 13 + provider/src/backend/vllm_mlx.rs | 38 +- provider/src/inference.rs | 4 +- provider/src/main.rs | 97 ++- 35 files changed, 2569 insertions(+), 205 deletions(-) create mode 100644 console-ui/src/app/api/device/approve/route.ts create mode 100644 console-ui/src/app/api/provider/account-earnings/route.ts create mode 100644 console-ui/src/app/api/providers/attestation/route.ts create mode 100644 console-ui/src/app/api/stats/route.ts create mode 100644 console-ui/src/components/WeaverTraceModal.tsx create mode 100644 console-ui/src/components/WeaverTraceStrip.tsx create mode 100644 coordinator/internal/api/weaver.go create mode 100644 coordinator/internal/api/weaver_test.go diff --git a/console-ui/__tests__/api.test.ts 
b/console-ui/__tests__/api.test.ts index 10f403351..87641ee9b 100644 --- a/console-ui/__tests__/api.test.ts +++ b/console-ui/__tests__/api.test.ts @@ -7,6 +7,7 @@ import { fetchModels, fetchPricing, healthCheck, + streamWeaverChat, } from "@/lib/api"; // --------------------------------------------------------------------------- @@ -24,6 +25,21 @@ function jsonResponse(body: unknown, status = 200): Response { } as unknown as Response; } +function sseResponse(events: string): Response { + const encoder = new TextEncoder(); + return { + ok: true, + status: 200, + headers: new Headers({ "content-type": "text/event-stream" }), + body: new ReadableStream({ + start(controller) { + controller.enqueue(encoder.encode(events)); + controller.close(); + }, + }), + } as unknown as Response; +} + // --------------------------------------------------------------------------- // Setup // --------------------------------------------------------------------------- @@ -262,6 +278,54 @@ describe("healthCheck", () => { }); }); +describe("streamWeaverChat", () => { + it("sends weaver options and routes orchestration SSE events", async () => { + fetchMock.mockResolvedValueOnce(sseResponse([ + 'event: weaver_init\ndata: {"mode":"deep","status":"generating","candidateCount":4,"verifierCount":3,"candidates":[],"verifierScores":[]}\n\n', + 'event: candidate_delta\ndata: {"candidate_id":"candidate-2","delta":"hello"}\n\n', + 'event: verifier_score\ndata: {"candidate_id":"candidate-2","verifier_model":"verifier-a","score":8,"rationale":"clear"}\n\n', + 'event: weaver_final\ndata: {"selected_candidate":"candidate-2","confidence":0.8,"content":"final answer","usage":{"tokenCount":12,"tps":0,"ttft":0}}\n\n', + 'event: done\ndata: {"status":"complete"}\n\n', + ].join(""))); + + const seen: string[] = []; + await streamWeaverChat( + [{ role: "user", content: "hi" }], + "base-model", + { mode: "deep", candidateCount: 4, verifierCount: 3 }, + { + onWeaverInit: (trace) => 
seen.push(`init:${trace.candidateCount}`), + onWeaverStatus: (status) => seen.push(`status:${status}`), + onCandidateStart: () => {}, + onCandidateDelta: (id, delta) => seen.push(`delta:${id}:${delta}`), + onCandidateDone: () => {}, + onCandidateError: () => {}, + onVerifierScore: (score) => seen.push(`score:${score.candidateId}:${score.score}`), + onFinal: (content, selected, confidence, usage) => seen.push(`final:${selected}:${confidence}:${usage?.tokenCount}:${content}`), + onDone: () => seen.push("done"), + onError: (error) => seen.push(`error:${error}`), + } + ); + + const [, opts] = fetchMock.mock.calls[0]; + const body = JSON.parse(opts.body); + expect(body.weaver).toEqual({ + mode: "deep", + candidate_count: 4, + verifier_count: 3, + verifier_selection: "auto", + trace: true, + }); + expect(seen).toEqual([ + "init:4", + "delta:candidate-2:hello", + "score:candidate-2:8", + "final:candidate-2:0.8:12:final answer", + "done", + ]); + }); +}); + // --------------------------------------------------------------------------- // proxyHeaders // --------------------------------------------------------------------------- diff --git a/console-ui/__tests__/components.test.tsx b/console-ui/__tests__/components.test.tsx index 7e0c6c556..16f266720 100644 --- a/console-ui/__tests__/components.test.tsx +++ b/console-ui/__tests__/components.test.tsx @@ -1,7 +1,12 @@ import { describe, it, expect } from "vitest"; -import { render, screen } from "@testing-library/react"; +import { fireEvent, render, screen } from "@testing-library/react"; +import { ChatInput } from "@/components/ChatInput"; +import { WeaverTraceStrip } from "@/components/WeaverTraceStrip"; +import { WeaverTraceModal } from "@/components/WeaverTraceModal"; import { TrustBadge } from "@/components/TrustBadge"; import { VerificationModeProvider } from "@/lib/verification-mode"; +import { useStore } from "@/lib/store"; +import type { WeaverTrace } from "@/lib/api"; import type { TrustMetadata } from "@/lib/api"; 
// --------------------------------------------------------------------------- @@ -93,3 +98,82 @@ describe("TrustBadge (normal mode)", () => { expect(span!.getAttribute("title")).toBe("Apple-verified hardware"); }); }); + +describe("WeaverTraceStrip", () => { + it("renders one dynamic diamond per candidate", () => { + const weaver: WeaverTrace = { + mode: "deep_plus", + status: "generating", + candidateCount: 8, + verifierCount: 5, + candidates: Array.from({ length: 8 }, (_, index) => ({ + id: `candidate-${index + 1}`, + index, + label: `Agent ${index + 1}`, + model: "base", + status: index === 0 ? "streaming" : "queued", + content: "", + })), + verifierScores: [], + }; + + render( {}} />); + + expect(screen.getByText("Deep+ · generating")).toBeInTheDocument(); + expect(screen.getAllByLabelText(/Agent \d/)).toHaveLength(8); + }); + + it("renders dynamic agent tabs in the modal", () => { + const weaver: WeaverTrace = { + mode: "deep_plus", + status: "complete", + candidateCount: 8, + verifierCount: 5, + selectedCandidate: "candidate-3", + confidence: 0.7, + candidates: Array.from({ length: 8 }, (_, index) => ({ + id: `candidate-${index + 1}`, + index, + label: `Agent ${index + 1}`, + model: "base", + status: index === 2 ? 
"selected" : "done", + content: `answer ${index + 1}`, + })), + verifierScores: [], + }; + + render( + {}} + onSelectTab={() => {}} + />, + ); + + expect(screen.getAllByRole("button", { name: /Agent \d/ })).toHaveLength(8); + expect(screen.getByRole("button", { name: "Verifiers" })).toBeInTheDocument(); + expect(screen.getByRole("button", { name: "Final" })).toBeInTheDocument(); + expect(screen.getByRole("dialog", { name: "Weaver trace" })).toHaveClass("h-[min(760px,86vh)]"); + }); +}); + +describe("ChatInput", () => { + it("opens the model selector even when models have not loaded", () => { + useStore.setState({ models: [], selectedModel: "" }); + + render( + {}} + onStop={() => {}} + isStreaming={false} + />, + ); + + fireEvent.click(screen.getByRole("button", { name: "Select model" })); + + expect(screen.getByText("No models loaded")).toBeInTheDocument(); + expect(screen.getByText("Check your API key or coordinator connection.")).toBeInTheDocument(); + }); +});
diff --git a/console-ui/src/app/api/device/approve/route.ts b/console-ui/src/app/api/device/approve/route.ts
new file mode 100644
index 000000000..4c8a3de7b
--- /dev/null
+++ b/console-ui/src/app/api/device/approve/route.ts
@@ -0,0 +1,36 @@
+import { NextRequest, NextResponse } from "next/server";
+
+const DEFAULT_COORD = process.env.NEXT_PUBLIC_COORDINATOR_URL || "https://api.darkbloom.dev";
+
+// Proxies a device-approval request to the coordinator. Auth is taken from the
+// Authorization header, falling back to the privy-token cookie as a Bearer token.
+export async function POST(req: NextRequest) {
+  let authHeader = req.headers.get("authorization") || "";
+  if (!authHeader) {
+    const privyToken = req.cookies.get("privy-token")?.value;
+    if (privyToken) {
+      authHeader = `Bearer ${privyToken}`;
+    }
+  }
+  if (!authHeader) {
+    return NextResponse.json({ error: "missing privy token" }, { status: 401 });
+  }
+
+  // A malformed body is a client error: answer 400 instead of an unhandled 500.
+  let payload: unknown;
+  try {
+    payload = await req.json();
+  } catch {
+    return NextResponse.json({ error: "invalid JSON body" }, { status: 400 });
+  }
+  const res = await fetch(`${DEFAULT_COORD}/v1/device/approve`, {
+    method: "POST",
+    headers: { "Content-Type": "application/json", Authorization: authHeader },
+    body: JSON.stringify(payload),
+  });
+  // Relay the coordinator's status and body unchanged, for success and failure alike.
+  return new Response(await res.text(), {
+    status: res.status,
+    headers: { "Content-Type": res.headers.get("content-type") || "application/json" },
+  });
+}
diff --git a/console-ui/src/app/api/provider/account-earnings/route.ts b/console-ui/src/app/api/provider/account-earnings/route.ts new file mode 100644 index 000000000..7b6c14eb1 --- /dev/null +++ b/console-ui/src/app/api/provider/account-earnings/route.ts @@ -0,0 +1,30 @@ +import { NextRequest, NextResponse } from "next/server"; + +const DEFAULT_COORD = process.env.NEXT_PUBLIC_COORDINATOR_URL || "https://api.darkbloom.dev"; + +export async function GET(req: NextRequest) { + let authHeader = req.headers.get("authorization") || ""; + if (!authHeader) { + const privyToken = req.cookies.get("privy-token")?.value; + if (privyToken) { + authHeader = `Bearer ${privyToken}`; + } + } + if (!authHeader) { + return NextResponse.json({ error: "missing auth token" }, { status: 401 }); + } + + const limit = req.nextUrl.searchParams.get("limit") || "100"; + const upstream = new URL(`${DEFAULT_COORD}/v1/provider/account-earnings`); + upstream.searchParams.set("limit", limit); + + const res = await fetch(upstream, { + headers: { Authorization: authHeader }, + cache: "no-store", + }); + if (!res.ok) { + const text = await res.text(); + return NextResponse.json({ error: text || `Upstream ${res.status}` }, { status: res.status }); + } + return NextResponse.json(await res.json()); +} diff --git a/console-ui/src/app/api/providers/attestation/route.ts b/console-ui/src/app/api/providers/attestation/route.ts new file mode 100644 index 000000000..d878f886f --- /dev/null +++ b/console-ui/src/app/api/providers/attestation/route.ts @@ -0,0 +1,15 @@ +import { NextResponse } from "next/server"; + +const DEFAULT_COORD = + process.env.NEXT_PUBLIC_COORDINATOR_URL || "https://api.darkbloom.dev"; +
+// Proxies the coordinator's provider-attestation endpoint for the console UI.
+// `cache: "no-store"` opts the upstream fetch out of Next.js caching so the
+// response is fetched fresh per request, as the sibling /api/stats proxy does.
+export async function GET() {
+  const res = await fetch(`${DEFAULT_COORD}/v1/providers/attestation`, { cache: "no-store" });
+  if (!res.ok) {
+    return NextResponse.json({ error: `Upstream ${res.status}` }, { status: res.status });
+  }
+  return NextResponse.json(await res.json());
+}
diff --git a/console-ui/src/app/api/stats/route.ts b/console-ui/src/app/api/stats/route.ts new file mode 100644 index 000000000..845b541f8 --- /dev/null +++ b/console-ui/src/app/api/stats/route.ts @@ -0,0 +1,12 @@ +import { NextResponse } from "next/server"; + +const DEFAULT_COORD = process.env.NEXT_PUBLIC_COORDINATOR_URL || "https://api.darkbloom.dev"; + +export async function GET() { + const res = await fetch(`${DEFAULT_COORD}/v1/stats`, { cache: "no-store" }); + if (!res.ok) { + const text = await res.text(); + return NextResponse.json({ error: text || `Upstream ${res.status}` }, { status: res.status }); + } + return NextResponse.json(await res.json()); +} diff --git a/console-ui/src/app/link/DeviceLinkForm.tsx b/console-ui/src/app/link/DeviceLinkForm.tsx index 6c25838c4..c93653b3d 100644 --- a/console-ui/src/app/link/DeviceLinkForm.tsx +++ b/console-ui/src/app/link/DeviceLinkForm.tsx @@ -4,10 +4,6 @@ import { useState, useCallback } from "react"; import { useAuthContext } from "@/components/providers/PrivyClientProvider"; import { trackEvent } from "@/lib/google-analytics"; -const COORDINATOR_URL = - process.env.NEXT_PUBLIC_COORDINATOR_URL || - "https://api.darkbloom.dev"; - type LinkStatus = "idle" | "submitting" | "success" | "error"; export function DeviceLinkForm() { @@ -39,7 +35,7 @@ export function DeviceLinkForm() { return; } - const res = await fetch(`${COORDINATOR_URL}/v1/device/approve`, { + const res = await fetch("/api/device/approve", { method: "POST", headers: { "Content-Type": "application/json", diff --git a/console-ui/src/app/page.tsx b/console-ui/src/app/page.tsx index caedf5fd1..57e67f9a7 100644 --- a/console-ui/src/app/page.tsx +++ b/console-ui/src/app/page.tsx @@ -2,7 +2,7 @@
import { useEffect, useRef, useCallback, useState } from "react"; import { useStore } from "@/lib/store"; -import { streamChat, fetchModels } from "@/lib/api"; +import { streamWeaverChat, streamChat, fetchModels } from "@/lib/api"; import { useToastStore } from "@/hooks/useToast"; import { useAuth } from "@/hooks/useAuth"; import { ChatMessage } from "@/components/ChatMessage"; @@ -12,12 +12,46 @@ import { PreSendTrustBanner } from "@/components/PreSendTrustBanner"; import { Mail } from "lucide-react"; import { InviteCodeBanner } from "@/components/InviteCodeBanner"; import type { Message } from "@/lib/store"; +import type { WeaverCandidate, WeaverMode, WeaverTrace } from "@/lib/api"; import { trackEvent } from "@/lib/google-analytics"; function generateId() { return Math.random().toString(36).slice(2, 10) + Date.now().toString(36); } +function weaverDefaults(mode: WeaverMode) { + return mode === "deep_plus" + ? { candidateCount: 8, verifierCount: 5 } + : { candidateCount: 4, verifierCount: 3 }; +} + +function initialWeaver(mode: WeaverMode, model: string): WeaverTrace { + const { candidateCount, verifierCount } = weaverDefaults(mode); + return { + mode, + status: "generating", + candidateCount, + verifierCount, + candidates: Array.from({ length: candidateCount }, (_, index) => ({ + id: `candidate-${index + 1}`, + index, + label: `Agent ${index + 1}`, + model, + status: "queued", + content: "", + })), + verifierScores: [], + }; +} + +function upsertCandidate(candidates: WeaverCandidate[], candidate: WeaverCandidate) { + const found = candidates.some((c) => c.id === candidate.id); + if (found) { + return candidates.map((c) => (c.id === candidate.id ? { ...c, ...candidate } : c)); + } + return [...candidates, candidate].sort((a, b) => a.index - b.index); +} + const SYSTEM_PROMPT = `You are an AI assistant running on Darkbloom, a decentralized private inference platform built by Eigen Labs. 
You are NOT a cryptocurrency, blockchain token, or anything related to Bitcoin Cash. Darkbloom is an AI infrastructure project. When users ask "what is Darkbloom" or about the platform, use ONLY these facts: @@ -82,7 +116,7 @@ export default function ChatPage() { }, [activeChat?.messages]); const handleSend = useCallback( - async (content: string) => { + async (content: string, weaverMode?: WeaverMode) => { const trimmedContent = content.trim(); if (!trimmedContent) { return; @@ -132,6 +166,7 @@ export default function ChatPage() { id: assistantId, role: "assistant", content: "", + weaver: weaverMode ? initialWeaver(weaverMode, selectedModel) : undefined, streaming: true, timestamp: Date.now(), }; @@ -151,55 +186,209 @@ export default function ChatPage() { ]; try { - await streamChat( - allMessages, - selectedModel, - { - onToken: (token) => { - appendToMessage(chatId!, assistantId, token); + if (weaverMode) { + const defaults = weaverDefaults(weaverMode); + await streamWeaverChat( + allMessages, + selectedModel, + { + mode: weaverMode, + candidateCount: defaults.candidateCount, + verifierCount: defaults.verifierCount, }, - onThinking: (token) => { - appendToThinking(chatId!, assistantId, token); - }, - onMetrics: (metrics) => { - updateMessage(chatId!, assistantId, { - tps: metrics.tps, - ttft: metrics.ttft, - tokenCount: metrics.tokenCount, - }); - }, - onDone: (trust, metrics) => { - trackEvent("chat_complete", { - model: selectedModel, - trust_level: trust?.trustLevel, - secure_enclave: trust?.secureEnclave, - token_count: metrics.tokenCount, - }); - updateMessage(chatId!, assistantId, { - streaming: false, - trust, - tps: metrics.tps, - ttft: metrics.ttft, - tokenCount: metrics.tokenCount, - }); - setIsStreaming(false); + { + onWeaverInit: (weaver) => { + updateMessage(chatId!, assistantId, { weaver }); + }, + onWeaverStatus: (status) => { + const current = useStore.getState().chats + .find((c) => c.id === chatId) + ?.messages.find((m) => m.id === 
assistantId)?.weaver; + if (current) { + updateMessage(chatId!, assistantId, { + weaver: { ...current, status }, + }); + } + }, + onCandidateStart: (candidate) => { + const current = useStore.getState().chats + .find((c) => c.id === chatId) + ?.messages.find((m) => m.id === assistantId)?.weaver; + if (current) { + updateMessage(chatId!, assistantId, { + weaver: { + ...current, + candidates: upsertCandidate(current.candidates, candidate), + }, + }); + } + }, + onCandidateDelta: (candidateId, delta) => { + const current = useStore.getState().chats + .find((c) => c.id === chatId) + ?.messages.find((m) => m.id === assistantId)?.weaver; + if (current) { + updateMessage(chatId!, assistantId, { + weaver: { + ...current, + candidates: current.candidates.map((c) => + c.id === candidateId + ? { ...c, status: "streaming", content: c.content + delta } + : c + ), + }, + }); + } + }, + onCandidateDone: (candidateId, content) => { + const current = useStore.getState().chats + .find((c) => c.id === chatId) + ?.messages.find((m) => m.id === assistantId)?.weaver; + if (current) { + updateMessage(chatId!, assistantId, { + weaver: { + ...current, + candidates: current.candidates.map((c) => + c.id === candidateId ? { ...c, status: "done", content } : c + ), + }, + }); + } + }, + onCandidateError: (candidateId, error) => { + const current = useStore.getState().chats + .find((c) => c.id === chatId) + ?.messages.find((m) => m.id === assistantId)?.weaver; + if (current) { + updateMessage(chatId!, assistantId, { + weaver: { + ...current, + candidates: current.candidates.map((c) => + c.id === candidateId ? 
{ ...c, status: "failed", error } : c + ), + }, + }); + } + }, + onVerifierScore: (score) => { + const current = useStore.getState().chats + .find((c) => c.id === chatId) + ?.messages.find((m) => m.id === assistantId)?.weaver; + if (current) { + updateMessage(chatId!, assistantId, { + weaver: { + ...current, + verifierScores: [...current.verifierScores, score], + }, + }); + } + }, + onFinal: (content, selectedCandidate, confidence, usage) => { + const current = useStore.getState().chats + .find((c) => c.id === chatId) + ?.messages.find((m) => m.id === assistantId)?.weaver; + trackEvent("chat_complete", { + model: selectedModel, + weaver_mode: weaverMode, + token_count: usage?.tokenCount || 0, + }); + updateMessage(chatId!, assistantId, { + content, + streaming: false, + tps: usage?.tps, + ttft: usage?.ttft, + tokenCount: usage?.tokenCount, + weaver: current + ? { + ...current, + status: "complete", + selectedCandidate, + confidence, + candidates: current.candidates.map((c) => + c.id === selectedCandidate + ? { ...c, status: "selected" } + : c + ), + } + : undefined, + }); + setIsStreaming(false); + }, + onDone: () => { + setIsStreaming(false); + }, + onError: (error) => { + trackEvent("chat_error", { + model: selectedModel, + weaver_mode: weaverMode, + error_type: "weaver_callback", + }); + const current = useStore.getState().chats + .find((c) => c.id === chatId) + ?.messages.find((m) => m.id === assistantId)?.weaver; + updateMessage(chatId!, assistantId, { + content: `Error: ${error}`, + streaming: false, + error: true, + weaver: current ? 
{ ...current, status: "failed" } : current, + }); + addToast(error); + setIsStreaming(false); + }, }, - onError: (error) => { - trackEvent("chat_error", { - model: selectedModel, - error_type: "stream_callback", - }); - updateMessage(chatId!, assistantId, { - content: `Error: ${error}`, - streaming: false, - error: true, - }); - addToast(error); - setIsStreaming(false); + abort.signal + ); + } else { + await streamChat( + allMessages, + selectedModel, + { + onToken: (token) => { + appendToMessage(chatId!, assistantId, token); + }, + onThinking: (token) => { + appendToThinking(chatId!, assistantId, token); + }, + onMetrics: (metrics) => { + updateMessage(chatId!, assistantId, { + tps: metrics.tps, + ttft: metrics.ttft, + tokenCount: metrics.tokenCount, + }); + }, + onDone: (trust, metrics) => { + trackEvent("chat_complete", { + model: selectedModel, + trust_level: trust?.trustLevel, + secure_enclave: trust?.secureEnclave, + token_count: metrics.tokenCount, + }); + updateMessage(chatId!, assistantId, { + streaming: false, + trust, + tps: metrics.tps, + ttft: metrics.ttft, + tokenCount: metrics.tokenCount, + }); + setIsStreaming(false); + }, + onError: (error) => { + trackEvent("chat_error", { + model: selectedModel, + error_type: "stream_callback", + }); + updateMessage(chatId!, assistantId, { + content: `Error: ${error}`, + streaming: false, + error: true, + }); + addToast(error); + setIsStreaming(false); + }, }, - }, - abort.signal - ); + abort.signal + ); + } } catch (err) { if ((err as Error).name !== "AbortError") { trackEvent("chat_error", { diff --git a/console-ui/src/app/providers/ProviderDashboardContent.tsx b/console-ui/src/app/providers/ProviderDashboardContent.tsx index 3eeb4351f..6d521ec69 100644 --- a/console-ui/src/app/providers/ProviderDashboardContent.tsx +++ b/console-ui/src/app/providers/ProviderDashboardContent.tsx @@ -84,12 +84,8 @@ export default function ProviderDashboardContent() { setError("Not authenticated"); return; } - const 
coordinatorUrl = - localStorage.getItem("darkbloom_coordinator_url") || - process.env.NEXT_PUBLIC_COORDINATOR_URL || - "https://api.darkbloom.dev"; const headers = { Authorization: `Bearer ${token}` }; - const pRes = await fetch(`${coordinatorUrl}/v1/me/providers`, { headers, cache: "no-store" }); + const pRes = await fetch("/api/me/providers", { headers, cache: "no-store" }); if (!pRes.ok) throw new Error(`providers: HTTP ${pRes.status}`); const p = await pRes.json() as MyProvidersResponse; setProvidersResp(p); diff --git a/console-ui/src/app/providers/earnings/EarningsContent.tsx b/console-ui/src/app/providers/earnings/EarningsContent.tsx index df4b4f5d0..9ebf77323 100644 --- a/console-ui/src/app/providers/earnings/EarningsContent.tsx +++ b/console-ui/src/app/providers/earnings/EarningsContent.tsx @@ -119,18 +119,9 @@ export default function EarningsContent() { const fetchEarnings = useCallback(async () => { setError(null); try { - const coordinatorUrl = - localStorage.getItem("darkbloom_coordinator_url") || - process.env.NEXT_PUBLIC_COORDINATOR_URL || - "https://api.darkbloom.dev"; const headers = await getAuthHeaders(); - const res = await fetch( - `${coordinatorUrl}/v1/provider/account-earnings?limit=100`, - { - headers, - } - ); + const res = await fetch("/api/provider/account-earnings?limit=100", { headers }); if (!res.ok) throw new Error(`HTTP ${res.status}`); setData(await res.json()); } catch (e) { diff --git a/console-ui/src/app/stats/page.tsx b/console-ui/src/app/stats/page.tsx index 68f8ca876..e22882a6d 100644 --- a/console-ui/src/app/stats/page.tsx +++ b/console-ui/src/app/stats/page.tsx @@ -10,8 +10,6 @@ import { } from "lucide-react"; import { TopBar } from "@/components/TopBar"; -const STATS_API = "https://api.darkbloom.dev"; - interface CPUCores { total: number; performance: number; @@ -451,7 +449,7 @@ export default function StatsPage() { const fetchStats = async () => { try { - const res = await fetch(`${STATS_API}/v1/stats`); + const res = 
await fetch("/api/stats"); if (!res.ok) throw new Error(`HTTP ${res.status}`); const data = await res.json(); setStats(data); diff --git a/console-ui/src/components/ChatInput.tsx b/console-ui/src/components/ChatInput.tsx index 84eca0c7d..300c48ac9 100644 --- a/console-ui/src/components/ChatInput.tsx +++ b/console-ui/src/components/ChatInput.tsx @@ -1,12 +1,13 @@ "use client"; import { useState, useRef, useCallback, useEffect } from "react"; -import { Send, Square, ChevronDown, LogIn } from "lucide-react"; +import { Send, Square, ChevronDown, LogIn, Brain } from "lucide-react"; import { useStore } from "@/lib/store"; import { trackEvent } from "@/lib/google-analytics"; +import type { WeaverMode } from "@/lib/api"; interface ChatInputProps { - onSend: (content: string) => void; + onSend: (content: string, mode?: WeaverMode) => void; onStop: () => void; isStreaming: boolean; authenticated?: boolean; @@ -18,14 +19,15 @@ export function ChatInput({ onSend, onStop, isStreaming, authenticated = true, o const textareaRef = useRef(null); const { selectedModel, models, setSelectedModel } = useStore(); const [modelOpen, setModelOpen] = useState(false); + const [weaverMode, setWeaverMode] = useState<"fast" | WeaverMode>("fast"); const handleSend = useCallback(() => { const trimmed = input.trim(); if (!trimmed || isStreaming) return; - onSend(trimmed); + onSend(trimmed, weaverMode === "fast" ? undefined : weaverMode); setInput(""); if (textareaRef.current) textareaRef.current.style.height = "auto"; - }, [input, isStreaming, onSend]); + }, [weaverMode, input, isStreaming, onSend]); const handleKeyDown = useCallback( (e: React.KeyboardEvent) => { @@ -100,7 +102,7 @@ export function ChatInput({ onSend, onStop, isStreaming, authenticated = true, o {/* Bottom bar */}
{/* Left: model selector */} -
+
- {modelOpen && chatModels.length > 0 && ( + {modelOpen && (
- {chatModels.map((m) => { - const name = m.display_name || m.id.split("/").pop() || m.id; - return ( - - ); - })} + {m.quantization && ( + + {m.quantization} + + )} + + ); + }) + ) : ( +
+

No models loaded

+

+ Check your API key or coordinator connection. +

+
+ )}
)}
+ +
+ {([ + ["fast", "Fast"], + ["deep", "Deep"], + ["deep_plus", "Deep+"], + ] as const).map(([value, label]) => ( + + ))} +
{/* Right: Send / Stop */} diff --git a/console-ui/src/components/ChatMessage.tsx b/console-ui/src/components/ChatMessage.tsx index f66a4463f..8522eb99d 100644 --- a/console-ui/src/components/ChatMessage.tsx +++ b/console-ui/src/components/ChatMessage.tsx @@ -2,6 +2,8 @@ import ReactMarkdown from "react-markdown"; import remarkGfm from "remark-gfm"; +import { WeaverTraceModal } from "./WeaverTraceModal"; +import { WeaverTraceStrip } from "./WeaverTraceStrip"; import { TrustBadge } from "./TrustBadge"; import { VerificationPanel } from "./VerificationPanel"; import type { Message } from "@/lib/store"; @@ -248,6 +250,8 @@ function parseThinkFromContent(content: string, existingThinking?: string): { th export function ChatMessage({ message, onRetry }: { message: Message; onRetry?: () => void }) { const isUser = message.role === "user"; + const [traceOpen, setTraceOpen] = useState(false); + const [traceTab, setTraceTab] = useState(undefined); const parsed = !isUser && !message.streaming ? parseThinkFromContent(message.content, message.thinking) @@ -303,6 +307,16 @@ export function ChatMessage({ message, onRetry }: { message: Message; onRetry?: /> )} + {message.weaver && ( + { + setTraceTab(candidateId); + setTraceOpen(true); + }} + /> + )} + {message.trust && !message.streaming && (
@@ -349,6 +363,17 @@ export function ChatMessage({ message, onRetry }: { message: Message; onRetry?: {message.error ? "Retry" : "Regenerate"} )} + + {message.weaver && traceOpen && ( + setTraceOpen(false)} + onSelectTab={setTraceTab} + /> + )}
diff --git a/console-ui/src/components/GoogleAnalytics.tsx b/console-ui/src/components/GoogleAnalytics.tsx index b82943c83..d88952d0f 100644 --- a/console-ui/src/components/GoogleAnalytics.tsx +++ b/console-ui/src/components/GoogleAnalytics.tsx @@ -17,10 +17,9 @@ import { export function GoogleAnalytics() { const pathname = usePathname(); const measurementId = getGoogleAnalyticsMeasurementId(); - const [consentState, setConsentState] = useState<"granted" | "denied" | "unset">( - () => getGoogleAnalyticsConsentStatus(), - ); - const [hasConsent, setHasConsent] = useState(consentState === "granted"); + const [mounted, setMounted] = useState(false); + const [consentState, setConsentState] = useState<"granted" | "denied" | "unset">("unset"); + const [hasConsent, setHasConsent] = useState(false); useEffect(() => { const syncConsentState = () => { @@ -30,6 +29,7 @@ export function GoogleAnalytics() { }; syncConsentState(); + setMounted(true); window.addEventListener("darkbloom-ga-consent-changed", syncConsentState); const onStorage = (event: StorageEvent) => { if (event.key === getGoogleAnalyticsConsentStorageKey()) { @@ -61,7 +61,7 @@ export function GoogleAnalytics() { trackRouteChange(pathname); }, [hasConsent, pathname]); - if (!measurementId) { + if (!mounted || !measurementId) { return null; } @@ -74,7 +74,7 @@ export function GoogleAnalytics() { /> )} {consentState === "unset" && !hasConsent && ( -
+

Allow privacy-filtered usage analytics to help improve Darkbloom's product experience. diff --git a/console-ui/src/components/InviteCodeBanner.tsx b/console-ui/src/components/InviteCodeBanner.tsx index f6f81e367..e173f0c08 100644 --- a/console-ui/src/components/InviteCodeBanner.tsx +++ b/console-ui/src/components/InviteCodeBanner.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState, useCallback } from "react"; +import { useState, useCallback, useEffect } from "react"; import { Ticket, X, Check, Loader2 } from "lucide-react"; import { redeemInviteCode } from "@/lib/api"; import { trackEvent } from "@/lib/google-analytics"; @@ -8,16 +8,19 @@ import { trackEvent } from "@/lib/google-analytics"; const DISMISSED_KEY = "darkbloom_invite_dismissed"; export function InviteCodeBanner() { - const [dismissed, setDismissed] = useState(() => { - if (typeof window === "undefined") return true; - return localStorage.getItem(DISMISSED_KEY) === "1"; - }); + const [mounted, setMounted] = useState(false); + const [dismissed, setDismissed] = useState(true); const [expanded, setExpanded] = useState(false); const [code, setCode] = useState(""); const [loading, setLoading] = useState(false); const [success, setSuccess] = useState(""); const [error, setError] = useState(""); + useEffect(() => { + setDismissed(localStorage.getItem(DISMISSED_KEY) === "1"); + setMounted(true); + }, []); + const dismissBanner = useCallback(() => { setDismissed(true); localStorage.setItem(DISMISSED_KEY, "1"); @@ -57,7 +60,7 @@ export function InviteCodeBanner() { setLoading(false); }, [code, dismissBanner]); - if (dismissed) return null; + if (!mounted || dismissed) return null; return (

diff --git a/console-ui/src/components/PreSendTrustBanner.tsx b/console-ui/src/components/PreSendTrustBanner.tsx index e53d1a008..68c96cdf8 100644 --- a/console-ui/src/components/PreSendTrustBanner.tsx +++ b/console-ui/src/components/PreSendTrustBanner.tsx @@ -4,8 +4,6 @@ import { useState, useEffect } from "react"; import { ShieldCheck, Info } from "lucide-react"; import { TrustExplainerModal } from "./TrustExplainerModal"; -const ATTESTATION_API = "https://api.darkbloom.dev"; - interface ProviderSummary { count: number; lastVerified: string; @@ -22,7 +20,7 @@ export function PreSendTrustBanner({ visible }: { visible: boolean }) { async function fetchProviders() { try { - const res = await fetch(`${ATTESTATION_API}/v1/providers/attestation`); + const res = await fetch("/api/providers/attestation"); if (!res.ok) return; const data = await res.json(); if (cancelled) return; diff --git a/console-ui/src/components/VerificationPanel.tsx b/console-ui/src/components/VerificationPanel.tsx index a0226afa0..0d33aa02c 100644 --- a/console-ui/src/components/VerificationPanel.tsx +++ b/console-ui/src/components/VerificationPanel.tsx @@ -26,7 +26,9 @@ import { Info, } from "lucide-react"; -const ATTESTATION_API = "https://api.darkbloom.dev"; +const ATTESTATION_API = + process.env.NEXT_PUBLIC_COORDINATOR_URL || "https://api.darkbloom.dev"; +const ATTESTATION_PROXY = "/api/providers/attestation"; /** Mask a serial number for normal mode: show first 4 + last 2, mask the rest. 
*/ function maskSerial(serial: string): string { @@ -508,7 +510,7 @@ export function VerificationPanel({ trust }: { trust: TrustMetadata }) { setVerifySteps([]); try { - const res = await fetch(`${ATTESTATION_API}/v1/providers/attestation`); + const res = await fetch(ATTESTATION_PROXY); const data = await res.json(); const provider = diff --git a/console-ui/src/components/WeaverTraceModal.tsx b/console-ui/src/components/WeaverTraceModal.tsx new file mode 100644 index 000000000..826bc012f --- /dev/null +++ b/console-ui/src/components/WeaverTraceModal.tsx @@ -0,0 +1,162 @@ +"use client"; + +import { X } from "lucide-react"; +import ReactMarkdown from "react-markdown"; +import remarkGfm from "remark-gfm"; +import type { WeaverTrace } from "@/lib/api"; + +function groupedScores(weaver: WeaverTrace) { + return weaver.candidates.map((candidate) => ({ + candidate, + scores: weaver.verifierScores.filter((score) => score.candidateId === candidate.id), + })); +} + +export function WeaverTraceModal({ + weaver, + finalContent, + usage, + initialTab, + onClose, + onSelectTab, +}: { + weaver: WeaverTrace; + finalContent: string; + usage?: { tokenCount?: number; tps?: number; ttft?: number }; + initialTab?: string; + onClose: () => void; + onSelectTab: (tab: string) => void; +}) { + const activeTab = initialTab || weaver.candidates[0]?.id || "verifiers"; + const activeCandidate = weaver.candidates.find((candidate) => candidate.id === activeTab); + let tabContent; + + if (activeCandidate) { + tabContent = ( +
+
+ {activeCandidate.model} + · + {activeCandidate.status} +
+ {activeCandidate.error ? ( +

{activeCandidate.error}

+ ) : ( +
+ + {activeCandidate.content || "Waiting for this agent..."} + +
+ )} +
+ ); + } else if (activeTab === "verifiers") { + tabContent = ( +
+ {groupedScores(weaver).map(({ candidate, scores }) => ( +
+
+ {candidate.label} + {scores.length} scores +
+ {scores.length ? ( +
+ {scores.map((score, idx) => ( +
+
+ {score.score.toFixed(1)} + {score.verifierModel} +
+ {score.rationale && ( +

{score.rationale}

+ )} +
+ ))} +
+ ) : ( +

No verifier scores yet.

+ )} +
+ ))} +
+ ); + } else { + tabContent = ( +
+
+ Winner: {weaver.selectedCandidate || "pending"} + Confidence: {weaver.confidence ? `${Math.round(weaver.confidence * 100)}%` : "pending"} + Tokens: {usage?.tokenCount ?? 0} +
+
+ + {finalContent || "Final answer pending..."} + +
+
+ ); + } + + return ( +
+
+
+
+

Weaver Trace

+

+ {weaver.candidateCount} agents · {weaver.verifierCount} verifier passes +

+
+ +
+ +
+ {weaver.candidates.map((candidate) => ( + + ))} + {["verifiers", "final"].map((tab) => ( + + ))} +
+ +
+ {tabContent} +
+
+
+ ); +} diff --git a/console-ui/src/components/WeaverTraceStrip.tsx b/console-ui/src/components/WeaverTraceStrip.tsx new file mode 100644 index 000000000..6acc9dfd4 --- /dev/null +++ b/console-ui/src/components/WeaverTraceStrip.tsx @@ -0,0 +1,62 @@ +"use client"; + +import type { WeaverCandidate, WeaverTrace } from "@/lib/api"; + +function diamondClass(candidate: WeaverCandidate, selectedCandidate?: string) { + if (candidate.status === "failed") { + return "border-accent-red bg-accent-red text-white"; + } + if (candidate.id === selectedCandidate || candidate.status === "selected") { + return "border-gold bg-gold text-white"; + } + if (selectedCandidate && candidate.status === "done") { + return "border-border-subtle bg-bg-tertiary text-text-tertiary opacity-60"; + } + if (candidate.status === "done") { + return "border-teal bg-teal text-white"; + } + if (candidate.status === "streaming") { + return "border-teal bg-transparent text-teal animate-pulse"; + } + return "border-border-subtle bg-bg-tertiary text-text-tertiary opacity-60"; +} + +export function WeaverTraceStrip({ + weaver, + onOpen, +}: { + weaver: WeaverTrace; + onOpen: (candidateId?: string) => void; +}) { + const done = weaver.candidates.filter((c) => c.status === "done" || c.status === "selected").length; + const failed = weaver.candidates.filter((c) => c.status === "failed").length; + + return ( +
+ + +
+ {weaver.candidates.map((candidate) => ( +
+ + + {done}/{weaver.candidateCount} done{failed ? ` · ${failed} failed` : ""} + +
+ ); +} diff --git a/console-ui/src/lib/api.ts b/console-ui/src/lib/api.ts index 6a008a809..fab16061f 100644 --- a/console-ui/src/lib/api.ts +++ b/console-ui/src/lib/api.ts @@ -38,6 +38,8 @@ export interface Model { attested?: boolean; trust_level?: string; display_name?: string; + family?: string; + can_verify?: boolean; } export interface BalanceResponse { @@ -90,6 +92,57 @@ export interface StreamCallbacks { onError: (error: string) => void; } +export type WeaverMode = "deep" | "deep_plus"; +export type WeaverStatus = "generating" | "verifying" | "selecting" | "complete" | "failed"; +export type CandidateStatus = "queued" | "streaming" | "done" | "failed" | "selected"; + +export interface WeaverCandidate { + id: string; + index: number; + label: string; + model: string; + status: CandidateStatus; + content: string; + error?: string; +} + +export interface WeaverVerifierScore { + candidateId: string; + verifierModel: string; + score: number; + rationale?: string; +} + +export interface WeaverTrace { + mode: WeaverMode; + status: WeaverStatus; + candidateCount: number; + verifierCount: number; + selectedCandidate?: string; + confidence?: number; + candidates: WeaverCandidate[]; + verifierScores: WeaverVerifierScore[]; +} + +export interface WeaverRequestOptions { + mode: WeaverMode; + candidateCount: number; + verifierCount: number; +} + +export interface WeaverStreamCallbacks { + onWeaverInit: (trace: WeaverTrace) => void; + onWeaverStatus: (status: WeaverStatus) => void; + onCandidateStart: (candidate: WeaverCandidate) => void; + onCandidateDelta: (candidateId: string, delta: string) => void; + onCandidateDone: (candidateId: string, content: string) => void; + onCandidateError: (candidateId: string, error: string) => void; + onVerifierScore: (score: WeaverVerifierScore) => void; + onFinal: (content: string, selectedCandidate?: string, confidence?: number, usage?: StreamMetrics) => void; + onDone: () => void; + onError: (error: string) => void; +} + export async 
function fetchModels(): Promise { const res = await fetch("/api/models", { headers: proxyHeaders() }); if (!res.ok) throw new Error(`Failed to fetch models: ${res.status}`); @@ -106,6 +159,8 @@ export async function fetchModels(): Promise { trust_level: m.trust_level || meta.trust_level, attested: m.attested ?? (meta.attested_providers as number) > 0, display_name: m.display_name || meta.display_name, + family: m.family || meta.family, + can_verify: m.can_verify ?? meta.can_verify, }; }); } @@ -296,6 +351,211 @@ export async function healthCheck(): Promise<{ status: string; providers: number return res.json(); } +function parseSseBlock(block: string): { event: string; data: string } | null { + const lines = block.split(/\r?\n/); + let event = "message"; + const data: string[] = []; + for (const line of lines) { + if (line.startsWith("event:")) { + event = line.slice(6).trim(); + } else if (line.startsWith("data:")) { + data.push(line.slice(5).trimStart()); + } + } + if (data.length === 0) return null; + return { event, data: data.join("\n") }; +} + +async function parseStreamError(res: Response, sealCtx: { ephemPriv: Uint8Array; coordPub: Uint8Array } | null): Promise { + if (res.status === 401) { + localStorage.removeItem("darkbloom_api_key"); + window.dispatchEvent(new Event("darkbloom-key-expired")); + return "Session expired — please try again"; + } + + let text = await res.text(); + const errCt = res.headers.get("content-type") || ""; + const errSealed = + sealCtx && (res.headers.get("x-eigen-sealed") === "true" || + errCt.toLowerCase().startsWith(SEALED_CONTENT_TYPE)); + if (errSealed && sealCtx) { + try { + const pt = unsealResponse(text, sealCtx.ephemPriv, sealCtx.coordPub); + text = new TextDecoder().decode(pt); + } catch (err) { + return `Could not decrypt sealed error response: ${err instanceof Error ? 
err.message : String(err)}`; + } + } + try { + const errData = JSON.parse(text); + const msg = errData?.error?.message || text; + if (res.status === 503 && msg.includes("queue timeout")) { + return "All providers are busy — please try again in a moment"; + } + if (res.status === 402) { + return "Insufficient credits — buy credits in Billing to continue"; + } + return `Request failed (${res.status}): ${msg}`; + } catch { + return `Request failed (${res.status}): ${text}`; + } +} + +export async function streamWeaverChat( + messages: ChatMessage[], + model: string, + options: WeaverRequestOptions, + callbacks: WeaverStreamCallbacks, + signal?: AbortSignal +): Promise { + const requestBody = { + model, + messages, + stream: true, + weaver: { + mode: options.mode, + candidate_count: options.candidateCount, + verifier_count: options.verifierCount, + verifier_selection: "auto", + trace: true, + }, + }; + let sealCtx: { ephemPriv: Uint8Array; coordPub: Uint8Array } | null = null; + let fetchHeaders = proxyHeaders(); + let fetchBody: string; + + if (isEncryptionEnabled()) { + try { + const coordKey = await getCoordinatorKey(); + const sealed = sealRequest(requestBody, coordKey); + fetchBody = sealed.envelopeJson; + fetchHeaders = proxyHeaders({ "Content-Type": SEALED_CONTENT_TYPE }); + sealCtx = { + ephemPriv: sealed.ephemeralPrivateKey, + coordPub: coordKey.publicKey, + }; + } catch (err) { + callbacks.onError( + `Encryption setup failed: ${err instanceof Error ? 
err.message : String(err)} — disable "Encrypt to coordinator" in Settings to continue in plaintext.`, + ); + return; + } + } else { + fetchBody = JSON.stringify(requestBody); + } + + const res = await fetch("/api/chat", { + method: "POST", + headers: fetchHeaders, + body: fetchBody, + signal, + }); + + if (sealCtx && res.status === 400) { + const text = await res.clone().text(); + if (text.includes("kid_mismatch")) { + clearCoordinatorKeyCache(); + } + } + + if (!res.ok) { + callbacks.onError(await parseStreamError(res, sealCtx)); + return; + } + + const reader = res.body?.getReader(); + if (!reader) { + callbacks.onError("No response body"); + return; + } + + const decoder = new TextDecoder(); + let buffer = ""; + const responseSealed = sealCtx !== null && res.headers.get("x-eigen-sealed") === "true"; + + const handleBlock = (rawBlock: string) => { + let block = rawBlock.trim(); + if (!block) return; + + if (responseSealed && sealCtx) { + const encrypted = parseSseBlock(block); + if (!encrypted) return; + try { + block = unsealSseEvent(encrypted.data, sealCtx.ephemPriv, sealCtx.coordPub).trim(); + } catch (err) { + callbacks.onError( + `Sealed stream decryption failed: ${err instanceof Error ? 
err.message : String(err)}`, + ); + return; + } + } + + const parsed = parseSseBlock(block); + if (!parsed) return; + if (parsed.data === "[DONE]" || parsed.event === "done") { + callbacks.onDone(); + return; + } + + let payload: any; + try { + payload = JSON.parse(parsed.data); + } catch { + return; + } + + switch (parsed.event) { + case "weaver_init": + callbacks.onWeaverInit(payload as WeaverTrace); + break; + case "weaver_status": + callbacks.onWeaverStatus(payload.status as WeaverStatus); + break; + case "candidate_start": + callbacks.onCandidateStart(payload as WeaverCandidate); + break; + case "candidate_delta": + callbacks.onCandidateDelta(payload.candidate_id, payload.delta || ""); + break; + case "candidate_done": + callbacks.onCandidateDone(payload.candidate_id, payload.content || ""); + break; + case "candidate_error": + callbacks.onCandidateError(payload.candidate_id, payload.error || "candidate failed"); + break; + case "verifier_score": + callbacks.onVerifierScore({ + candidateId: payload.candidate_id, + verifierModel: payload.verifier_model, + score: Number(payload.score) || 0, + rationale: payload.rationale, + }); + break; + case "weaver_final": + callbacks.onFinal(payload.content || "", payload.selected_candidate, payload.confidence, payload.usage); + break; + case "error": + callbacks.onError(payload?.error?.message || payload?.message || "weaver failed"); + break; + } + }; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + buffer += decoder.decode(value, { stream: true }); + const blocks = buffer.split(/\n\n/); + buffer = blocks.pop() || ""; + for (const block of blocks) { + handleBlock(block); + } + } + + if (buffer.trim()) { + handleBlock(buffer); + } +} + export async function streamChat( messages: ChatMessage[], model: string, diff --git a/console-ui/src/lib/store.ts b/console-ui/src/lib/store.ts index a5418d2da..1f4b98b1f 100644 --- a/console-ui/src/lib/store.ts +++ b/console-ui/src/lib/store.ts @@ -1,6 
+1,6 @@ import { create } from "zustand"; import { persist } from "zustand/middleware"; -import type { TrustMetadata, Model } from "./api"; +import type { WeaverTrace, TrustMetadata, Model } from "./api"; export interface Chat { id: string; @@ -14,6 +14,7 @@ export interface Message { role: "user" | "assistant"; content: string; thinking?: string; + weaver?: WeaverTrace; trust?: TrustMetadata; streaming?: boolean; error?: boolean; diff --git a/coordinator/cmd/coordinator/main.go b/coordinator/cmd/coordinator/main.go index f2c69ee33..cb00cef73 100644 --- a/coordinator/cmd/coordinator/main.go +++ b/coordinator/cmd/coordinator/main.go @@ -18,6 +18,13 @@ // production; omit + EIGENINFERENCE_ALLOW_MEMORY_STORE=true for dev) // EIGENINFERENCE_ALLOW_MEMORY_STORE - Set to "true" to permit MemoryStore boot // when DATABASE_URL is unset (dev/test only) +// EIGENINFERENCE_ALLOW_UNCATALOGUED_MODELS +// - Set to "true" for local dev to permit +// ad-hoc downloaded models outside the +// production model catalog +// EIGENINFERENCE_DEV_PROFILE - Set to "true" to enable local dev defaults: +// memory store, uncatalogued models, +// TrustNone routing, and env runtime hashes // // Graceful shutdown: The coordinator handles SIGINT/SIGTERM, stops the // eviction loop, and drains active connections with a 15-second deadline. @@ -70,10 +77,14 @@ func main() { // Configuration from environment. port := envOr("EIGENINFERENCE_PORT", "8080") adminKey := os.Getenv("EIGENINFERENCE_ADMIN_KEY") + devProfile := localDevProfileEnabled() if adminKey == "" { logger.Warn("EIGENINFERENCE_ADMIN_KEY is not set — no pre-seeded API key available") } + if devProfile { + logger.Warn("LOCAL DEV PROFILE ENABLED — memory store, uncatalogued models, TrustNone routing, and env runtime hashes are allowed") + } // Create core components. ctx, cancel := context.WithCancel(context.Background()) @@ -101,14 +112,14 @@ func main() { // In production that would lose USDC deposits and provider payouts. 
// Refuse to boot unless the operator has explicitly opted in (e.g. // for local dev or integration tests). - if os.Getenv("EIGENINFERENCE_ALLOW_MEMORY_STORE") != "true" { + if !devProfile && os.Getenv("EIGENINFERENCE_ALLOW_MEMORY_STORE") != "true" { logger.Error("EIGENINFERENCE_DATABASE_URL is not set and EIGENINFERENCE_ALLOW_MEMORY_STORE is not \"true\" — refusing to start with non-durable store") os.Exit(1) } memStore := store.NewMemory(adminKey) st = memStore - logger.Warn("using in-memory store — billing state will not survive restart (set EIGENINFERENCE_DATABASE_URL for production)") + logger.Warn("using in-memory store — billing state will not survive restart (set EIGENINFERENCE_DATABASE_URL for production)", "dev_profile", devProfile) // MemoryStore's append-only slices (usage, ledger, earnings, // payouts, payments) grow unboundedly over the lifetime of the @@ -130,8 +141,13 @@ func main() { }) } - // Seed the model catalog if empty (first startup or fresh DB). - seedModelCatalog(st, logger) + allowUncataloguedModels := devProfile || os.Getenv("EIGENINFERENCE_ALLOW_UNCATALOGUED_MODELS") == "true" + if allowUncataloguedModels { + logger.Warn("uncatalogued model routing ENABLED — local dev only", "dev_profile", devProfile) + } else { + // Seed the model catalog if empty (first startup or fresh DB). + seedModelCatalog(st, logger) + } reg := registry.New(logger) @@ -140,6 +156,9 @@ func main() { if minTrust := os.Getenv("EIGENINFERENCE_MIN_TRUST"); minTrust != "" { reg.MinTrustLevel = registry.TrustLevel(minTrust) logger.Info("minimum trust level override", "level", minTrust) + } else if devProfile { + reg.MinTrustLevel = registry.TrustNone + logger.Warn("minimum trust lowered to none by local dev profile") } srv := api.NewServer(reg, st, logger) @@ -207,8 +226,14 @@ func main() { } // Sync the model catalog to the registry so providers and consumers - // are filtered against the admin-managed whitelist. 
- srv.SyncModelCatalog() + // are filtered against the admin-managed whitelist. Local dev can disable + // this to test small ad-hoc models without downloading production weights. + if allowUncataloguedModels { + reg.SetModelCatalog(nil) + logger.Warn("model catalog filtering disabled — local dev only", "dev_profile", devProfile) + } else { + srv.SyncModelCatalog() + } // Console URL — frontend for device auth verification links. if consoleURL := os.Getenv("EIGENINFERENCE_CONSOLE_URL"); consoleURL != "" { @@ -273,41 +298,49 @@ func main() { templateHashes := os.Getenv("EIGENINFERENCE_KNOWN_TEMPLATE_HASHES") // format: name=hash,name=hash if pythonHashes != "" || runtimeHashes != "" || templateHashes != "" { - manifest := &api.RuntimeManifest{ - PythonHashes: make(map[string]bool), - RuntimeHashes: make(map[string]bool), - TemplateHashes: make(map[string]string), + allowEnvRuntimeManifest := devProfile || os.Getenv("EIGENINFERENCE_ALLOW_ENV_RUNTIME_HASHES") == "true" + if !allowEnvRuntimeManifest { + logger.Error("runtime manifest env vars ignored outside local dev profile; register a release manifest or set EIGENINFERENCE_DEV_PROFILE=true for local development") + } else if os.Getenv("EIGENINFERENCE_ALLOW_ENV_RUNTIME_HASHES") == "true" && !devProfile { + logger.Warn("runtime manifest configured from env outside dev profile — operator override") } - if pythonHashes != "" { - for _, h := range strings.Split(pythonHashes, ",") { - h = strings.TrimSpace(h) - if h != "" { - manifest.PythonHashes[h] = true + if allowEnvRuntimeManifest { + manifest := &api.RuntimeManifest{ + PythonHashes: make(map[string]bool), + RuntimeHashes: make(map[string]bool), + TemplateHashes: make(map[string]string), + } + if pythonHashes != "" { + for _, h := range strings.Split(pythonHashes, ",") { + h = strings.TrimSpace(h) + if h != "" { + manifest.PythonHashes[h] = true + } } } - } - if runtimeHashes != "" { - for _, h := range strings.Split(runtimeHashes, ",") { - h = strings.TrimSpace(h) - 
if h != "" { - manifest.RuntimeHashes[h] = true + if runtimeHashes != "" { + for _, h := range strings.Split(runtimeHashes, ",") { + h = strings.TrimSpace(h) + if h != "" { + manifest.RuntimeHashes[h] = true + } } } - } - if templateHashes != "" { - for _, pair := range strings.Split(templateHashes, ",") { - parts := strings.SplitN(strings.TrimSpace(pair), "=", 2) - if len(parts) == 2 { - manifest.TemplateHashes[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) + if templateHashes != "" { + for _, pair := range strings.Split(templateHashes, ",") { + parts := strings.SplitN(strings.TrimSpace(pair), "=", 2) + if len(parts) == 2 { + manifest.TemplateHashes[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) + } } } + srv.SetRuntimeManifest(manifest) + logger.Info("runtime manifest configured", + "python_hashes", len(manifest.PythonHashes), + "runtime_hashes", len(manifest.RuntimeHashes), + "template_hashes", len(manifest.TemplateHashes), + ) } - srv.SetRuntimeManifest(manifest) - logger.Info("runtime manifest configured", - "python_hashes", len(manifest.PythonHashes), - "runtime_hashes", len(manifest.RuntimeHashes), - "template_hashes", len(manifest.TemplateHashes), - ) } } @@ -554,13 +587,18 @@ func envInt(key string, fallback int) int { return fallback } +func localDevProfileEnabled() bool { + return os.Getenv("EIGENINFERENCE_DEV_PROFILE") == "true" || + strings.EqualFold(os.Getenv("EIGENINFERENCE_PROFILE"), "dev") +} + // seedModelCatalog ensures all hardcoded models exist in the catalog. // On first startup it populates everything; on subsequent starts it adds // any new models that were added to the code but not yet in the DB and removes // catalog entries that should no longer be provider-selectable. 
func seedModelCatalog(st store.Store, logger *slog.Logger) { existing := st.ListSupportedModels() - existingIDs := make(map[string]bool, len(existing)) + existingByID := make(map[string]store.SupportedModel, len(existing)) removed := 0 for _, m := range existing { if api.IsRetiredProviderModel(m) { @@ -571,24 +609,33 @@ func seedModelCatalog(st store.Store, logger *slog.Logger) { } continue } - existingIDs[m.ID] = true + existingByID[m.ID] = m } models := []store.SupportedModel{ // --- Text generation (8-bit quantization) --- - {ID: "qwen3.5-27b-claude-opus-8bit", S3Name: "qwen35-27b-claude-opus-8bit", DisplayName: "Qwen3.5 27B Claude Opus Distilled", ModelType: "text", SizeGB: 27.0, Architecture: "27B dense, Claude Opus distilled", Description: "Frontier quality reasoning", MinRAMGB: 36, Active: true}, - {ID: "mlx-community/Trinity-Mini-8bit", S3Name: "Trinity-Mini-8bit", DisplayName: "Trinity Mini", ModelType: "text", SizeGB: 26.0, Architecture: "27B Adaptive MoE", Description: "Fast agentic inference", MinRAMGB: 48, Active: true}, - {ID: "mlx-community/gemma-4-26b-a4b-it-8bit", S3Name: "gemma-4-26b-a4b-it-8bit", DisplayName: "Gemma 4 26B", ModelType: "text", SizeGB: 28.0, Architecture: "26B MoE, 4B active", Description: "Fast multimodal MoE", MinRAMGB: 36, Active: true}, - {ID: "mlx-community/Qwen3.5-122B-A10B-8bit", S3Name: "Qwen3.5-122B-A10B-8bit", DisplayName: "Qwen3.5 122B", ModelType: "text", SizeGB: 122.0, Architecture: "122B MoE, 10B active", Description: "Best quality", MinRAMGB: 128, Active: true}, - {ID: "mlx-community/MiniMax-M2.5-8bit", S3Name: "MiniMax-M2.5-8bit", DisplayName: "MiniMax M2.5", ModelType: "text", SizeGB: 243.0, Architecture: "239B MoE, 11B active", Description: "SOTA coding, 100 tok/s", MinRAMGB: 256, Active: true}, + {ID: "qwen3.5-27b-claude-opus-8bit", S3Name: "qwen35-27b-claude-opus-8bit", DisplayName: "Qwen3.5 27B Claude Opus Distilled", ModelType: "text", SizeGB: 27.0, Family: "qwen", CanVerify: true, Architecture: "27B dense, 
Claude Opus distilled", Description: "Frontier quality reasoning", MinRAMGB: 36, Active: true}, + {ID: "mlx-community/Trinity-Mini-8bit", S3Name: "Trinity-Mini-8bit", DisplayName: "Trinity Mini", ModelType: "text", SizeGB: 26.0, Family: "trinity", CanVerify: true, Architecture: "27B Adaptive MoE", Description: "Fast agentic inference", MinRAMGB: 48, Active: true}, + {ID: "mlx-community/gemma-4-26b-a4b-it-8bit", S3Name: "gemma-4-26b-a4b-it-8bit", DisplayName: "Gemma 4 26B", ModelType: "text", SizeGB: 28.0, Family: "gemma", CanVerify: true, Architecture: "26B MoE, 4B active", Description: "Fast multimodal MoE", MinRAMGB: 36, Active: true}, + {ID: "mlx-community/Qwen3.5-122B-A10B-8bit", S3Name: "Qwen3.5-122B-A10B-8bit", DisplayName: "Qwen3.5 122B", ModelType: "text", SizeGB: 122.0, Family: "qwen", CanVerify: true, Architecture: "122B MoE, 10B active", Description: "Best quality", MinRAMGB: 128, Active: true}, + {ID: "mlx-community/MiniMax-M2.5-8bit", S3Name: "MiniMax-M2.5-8bit", DisplayName: "MiniMax M2.5", ModelType: "text", SizeGB: 243.0, Family: "minimax", CanVerify: true, Architecture: "239B MoE, 11B active", Description: "SOTA coding, 100 tok/s", MinRAMGB: 256, Active: true}, } added := 0 + updated := 0 for i := range models { if api.IsRetiredProviderModel(models[i]) { continue } - if existingIDs[models[i].ID] { + if existing, ok := existingByID[models[i].ID]; ok { + if existing.Family == "" && models[i].Family != "" { + existing.Family = models[i].Family + if err := st.SetSupportedModel(&existing); err != nil { + logger.Warn("failed to update seeded model metadata", "id", existing.ID, "error", err) + } else { + updated++ + } + } continue } if err := st.SetSupportedModel(&models[i]); err != nil { @@ -597,8 +644,8 @@ func seedModelCatalog(st store.Store, logger *slog.Logger) { added++ } } - if added > 0 || removed > 0 { - logger.Info("model catalog reconciled", "added", added, "removed", removed, "total", len(existing)+added-removed) + if added > 0 || removed > 0 
|| updated > 0 { + logger.Info("model catalog reconciled", "added", added, "updated", updated, "removed", removed, "total", len(existing)+added-removed) } else { logger.Info("model catalog loaded", "count", len(existing)) } diff --git a/coordinator/cmd/coordinator/main_test.go b/coordinator/cmd/coordinator/main_test.go index 2ccb62774..becaef61c 100644 --- a/coordinator/cmd/coordinator/main_test.go +++ b/coordinator/cmd/coordinator/main_test.go @@ -8,6 +8,25 @@ import ( "github.com/eigeninference/coordinator/internal/store" ) +func TestLocalDevProfileEnabled(t *testing.T) { + t.Setenv("EIGENINFERENCE_DEV_PROFILE", "") + t.Setenv("EIGENINFERENCE_PROFILE", "") + if localDevProfileEnabled() { + t.Fatal("dev profile should be disabled by default") + } + + t.Setenv("EIGENINFERENCE_DEV_PROFILE", "true") + if !localDevProfileEnabled() { + t.Fatal("EIGENINFERENCE_DEV_PROFILE=true should enable dev profile") + } + + t.Setenv("EIGENINFERENCE_DEV_PROFILE", "") + t.Setenv("EIGENINFERENCE_PROFILE", "dev") + if !localDevProfileEnabled() { + t.Fatal("EIGENINFERENCE_PROFILE=dev should enable dev profile") + } +} + func TestSeedModelCatalogRemovesRetiredProviderModels(t *testing.T) { st := store.NewMemory("") logger := slog.New(slog.NewTextHandler(io.Discard, nil)) @@ -60,7 +79,14 @@ func TestSeedModelCatalogRemovesRetiredProviderModels(t *testing.T) { if _, ok := byID[keep.ID]; !ok { t.Fatalf("supported model %q was removed", keep.ID) } - if _, ok := byID["qwen3.5-27b-claude-opus-8bit"]; !ok { + qwen, ok := byID["qwen3.5-27b-claude-opus-8bit"] + if !ok { t.Fatalf("seeded text model missing") } + if qwen.Family != "qwen" { + t.Fatalf("seeded model family = %q, want qwen", qwen.Family) + } + if !qwen.CanVerify { + t.Fatalf("seeded text model should be verifier-capable by default") + } } diff --git a/coordinator/internal/api/billing_handlers.go b/coordinator/internal/api/billing_handlers.go index 2216575a8..27c06d3fe 100644 --- a/coordinator/internal/api/billing_handlers.go +++ 
b/coordinator/internal/api/billing_handlers.go @@ -567,8 +567,14 @@ func (s *Server) handleAdminSetModel(w http.ResponseWriter, r *http.Request) { return } + var raw map[string]json.RawMessage + if err := json.NewDecoder(r.Body).Decode(&raw); err != nil { + writeJSON(w, http.StatusBadRequest, errorResponse("invalid_request_error", "invalid JSON: "+err.Error())) + return + } + body, _ := json.Marshal(raw) var model store.SupportedModel - if err := json.NewDecoder(r.Body).Decode(&model); err != nil { + if err := json.Unmarshal(body, &model); err != nil { writeJSON(w, http.StatusBadRequest, errorResponse("invalid_request_error", "invalid JSON: "+err.Error())) return } @@ -580,6 +586,9 @@ func (s *Server) handleAdminSetModel(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusBadRequest, errorResponse("invalid_request_error", "display_name is required")) return } + if _, provided := raw["can_verify"]; !provided && model.Active && isTextCatalogModel(model.ModelType) { + model.CanVerify = true + } if err := s.store.SetSupportedModel(&model); err != nil { s.logger.Error("admin: set model failed", "error", err) diff --git a/coordinator/internal/api/consumer.go b/coordinator/internal/api/consumer.go index e82afbecc..e8011d2df 100644 --- a/coordinator/internal/api/consumer.go +++ b/coordinator/internal/api/consumer.go @@ -583,6 +583,31 @@ func (s *Server) handleChatCompletions(w http.ResponseWriter, r *http.Request) { rawBody, _ = json.Marshal(providerParsed) } + if weaverConfig, ok, err := parseWeaverConfig(parsed); ok || err != nil { + if err != nil { + writeJSON(w, http.StatusBadRequest, errorResponse("invalid_request_error", err.Error())) + return + } + if isResponsesAPI { + writeJSON(w, http.StatusBadRequest, errorResponse("invalid_request_error", "weaver is only supported with chat messages")) + return + } + if !stream { + writeJSON(w, http.StatusBadRequest, errorResponse("invalid_request_error", "weaver requires stream=true")) + return + } + if 
!s.registry.IsModelInCatalog(model) { + writeJSON(w, http.StatusNotFound, errorResponse("model_not_found", + fmt.Sprintf("model %q is not available — see /v1/models for supported models", model))) + return + } + delete(parsed, "weaver") + parsed["stream"] = true + rawBody, _ = json.Marshal(parsed) + s.handleWeaverChatCompletions(w, r, weaverConfig, rawBody, model, estimatedPromptTokens, requestedMaxTokens, allowedProviderSerials) + return + } + // Pre-flight balance reservation — atomically debit the worst-case cost // (prompt tokens already consumed + max_tokens we just bounded the // generation to) before routing to a provider. The post-inference charge @@ -2088,6 +2113,10 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) { if inCatalog && cm.DisplayName != "" { metadata["display_name"] = cm.DisplayName } + if inCatalog { + metadata["family"] = cm.Family + metadata["can_verify"] = cm.CanVerify + } data = append(data, map[string]any{ "id": m.ID, "object": "model", diff --git a/coordinator/internal/api/server.go b/coordinator/internal/api/server.go index b4969f830..8090c4b6a 100644 --- a/coordinator/internal/api/server.go +++ b/coordinator/internal/api/server.go @@ -409,6 +409,8 @@ func (s *Server) SyncModelCatalog() { ID: m.ID, WeightHash: m.WeightHash, SizeGB: m.SizeGB, + Family: m.Family, + CanVerify: m.CanVerify, }) } } diff --git a/coordinator/internal/api/weaver.go b/coordinator/internal/api/weaver.go new file mode 100644 index 000000000..670538219 --- /dev/null +++ b/coordinator/internal/api/weaver.go @@ -0,0 +1,650 @@ +package api + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math" + "net/http" + "sort" + "strings" + "sync" + "time" + + "github.com/eigeninference/coordinator/internal/e2e" + "github.com/eigeninference/coordinator/internal/protocol" + "github.com/eigeninference/coordinator/internal/registry" + "github.com/eigeninference/coordinator/internal/store" + "github.com/google/uuid" + "nhooyr.io/websocket" 
+) + +const ( + weaverModeDeep = "deep" + weaverModeDeepPlus = "deep_plus" + + weaverVerifierMaxTokens = 256 +) + +type weaverConfig struct { + Mode string + CandidateCount int + VerifierCount int + Trace bool +} + +type weaverCandidate struct { + ID string `json:"id"` + Index int `json:"index"` + Label string `json:"label"` + Model string `json:"model"` + Status string `json:"status"` + Content string `json:"content"` + Error string `json:"error,omitempty"` +} + +type weaverVerifierScore struct { + CandidateID string `json:"candidate_id"` + VerifierModel string `json:"verifier_model"` + Score float64 `json:"score"` + Rationale string `json:"rationale,omitempty"` +} + +type weaverTrace struct { + Mode string `json:"mode"` + Status string `json:"status"` + CandidateCount int `json:"candidateCount"` + VerifierCount int `json:"verifierCount"` + SelectedCandidate string `json:"selectedCandidate,omitempty"` + Confidence float64 `json:"confidence,omitempty"` + Candidates []weaverCandidate `json:"candidates"` + VerifierScores []weaverVerifierScore `json:"verifierScores"` +} + +type weaverRunResult struct { + Candidate weaverCandidate + Usage protocol.UsageInfo + Err error +} + +type weaverProviderJob struct { + Provider *registry.Provider + Pending *registry.PendingRequest +} + +type weaverSSEWriter struct { + mu sync.Mutex + w http.ResponseWriter + flusher http.Flusher +} + +func (sw *weaverSSEWriter) write(event string, payload any) { + data, err := json.Marshal(payload) + if err != nil { + return + } + sw.mu.Lock() + defer sw.mu.Unlock() + fmt.Fprintf(sw.w, "event: %s\n", event) + fmt.Fprintf(sw.w, "data: %s\n\n", data) + sw.flusher.Flush() +} + +func parseWeaverConfig(parsed map[string]any) (weaverConfig, bool, error) { + raw, ok := parsed["weaver"] + if !ok || raw == nil { + return weaverConfig{}, false, nil + } + obj, ok := raw.(map[string]any) + if !ok { + return weaverConfig{}, true, errors.New("weaver must be an object") + } + + mode, _ := obj["mode"].(string) + 
cfg := weaverConfig{Mode: mode, Trace: true} + switch mode { + case weaverModeDeep: + cfg.CandidateCount = 4 + cfg.VerifierCount = 3 + case weaverModeDeepPlus: + cfg.CandidateCount = 8 + cfg.VerifierCount = 5 + default: + return weaverConfig{}, true, errors.New("weaver.mode must be deep or deep_plus") + } + + if v, ok := numberFromAny(obj["candidate_count"]); ok { + cfg.CandidateCount = v + } + if v, ok := numberFromAny(obj["verifier_count"]); ok { + cfg.VerifierCount = v + } + if trace, ok := obj["trace"].(bool); ok { + cfg.Trace = trace + } + if cfg.CandidateCount < 1 || cfg.CandidateCount > 16 { + return weaverConfig{}, true, errors.New("weaver.candidate_count must be between 1 and 16") + } + if cfg.VerifierCount < 1 || cfg.VerifierCount > 10 { + return weaverConfig{}, true, errors.New("weaver.verifier_count must be between 1 and 10") + } + return cfg, true, nil +} + +func numberFromAny(v any) (int, bool) { + switch n := v.(type) { + case float64: + return int(n), true + case int: + return n, true + default: + return 0, false + } +} + +func (s *Server) handleWeaverChatCompletions(w http.ResponseWriter, r *http.Request, cfg weaverConfig, rawBody []byte, model string, estimatedPromptTokens, requestedMaxTokens int, allowedProviderSerials []string) { + flusher, ok := w.(http.Flusher) + if !ok { + writeJSON(w, http.StatusInternalServerError, errorResponse("internal_error", "streaming not supported")) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(http.StatusOK) + writer := &weaverSSEWriter{w: w, flusher: flusher} + + candidates := make([]weaverCandidate, cfg.CandidateCount) + for i := range candidates { + candidates[i] = weaverCandidate{ + ID: fmt.Sprintf("candidate-%d", i+1), + Index: i, + Label: fmt.Sprintf("Agent %d", i+1), + Model: model, + Status: "queued", + } + } + trace := weaverTrace{ + Mode: cfg.Mode, + Status: "generating", + 
CandidateCount: cfg.CandidateCount, + VerifierCount: cfg.VerifierCount, + Candidates: candidates, + VerifierScores: []weaverVerifierScore{}, + } + writer.write("weaver_init", trace) + writer.write("weaver_status", map[string]any{"status": "generating"}) + + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + + consumerKey := consumerKeyFromContext(r.Context()) + results := make(chan weaverRunResult, cfg.CandidateCount) + var wg sync.WaitGroup + for _, candidate := range candidates { + wg.Add(1) + go func(c weaverCandidate) { + defer wg.Done() + results <- s.runWeaverCandidate(ctx, writer, c, rawBody, model, consumerKey, estimatedPromptTokens, requestedMaxTokens, allowedProviderSerials) + }(candidate) + } + go func() { + wg.Wait() + close(results) + }() + + completed := make([]weaverRunResult, 0, cfg.CandidateCount) + var totalUsage protocol.UsageInfo + for result := range results { + if result.Err == nil && strings.TrimSpace(result.Candidate.Content) != "" { + completed = append(completed, result) + totalUsage.PromptTokens += result.Usage.PromptTokens + totalUsage.CompletionTokens += result.Usage.CompletionTokens + } + } + if len(completed) == 0 { + writer.write("weaver_status", map[string]any{"status": "failed"}) + writer.write("error", map[string]any{"message": "all weaver candidates failed"}) + writer.write("done", map[string]any{"status": "failed"}) + return + } + + writer.write("weaver_status", map[string]any{"status": "verifying"}) + verifierModels := s.selectWeaverVerifierModels(model, cfg.VerifierCount) + scores := make([]weaverVerifierScore, 0, len(completed)*cfg.VerifierCount) + for pass := 0; pass < cfg.VerifierCount; pass++ { + verifierModel := verifierModels[pass%len(verifierModels)] + for _, result := range completed { + score, usage := s.runWeaverVerifier(ctx, result.Candidate, verifierModel, consumerKey, estimatedPromptTokens, requestedMaxTokens, allowedProviderSerials) + scores = append(scores, score) + totalUsage.PromptTokens += 
usage.PromptTokens + totalUsage.CompletionTokens += usage.CompletionTokens + writer.write("verifier_score", score) + } + } + + writer.write("weaver_status", map[string]any{"status": "selecting"}) + winner, confidence := selectWeaverWinner(completed, scores) + writer.write("weaver_final", map[string]any{ + "selected_candidate": winner.Candidate.ID, + "confidence": confidence, + "content": winner.Candidate.Content, + "usage": map[string]any{ + "tokenCount": totalUsage.CompletionTokens, + "tps": 0, + "ttft": 0, + }, + }) + writer.write("weaver_status", map[string]any{"status": "complete"}) + writer.write("done", map[string]any{"status": "complete"}) +} + +func (s *Server) runWeaverCandidate(ctx context.Context, writer *weaverSSEWriter, candidate weaverCandidate, rawBody []byte, model, consumerKey string, estimatedPromptTokens, requestedMaxTokens int, allowedProviderSerials []string) weaverRunResult { + candidate.Status = "streaming" + writer.write("candidate_start", candidate) + + job, err := s.dispatchWeaverProviderJob(ctx, rawBody, model, consumerKey, estimatedPromptTokens, requestedMaxTokens, allowedProviderSerials) + if err != nil { + candidate.Status = "failed" + candidate.Error = err.Error() + writer.write("candidate_error", map[string]any{"candidate_id": candidate.ID, "error": candidate.Error}) + return weaverRunResult{Candidate: candidate, Err: err} + } + defer s.cleanupWeaverProviderJob(job) + + chunks, usage, err := s.collectWeaverChunks(ctx, job, func(delta string) { + if delta != "" { + writer.write("candidate_delta", map[string]any{"candidate_id": candidate.ID, "delta": delta}) + } + }) + if err != nil { + candidate.Status = "failed" + candidate.Error = err.Error() + writer.write("candidate_error", map[string]any{"candidate_id": candidate.ID, "error": candidate.Error}) + return weaverRunResult{Candidate: candidate, Err: err} + } + + msg := extractMessage(chunks) + candidate.Status = "done" + candidate.Content = msg.Content + writer.write("candidate_done", 
map[string]any{"candidate_id": candidate.ID, "content": candidate.Content}) + return weaverRunResult{Candidate: candidate, Usage: usage} +} + +func (s *Server) dispatchWeaverProviderJob(ctx context.Context, rawBody []byte, model, consumerKey string, estimatedPromptTokens, requestedMaxTokens int, allowedProviderSerials []string) (*weaverProviderJob, error) { + requestID := uuid.New().String() + pr := ®istry.PendingRequest{ + RequestID: requestID, + Model: model, + ConsumerKey: consumerKey, + EstimatedPromptTokens: estimatedPromptTokens, + RequestedMaxTokens: requestedMaxTokens, + AllowedProviderSerials: allowedProviderSerials, + AcceptedCh: make(chan struct{}, 1), + ChunkCh: make(chan string, chunkBufferSize), + CompleteCh: make(chan protocol.UsageInfo, 1), + ErrorCh: make(chan protocol.InferenceErrorMessage, 1), + } + + provider, decision := s.registry.ReserveProviderEx(model, pr) + if provider == nil { + outcome := "no provider available" + if decision.CapacityRejections > 0 && decision.CandidateCount == 0 { + outcome = "provider capacity unavailable" + } + return nil, errors.New(outcome) + } + reservedMicroUSD, err := s.reserveWeaverProviderJob(pr, model, estimatedPromptTokens, requestedMaxTokens) + if err != nil { + provider.RemovePending(requestID) + s.registry.SetProviderIdle(provider.ID) + return nil, err + } + pr.ReservedMicroUSD = reservedMicroUSD + if provider.PublicKey == "" { + provider.RemovePending(requestID) + s.registry.SetProviderIdle(provider.ID) + s.refundWeaverProviderJob(pr) + return nil, errors.New("no provider with E2E encryption") + } + providerPubKey, err := e2e.ParsePublicKey(provider.PublicKey) + if err != nil { + provider.RemovePending(requestID) + s.registry.SetProviderIdle(provider.ID) + s.refundWeaverProviderJob(pr) + return nil, errors.New("provider public key invalid") + } + sessionKeys, err := e2e.GenerateSessionKeys() + if err != nil { + provider.RemovePending(requestID) + s.registry.SetProviderIdle(provider.ID) + 
s.refundWeaverProviderJob(pr) + return nil, errors.New("failed to generate session keys") + } + encrypted, err := e2e.Encrypt(rawBody, providerPubKey, sessionKeys) + if err != nil { + provider.RemovePending(requestID) + s.registry.SetProviderIdle(provider.ID) + s.refundWeaverProviderJob(pr) + return nil, errors.New("failed to encrypt request") + } + wireMsg := map[string]any{ + "type": protocol.TypeInferenceRequest, + "request_id": requestID, + "encrypted_body": map[string]string{ + "ephemeral_public_key": encrypted.EphemeralPublicKey, + "ciphertext": encrypted.Ciphertext, + }, + } + data, err := json.Marshal(wireMsg) + if err != nil { + provider.RemovePending(requestID) + s.registry.SetProviderIdle(provider.ID) + s.refundWeaverProviderJob(pr) + return nil, errors.New("failed to marshal request") + } + pr.SessionPrivKey = &sessionKeys.PrivateKey + if err := provider.Conn.Write(ctx, websocket.MessageText, data); err != nil { + provider.RemovePending(requestID) + s.registry.SetProviderIdle(provider.ID) + s.refundWeaverProviderJob(pr) + return nil, errors.New("failed to send request to provider") + } + return &weaverProviderJob{Provider: provider, Pending: pr}, nil +} + +func (s *Server) reserveWeaverProviderJob(pr *registry.PendingRequest, model string, promptTokens, maxTokens int) (int64, error) { + if s.billing == nil { + return 0, nil + } + reservedMicroUSD := s.reservationCost(model, promptTokens, maxTokens) + if err := s.ledger.Charge(pr.ConsumerKey, reservedMicroUSD, "reserve:"+pr.RequestID); err != nil { + return 0, errors.New("consumer balance is too low for weaver request") + } + return reservedMicroUSD, nil +} + +func (s *Server) refundWeaverProviderJob(pr *registry.PendingRequest) { + if pr == nil || pr.ReservedMicroUSD <= 0 { + return + } + _ = s.store.Credit(pr.ConsumerKey, pr.ReservedMicroUSD, store.LedgerRefund, pr.RequestID) + pr.ReservedMicroUSD = 0 +} + +func (s *Server) cleanupWeaverProviderJob(job *weaverProviderJob) { + if job == nil || 
job.Provider == nil || job.Pending == nil { + return + } + removed := job.Provider.RemovePending(job.Pending.RequestID) + if removed == nil { + return + } + s.refundWeaverProviderJob(removed) + s.registry.SetProviderIdle(job.Provider.ID) + s.sendProviderCancel(job.Provider, job.Pending.RequestID) +} + +func (s *Server) collectWeaverChunks(ctx context.Context, job *weaverProviderJob, onDelta func(string)) ([]string, protocol.UsageInfo, error) { + chunks := make([]string, 0, 16) + timer := time.NewTimer(inferenceTimeout) + defer timer.Stop() + for { + select { + case chunk, ok := <-job.Pending.ChunkCh: + if !ok { + var usage protocol.UsageInfo + select { + case u, ok := <-job.Pending.CompleteCh: + if ok { + usage = u + } + default: + } + return chunks, usage, nil + } + normalized := normalizeSSEChunk(chunk) + chunks = append(chunks, normalized) + onDelta(extractMessage([]string{normalized}).Content) + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timer.Reset(inferenceTimeout) + case errMsg := <-job.Pending.ErrorCh: + if errMsg.Error == "" { + errMsg.Error = "provider error" + } + s.refundWeaverProviderJob(job.Pending) + return nil, protocol.UsageInfo{}, errors.New(errMsg.Error) + case <-timer.C: + return nil, protocol.UsageInfo{}, errors.New("request timed out") + case <-ctx.Done(): + return nil, protocol.UsageInfo{}, ctx.Err() + } + } +} + +func (s *Server) selectWeaverVerifierModels(baseModel string, count int) []string { + if count <= 0 { + return []string{baseModel} + } + + online := make(map[string]struct{}) + for _, m := range s.registry.ListModels() { + online[m.ID] = struct{}{} + } + hasOnline := len(online) > 0 + isOnline := func(id string) bool { + if !hasOnline { + return true + } + _, ok := online[id] + return ok + } + + catalogModels := s.store.ListSupportedModels() + if len(catalogModels) == 0 { + return selectOnlineVerifierModels(s.registry.ListModels(), baseModel, count) + } + + var baseFamily string + eligible := 
make([]store.SupportedModel, 0, len(catalogModels)) + for _, m := range catalogModels { + if m.ID == baseModel { + baseFamily = strings.TrimSpace(m.Family) + } + if !m.Active || IsRetiredProviderModel(m) || !isTextCatalogModel(m.ModelType) || !m.CanVerify { + continue + } + if m.ID == baseModel || !isOnline(m.ID) { + continue + } + eligible = append(eligible, m) + } + + if len(eligible) == 0 { + s.logger.Warn("weaver: no eligible verifier models, falling back to self-verification", + "base_model", baseModel, + "catalog_size", len(catalogModels)) + return []string{baseModel} + } + + sort.SliceStable(eligible, func(i, j int) bool { + di := isDifferentVerifierFamily(eligible[i].Family, baseFamily) + dj := isDifferentVerifierFamily(eligible[j].Family, baseFamily) + if di != dj { + return di + } + si := verifierSizeSortKey(eligible[i].SizeGB) + sj := verifierSizeSortKey(eligible[j].SizeGB) + if si != sj { + return si < sj + } + return eligible[i].ID < eligible[j].ID + }) + + limit := count + if limit > len(eligible) { + limit = len(eligible) + } + ids := make([]string, 0, limit) + for _, m := range eligible[:limit] { + ids = append(ids, m.ID) + } + return ids +} + +func selectOnlineVerifierModels(models []registry.AggregateModel, baseModel string, count int) []string { + ids := make([]string, 0, len(models)) + for _, m := range models { + if m.ID == baseModel || !isTextCatalogModel(m.ModelType) { + continue + } + ids = append(ids, m.ID) + } + if len(ids) == 0 { + return []string{baseModel} + } + sort.Strings(ids) + if count < len(ids) { + return ids[:count] + } + return ids +} + +func isTextCatalogModel(modelType string) bool { + modelType = strings.ToLower(strings.TrimSpace(modelType)) + return modelType == "" || + modelType == "text" || + modelType == "chat" || + strings.Contains(modelType, "text") || + strings.Contains(modelType, "chat") || + strings.Contains(modelType, "completion") || + strings.Contains(modelType, "test") +} + +func 
isDifferentVerifierFamily(family, baseFamily string) bool { + family = strings.TrimSpace(family) + baseFamily = strings.TrimSpace(baseFamily) + return family != "" && baseFamily != "" && family != baseFamily +} + +func verifierSizeSortKey(sizeGB float64) float64 { + if sizeGB <= 0 { + return math.MaxFloat64 + } + return sizeGB +} + +func (s *Server) runWeaverVerifier(ctx context.Context, candidate weaverCandidate, verifierModel, consumerKey string, estimatedPromptTokens, requestedMaxTokens int, allowedProviderSerials []string) (weaverVerifierScore, protocol.UsageInfo) { + body := map[string]any{ + "model": verifierModel, + "stream": true, + "max_tokens": weaverVerifierMaxTokens, + "messages": []map[string]string{ + { + "role": "system", + "content": "Score the candidate answer from 0 to 10 for correctness, completeness, and clarity. Return only strict JSON: {\"score\": number, \"rationale\": string}. Do not include hidden reasoning.", + }, + { + "role": "user", + "content": fmt.Sprintf("Candidate %s answer:\n\n%s", candidate.Label, candidate.Content), + }, + }, + } + raw, _ := json.Marshal(body) + verifierPromptTokens := estimatePromptTokens(body) + score := weaverVerifierScore{ + CandidateID: candidate.ID, + VerifierModel: verifierModel, + Score: 0, + Rationale: "abstained", + } + job, err := s.dispatchWeaverProviderJob(ctx, raw, verifierModel, consumerKey, verifierPromptTokens, weaverVerifierMaxTokens, allowedProviderSerials) + if err != nil { + score.Rationale = "abstained: " + err.Error() + return score, protocol.UsageInfo{} + } + defer s.cleanupWeaverProviderJob(job) + + chunks, usage, err := s.collectWeaverChunks(ctx, job, func(string) {}) + if err != nil { + score.Rationale = "abstained: " + err.Error() + return score, usage + } + msg := extractMessage(chunks) + var parsed struct { + Score float64 `json:"score"` + Rationale string `json:"rationale"` + } + if err := json.Unmarshal([]byte(strings.TrimSpace(msg.Content)), &parsed); err != nil { + 
score.Rationale = "abstained: malformed verifier output" + return score, usage + } + if math.IsNaN(parsed.Score) || math.IsInf(parsed.Score, 0) { + score.Rationale = "abstained: invalid verifier score" + return score, usage + } + if parsed.Score < 0 { + parsed.Score = 0 + } + if parsed.Score > 10 { + parsed.Score = 10 + } + score.Score = parsed.Score + score.Rationale = parsed.Rationale + return score, usage +} + +func selectWeaverWinner(results []weaverRunResult, scores []weaverVerifierScore) (weaverRunResult, float64) { + byID := make(map[string]weaverRunResult, len(results)) + for _, result := range results { + byID[result.Candidate.ID] = result + } + type avgScore struct { + id string + avg float64 + count int + index int + } + avgs := make([]avgScore, 0, len(results)) + for _, result := range results { + var sum float64 + var count int + for _, score := range scores { + if score.CandidateID == result.Candidate.ID && score.Score > 0 { + sum += score.Score + count++ + } + } + avg := 0.0 + if count > 0 { + avg = sum / float64(count) + } + avgs = append(avgs, avgScore{id: result.Candidate.ID, avg: avg, count: count, index: result.Candidate.Index}) + } + sort.SliceStable(avgs, func(i, j int) bool { + if avgs[i].avg == avgs[j].avg { + return avgs[i].index < avgs[j].index + } + return avgs[i].avg > avgs[j].avg + }) + if len(avgs) == 0 { + return results[0], 0 + } + best := avgs[0] + confidence := best.avg / 10 + if len(avgs) > 1 && best.avg > 0 { + gap := best.avg - avgs[1].avg + confidence = math.Max(0.5, math.Min(0.99, 0.5+gap/10)) + } + if best.count == 0 { + confidence = 0 + } + return byID[best.id], confidence +} diff --git a/coordinator/internal/api/weaver_test.go b/coordinator/internal/api/weaver_test.go new file mode 100644 index 000000000..e4fe73388 --- /dev/null +++ b/coordinator/internal/api/weaver_test.go @@ -0,0 +1,525 @@ +package api + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "strings" + 
"testing" + "time" + + "github.com/eigeninference/coordinator/internal/payments" + "github.com/eigeninference/coordinator/internal/protocol" + "github.com/eigeninference/coordinator/internal/registry" + "github.com/eigeninference/coordinator/internal/store" + "nhooyr.io/websocket" +) + +func TestParseWeaverConfigDefaults(t *testing.T) { + cfg, ok, err := parseWeaverConfig(map[string]any{ + "weaver": map[string]any{"mode": "deep"}, + }) + if err != nil { + t.Fatalf("parseWeaverConfig: %v", err) + } + if !ok { + t.Fatal("expected weaver config") + } + if cfg.CandidateCount != 4 { + t.Fatalf("candidate count = %d, want 4", cfg.CandidateCount) + } + if cfg.VerifierCount != 3 { + t.Fatalf("verifier count = %d, want 3", cfg.VerifierCount) + } +} + +func TestSelectWeaverWinnerDeterministicTie(t *testing.T) { + results := []weaverRunResult{ + {Candidate: weaverCandidate{ID: "candidate-1", Index: 0, Content: "a"}}, + {Candidate: weaverCandidate{ID: "candidate-2", Index: 1, Content: "b"}}, + } + scores := []weaverVerifierScore{ + {CandidateID: "candidate-2", Score: 8}, + {CandidateID: "candidate-1", Score: 8}, + } + winner, _ := selectWeaverWinner(results, scores) + if winner.Candidate.ID != "candidate-1" { + t.Fatalf("winner = %s, want candidate-1", winner.Candidate.ID) + } +} + +func TestSelectWeaverVerifierModelsUsesCatalogDiversityAndSize(t *testing.T) { + srv := weaverSelectionServer(t, []store.SupportedModel{ + {ID: "base", DisplayName: "Base", ModelType: "text", SizeGB: 27, Family: "qwen", CanVerify: true, Active: true}, + {ID: "same-family-small", DisplayName: "Same", ModelType: "text", SizeGB: 4, Family: "qwen", CanVerify: true, Active: true}, + {ID: "different-family-large", DisplayName: "Large", ModelType: "text", SizeGB: 26, Family: "trinity", CanVerify: true, Active: true}, + {ID: "different-family-small", DisplayName: "Small", ModelType: "text", SizeGB: 8, Family: "gemma", CanVerify: true, Active: true}, + {ID: "disabled-verifier", DisplayName: "Disabled", 
ModelType: "text", SizeGB: 1, Family: "minimax", CanVerify: false, Active: true}, + {ID: "image-model", DisplayName: "Image", ModelType: "image", SizeGB: 1, Family: "image", CanVerify: true, Active: true}, + }) + + got := srv.selectWeaverVerifierModels("base", 3) + want := []string{"different-family-small", "different-family-large", "same-family-small"} + if strings.Join(got, ",") != strings.Join(want, ",") { + t.Fatalf("verifiers = %v, want %v", got, want) + } +} + +func TestSelectWeaverVerifierModelsFallsBackToBase(t *testing.T) { + srv := weaverSelectionServer(t, []store.SupportedModel{ + {ID: "base", DisplayName: "Base", ModelType: "text", SizeGB: 27, Family: "qwen", CanVerify: true, Active: true}, + {ID: "disabled-verifier", DisplayName: "Disabled", ModelType: "text", SizeGB: 1, Family: "gemma", CanVerify: false, Active: true}, + }) + + got := srv.selectWeaverVerifierModels("base", 3) + if len(got) != 1 || got[0] != "base" { + t.Fatalf("verifiers = %v, want [base]", got) + } +} + +func weaverSelectionServer(t *testing.T, models []store.SupportedModel) *Server { + t.Helper() + + logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError})) + st := store.NewMemory("test-key") + reg := registry.New(logger) + reg.MinTrustLevel = registry.TrustNone + srv := NewServer(reg, st, logger) + + for i := range models { + if err := st.SetSupportedModel(&models[i]); err != nil { + t.Fatalf("SetSupportedModel(%q): %v", models[i].ID, err) + } + } + srv.SyncModelCatalog() + + for _, model := range models { + if !model.Active { + continue + } + registerSelectionProvider(t, reg, model.ID) + } + return srv +} + +func registerSelectionProvider(t *testing.T, reg *registry.Registry, modelID string) { + t.Helper() + providerID := "provider-" + strings.NewReplacer("/", "-", " ", "-").Replace(modelID) + msg := &protocol.RegisterMessage{ + Type: protocol.TypeRegister, + Hardware: protocol.Hardware{ + MachineModel: "Mac15,8", + ChipName: "Apple M3 Max", 
+ MemoryGB: 64, + }, + Models: []protocol.ModelInfo{ + {ID: modelID, SizeBytes: 1000, ModelType: "text", Quantization: "4bit"}, + }, + Backend: "inprocess-mlx", + PublicKey: testPublicKeyB64(), + EncryptedResponseChunks: true, + PrivacyCapabilities: testPrivacyCaps(), + } + reg.Register(providerID, nil, msg) + reg.RecordChallengeSuccess(providerID) +} + +func TestWeaverStreamingE2E(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + st := store.NewMemory("test-key") + reg := registry.New(logger) + srv := NewServer(reg, st, logger) + ts := httptest.NewServer(srv.Handler()) + defer ts.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/provider" + conn, _, err := websocket.Dial(ctx, wsURL, nil) + if err != nil { + t.Fatalf("websocket dial: %v", err) + } + defer conn.Close(websocket.StatusNormalClosure, "") + + pubKey := testPublicKeyB64() + regMsg := protocol.RegisterMessage{ + Type: protocol.TypeRegister, + Hardware: protocol.Hardware{ + MachineModel: "Mac15,8", + ChipName: "Apple M3 Max", + MemoryGB: 64, + }, + Models: []protocol.ModelInfo{ + {ID: "test-model", SizeBytes: 1000, ModelType: "test", Quantization: "4bit"}, + }, + Backend: "inprocess-mlx", + PublicKey: pubKey, + EncryptedResponseChunks: true, + PrivacyCapabilities: testPrivacyCaps(), + } + regData, _ := json.Marshal(regMsg) + if err := conn.Write(ctx, websocket.MessageText, regData); err != nil { + t.Fatalf("write register: %v", err) + } + + time.Sleep(100 * time.Millisecond) + for _, id := range reg.ProviderIDs() { + reg.SetTrustLevel(id, registry.TrustHardware) + reg.RecordChallengeSuccess(id) + } + + providerDone := make(chan struct{}) + go func() { + defer close(providerDone) + served := 0 + for served < 2 { + _, data, err := conn.Read(ctx) + if err != nil { + t.Errorf("provider read: %v", err) + return + } + var raw 
map[string]any + if err := json.Unmarshal(data, &raw); err == nil { + msgType, _ := raw["type"].(string) + if msgType == protocol.TypeAttestationChallenge { + respData := makeValidChallengeResponse(data, pubKey) + _ = conn.Write(ctx, websocket.MessageText, respData) + continue + } + if msgType == protocol.TypeRuntimeStatus || msgType == protocol.TypeCancel { + continue + } + } + var inferReq protocol.InferenceRequestMessage + if err := json.Unmarshal(data, &inferReq); err != nil { + t.Errorf("unmarshal inference request: %v", err) + return + } + if served == 0 { + writeEncryptedTestChunk(t, ctx, conn, inferReq, pubKey, + `data: {"id":"chatcmpl-candidate","choices":[{"delta":{"content":"Answer A"}}]}`+"\n\n") + } else { + writeEncryptedTestChunk(t, ctx, conn, inferReq, pubKey, + `data: {"id":"chatcmpl-verifier","choices":[{"delta":{"content":"{\"score\":9,\"rationale\":\"clear\"}"}}]}`+"\n\n") + } + complete := protocol.InferenceCompleteMessage{ + Type: protocol.TypeInferenceComplete, + RequestID: inferReq.RequestID, + Usage: protocol.UsageInfo{PromptTokens: 3, CompletionTokens: 2}, + } + completeData, _ := json.Marshal(complete) + if err := conn.Write(ctx, websocket.MessageText, completeData); err != nil { + t.Errorf("write complete: %v", err) + return + } + served++ + } + }() + + chatBody := `{"model":"test-model","messages":[{"role":"user","content":"hi"}],"stream":true,"weaver":{"mode":"deep","candidate_count":1,"verifier_count":1,"trace":true}}` + httpReq, _ := http.NewRequestWithContext(ctx, http.MethodPost, ts.URL+"/v1/chat/completions", strings.NewReader(chatBody)) + httpReq.Header.Set("Authorization", "Bearer test-key") + + resp, err := http.DefaultClient.Do(httpReq) + if err != nil { + t.Fatalf("http request: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("status = %d, body = %s", resp.StatusCode, body) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read 
body: %v", err) + } + response := string(body) + for _, want := range []string{"event: weaver_init", "event: candidate_delta", "event: verifier_score", "event: weaver_final", "Answer A"} { + if !strings.Contains(response, want) { + t.Fatalf("response missing %q:\n%s", want, response) + } + } + <-providerDone + + if got := len(st.UsageRecords()); got != 2 { + t.Fatalf("usage records = %d, want 2", got) + } +} + +func TestWeaverAllCandidatesFailTerminates(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError})) + st := store.NewMemory("test-key") + reg := registry.New(logger) + srv := NewServer(reg, st, logger) + ts := httptest.NewServer(srv.Handler()) + defer ts.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + chatBody := `{"model":"missing-model","messages":[{"role":"user","content":"hi"}],"stream":true,"weaver":{"mode":"deep","candidate_count":2,"verifier_count":1,"trace":true}}` + httpReq, _ := http.NewRequestWithContext(ctx, http.MethodPost, ts.URL+"/v1/chat/completions", strings.NewReader(chatBody)) + httpReq.Header.Set("Authorization", "Bearer test-key") + + resp, err := http.DefaultClient.Do(httpReq) + if err != nil { + t.Fatalf("http request: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("status = %d, body = %s", resp.StatusCode, body) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + response := string(body) + for _, want := range []string{"event: candidate_error", "event: weaver_status\ndata: {\"status\":\"failed\"}", "event: done\ndata: {\"status\":\"failed\"}"} { + if !strings.Contains(response, want) { + t.Fatalf("response missing %q:\n%s", want, response) + } + } +} + +func TestWeaverMalformedVerifierOutputAbstains(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: 
slog.LevelError})) + st := store.NewMemory("test-key") + reg := registry.New(logger) + srv := NewServer(reg, st, logger) + ts := httptest.NewServer(srv.Handler()) + defer ts.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/provider" + conn, _, err := websocket.Dial(ctx, wsURL, nil) + if err != nil { + t.Fatalf("websocket dial: %v", err) + } + defer conn.Close(websocket.StatusNormalClosure, "") + + pubKey := testPublicKeyB64() + regMsg := protocol.RegisterMessage{ + Type: protocol.TypeRegister, + Hardware: protocol.Hardware{ + MachineModel: "Mac15,8", + ChipName: "Apple M3 Max", + MemoryGB: 64, + }, + Models: []protocol.ModelInfo{ + {ID: "test-model", SizeBytes: 1000, ModelType: "test", Quantization: "4bit"}, + }, + Backend: "inprocess-mlx", + PublicKey: pubKey, + EncryptedResponseChunks: true, + PrivacyCapabilities: testPrivacyCaps(), + } + regData, _ := json.Marshal(regMsg) + if err := conn.Write(ctx, websocket.MessageText, regData); err != nil { + t.Fatalf("write register: %v", err) + } + + time.Sleep(100 * time.Millisecond) + for _, id := range reg.ProviderIDs() { + reg.SetTrustLevel(id, registry.TrustHardware) + reg.RecordChallengeSuccess(id) + } + + providerDone := make(chan struct{}) + go func() { + defer close(providerDone) + served := 0 + for served < 2 { + _, data, err := conn.Read(ctx) + if err != nil { + t.Errorf("provider read: %v", err) + return + } + var raw map[string]any + if err := json.Unmarshal(data, &raw); err == nil { + msgType, _ := raw["type"].(string) + if msgType == protocol.TypeAttestationChallenge { + respData := makeValidChallengeResponse(data, pubKey) + _ = conn.Write(ctx, websocket.MessageText, respData) + continue + } + if msgType == protocol.TypeRuntimeStatus || msgType == protocol.TypeCancel { + continue + } + } + var inferReq protocol.InferenceRequestMessage + if err := json.Unmarshal(data, &inferReq); err != nil { + 
t.Errorf("unmarshal inference request: %v", err) + return + } + chunk := `data: {"id":"chatcmpl-candidate","choices":[{"delta":{"content":"Answer A"}}]}` + "\n\n" + if served == 1 { + chunk = `data: {"id":"chatcmpl-verifier","choices":[{"delta":{"content":"not json"}}]}` + "\n\n" + } + writeEncryptedTestChunk(t, ctx, conn, inferReq, pubKey, chunk) + complete := protocol.InferenceCompleteMessage{ + Type: protocol.TypeInferenceComplete, + RequestID: inferReq.RequestID, + Usage: protocol.UsageInfo{PromptTokens: 3, CompletionTokens: 2}, + } + completeData, _ := json.Marshal(complete) + if err := conn.Write(ctx, websocket.MessageText, completeData); err != nil { + t.Errorf("write complete: %v", err) + return + } + served++ + } + }() + + chatBody := `{"model":"test-model","messages":[{"role":"user","content":"hi"}],"stream":true,"weaver":{"mode":"deep","candidate_count":1,"verifier_count":1,"trace":true}}` + httpReq, _ := http.NewRequestWithContext(ctx, http.MethodPost, ts.URL+"/v1/chat/completions", strings.NewReader(chatBody)) + httpReq.Header.Set("Authorization", "Bearer test-key") + + resp, err := http.DefaultClient.Do(httpReq) + if err != nil { + t.Fatalf("http request: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("status = %d, body = %s", resp.StatusCode, body) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + response := string(body) + if !strings.Contains(response, "abstained: malformed verifier output") { + t.Fatalf("response missing malformed abstention:\n%s", response) + } + if !strings.Contains(response, "event: weaver_final") { + t.Fatalf("response missing final event:\n%s", response) + } + <-providerDone +} + +func TestWeaverBillingChargesGenerationAndVerifierCalls(t *testing.T) { + srv, _, ledger := billingTestServer(t) + ts := httptest.NewServer(srv.Handler()) + defer ts.Close() + + ctx, cancel := 
context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + const consumerID = "test-key" + const model = "billing-weaver-model" + initialBalance := ledger.Balance(consumerID) + if initialBalance <= 0 { + t.Fatalf("initial balance = %d, want > 0", initialBalance) + } + + conn, _, pubKey := setupProviderForBilling(t, ctx, ts, srv.registry, model) + defer conn.Close(websocket.StatusNormalClosure, "") + + candidateUsage := protocol.UsageInfo{PromptTokens: 10, CompletionTokens: 5} + verifierUsage := protocol.UsageInfo{PromptTokens: 20, CompletionTokens: 3} + providerDone := serveWeaverBillingInferences(ctx, t, conn, pubKey, candidateUsage, verifierUsage) + + chatBody := `{"model":"` + model + `","messages":[{"role":"user","content":"hi"}],"stream":true,"max_tokens":16,"weaver":{"mode":"deep","candidate_count":1,"verifier_count":1,"trace":true}}` + httpReq, _ := http.NewRequestWithContext(ctx, http.MethodPost, ts.URL+"/v1/chat/completions", strings.NewReader(chatBody)) + httpReq.Header.Set("Authorization", "Bearer "+consumerID) + + resp, err := http.DefaultClient.Do(httpReq) + if err != nil { + t.Fatalf("http request: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("status = %d, body = %s", resp.StatusCode, body) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + response := string(body) + for _, want := range []string{"event: weaver_final", "billing candidate"} { + if !strings.Contains(response, want) { + t.Fatalf("response missing %q:\n%s", want, response) + } + } + <-providerDone + + expectedCost := payments.CalculateCost(model, candidateUsage.PromptTokens, candidateUsage.CompletionTokens) + + payments.CalculateCost(model, verifierUsage.PromptTokens, verifierUsage.CompletionTokens) + actualBalance := ledger.Balance(consumerID) + if actualBalance != initialBalance-expectedCost { + t.Fatalf("consumer balance = %d, want %d", 
actualBalance, initialBalance-expectedCost) + } + + usageEntries := ledger.Usage(consumerID) + if len(usageEntries) != 2 { + t.Fatalf("usage entries = %d, want 2", len(usageEntries)) + } + var actualCost int64 + for _, entry := range usageEntries { + actualCost += entry.CostMicroUSD + } + if actualCost != expectedCost { + t.Fatalf("usage cost = %d, want %d", actualCost, expectedCost) + } +} + +func serveWeaverBillingInferences(ctx context.Context, t *testing.T, conn *websocket.Conn, pubKey string, candidateUsage, verifierUsage protocol.UsageInfo) <-chan struct{} { + t.Helper() + done := make(chan struct{}) + go func() { + defer close(done) + served := 0 + for served < 2 { + _, data, err := conn.Read(ctx) + if err != nil { + t.Errorf("provider read: %v", err) + return + } + var raw map[string]any + if err := json.Unmarshal(data, &raw); err == nil { + msgType, _ := raw["type"].(string) + if msgType == protocol.TypeAttestationChallenge { + respData := makeValidChallengeResponse(data, pubKey) + _ = conn.Write(ctx, websocket.MessageText, respData) + continue + } + if msgType == protocol.TypeRuntimeStatus || msgType == protocol.TypeCancel { + continue + } + } + var inferReq protocol.InferenceRequestMessage + if err := json.Unmarshal(data, &inferReq); err != nil { + t.Errorf("unmarshal inference request: %v", err) + return + } + usage := candidateUsage + chunk := `data: {"id":"chatcmpl-candidate","choices":[{"delta":{"content":"billing candidate"}}]}` + "\n\n" + if served == 1 { + usage = verifierUsage + chunk = `data: {"id":"chatcmpl-verifier","choices":[{"delta":{"content":"{\"score\":9,\"rationale\":\"clear\"}"}}]}` + "\n\n" + } + writeEncryptedTestChunk(t, ctx, conn, inferReq, pubKey, chunk) + complete := protocol.InferenceCompleteMessage{ + Type: protocol.TypeInferenceComplete, + RequestID: inferReq.RequestID, + Usage: usage, + } + completeData, _ := json.Marshal(complete) + if err := conn.Write(ctx, websocket.MessageText, completeData); err != nil { + 
t.Errorf("write complete: %v", err) + return + } + served++ + } + }() + return done +} diff --git a/coordinator/internal/registry/registry.go b/coordinator/internal/registry/registry.go index 499de8f8a..c225719e4 100644 --- a/coordinator/internal/registry/registry.go +++ b/coordinator/internal/registry/registry.go @@ -580,6 +580,8 @@ type CatalogEntry struct { ID string WeightHash string // expected SHA-256 weight fingerprint (empty = not enforced) SizeGB float64 // disk/GPU footprint of the model weights (zero = unknown, gate disabled) + Family string // model family for verifier diversity (empty = unknown) + CanVerify bool // whether this text model can act as a Weaver verifier } // SetModelCatalog updates the set of active models. Only models in this diff --git a/coordinator/internal/store/interface.go b/coordinator/internal/store/interface.go index 4c2322190..a39a547aa 100644 --- a/coordinator/internal/store/interface.go +++ b/coordinator/internal/store/interface.go @@ -550,6 +550,8 @@ type SupportedModel struct { DisplayName string `json:"display_name"` // Human-readable (e.g. "Qwen3.5 9B") ModelType string `json:"model_type"` // "text", "embedding", "tts" SizeGB float64 `json:"size_gb"` // Disk/memory size in GB + Family string `json:"family"` // Model family for verifier diversity (empty = unknown) + CanVerify bool `json:"can_verify"` // Whether this text model can act as a Weaver verifier Architecture string `json:"architecture"` // e.g. "9B dense", "2B conformer" Description string `json:"description"` // e.g. 
"Balanced", "Fast reasoning" MinRAMGB int `json:"min_ram_gb"` // Minimum system RAM for auto-selection diff --git a/coordinator/internal/store/postgres.go b/coordinator/internal/store/postgres.go index dbea9e43c..ea9a99969 100644 --- a/coordinator/internal/store/postgres.go +++ b/coordinator/internal/store/postgres.go @@ -259,6 +259,8 @@ func (s *PostgresStore) migrate(ctx context.Context) error { display_name TEXT NOT NULL DEFAULT '', model_type TEXT NOT NULL DEFAULT 'text', size_gb DOUBLE PRECISION NOT NULL DEFAULT 0, + family TEXT NOT NULL DEFAULT '', + can_verify BOOLEAN NOT NULL DEFAULT TRUE, architecture TEXT NOT NULL DEFAULT '', description TEXT NOT NULL DEFAULT '', min_ram_gb INTEGER NOT NULL DEFAULT 0, @@ -275,6 +277,14 @@ func (s *PostgresStore) migrate(ctx context.Context) error { ALTER TABLE supported_models ADD COLUMN IF NOT EXISTS weight_hash TEXT NOT NULL DEFAULT ''; EXCEPTION WHEN others THEN NULL; END $$`, + `DO $$ BEGIN + ALTER TABLE supported_models ADD COLUMN IF NOT EXISTS family TEXT NOT NULL DEFAULT ''; + EXCEPTION WHEN others THEN NULL; + END $$`, + `DO $$ BEGIN + ALTER TABLE supported_models ADD COLUMN IF NOT EXISTS can_verify BOOLEAN NOT NULL DEFAULT TRUE; + EXCEPTION WHEN others THEN NULL; + END $$`, // Releases (provider binary versioning) `CREATE TABLE IF NOT EXISTS releases ( @@ -1607,13 +1617,15 @@ func (s *PostgresStore) SetSupportedModel(model *SupportedModel) error { defer cancel() _, err := s.pool.Exec(ctx, - `INSERT INTO supported_models (id, s3_name, display_name, model_type, size_gb, architecture, description, min_ram_gb, active, weight_hash, updated_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW()) + `INSERT INTO supported_models (id, s3_name, display_name, model_type, size_gb, family, can_verify, architecture, description, min_ram_gb, active, weight_hash, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NOW()) ON CONFLICT (id) DO UPDATE SET - s3_name = $2, display_name = $3, model_type = $4, 
size_gb = $5, architecture = $6, - description = $7, min_ram_gb = $8, active = $9, weight_hash = $10, updated_at = NOW()`, + s3_name = $2, display_name = $3, model_type = $4, size_gb = $5, family = $6, + can_verify = $7, architecture = $8, description = $9, min_ram_gb = $10, + active = $11, weight_hash = $12, updated_at = NOW()`, model.ID, model.S3Name, model.DisplayName, model.ModelType, model.SizeGB, - model.Architecture, model.Description, model.MinRAMGB, model.Active, model.WeightHash, + model.Family, model.CanVerify, model.Architecture, model.Description, model.MinRAMGB, + model.Active, model.WeightHash, ) if err != nil { return fmt.Errorf("store: set supported model: %w", err) @@ -1626,7 +1638,7 @@ func (s *PostgresStore) ListSupportedModels() []SupportedModel { defer cancel() rows, err := s.pool.Query(ctx, - `SELECT id, s3_name, display_name, model_type, size_gb, architecture, description, min_ram_gb, active, weight_hash + `SELECT id, s3_name, display_name, model_type, size_gb, family, can_verify, architecture, description, min_ram_gb, active, weight_hash FROM supported_models ORDER BY model_type ASC, min_ram_gb ASC, size_gb ASC`, ) if err != nil { @@ -1638,7 +1650,8 @@ func (s *PostgresStore) ListSupportedModels() []SupportedModel { for rows.Next() { var m SupportedModel if err := rows.Scan(&m.ID, &m.S3Name, &m.DisplayName, &m.ModelType, &m.SizeGB, - &m.Architecture, &m.Description, &m.MinRAMGB, &m.Active, &m.WeightHash); err != nil { + &m.Family, &m.CanVerify, &m.Architecture, &m.Description, &m.MinRAMGB, + &m.Active, &m.WeightHash); err != nil { continue } models = append(models, m) diff --git a/coordinator/internal/store/store_test.go b/coordinator/internal/store/store_test.go index 338125c5e..34bd9a4ee 100644 --- a/coordinator/internal/store/store_test.go +++ b/coordinator/internal/store/store_test.go @@ -189,6 +189,8 @@ func TestSupportedModels(t *testing.T) { DisplayName: "Qwen2.5 0.5B", ModelType: "text", SizeGB: 0.5, + Family: "qwen", + CanVerify: 
true, Architecture: "0.5B dense", Description: "Lightweight chat model", MinRAMGB: 8, @@ -200,6 +202,8 @@ func TestSupportedModels(t *testing.T) { DisplayName: "Qwen3.5 9B", ModelType: "text", SizeGB: 6.0, + Family: "qwen", + CanVerify: true, Architecture: "9B dense", Description: "Balanced", MinRAMGB: 16, @@ -230,7 +234,10 @@ func TestSupportedModels(t *testing.T) { ID: m1.ID, S3Name: m1.S3Name, DisplayName: "Qwen2.5 0.5B (updated)", + ModelType: "text", SizeGB: 0.5, + Family: "qwen", + CanVerify: true, Architecture: m1.Architecture, Description: "Updated description", MinRAMGB: 8, @@ -252,6 +259,12 @@ func TestSupportedModels(t *testing.T) { if m.Active { t.Error("model should be inactive after update") } + if m.Family != "qwen" { + t.Errorf("family = %q, want qwen", m.Family) + } + if !m.CanVerify { + t.Error("can_verify should persist through update") + } } } diff --git a/provider/src/backend/vllm_mlx.rs b/provider/src/backend/vllm_mlx.rs index eed80713b..04ce958e7 100644 --- a/provider/src/backend/vllm_mlx.rs +++ b/provider/src/backend/vllm_mlx.rs @@ -90,20 +90,24 @@ impl VllmMlxBackend { args.push("--tool-call-parser".to_string()); args.push(tool_parser.to_string()); - // Extract <think>...</think> into reasoning_content field instead of - // leaking thinking tags into the response content. + // Extract explicit reasoning markers only for known reasoning model + // families. Qwen2.x instruct models should stream normal content. let reasoning_parser = if model_lower.contains("gemma") { - "gemma4" // <|channel>thought\n...<channel|> + Some("gemma4") // <|channel>thought\n...<channel|> } else if model_lower.contains("deepseek") || model_lower.contains("trinity") || model_lower.contains("minimax") { - "deepseek_r1" // <think>...</think> (supports implicit reasoning) + Some("deepseek_r1") // <think>...</think>
(supports implicit reasoning) + } else if model_lower.contains("qwen3") || model_lower.contains("qwq") { + Some("qwen3") } else { - "qwen3" // safe default: no-op if model doesn't output tags + None }; - args.push("--reasoning-parser".to_string()); - args.push(reasoning_parser.to_string()); + if let Some(reasoning_parser) = reasoning_parser { + args.push("--reasoning-parser".to_string()); + args.push(reasoning_parser.to_string()); + } args } @@ -233,13 +237,13 @@ mod tests { fn test_build_args_basic() { let backend = VllmMlxBackend::new("mlx-community/Qwen2.5-7B-4bit".into(), 8100, false); let args = backend.build_args(); - // Qwen model gets nemotron tool parser + qwen3 reasoning parser + // Qwen2 model gets nemotron tool parser without a Qwen3 reasoning parser. assert!(args.contains(&"serve".to_string())); assert!(args.contains(&"--enable-auto-tool-choice".to_string())); assert!(args.contains(&"--tool-call-parser".to_string())); assert!(args.contains(&"nemotron".to_string())); - assert!(args.contains(&"--reasoning-parser".to_string())); - assert!(args.contains(&"qwen3".to_string())); + assert!(!args.contains(&"--reasoning-parser".to_string())); + assert!(!args.contains(&"qwen3".to_string())); assert!(!args.contains(&"--continuous-batching".to_string())); } @@ -250,6 +254,14 @@ mod tests { assert!(args.contains(&"--continuous-batching".to_string())); assert!(args.contains(&"--enable-auto-tool-choice".to_string())); assert!(args.contains(&"nemotron".to_string())); + assert!(!args.contains(&"qwen3".to_string())); + } + + #[test] + fn test_build_args_qwen3_reasoning_model() { + let backend = VllmMlxBackend::new("mlx-community/Qwen3-8B-4bit".into(), 8100, false); + let args = backend.build_args(); + assert!(args.contains(&"--reasoning-parser".to_string())); assert!(args.contains(&"qwen3".to_string())); } @@ -279,12 +291,12 @@ mod tests { fn test_build_args_generic_model() { let backend = VllmMlxBackend::new("mlx-community/Llama-3-8B".into(), 8100, false); let args = 
backend.build_args(); - // Generic model gets auto tool parser + qwen3 reasoning (safe default) + // Generic model gets auto tool parser and no reasoning parser. assert!(args.contains(&"--enable-auto-tool-choice".to_string())); assert!(args.contains(&"--tool-call-parser".to_string())); assert!(args.contains(&"auto".to_string())); - assert!(args.contains(&"--reasoning-parser".to_string())); - assert!(args.contains(&"qwen3".to_string())); + assert!(!args.contains(&"--reasoning-parser".to_string())); + assert!(!args.contains(&"qwen3".to_string())); } #[test] diff --git a/provider/src/inference.rs b/provider/src/inference.rs index 187b64f70..a4472d122 100644 --- a/provider/src/inference.rs +++ b/provider/src/inference.rs @@ -501,7 +501,7 @@ try: from vllm_mlx.reasoning import get_parser _name_lower = (_model_name or "").lower() _parser_name = None - if "qwen" in _name_lower: + if "qwen3" in _name_lower or "qwq" in _name_lower: _parser_name = "qwen3" elif "gemma" in _name_lower: _parser_name = "gemma4" @@ -694,7 +694,7 @@ def _select_reasoning_parser(model_id): return None name = (model_id or "").lower() parser_name = None - if "qwen" in name: + if "qwen3" in name or "qwq" in name: parser_name = "qwen3" elif "gemma" in name: parser_name = "gemma4" diff --git a/provider/src/main.rs b/provider/src/main.rs index 256788570..a1e84250d 100644 --- a/provider/src/main.rs +++ b/provider/src/main.rs @@ -2938,7 +2938,7 @@ async fn cmd_serve( // (running requests, token counts, GPU memory). This data is included // in heartbeats so the coordinator can make informed routing decisions. 
if !using_inprocess { - if let Some(cap_arc) = backend_capacity_opt { + if let Some(cap_arc) = backend_capacity_opt.clone() { let poll_shared_slots = shared_slots.clone(); let total_mem_gb = hw.memory_gb as f64; let poll_backend_running = backend_running_flag_opt @@ -3238,6 +3238,7 @@ async fn cmd_serve( #[cfg(feature = "python")] let event_inprocess_engines = inprocess_engines.clone(); + let event_backend_capacity = backend_capacity_opt.clone(); let event_backend_running = backend_running_flag_opt.expect("backend_running_flag must be set in non-local mode"); let event_handle = tokio::spawn(async move { @@ -3246,8 +3247,10 @@ async fn cmd_serve( // Track in-flight inference tasks so we can cancel them on // coordinator disconnect or explicit cancel messages. - let mut inflight: HashMap)> = - HashMap::new(); + let mut inflight: HashMap< + String, + (CancellationToken, tokio::task::JoinHandle<()>, String), + > = HashMap::new(); let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<(String, bool)>(64); // Idle timeout: shut down the backend after a period of no @@ -3262,6 +3265,64 @@ async fn cmd_serve( let set_backend_state = |state: u8| { event_backend_running.store(state, std::sync::atomic::Ordering::Relaxed); }; + let update_inprocess_capacity = + |inflight: &HashMap< + String, + (CancellationToken, tokio::task::JoinHandle<()>, String), + >| { + if !is_inprocess { + return; + } + let Some(cap_arc) = event_backend_capacity.as_ref() else { + return; + }; + + let mut running_by_model: HashMap = HashMap::new(); + for (_, _, model_id) in inflight.values() { + *running_by_model.entry(model_id.clone()).or_insert(0) += 1; + } + + let backend_state = + event_backend_running.load(std::sync::atomic::Ordering::Relaxed); + let slots = { + let slots = shared_slots.lock().unwrap(); + slots + .iter() + .map(|slot| { + let state = if slot.restarting { + "reloading" + } else if !slot.healthy { + match backend_state { + BACKEND_IDLE_SHUTDOWN => "idle_shutdown", + _ => "crashed", 
+ } + } else { + "running" + }; + + protocol::BackendSlotCapacity { + model: slot.model_id.clone(), + state: state.to_string(), + num_running: *running_by_model + .get(&slot.model_id) + .unwrap_or(&0), + num_waiting: 0, + active_tokens: 0, + max_tokens_potential: 0, + } + }) + .collect() + }; + + *cap_arc.lock().unwrap() = Some(protocol::BackendCapacity { + slots, + gpu_memory_active_gb: 0.0, + gpu_memory_peak_gb: 0.0, + gpu_memory_cache_gb: 0.0, + total_memory_gb: hw.memory_gb as f64, + }); + }; + update_inprocess_capacity(&inflight); loop { let idle_sleep = async { @@ -3289,11 +3350,12 @@ async fn cmd_serve( tracing::warn!( "Disconnected from coordinator — aborting {count} in-flight request(s)" ); - for (rid, (token, handle)) in inflight.drain() { + for (rid, (token, handle, _model_id)) in inflight.drain() { tracing::info!("Aborting request {rid} (coordinator disconnected)"); token.cancel(); handle.abort(); } + update_inprocess_capacity(&inflight); inference_active.store(false, std::sync::atomic::Ordering::Relaxed); } else { tracing::warn!("Disconnected from coordinator"); @@ -3623,12 +3685,14 @@ async fn cmd_serve( } }; - inflight.insert(request_id, (cancel_token, handle)); + inflight.insert(request_id, (cancel_token, handle, requested_model)); + update_inprocess_capacity(&inflight); } coordinator::CoordinatorEvent::Cancel { request_id } => { - if let Some((token, _handle)) = inflight.remove(&request_id) { + if let Some((token, _handle, _model_id)) = inflight.remove(&request_id) { tracing::info!("Cancelling request {request_id}"); token.cancel(); + update_inprocess_capacity(&inflight); if inflight.is_empty() { inference_active.store(false, std::sync::atomic::Ordering::Relaxed); } @@ -3688,6 +3752,7 @@ async fn cmd_serve( Some((rid, backend_dead)) = done_rx.recv() => { if inflight.remove(&rid).is_some() { tracing::debug!("Request {rid} completed, removed from tracker ({} in-flight)", inflight.len()); + update_inprocess_capacity(&inflight); if inflight.is_empty() 
{ inference_active.store(false, std::sync::atomic::Ordering::Relaxed); } @@ -3715,6 +3780,7 @@ async fn cmd_serve( slot.restarting = false; } } + update_inprocess_capacity(&inflight); } else { shutdown_backends(&backend_pids).await; } @@ -3984,15 +4050,20 @@ fn spawn_inference_backend( cmd.args(["--tool-call-parser", tool_parser]); let reasoning_parser = if model_lower.contains("gemma") { - "gemma4" - } else if model_lower.contains("deepseek") || model_lower.contains("trinity") { - "deepseek_r1" - } else if model_lower.contains("minimax") { - "deepseek_r1" // MiniMax uses <think>...</think> like DeepSeek + Some("gemma4") + } else if model_lower.contains("deepseek") + || model_lower.contains("trinity") + || model_lower.contains("minimax") + { + Some("deepseek_r1") + } else if model_lower.contains("qwen3") || model_lower.contains("qwq") { + Some("qwen3") } else { - "qwen3" + None }; - cmd.args(["--reasoning-parser", reasoning_parser]); + if let Some(reasoning_parser) = reasoning_parser { + cmd.args(["--reasoning-parser", reasoning_parser]); + } } let log_target = if module.contains("vllm_mlx") {