diff --git a/coordinator/internal/api/consumer.go b/coordinator/internal/api/consumer.go
index 584dda01..5b78b26b 100644
--- a/coordinator/internal/api/consumer.go
+++ b/coordinator/internal/api/consumer.go
@@ -53,9 +53,16 @@ const (
 	// backend crashed, model not loaded after idle shutdown).
 	maxDispatchAttempts = 3

-	// firstChunkTimeout is how long to wait for the first chunk from a provider
-	// before considering the attempt failed and retrying.
-	firstChunkTimeout = 10 * time.Second
+	// defaultFirstResponseTimeout is the baseline TTFT deadline before the
+	// coordinator retries on another provider. Keep this short: retries are
+	// invisible to the client until the first chunk, so a tight deadline here
+	// is what keeps time-to-first-token low when a provider stalls.
+	defaultFirstResponseTimeout = 4 * time.Second
+
+	// maxFirstResponseTimeout caps adaptive first-response deadlines for
+	// intentionally queued/backlogged routes. A route can earn extra time from
+	// its measured queue/backlog/network cost, but never an unbounded wait.
+	maxFirstResponseTimeout = 15 * time.Second

 	// cancelWriteTimeout bounds how long a cancel write to the provider can
 	// block. Using context.Background() unbounded here risks hanging the HTTP
@@ -65,6 +72,18 @@ const (

 var thinkBlockPattern = regexp.MustCompile(`(?is)(.*?)\s*`)

+// firstResponseTimeout returns defaultFirstResponseTimeout plus the route's
+// estimated wait cost, capped at maxFirstResponseTimeout. The cost is bounded
+// as a float first: converting NaN/huge floats to Duration can go negative.
+func firstResponseTimeout(decision registry.RoutingDecision) time.Duration {
+	costMs := decision.StateMs + decision.QueueMs + decision.PendingMs + decision.BacklogMs + decision.NetworkMs
+	headroomMs := float64((maxFirstResponseTimeout - defaultFirstResponseTimeout) / time.Millisecond)
+	if !(costMs > 0) { // non-positive or NaN (NaN fails every comparison): use the default
+		return defaultFirstResponseTimeout
+	}
+	return defaultFirstResponseTimeout + time.Duration(min(costMs, headroomMs)*float64(time.Millisecond))
+}
+
 // sendProviderCancel sends a Cancel message for the given request to the
 // provider with a bounded timeout so a half-dead WebSocket doesn't hang the
 // caller.
Errors are logged at debug level because a disconnect race is the @@ -786,7 +805,7 @@ func (s *Server) handleChatCompletions(w http.ResponseWriter, r *http.Request) { // Wait for an accepted signal, first chunk, or error before committing. // No HTTP response has been written yet, so retries are invisible. - timer := time.NewTimer(firstChunkTimeout) + timer := time.NewTimer(firstResponseTimeout(decision)) accepted := false select { case <-pr.AcceptedCh: @@ -888,6 +907,7 @@ func (s *Server) handleChatCompletions(w http.ResponseWriter, r *http.Request) { } // Provider accepted or sent first chunk — commit to this provider. + s.dispatchStandbyPreloads(r.Context(), model, provider) // If only accepted (no chunk yet), wait for the first chunk with // the full inference timeout since the backend may be reloading. if accepted && !committed { diff --git a/coordinator/internal/api/consumer_test.go b/coordinator/internal/api/consumer_test.go index 6b6eeab2..4dea3bf4 100644 --- a/coordinator/internal/api/consumer_test.go +++ b/coordinator/internal/api/consumer_test.go @@ -32,6 +32,33 @@ func testServer(t *testing.T) (*Server, *store.MemoryStore) { return srv, st } +func TestFirstResponseTimeoutUsesFastDefaultForHealthyRoutes(t *testing.T) { + got := firstResponseTimeout(registry.RoutingDecision{}) + if got != defaultFirstResponseTimeout { + t.Fatalf("firstResponseTimeout(zero)=%s, want %s", got, defaultFirstResponseTimeout) + } +} + +func TestFirstResponseTimeoutAccountsForQueueAndNetworkButCaps(t *testing.T) { + decision := registry.RoutingDecision{ + StateMs: 500, + QueueMs: 3_000, + PendingMs: 750, + BacklogMs: 2_000, + NetworkMs: 1_000, + } + got := firstResponseTimeout(decision) + if got <= defaultFirstResponseTimeout { + t.Fatalf("firstResponseTimeout(degraded)=%s, want > %s", got, defaultFirstResponseTimeout) + } + + decision.BacklogMs = 100_000 + got = firstResponseTimeout(decision) + if got != maxFirstResponseTimeout { + t.Fatalf("firstResponseTimeout(huge 
backlog)=%s, want cap %s", got, maxFirstResponseTimeout)
+	}
+}
+
 func TestHealthEndpoint(t *testing.T) {
 	srv, _ := testServer(t)

diff --git a/coordinator/internal/api/prewarm.go b/coordinator/internal/api/prewarm.go
new file mode 100644
index 00000000..627520be
--- /dev/null
+++ b/coordinator/internal/api/prewarm.go
@@ -0,0 +1,111 @@
+package api
+
+import (
+	"context"
+	"encoding/json"
+	"time"
+
+	"github.com/eigeninference/coordinator/internal/protocol"
+	"github.com/eigeninference/coordinator/internal/registry"
+	"nhooyr.io/websocket"
+)
+
+const (
+	// standbyPrewarmTimeout bounds each preload-hint write to a standby provider.
+	standbyPrewarmTimeout = 2 * time.Second
+	// maxStandbyPrewarmHints caps how many providers are hinted per call.
+	maxStandbyPrewarmHints = 2
+)
+
+// providerIsReadyForModel reports whether p already has model loaded, either
+// as its CurrentModel or in its WarmModels list. It takes p's mutex itself,
+// so the caller must not hold it. A nil provider or empty model is not ready.
+func providerIsReadyForModel(p *registry.Provider, model string) bool {
+	if p == nil || model == "" {
+		return false
+	}
+	p.Mu().Lock()
+	defer p.Mu().Unlock()
+	if p.CurrentModel == model {
+		return true
+	}
+	for _, warm := range p.WarmModels {
+		if warm == model {
+			return true
+		}
+	}
+	return false
+}
+
+// dispatchStandbyPreloads sends a LoadModel hint for model to at most
+// maxStandbyPrewarmHints providers other than primary. A provider is hinted
+// only if it is an MLX Swift backend, runtime-verified, neither offline nor
+// untrusted, lists model among its Models, and does not already have it as
+// CurrentModel or in WarmModels. Each hint is written from its own goroutine
+// under standbyPrewarmTimeout; a failed write is logged at debug level only.
+// s.providerPrewarmWriter replaces the WebSocket writer when it is non-nil.
+//
+// NOTE(review): no record is kept of hints already sent, so every call
+// re-hints providers that still look cold. Confirm that repeated LoadModel
+// messages are harmless on the provider side.
+func (s *Server) dispatchStandbyPreloads(ctx context.Context, model string, primary *registry.Provider) {
+	if s == nil || s.registry == nil || model == "" {
+		return
+	}
+	writer := s.providerPrewarmWriter
+	if writer == nil {
+		writer = s.writeProviderPrewarm
+	}
+	primaryID := ""
+	if primary != nil {
+		primaryID = primary.ID
+	}
+	count := 0
+	s.registry.ForEachProvider(func(p *registry.Provider) {
+		if count >= maxStandbyPrewarmHints || p == nil || p.ID == primaryID {
+			return
+		}
+		p.Mu().Lock()
+		eligible := p.Backend == registry.BackendMLXSwift &&
+			p.Status != registry.StatusOffline && p.Status != registry.StatusUntrusted &&
+			p.RuntimeVerified && providerHasModelLocked(p, model)
+		p.Mu().Unlock()
+		if !eligible || providerIsReadyForModel(p, model) {
+			return
+		}
+		count++
+		go func(provider *registry.Provider) {
+			// Keep ctx's values but not its cancellation. The chat-completions
+			// handler passes r.Context(), and nhooyr.io/websocket closes the
+			// connection when a write's context is cancelled, so a consumer
+			// disconnect could otherwise drop a healthy standby's WebSocket.
+			// standbyPrewarmTimeout still bounds the write.
+			prewarmCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), standbyPrewarmTimeout)
+			defer cancel()
+			msg := protocol.LoadModelMessage{Type: protocol.TypeLoadModel, ModelID: model}
+			if err := writer(prewarmCtx, provider,
msg); err != nil && s.logger != nil { + s.logger.Debug("failed to dispatch standby preload", "provider_id", provider.ID, "model", model, "error", err) + } + }(p) + }) +} + +func providerHasModelLocked(p *registry.Provider, model string) bool { + for _, m := range p.Models { + if m.ID == model { + return true + } + } + return false +} + +func (s *Server) writeProviderPrewarm(ctx context.Context, p *registry.Provider, msg protocol.LoadModelMessage) error { + if p == nil || p.Conn == nil { + return nil + } + data, err := json.Marshal(msg) + if err != nil { + return err + } + return p.Conn.Write(ctx, websocket.MessageText, data) +} diff --git a/coordinator/internal/api/prewarm_test.go b/coordinator/internal/api/prewarm_test.go new file mode 100644 index 00000000..58c01ec0 --- /dev/null +++ b/coordinator/internal/api/prewarm_test.go @@ -0,0 +1,96 @@ +package api + +import ( + "context" + "log/slog" + "os" + "testing" + "time" + + "github.com/eigeninference/coordinator/internal/protocol" + "github.com/eigeninference/coordinator/internal/registry" +) + +func TestProviderIsReadyForModel(t *testing.T) { + p := &registry.Provider{CurrentModel: "m1", WarmModels: []string{"m2"}} + if !providerIsReadyForModel(p, "m1") { + t.Fatal("current model should be ready") + } + if !providerIsReadyForModel(p, "m2") { + t.Fatal("warm model should be ready") + } + if providerIsReadyForModel(p, "m3") { + t.Fatal("unknown model should not be ready") + } +} + +func TestDispatchStandbyPreloadsColdSwiftProviders(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + reg := registry.New(logger) + model := "standby-model" + primary := &registry.Provider{ + ID: "primary", + Backend: registry.BackendMLXSwift, + Models: []protocol.ModelInfo{{ID: model}}, + WarmModels: []string{model}, + Status: registry.StatusServing, + TrustLevel: registry.TrustHardware, + RuntimeVerified: true, + LastHeartbeat: time.Now(), + } + cold := &registry.Provider{ + ID:
"cold", + Backend: registry.BackendMLXSwift, + Models: []protocol.ModelInfo{{ID: model}}, + Status: registry.StatusOnline, + TrustLevel: registry.TrustHardware, + RuntimeVerified: true, + LastHeartbeat: time.Now(), + } + + primaryMsg := protocol.RegisterMessage{ + Type: protocol.TypeRegister, + Models: []protocol.ModelInfo{{ID: model}}, + Backend: registry.BackendMLXSwift, + EncryptedResponseChunks: true, + PublicKey: "pub-primary", + PrivacyCapabilities: testPrivacyCaps(), + } + coldMsg := primaryMsg + coldMsg.PublicKey = "pub-cold" + primary = reg.Register("primary", nil, &primaryMsg) + cold = reg.Register("cold", nil, &coldMsg) + primary.WarmModels = []string{model} + primary.RuntimeVerified = true + cold.RuntimeVerified = true + reg.SetTrustLevel(primary.ID, registry.TrustHardware) + reg.SetTrustLevel(cold.ID, registry.TrustHardware) + reg.RecordChallengeSuccess(primary.ID) + reg.RecordChallengeSuccess(cold.ID) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + sent := make(chan protocol.LoadModelMessage, 1) + server := &Server{registry: reg, logger: logger, providerPrewarmWriter: func(ctx context.Context, p *registry.Provider, msg protocol.LoadModelMessage) error { + sent <- msg + return nil + }} + + server.dispatchStandbyPreloads(ctx, model, primary) + select { + case msg := <-sent: + if msg.Type != protocol.TypeLoadModel || msg.ModelID != model { + t.Fatalf("unexpected preload message: %+v", msg) + } + case <-time.After(time.Second): + t.Fatal("expected standby preload message") + } + + cold.WarmModels = []string{model} + server.dispatchStandbyPreloads(ctx, model, primary) + select { + case msg := <-sent: + t.Fatalf("unexpected preload for already-warm provider: %+v", msg) + default: + } +} diff --git a/coordinator/internal/api/server.go b/coordinator/internal/api/server.go index aad12623..2ebe4bcd 100644 --- a/coordinator/internal/api/server.go +++ b/coordinator/internal/api/server.go @@ -182,6 +182,10 @@ type 
Server struct { // and used by internal counters/histograms. Never nil. metrics *Metrics + // providerPrewarmWriter sends coordinator-driven preload hints. Tests can + // replace it to assert standby behavior without a real WebSocket. + providerPrewarmWriter func(context.Context, *registry.Provider, protocol.LoadModelMessage) error + // telemetryLimiter throttles telemetry ingestion per submitter. telemetryLimiter *telemetryLimiter @@ -230,15 +234,16 @@ func NewServer(reg *registry.Registry, st store.Store, logger *slog.Logger) *Ser reg.SetStore(st) s := &Server{ - registry: reg, - store: st, - ledger: payments.NewLedger(st), - logger: logger, - mux: http.NewServeMux(), - knownRuntimeManifest: &RuntimeManifest{}, - metrics: NewMetrics(), - telemetryLimiter: newTelemetryLimiter(), - readCache: newTTLCache(), + registry: reg, + store: st, + ledger: payments.NewLedger(st), + logger: logger, + mux: http.NewServeMux(), + knownRuntimeManifest: &RuntimeManifest{}, + metrics: NewMetrics(), + providerPrewarmWriter: nil, + telemetryLimiter: newTelemetryLimiter(), + readCache: newTTLCache(), } s.registerDefaultGauges() s.routes() diff --git a/coordinator/internal/api/telemetry_handlers.go b/coordinator/internal/api/telemetry_handlers.go index 44c73674..1e4e2518 100644 --- a/coordinator/internal/api/telemetry_handlers.go +++ b/coordinator/internal/api/telemetry_handlers.go @@ -57,13 +57,20 @@ var telemetryFieldAllowlist = map[string]struct{}{ "error": {}, "target": {}, // Provider / backend - "model": {}, - "backend": {}, - "exit_code": {}, - "signal": {}, - "hardware_chip": {}, - "memory_gb": {}, - "macos_version": {}, + "model": {}, + "backend": {}, + "queue_ms": {}, + "admit_ms": {}, + "prompt_tokens": {}, + "completion_tokens": {}, + "ttft_ms": {}, + "total_ms": {}, + "active_count": {}, + "exit_code": {}, + "signal": {}, + "hardware_chip": {}, + "memory_gb": {}, + "macos_version": {}, // Coordinator "handler": {}, "provider_id": {}, diff --git 
a/coordinator/internal/protocol/messages.go b/coordinator/internal/protocol/messages.go index ab94c4a1..41a33668 100644 --- a/coordinator/internal/protocol/messages.go +++ b/coordinator/internal/protocol/messages.go @@ -134,6 +134,7 @@ type HeartbeatMessage struct { Stats HeartbeatStats `json:"stats"` WarmModels []string `json:"warm_models,omitempty"` // models currently loaded in memory SystemMetrics SystemMetrics `json:"system_metrics"` // live resource utilization + NetworkQuality NetworkQuality `json:"network_quality"` // provider-observed coordinator transport quality BackendCapacity *BackendCapacity `json:"backend_capacity,omitempty"` // live backend capacity (nil for old providers) } @@ -166,6 +167,17 @@ type SystemMetrics struct { ThermalState string `json:"thermal_state"` // nominal, fair, serious, critical } +// NetworkQuality contains provider-observed coordinator WebSocket transport +// health. All fields default to zero for backward compatibility with older +// providers, and zero means "no measured penalty" rather than "bad". +type NetworkQuality struct { + RTTMs float64 `json:"rtt_ms"` // latest WebSocket ping/pong round-trip time + JitterMs float64 `json:"jitter_ms"` // absolute delta between latest and previous RTT + ReconnectCount int64 `json:"reconnect_count"` // reconnect attempts since provider process start + WebSocketWriteFailures int64 `json:"websocket_write_failures"` // failed WebSocket writes since provider process start + LastWriteLatencyMs float64 `json:"last_write_latency_ms"` // duration of most recent successful WebSocket write +} + // HeartbeatStats contains counters reported in heartbeats. 
type HeartbeatStats struct { RequestsServed int64 `json:"requests_served"` diff --git a/coordinator/internal/protocol/network_quality_test.go b/coordinator/internal/protocol/network_quality_test.go new file mode 100644 index 00000000..babe7da2 --- /dev/null +++ b/coordinator/internal/protocol/network_quality_test.go @@ -0,0 +1,61 @@ +package protocol + +import ( + "encoding/json" + "testing" +) + +func TestHeartbeatNetworkQualityRoundTrip(t *testing.T) { + msg := HeartbeatMessage{ + Type: TypeHeartbeat, + Status: "idle", + Stats: HeartbeatStats{}, + NetworkQuality: NetworkQuality{ + RTTMs: 125.5, + JitterMs: 42.25, + ReconnectCount: 3, + WebSocketWriteFailures: 2, + LastWriteLatencyMs: 17.75, + }, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var decoded HeartbeatMessage + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if decoded.NetworkQuality.RTTMs != 125.5 { + t.Fatalf("rtt_ms=%f, want 125.5", decoded.NetworkQuality.RTTMs) + } + if decoded.NetworkQuality.JitterMs != 42.25 { + t.Fatalf("jitter_ms=%f, want 42.25", decoded.NetworkQuality.JitterMs) + } + if decoded.NetworkQuality.ReconnectCount != 3 { + t.Fatalf("reconnect_count=%d, want 3", decoded.NetworkQuality.ReconnectCount) + } + if decoded.NetworkQuality.WebSocketWriteFailures != 2 { + t.Fatalf("websocket_write_failures=%d, want 2", decoded.NetworkQuality.WebSocketWriteFailures) + } + if decoded.NetworkQuality.LastWriteLatencyMs != 17.75 { + t.Fatalf("last_write_latency_ms=%f, want 17.75", decoded.NetworkQuality.LastWriteLatencyMs) + } +} + +func TestHeartbeatMissingNetworkQualityDefaultsToZero(t *testing.T) { + raw := `{"type":"heartbeat","status":"idle","active_model":null,"stats":{"requests_served":0,"tokens_generated":0}}` + + var pm ProviderMessage + if err := json.Unmarshal([]byte(raw), &pm); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + hb := pm.Payload.(*HeartbeatMessage) + if hb.NetworkQuality 
!= (NetworkQuality{}) { + t.Fatalf("NetworkQuality=%+v, want zero value", hb.NetworkQuality) + } +} diff --git a/coordinator/internal/registry/network_quality_test.go b/coordinator/internal/registry/network_quality_test.go new file mode 100644 index 00000000..ae8321ab --- /dev/null +++ b/coordinator/internal/registry/network_quality_test.go @@ -0,0 +1,95 @@ +package registry + +import ( + "testing" + + "github.com/eigeninference/coordinator/internal/protocol" +) + +func TestNetworkPenaltyDefaultsToZeroForMissingMetrics(t *testing.T) { + if got := networkPenaltyMs(protocol.NetworkQuality{}); got != 0 { + t.Fatalf("networkPenaltyMs(zero)=%f, want 0", got) + } +} + +func TestNetworkPenaltyDegradesHighLatencyAndInstability(t *testing.T) { + good := protocol.NetworkQuality{RTTMs: 20, JitterMs: 5} + bad := protocol.NetworkQuality{ + RTTMs: 900, + JitterMs: 250, + ReconnectCount: 4, + WebSocketWriteFailures: 3, + LastWriteLatencyMs: 150, + } + + goodPenalty := networkPenaltyMs(good) + badPenalty := networkPenaltyMs(bad) + if goodPenalty != 0 { + t.Fatalf("good network penalty=%f, want 0 below thresholds", goodPenalty) + } + if badPenalty <= goodPenalty { + t.Fatalf("bad penalty=%f, want > good penalty=%f", badPenalty, goodPenalty) + } + if badPenalty > networkQualityMaxPenaltyMs { + t.Fatalf("bad penalty=%f exceeds max=%f", badPenalty, networkQualityMaxPenaltyMs) + } +} + +func TestReserveProviderPenalizesLowerNetworkQuality(t *testing.T) { + reg := New(testLogger()) + model := "network-quality-model" + good := makeSchedulerProvider(t, reg, "good-network", model, 100) + bad := makeSchedulerProvider(t, reg, "bad-network", model, 100) + + bad.mu.Lock() + bad.NetworkQuality = protocol.NetworkQuality{ + RTTMs: 1000, + JitterMs: 300, + ReconnectCount: 5, + WebSocketWriteFailures: 5, + LastWriteLatencyMs: 250, + } + bad.mu.Unlock() + + selected, decision := reg.ReserveProviderEx(model, &PendingRequest{ + RequestID: "network-quality-req", + Model: model, + RequestedMaxTokens: 
128, + }) + if selected == nil { + t.Fatal("ReserveProviderEx returned nil") + } + if selected.ID != good.ID { + t.Fatalf("selected %q, want good-network; decision=%+v", selected.ID, decision) + } + if decision.NetworkMs != 0 { + t.Fatalf("winning good provider NetworkMs=%f, want 0", decision.NetworkMs) + } +} + +func TestHeartbeatStoresNetworkQualityForRoutingSnapshot(t *testing.T) { + reg := New(testLogger()) + model := "heartbeat-network-model" + p := makeSchedulerProvider(t, reg, "heartbeat-network", model, 100) + + reg.Heartbeat(p.ID, &protocol.HeartbeatMessage{ + Type: protocol.TypeHeartbeat, + Status: "idle", + Stats: protocol.HeartbeatStats{}, + NetworkQuality: protocol.NetworkQuality{ + RTTMs: 321, + JitterMs: 45, + ReconnectCount: 2, + WebSocketWriteFailures: 1, + LastWriteLatencyMs: 12, + }, + }) + + snap, ok := reg.snapshotProviderLocked(p, model, &PendingRequest{RequestID: "snap", Model: model}) + if !ok { + t.Fatal("snapshotProviderLocked returned !ok") + } + if snap.networkQuality.RTTMs != 321 { + t.Fatalf("snapshot RTTMs=%f, want 321", snap.networkQuality.RTTMs) + } +} diff --git a/coordinator/internal/registry/registry.go b/coordinator/internal/registry/registry.go index 5cd4c013..c2beb511 100644 --- a/coordinator/internal/registry/registry.go +++ b/coordinator/internal/registry/registry.go @@ -130,7 +130,8 @@ type Provider struct { CurrentModel string // model currently being served // Live system metrics from heartbeats - SystemMetrics protocol.SystemMetrics + SystemMetrics protocol.SystemMetrics + NetworkQuality protocol.NetworkQuality // Live backend capacity from heartbeats (nil for providers without capacity reporting) BackendCapacity *protocol.BackendCapacity @@ -696,6 +697,33 @@ func clampNonNeg(v, max float64) (float64, bool) { // TotalMemoryGB=1e9 would make gpuUtil ~= 0 and dodge health penalties, so // we cap it at maxMemoryGBFloat. Same for MaxTokensPotential which directly // controls backlog cost. NaN/negative become 0. 
+func clampNetworkQuality(n *protocol.NetworkQuality) { + if n == nil { + return + } + if v, changed := clampNonNeg(n.RTTMs, 10_000); changed { + n.RTTMs = v + } + if v, changed := clampNonNeg(n.JitterMs, 10_000); changed { + n.JitterMs = v + } + if v, changed := clampNonNeg(n.LastWriteLatencyMs, 10_000); changed { + n.LastWriteLatencyMs = v + } + if n.ReconnectCount < 0 { + n.ReconnectCount = 0 + } + if n.ReconnectCount > 1_000_000 { + n.ReconnectCount = 1_000_000 + } + if n.WebSocketWriteFailures < 0 { + n.WebSocketWriteFailures = 0 + } + if n.WebSocketWriteFailures > 1_000_000 { + n.WebSocketWriteFailures = 1_000_000 + } +} + func clampBackendCapacity(logger *slog.Logger, providerID string, bc *protocol.BackendCapacity) { if bc == nil { return @@ -923,6 +951,7 @@ func (r *Registry) Heartbeat(id string, msg *protocol.HeartbeatMessage) { if v, changed := clampNonNeg(msg.SystemMetrics.CPUUsage, 1.0); changed { msg.SystemMetrics.CPUUsage = v } + clampNetworkQuality(&msg.NetworkQuality) p.mu.Lock() p.LastHeartbeat = time.Now() @@ -930,6 +959,7 @@ func (r *Registry) Heartbeat(id string, msg *protocol.HeartbeatMessage) { p.Stats.TokensGenerated += cumulativeDelta(p.lastSessionStats.TokensGenerated, msg.Stats.TokensGenerated) p.lastSessionStats = msg.Stats p.SystemMetrics = msg.SystemMetrics + p.NetworkQuality = msg.NetworkQuality // Update backend capacity from heartbeat (nil-safe for old providers). if msg.BackendCapacity != nil { p.BackendCapacity = msg.BackendCapacity diff --git a/coordinator/internal/registry/scheduler.go b/coordinator/internal/registry/scheduler.go index 9f516df3..497f9153 100644 --- a/coordinator/internal/registry/scheduler.go +++ b/coordinator/internal/registry/scheduler.go @@ -26,15 +26,16 @@ const ( // slow-provider decode, so the cost function actually spreads load // across the fleet. Wider tie window admits more candidates to the // queue-depth tie-break + random distribution. 
- queueDepthPenaltyMs = 3_000.0 - totalPendingPenaltyMs = 750.0 - memoryPressurePenaltyMs = 4_000.0 - cpuUsagePenaltyMs = 1_500.0 - gpuUtilizationPenaltyMs = 5_000.0 - thermalPenaltyFairMs = 2_000.0 - thermalPenaltySeriousMs = 8_000.0 - nearTieCostWindowMs = 3_000.0 - challengeFreshnessMaxAge = 6 * time.Minute + queueDepthPenaltyMs = 3_000.0 + totalPendingPenaltyMs = 750.0 + memoryPressurePenaltyMs = 4_000.0 + cpuUsagePenaltyMs = 1_500.0 + gpuUtilizationPenaltyMs = 5_000.0 + thermalPenaltyFairMs = 2_000.0 + thermalPenaltySeriousMs = 8_000.0 + networkQualityMaxPenaltyMs = 5_000.0 + nearTieCostWindowMs = 3_000.0 + challengeFreshnessMaxAge = 6 * time.Minute // kvCacheBytesPerToken is a per-token KV-cache size estimate used by // the free-memory admission gate. @@ -77,6 +78,7 @@ type routingSnapshot struct { decodeTPS float64 prefillTPS float64 systemMetrics protocol.SystemMetrics + networkQuality protocol.NetworkQuality gpuMemoryActiveGB float64 totalMemoryGB float64 modelSizeGB float64 // catalog-reported weight footprint (0 = unknown, gate disabled) @@ -115,6 +117,7 @@ type costBreakdown struct { BacklogMs float64 ThisReqMs float64 HealthMs float64 + NetworkMs float64 Total float64 } @@ -131,6 +134,7 @@ type RoutingDecision struct { BacklogMs float64 // tokens-ahead / decodeTPS contribution ThisReqMs float64 // this request's prefill+decode contribution HealthMs float64 // memory/CPU/thermal/GPU-util contribution + NetworkMs float64 // bounded network-quality contribution EffectiveQueue int // max(pendingForModel, backendRunning+backendWaiting) CandidateCount int // total candidates that passed all gates CapacityRejections int // candidates rejected by the free-memory admission gate @@ -206,6 +210,7 @@ func (r *Registry) ReserveProviderEx(model string, pr *PendingRequest, excludeID BacklogMs: bd.BacklogMs, ThisReqMs: bd.ThisReqMs, HealthMs: bd.HealthMs, + NetworkMs: bd.NetworkMs, EffectiveQueue: selected.effectiveQueue, CandidateCount: candidateCount, 
CapacityRejections: capacityRejections, @@ -349,6 +354,7 @@ func (r *Registry) logRoutingDecision(model string, pr *PendingRequest, winner * "backlog_ms", bd.BacklogMs, "this_req_ms", bd.ThisReqMs, "health_ms", bd.HealthMs, + "network_ms", bd.NetworkMs, "effective_tps", winner.effectiveTPS, "effective_queue", winner.effectiveQueue, "candidates", candidates, @@ -384,15 +390,16 @@ func (r *Registry) snapshotProviderLocked(p *Provider, model string, pr *Pending } snap := routingSnapshot{ - provider: p, - model: model, - slotState: "unknown", - totalPending: p.pendingCount(), - systemMetrics: p.SystemMetrics, - decodeTPS: resolvedDecodeTPS(p), - prefillTPS: resolvedPrefillTPS(p), - totalMemoryGB: float64(p.Hardware.MemoryGB), - modelSizeGB: r.catalogSizeGBLocked(model), + provider: p, + model: model, + slotState: "unknown", + totalPending: p.pendingCount(), + systemMetrics: p.SystemMetrics, + networkQuality: p.NetworkQuality, + decodeTPS: resolvedDecodeTPS(p), + prefillTPS: resolvedPrefillTPS(p), + totalMemoryGB: float64(p.Hardware.MemoryGB), + modelSizeGB: r.catalogSizeGBLocked(model), } for _, pr := range p.pendingReqs { @@ -513,7 +520,8 @@ func (r *Registry) buildCandidateWithReason(snap routingSnapshot, pr *PendingReq backlogMs := backlogTokenMs(snap.maxTokensPotential, waitingBacklogTokens, unaccountedPendingTokens, effectiveTPS) thisReqMs := float64(reqPrompt)/snap.prefillTPS*1000.0 + float64(reqMax)/effectiveTPS*1000.0 healthMs := healthPenaltyMs(snap.systemMetrics, snap.gpuMemoryActiveGB, snap.totalMemoryGB) - cost := statePenalty + queueMs + pendingMs + backlogMs + thisReqMs + healthMs + networkMs := networkPenaltyMs(snap.networkQuality) + cost := statePenalty + queueMs + pendingMs + backlogMs + thisReqMs + healthMs + networkMs return &routingCandidate{ provider: snap.provider, @@ -528,6 +536,7 @@ func (r *Registry) buildCandidateWithReason(snap routingSnapshot, pr *PendingReq BacklogMs: backlogMs, ThisReqMs: thisReqMs, HealthMs: healthMs, + NetworkMs: 
networkMs, Total: cost, }, }, rejectNone, true @@ -580,6 +589,35 @@ func healthPenaltyMs(m protocol.SystemMetrics, gpuActiveGB, totalMemGB float64) return penalty } +func networkPenaltyMs(n protocol.NetworkQuality) float64 { + // Keep normal regional RTT/jitter free, but penalize degraded transport + // enough to break otherwise-similar ties. The cap preserves hardware/load + // as dominant signals and keeps malicious counters bounded. + penalty := 0.0 + if n.RTTMs > 50 { + penalty += (n.RTTMs - 50) * 2.0 + } + if n.JitterMs > 20 { + penalty += (n.JitterMs - 20) * 5.0 + } + if n.LastWriteLatencyMs > 25 { + penalty += (n.LastWriteLatencyMs - 25) * 2.0 + } + if n.ReconnectCount > 0 { + penalty += float64(n.ReconnectCount) * 250.0 + } + if n.WebSocketWriteFailures > 0 { + penalty += float64(n.WebSocketWriteFailures) * 500.0 + } + if penalty < 0 || math.IsNaN(penalty) { + return 0 + } + if penalty > networkQualityMaxPenaltyMs { + return networkQualityMaxPenaltyMs + } + return penalty +} + // effectiveDecodeTPS scales the static decode TPS down by current // backend batch size. Returns the static value when the load factor is // disabled or batch is unknown. Floored at 1 token/s to avoid divide- diff --git a/coordinator/internal/registry/scheduler_test.go b/coordinator/internal/registry/scheduler_test.go index d10768b0..c6134ddd 100644 --- a/coordinator/internal/registry/scheduler_test.go +++ b/coordinator/internal/registry/scheduler_test.go @@ -107,7 +107,7 @@ func TestReserveProviderExReturnsCostBreakdown(t *testing.T) { } // Sum of components should approximately equal the total cost. 
sum := decision.StateMs + decision.QueueMs + decision.PendingMs + - decision.BacklogMs + decision.ThisReqMs + decision.HealthMs + decision.BacklogMs + decision.ThisReqMs + decision.HealthMs + decision.NetworkMs if diff := sum - decision.CostMs; diff > 0.001 || diff < -0.001 { t.Fatalf("breakdown sum %f != CostMs %f", sum, decision.CostMs) } diff --git a/coordinator/internal/registry/simulator.go b/coordinator/internal/registry/simulator.go new file mode 100644 index 00000000..c8555661 --- /dev/null +++ b/coordinator/internal/registry/simulator.go @@ -0,0 +1,426 @@ +package registry + +import ( + "math" + "math/rand" + "sort" + + "github.com/eigeninference/coordinator/internal/protocol" +) + +// RoutingReplayStrategy names a local-only simulator routing policy. +type RoutingReplayStrategy string + +const ( + RoutingStrategyCurrentCostModel RoutingReplayStrategy = "current_cost_model" + RoutingStrategyRoundRobin RoutingReplayStrategy = "round_robin" + RoutingStrategyLeastActive RoutingReplayStrategy = "least_active" + // RoutingStrategyLeastMetric intentionally uses only stale snapshot metrics + // (CPU/memory) and ignores assignments made during the replay. It models the + // Riot/Valorant failure mode where a burst herds onto the machine that looked + // best in the last heartbeat. + RoutingStrategyLeastMetric RoutingReplayStrategy = "least_metric" + RoutingStrategyRandomNearTie RoutingReplayStrategy = "random_near_tie" +) + +// RoutingReplayProvider is the synthetic provider snapshot used by the replay +// harness. It is intentionally small and independent of live registry state. +type RoutingReplayProvider struct { + ID string + Model string + DecodeTPS float64 + PrefillTPS float64 + MaxConcurrency int + CPUUsage float64 + MemoryPressure float64 + ThermalState string + TotalMemoryGB float64 + GPUActiveGB float64 +} + +// RoutingReplayRequest is a synthetic request arrival. 
+type RoutingReplayRequest struct {
+	ID           string  // copied to the assignment's RequestID; orders requests with equal ArrivalMs
+	Model        string  // must equal a provider's Model unless that provider's Model is empty
+	ArrivalMs    float64 // arrival time in ms; requests are replayed in ascending ArrivalMs
+	PromptTokens int     // prefill work: PromptTokens / PrefillTPS seconds
+	MaxTokens    int     // decode work; values <= 0 fall back to defaultRequestedMaxTokens
+	TimeoutMs    float64 // when > 0, a TTFT above this marks the assignment TimedOut
+}
+
+// RoutingReplayScenario contains all inputs for a deterministic replay.
+type RoutingReplayScenario struct {
+	Providers  []RoutingReplayProvider // provider snapshots; every strategy starts from the same set
+	Requests   []RoutingReplayRequest  // arrivals; copied, then sorted by ArrivalMs and ID before replay
+	Strategies []RoutingReplayStrategy // empty selects four defaults (random_near_tie is not among them)
+	RandomSeed int64                   // seeds the replay RNG; 0 is treated as 1
+}
+
+// RoutingReplayAssignment records one replayed routing decision.
+type RoutingReplayAssignment struct {
+	RequestID      string
+	ProviderID     string  // empty when no provider was eligible for the request
+	ArrivalMs      float64
+	StartMs        float64 // ArrivalMs, or the provider's next free slot when it was over capacity
+	TTFTMs         float64 // queueing delay (StartMs - ArrivalMs) plus prefill time
+	TotalLatencyMs float64 // completion time minus ArrivalMs
+	OverCapacity   bool    // provider was already at MaxConcurrency; also set when none was eligible
+	TimedOut       bool    // TTFT exceeded TimeoutMs; also set when none was eligible
+}
+
+// RoutingReplayMetrics contains per-strategy replay output.
+type RoutingReplayMetrics struct {
+	Strategy              RoutingReplayStrategy
+	Assignments           []RoutingReplayAssignment // one entry per request, in replay order
+	ProviderRequestCounts map[string]int            // keyed by provider ID; idle providers appear with 0
+	ProviderUtilization   map[string]float64        // summed service time / replay horizon; can exceed 1 when requests overlap
+	TTFTP50Ms             float64                   // TTFT percentiles are computed over assigned requests only
+	TTFTP95Ms             float64
+	TTFTP99Ms             float64
+	TotalLatencyP50Ms     float64 // latency percentiles are computed over assigned requests only
+	TotalLatencyP95Ms     float64
+	TotalLatencyP99Ms     float64
+	TotalLatencyMs        float64 // sum of TotalLatencyMs over assigned requests
+	TimeoutCount          int     // timed-out assignments plus requests with no eligible provider
+	OverCapacityCount     int     // over-capacity assignments plus requests with no eligible provider
+}
+
+// ReplayRoutingScenario replays synthetic request arrivals against provider
+// snapshots for each requested strategy. It has no side effects on the live
+// registry and is intended for local unit tests and routing experiments.
+func ReplayRoutingScenario(s RoutingReplayScenario) map[RoutingReplayStrategy]RoutingReplayMetrics { + strategies := s.Strategies + if len(strategies) == 0 { + strategies = []RoutingReplayStrategy{RoutingStrategyCurrentCostModel, RoutingStrategyRoundRobin, RoutingStrategyLeastActive, RoutingStrategyLeastMetric} + } + out := make(map[RoutingReplayStrategy]RoutingReplayMetrics, len(strategies)) + for _, strategy := range strategies { + out[strategy] = replayRoutingStrategy(s, strategy) + } + return out +} + +type replayProviderState struct { + spec RoutingReplayProvider + active []replayActiveRequest + busyMs float64 + rrCursor int + inputSlot int +} + +type replayActiveRequest struct { + endMs float64 + maxTokens int +} + +func replayRoutingStrategy(s RoutingReplayScenario, strategy RoutingReplayStrategy) RoutingReplayMetrics { + providers := make([]*replayProviderState, 0, len(s.Providers)) + for i, p := range s.Providers { + if p.ID == "" { + p.ID = string(rune('a' + i)) + } + if p.DecodeTPS <= 0 { + p.DecodeTPS = 1 + } + if p.PrefillTPS <= 0 { + p.PrefillTPS = p.DecodeTPS * 4 + } + if p.MaxConcurrency <= 0 { + p.MaxConcurrency = DefaultMaxConcurrent + } + if p.ThermalState == "" { + p.ThermalState = "nominal" + } + providers = append(providers, &replayProviderState{spec: p, inputSlot: i}) + } + + requests := append([]RoutingReplayRequest(nil), s.Requests...) 
+ sort.SliceStable(requests, func(i, j int) bool { + if requests[i].ArrivalMs == requests[j].ArrivalMs { + return requests[i].ID < requests[j].ID + } + return requests[i].ArrivalMs < requests[j].ArrivalMs + }) + + seed := s.RandomSeed + if seed == 0 { + seed = 1 + } + rng := rand.New(rand.NewSource(seed)) + counts := make(map[string]int, len(providers)) + util := make(map[string]float64, len(providers)) + assignments := make([]RoutingReplayAssignment, 0, len(requests)) + var rr int + + for _, req := range requests { + if req.MaxTokens <= 0 { + req.MaxTokens = defaultRequestedMaxTokens + } + for _, p := range providers { + p.completeThrough(req.ArrivalMs) + } + + selected := selectReplayProvider(providers, req, strategy, rr, rng) + if selected == nil { + assignments = append(assignments, RoutingReplayAssignment{RequestID: req.ID, ArrivalMs: req.ArrivalMs, OverCapacity: true, TimedOut: true}) + continue + } + if strategy == RoutingStrategyRoundRobin { + rr = selected.inputSlot + 1 + } + + activeBefore := len(selected.active) + overCapacity := activeBefore >= selected.spec.MaxConcurrency + start := req.ArrivalMs + if overCapacity { + start = selected.nextAvailableMs(req.ArrivalMs) + } + prefillMs := float64(req.PromptTokens) / selected.spec.PrefillTPS * 1000 + decodeTPS := effectiveDecodeTPS(selected.spec.DecodeTPS, minInt(activeBefore, selected.spec.MaxConcurrency)) + decodeMs := float64(req.MaxTokens) / decodeTPS * 1000 + serviceMs := prefillMs + decodeMs + end := start + serviceMs + ttft := (start - req.ArrivalMs) + prefillMs + latency := end - req.ArrivalMs + timedOut := req.TimeoutMs > 0 && ttft > req.TimeoutMs + + selected.active = append(selected.active, replayActiveRequest{endMs: end, maxTokens: req.MaxTokens}) + selected.busyMs += serviceMs + counts[selected.spec.ID]++ + assignments = append(assignments, RoutingReplayAssignment{ + RequestID: req.ID, + ProviderID: selected.spec.ID, + ArrivalMs: req.ArrivalMs, + StartMs: start, + TTFTMs: ttft, + TotalLatencyMs: 
latency, + OverCapacity: overCapacity, + TimedOut: timedOut, + }) + } + + var horizon float64 + for _, req := range requests { + if req.ArrivalMs > horizon { + horizon = req.ArrivalMs + } + } + for _, p := range providers { + for _, a := range p.active { + if a.endMs > horizon { + horizon = a.endMs + } + } + } + if horizon <= 0 { + horizon = 1 + } + for _, p := range providers { + counts[p.spec.ID] += 0 + util[p.spec.ID] = p.busyMs / horizon + } + + metrics := RoutingReplayMetrics{ + Strategy: strategy, + Assignments: assignments, + ProviderRequestCounts: counts, + ProviderUtilization: util, + } + var ttfts, latencies []float64 + for _, a := range assignments { + if a.ProviderID == "" { + metrics.OverCapacityCount++ + metrics.TimeoutCount++ + continue + } + ttfts = append(ttfts, a.TTFTMs) + latencies = append(latencies, a.TotalLatencyMs) + metrics.TotalLatencyMs += a.TotalLatencyMs + if a.OverCapacity { + metrics.OverCapacityCount++ + } + if a.TimedOut { + metrics.TimeoutCount++ + } + } + metrics.TTFTP50Ms = percentile(ttfts, 0.50) + metrics.TTFTP95Ms = percentile(ttfts, 0.95) + metrics.TTFTP99Ms = percentile(ttfts, 0.99) + metrics.TotalLatencyP50Ms = percentile(latencies, 0.50) + metrics.TotalLatencyP95Ms = percentile(latencies, 0.95) + metrics.TotalLatencyP99Ms = percentile(latencies, 0.99) + return metrics +} + +func (p *replayProviderState) completeThrough(nowMs float64) { + kept := p.active[:0] + for _, a := range p.active { + if a.endMs > nowMs { + kept = append(kept, a) + } + } + p.active = kept +} + +func (p *replayProviderState) nextAvailableMs(arrivalMs float64) float64 { + if len(p.active) < p.spec.MaxConcurrency { + return arrivalMs + } + ends := make([]float64, 0, len(p.active)) + for _, a := range p.active { + ends = append(ends, a.endMs) + } + sort.Float64s(ends) + idx := len(p.active) - p.spec.MaxConcurrency + if idx < 0 { + idx = 0 + } + if ends[idx] < arrivalMs { + return arrivalMs + } + return ends[idx] +} + +func selectReplayProvider(providers 
[]*replayProviderState, req RoutingReplayRequest, strategy RoutingReplayStrategy, rr int, rng *rand.Rand) *replayProviderState { + eligible := make([]*replayProviderState, 0, len(providers)) + for _, p := range providers { + if p.spec.Model == "" || p.spec.Model == req.Model { + eligible = append(eligible, p) + } + } + if len(eligible) == 0 { + return nil + } + + switch strategy { + case RoutingStrategyRoundRobin: + for i := 0; i < len(providers); i++ { + p := providers[(rr+i)%len(providers)] + if (p.spec.Model == "" || p.spec.Model == req.Model) && len(p.active) < p.spec.MaxConcurrency { + return p + } + } + return eligible[rr%len(eligible)] + case RoutingStrategyLeastActive: + return minReplayProvider(eligible, func(p *replayProviderState) float64 { return float64(len(p.active)) }) + case RoutingStrategyLeastMetric: + return minReplayProvider(eligible, func(p *replayProviderState) float64 { return p.spec.CPUUsage + p.spec.MemoryPressure }) + case RoutingStrategyRandomNearTie: + costs := replayCostCandidates(eligible, req) + if len(costs) == 0 { + return minReplayProvider(eligible, func(p *replayProviderState) float64 { return float64(len(p.active)) }) + } + best := costs[0].cost + pool := make([]*replayProviderState, 0, len(costs)) + for _, c := range costs { + if math.Abs(c.cost-best) <= nearTieCostWindowMs { + pool = append(pool, c.provider) + } + } + return pool[rng.Intn(len(pool))] + case RoutingStrategyCurrentCostModel: + fallthrough + default: + costs := replayCostCandidates(eligible, req) + if len(costs) == 0 { + return nil + } + return costs[0].provider + } +} + +type replayCostCandidate struct { + provider *replayProviderState + cost float64 +} + +func replayCostCandidates(providers []*replayProviderState, req RoutingReplayRequest) []replayCostCandidate { + out := make([]replayCostCandidate, 0, len(providers)) + reg := &Registry{} + for _, p := range providers { + if len(p.active) >= p.spec.MaxConcurrency { + continue + } + pendingTokens := 0 + 
maxPotential := int64(0) + for _, a := range p.active { + pendingTokens += a.maxTokens + maxPotential += int64(a.maxTokens) + } + snap := routingSnapshot{ + provider: &Provider{ID: p.spec.ID}, + model: req.Model, + slotState: "running", + totalPending: len(p.active), + pendingForModel: len(p.active), + pendingMaxTokens: pendingTokens, + backendRunning: len(p.active), + maxTokensPotential: maxPotential, + decodeTPS: p.spec.DecodeTPS, + prefillTPS: p.spec.PrefillTPS, + systemMetrics: protocol.SystemMetrics{ + MemoryPressure: p.spec.MemoryPressure, + CPUUsage: p.spec.CPUUsage, + ThermalState: p.spec.ThermalState, + }, + gpuMemoryActiveGB: p.spec.GPUActiveGB, + totalMemoryGB: p.spec.TotalMemoryGB, + modelLoaded: true, + } + candidate, _, ok := reg.buildCandidateWithReason(snap, &PendingRequest{ + RequestID: req.ID, + Model: req.Model, + EstimatedPromptTokens: req.PromptTokens, + RequestedMaxTokens: req.MaxTokens, + }) + if ok { + out = append(out, replayCostCandidate{provider: p, cost: candidate.costMs}) + } + } + sort.SliceStable(out, func(i, j int) bool { + if out[i].cost == out[j].cost { + if len(out[i].provider.active) == len(out[j].provider.active) { + return out[i].provider.spec.ID < out[j].provider.spec.ID + } + return len(out[i].provider.active) < len(out[j].provider.active) + } + return out[i].cost < out[j].cost + }) + return out +} + +func minReplayProvider(providers []*replayProviderState, score func(*replayProviderState) float64) *replayProviderState { + var best *replayProviderState + var bestScore float64 + for _, p := range providers { + s := score(p) + if best == nil || s < bestScore || (s == bestScore && p.spec.ID < best.spec.ID) { + best = p + bestScore = s + } + } + return best +} + +func percentile(values []float64, q float64) float64 { + if len(values) == 0 { + return 0 + } + copyVals := append([]float64(nil), values...) 
+ sort.Float64s(copyVals) + idx := int(math.Ceil(q*float64(len(copyVals)))) - 1 + if idx < 0 { + idx = 0 + } + if idx >= len(copyVals) { + idx = len(copyVals) - 1 + } + return copyVals[idx] +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/coordinator/internal/registry/simulator_test.go b/coordinator/internal/registry/simulator_test.go new file mode 100644 index 00000000..4fcb2a60 --- /dev/null +++ b/coordinator/internal/registry/simulator_test.go @@ -0,0 +1,110 @@ +package registry + +import "testing" + +func TestRoutingReplaySimulatorRiotLessonStaleMetricHerdsWhileRoundRobinSpreads(t *testing.T) { + model := "riot-lesson-model" + scenario := RoutingReplayScenario{ + Providers: []RoutingReplayProvider{ + {ID: "p0", Model: model, DecodeTPS: 100, PrefillTPS: 400, MaxConcurrency: 2, CPUUsage: 0.10, MemoryPressure: 0.10, ThermalState: "nominal"}, + {ID: "p1", Model: model, DecodeTPS: 100, PrefillTPS: 400, MaxConcurrency: 2, CPUUsage: 0.50, MemoryPressure: 0.10, ThermalState: "nominal"}, + {ID: "p2", Model: model, DecodeTPS: 100, PrefillTPS: 400, MaxConcurrency: 2, CPUUsage: 0.50, MemoryPressure: 0.10, ThermalState: "nominal"}, + }, + Requests: burstRequests(model, 6, 0, 100, 200, 1_000), + Strategies: []RoutingReplayStrategy{ + RoutingStrategyLeastMetric, + RoutingStrategyRoundRobin, + }, + } + + results := ReplayRoutingScenario(scenario) + leastMetric := results[RoutingStrategyLeastMetric] + roundRobin := results[RoutingStrategyRoundRobin] + + if got := leastMetric.ProviderRequestCounts["p0"]; got != 6 { + t.Fatalf("stale least_metric routed %d requests to p0, want 6-request herd", got) + } + if leastMetric.ProviderRequestCounts["p1"] != 0 || leastMetric.ProviderRequestCounts["p2"] != 0 { + t.Fatalf("stale least_metric should not spread: counts=%v", leastMetric.ProviderRequestCounts) + } + if leastMetric.OverCapacityCount != 4 { + t.Fatalf("least_metric over-capacity count=%d, want 4", leastMetric.OverCapacityCount) + } + if 
leastMetric.TimeoutCount == 0 { + t.Fatal("least_metric should produce TTFT timeouts when stale CPU herds the burst") + } + + for _, id := range []string{"p0", "p1", "p2"} { + if got := roundRobin.ProviderRequestCounts[id]; got != 2 { + t.Fatalf("round_robin count for %s=%d, want 2 (counts=%v)", id, got, roundRobin.ProviderRequestCounts) + } + } + if roundRobin.OverCapacityCount != 0 { + t.Fatalf("round_robin over-capacity count=%d, want 0", roundRobin.OverCapacityCount) + } + if roundRobin.TimeoutCount != 0 { + t.Fatalf("round_robin timeout count=%d, want 0", roundRobin.TimeoutCount) + } + if !(leastMetric.TTFTP95Ms > roundRobin.TTFTP95Ms) { + t.Fatalf("expected least_metric p95 TTFT (%f) > round_robin (%f)", leastMetric.TTFTP95Ms, roundRobin.TTFTP95Ms) + } +} + +func TestRoutingReplaySimulatorCurrentCostModelDeterministicComparison(t *testing.T) { + model := "cost-model-replay" + scenario := RoutingReplayScenario{ + Providers: []RoutingReplayProvider{ + {ID: "fast", Model: model, DecodeTPS: 100, PrefillTPS: 400, MaxConcurrency: 2, CPUUsage: 0.10, MemoryPressure: 0.10, ThermalState: "nominal"}, + {ID: "slow", Model: model, DecodeTPS: 80, PrefillTPS: 320, MaxConcurrency: 2, CPUUsage: 0.10, MemoryPressure: 0.10, ThermalState: "nominal"}, + }, + Requests: burstRequests(model, 4, 0, 100, 200, 10_000), + Strategies: []RoutingReplayStrategy{ + RoutingStrategyCurrentCostModel, + RoutingStrategyLeastActive, + }, + } + + first := ReplayRoutingScenario(scenario) + second := ReplayRoutingScenario(scenario) + + costA := first[RoutingStrategyCurrentCostModel] + costB := second[RoutingStrategyCurrentCostModel] + if len(costA.Assignments) != len(costB.Assignments) { + t.Fatalf("assignment lengths differ: %d vs %d", len(costA.Assignments), len(costB.Assignments)) + } + for i := range costA.Assignments { + if costA.Assignments[i].ProviderID != costB.Assignments[i].ProviderID { + t.Fatalf("current cost model assignment %d not deterministic: %q vs %q", i, 
costA.Assignments[i].ProviderID, costB.Assignments[i].ProviderID) + } + } + if got := costA.ProviderRequestCounts["fast"]; got != 2 { + t.Fatalf("current cost model fast count=%d, want 2 (counts=%v assignments=%v)", got, costA.ProviderRequestCounts, costA.Assignments) + } + if got := costA.ProviderRequestCounts["slow"]; got != 2 { + t.Fatalf("current cost model slow count=%d, want 2 (counts=%v assignments=%v)", got, costA.ProviderRequestCounts, costA.Assignments) + } + if costA.TTFTP50Ms <= 0 || costA.TTFTP95Ms <= 0 || costA.TTFTP99Ms <= 0 { + t.Fatalf("TTFT percentiles must be populated: %+v", costA) + } + if costA.TotalLatencyP50Ms <= 0 || costA.TotalLatencyP95Ms <= 0 || costA.TotalLatencyP99Ms <= 0 || costA.TotalLatencyMs <= 0 { + t.Fatalf("latency metrics must be populated: %+v", costA) + } + if len(costA.ProviderUtilization) != 2 || costA.ProviderUtilization["fast"] <= 0 || costA.ProviderUtilization["slow"] <= 0 { + t.Fatalf("provider utilization must be populated: %+v", costA.ProviderUtilization) + } +} + +func burstRequests(model string, n int, arrivalMs float64, promptTokens, maxTokens int, timeoutMs float64) []RoutingReplayRequest { + reqs := make([]RoutingReplayRequest, 0, n) + for i := 0; i < n; i++ { + reqs = append(reqs, RoutingReplayRequest{ + ID: string(rune('a' + i)), + Model: model, + ArrivalMs: arrivalMs, + PromptTokens: promptTokens, + MaxTokens: maxTokens, + TimeoutMs: timeoutMs, + }) + } + return reqs +} diff --git a/provider-swift/Sources/ProviderCore/Coordinator/CoordinatorClient.swift b/provider-swift/Sources/ProviderCore/Coordinator/CoordinatorClient.swift index ffc5a892..b97e7155 100644 --- a/provider-swift/Sources/ProviderCore/Coordinator/CoordinatorClient.swift +++ b/provider-swift/Sources/ProviderCore/Coordinator/CoordinatorClient.swift @@ -59,6 +59,7 @@ public final class ProviderState: @unchecked Sendable { private var _warmModels: [String] = [] private var _currentModelHash: String? 
= nil private var _backendCapacity: BackendCapacity? = nil + private var _networkQuality = NetworkQuality() public init() {} @@ -86,6 +87,11 @@ public final class ProviderState: @unchecked Sendable { get { lock.withLock { _backendCapacity } } set { lock.withLock { _backendCapacity = newValue } } } + + public var networkQuality: NetworkQuality { + get { lock.withLock { _networkQuality } } + set { lock.withLock { _networkQuality = newValue } } + } } // MARK: - os_unfair_lock wrapper (Sendable-safe) @@ -280,6 +286,7 @@ public actor CoordinatorClient { private let state: ProviderState private let logger = Logger(subsystem: "dev.darkbloom.provider", category: "coordinator") + private let networkQualityTracker = NetworkQualityTracker() private var eventContinuation: AsyncStream.Continuation? private var outboundContinuation: AsyncStream.Continuation? @@ -349,6 +356,8 @@ public actor CoordinatorClient { logger.warning("Coordinator connection error: \(error.localizedDescription). Reconnecting in \(delay)s") reconnectCount += 1 + networkQualityTracker.recordReconnect() + state.networkQuality = networkQualityTracker.snapshot() if shouldEmitReconnectTelemetry(count: reconnectCount) { emitReconnectTelemetry(count: reconnectCount, error: error) } @@ -413,7 +422,15 @@ public actor CoordinatorClient { let shutting = await self.shutdownRequested if shutting { break } let json = await self.encodeOutbound(msg) - try await ws.send(.string(json)) + let started = CFAbsoluteTimeGetCurrent() + do { + try await ws.send(.string(json)) + let elapsedMs = (CFAbsoluteTimeGetCurrent() - started) * 1000.0 + await self.recordWriteLatency(elapsedMs) + } catch { + await self.recordWriteFailure() + throw error + } } } @@ -442,9 +459,18 @@ public actor CoordinatorClient { throw CoordinatorError.pongTimeout } - ws.sendPing { error in + let started = CFAbsoluteTimeGetCurrent() + ws.sendPing { [weak self] error in if error == nil { pongTracker.recordPong() + let elapsedMs = (CFAbsoluteTimeGetCurrent() - 
started) * 1000.0 + self?.networkQualityTracker.recordPong(rttMs: elapsedMs) + if let self { + Task { await self.publishNetworkQualitySnapshot() } + } + } else if let self { + self.networkQualityTracker.recordWriteFailure() + Task { await self.publishNetworkQualitySnapshot() } } } } @@ -613,6 +639,7 @@ public actor CoordinatorClient { tokensGenerated: stats.tokensGenerated ), systemMetrics: metrics, + networkQuality: state.networkQuality, backendCapacity: capacity ) @@ -623,6 +650,20 @@ public actor CoordinatorClient { return json } + private func recordWriteLatency(_ ms: Double) { + networkQualityTracker.recordWriteLatency(ms: ms) + state.networkQuality = networkQualityTracker.snapshot() + } + + private func recordWriteFailure() { + networkQualityTracker.recordWriteFailure() + state.networkQuality = networkQualityTracker.snapshot() + } + + private func publishNetworkQualitySnapshot() { + state.networkQuality = networkQualityTracker.snapshot() + } + // MARK: - Outbound Encoding private func encodeOutbound(_ msg: OutboundMessage) -> String { diff --git a/provider-swift/Sources/ProviderCore/Coordinator/CoordinatorClientCodec.swift b/provider-swift/Sources/ProviderCore/Coordinator/CoordinatorClientCodec.swift index ed340c59..da43b96b 100644 --- a/provider-swift/Sources/ProviderCore/Coordinator/CoordinatorClientCodec.swift +++ b/provider-swift/Sources/ProviderCore/Coordinator/CoordinatorClientCodec.swift @@ -46,6 +46,7 @@ public enum CoordinatorClientCodec { warmModels: [String], stats: ProviderStats, systemMetrics: SystemMetrics, + networkQuality: NetworkQuality = NetworkQuality(), backendCapacity: BackendCapacity? 
) -> ProviderMessage { .heartbeat(ProviderMessage.Heartbeat( @@ -54,6 +55,7 @@ public enum CoordinatorClientCodec { warmModels: warmModels, stats: stats, systemMetrics: systemMetrics, + networkQuality: networkQuality, backendCapacity: backendCapacity )) } diff --git a/provider-swift/Sources/ProviderCore/Coordinator/NetworkQualityTracker.swift b/provider-swift/Sources/ProviderCore/Coordinator/NetworkQualityTracker.swift new file mode 100644 index 00000000..e9895f16 --- /dev/null +++ b/provider-swift/Sources/ProviderCore/Coordinator/NetworkQualityTracker.swift @@ -0,0 +1,59 @@ +import Foundation + +/// Tracks provider-observed coordinator WebSocket transport quality. +/// All methods are synchronous and lock-backed so ping/write callbacks can +/// update from arbitrary URLSession queues without actor hops. +public final class NetworkQualityTracker: @unchecked Sendable { + private let lock = NSLock() + private var latestRTTMs: Double = 0 + private var latestJitterMs: Double = 0 + private var previousRTTMs: Double? 
+ private var reconnects: UInt64 = 0 + private var writeFailures: UInt64 = 0 + private var latestWriteLatencyMs: Double = 0 + + public init() {} + + public func recordPong(rttMs: Double) { + let bounded = max(0, rttMs) + lock.withLock { + if let previousRTTMs { + latestJitterMs = abs(bounded - previousRTTMs) + } + previousRTTMs = bounded + latestRTTMs = bounded + } + } + + public func recordReconnect() { + lock.withLock { reconnects &+= 1 } + } + + public func recordWriteFailure() { + lock.withLock { writeFailures &+= 1 } + } + + public func recordWriteLatency(ms: Double) { + lock.withLock { latestWriteLatencyMs = max(0, ms) } + } + + public func snapshot() -> NetworkQuality { + lock.withLock { + NetworkQuality( + rttMs: latestRTTMs, + jitterMs: latestJitterMs, + reconnectCount: reconnects, + websocketWriteFailures: writeFailures, + lastWriteLatencyMs: latestWriteLatencyMs + ) + } + } +} + +private extension NSLock { + func withLock(_ body: () -> T) -> T { + lock() + defer { unlock() } + return body() + } +} diff --git a/provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift b/provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift index e6d27b08..eb51d0f9 100644 --- a/provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift +++ b/provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift @@ -21,6 +21,7 @@ public struct SchedulerCapacity: Sendable { public let model: String public let activeRequests: Int public let pendingRequests: Int + public let activeTokens: Int public let maxConcurrent: Int public let gpuMemoryActiveBytes: Int public let gpuMemoryPeakBytes: Int @@ -168,11 +169,13 @@ public actor BatchScheduler { topK: request.top_k ?? 
0, seed: request.seed ) + let admitStartedAt = ContinuousClock.now let assignedUids = gen.insert( prompts: [promptTokens], maxTokens: [maxTokens], samplers: [sampler] ) + let admittedAt = ContinuousClock.now guard let uid = assignedUids.first else { continuation.yield(.error("BatchGenerator rejected the prompt")) continuation.finish() @@ -185,10 +188,15 @@ public actor BatchScheduler { detokenizer: NaiveStreamingDetokenizer(tokenizer: tk.inner), promptTokens: promptTokens.count, completionTokens: 0, - submittedAt: .now + submittedAt: admitStartedAt, + admittedAt: admittedAt, + firstTokenAt: nil, + model: telemetryModelId(for: request) ) requestIdToUid[id] = uid + emitLifecycle(.schedulerAdmit, entry: active[uid]!, activeCount: active.count) + let scheduler = self continuation.onTermination = { @Sendable termination in if case .cancelled = termination { @@ -214,10 +222,12 @@ public actor BatchScheduler { // MARK: - Capacity public func capacity() -> SchedulerCapacity { - SchedulerCapacity( + let counters = capacityCounters() + return SchedulerCapacity( model: modelId, - activeRequests: active.count, - pendingRequests: 0, + activeRequests: counters.activeRequests, + pendingRequests: counters.pendingRequests, + activeTokens: counters.activeTokens, maxConcurrent: maxConcurrentRequests, gpuMemoryActiveBytes: gpuMemory(.active), gpuMemoryPeakBytes: gpuMemory(.peak), @@ -234,7 +244,7 @@ public actor BatchScheduler { state: cap.activeRequests > 0 ? 
"running" : "idle", numRunning: UInt32(cap.activeRequests), numWaiting: UInt32(cap.pendingRequests), - activeTokens: 0, + activeTokens: Int64(cap.activeTokens), maxTokensPotential: Int64(defaultMaxTokens * maxConcurrentRequests) ) return BackendCapacity( @@ -279,6 +289,10 @@ public actor BatchScheduler { entry.detokenizer.append(token: response.token) entry.completionTokens += 1 + if entry.firstTokenAt == nil { + entry.firstTokenAt = .now + emitLifecycle(.firstToken, entry: entry, activeCount: active.count) + } if let chunk = entry.detokenizer.next() { entry.continuation.yield(.chunk(chunk)) } @@ -307,6 +321,7 @@ public actor BatchScheduler { entry.continuation.finish() active.removeValue(forKey: response.uid) requestIdToUid.removeValue(forKey: entry.requestId) + emitLifecycle(.inferenceComplete, entry: entry, activeCount: active.count) } } @@ -315,6 +330,50 @@ public actor BatchScheduler { requestIdToUid.removeValue(forKey: entry.requestId) entry.continuation.yield(.error(error)) entry.continuation.finish() + let stage: InferenceLifecycleTelemetryStage = error == "Request cancelled" + ? .inferenceCancel + : .inferenceError + emitLifecycle(stage, entry: entry, activeCount: active.count, error: error, reason: error) + } + + private func capacityCounters() -> SchedulerCapacityCounters { + SchedulerCapacityCounters.from(active.values.map { entry in + InferenceRequestProgress( + promptTokens: entry.promptTokens, + completionTokens: entry.completionTokens, + firstTokenReceived: entry.firstTokenAt != nil + ) + }) + } + + private func emitLifecycle( + _ stage: InferenceLifecycleTelemetryStage, + entry: ActiveRequest, + activeCount: Int, + error: String? = nil, + reason: String? 
= nil + ) { + TelemetryClient.shared.emit(InferenceLifecycleTelemetry.event( + stage, + snapshot: InferenceLifecycleTelemetrySnapshot( + requestId: entry.requestId, + model: entry.model, + promptTokens: entry.promptTokens, + completionTokens: entry.completionTokens, + queueMilliseconds: InferenceLifecycleTelemetry.milliseconds(entry.admittedAt - entry.submittedAt), + admitMilliseconds: InferenceLifecycleTelemetry.milliseconds(entry.admittedAt - entry.submittedAt), + ttftMilliseconds: entry.firstTokenAt.map { InferenceLifecycleTelemetry.milliseconds($0 - entry.submittedAt) }, + totalMilliseconds: InferenceLifecycleTelemetry.milliseconds(ContinuousClock.now - entry.submittedAt), + activeCount: activeCount + ), + error: error, + reason: reason + )) + } + + private func telemetryModelId(for request: ChatCompletionRequest) -> String { + if !request.model.isEmpty { return request.model } + return modelId } private enum MemoryKind { case active, peak, cache } @@ -341,6 +400,9 @@ private struct ActiveRequest { var promptTokens: Int var completionTokens: Int let submittedAt: ContinuousClock.Instant + let admittedAt: ContinuousClock.Instant + var firstTokenAt: ContinuousClock.Instant? + let model: String } private struct LoadSnapshot: @unchecked Sendable { diff --git a/provider-swift/Sources/ProviderCore/Protocol/Messages.swift b/provider-swift/Sources/ProviderCore/Protocol/Messages.swift index 5b75c796..ebf5a1a4 100644 --- a/provider-swift/Sources/ProviderCore/Protocol/Messages.swift +++ b/provider-swift/Sources/ProviderCore/Protocol/Messages.swift @@ -70,6 +70,7 @@ public enum ProviderMessage: Sendable, Equatable { public var warmModels: [String] public var stats: ProviderStats public var systemMetrics: SystemMetrics + public var networkQuality: NetworkQuality public var backendCapacity: BackendCapacity? 
public init( @@ -78,6 +79,7 @@ public enum ProviderMessage: Sendable, Equatable { warmModels: [String] = [], stats: ProviderStats, systemMetrics: SystemMetrics, + networkQuality: NetworkQuality = NetworkQuality(), backendCapacity: BackendCapacity? = nil ) { self.status = status @@ -85,6 +87,7 @@ public enum ProviderMessage: Sendable, Equatable { self.warmModels = warmModels self.stats = stats self.systemMetrics = systemMetrics + self.networkQuality = networkQuality self.backendCapacity = backendCapacity } } @@ -238,6 +241,7 @@ extension ProviderMessage: Codable { case warmModels = "warm_models" case stats case systemMetrics = "system_metrics" + case networkQuality = "network_quality" case backendCapacity = "backend_capacity" // Common case requestId = "request_id" @@ -300,6 +304,7 @@ extension ProviderMessage: Codable { } try container.encode(h.stats, forKey: .stats) try container.encode(h.systemMetrics, forKey: .systemMetrics) + try container.encode(h.networkQuality, forKey: .networkQuality) try container.encodeIfPresent(h.backendCapacity, forKey: .backendCapacity) case .inferenceAccepted(let a): @@ -387,6 +392,7 @@ extension ProviderMessage: Codable { warmModels: try container.decodeIfPresent([String].self, forKey: .warmModels) ?? [], stats: try container.decode(ProviderStats.self, forKey: .stats), systemMetrics: try container.decode(SystemMetrics.self, forKey: .systemMetrics), + networkQuality: try container.decodeIfPresent(NetworkQuality.self, forKey: .networkQuality) ?? 
NetworkQuality(), backendCapacity: try container.decodeIfPresent(BackendCapacity.self, forKey: .backendCapacity) )) diff --git a/provider-swift/Sources/ProviderCore/Protocol/Types.swift b/provider-swift/Sources/ProviderCore/Protocol/Types.swift index 1417a8db..993467ea 100644 --- a/provider-swift/Sources/ProviderCore/Protocol/Types.swift +++ b/provider-swift/Sources/ProviderCore/Protocol/Types.swift @@ -140,6 +140,36 @@ public struct ProviderStats: Codable, Sendable, Equatable { } } +public struct NetworkQuality: Codable, Sendable, Equatable { + public var rttMs: Double + public var jitterMs: Double + public var reconnectCount: UInt64 + public var websocketWriteFailures: UInt64 + public var lastWriteLatencyMs: Double + + enum CodingKeys: String, CodingKey { + case rttMs = "rtt_ms" + case jitterMs = "jitter_ms" + case reconnectCount = "reconnect_count" + case websocketWriteFailures = "websocket_write_failures" + case lastWriteLatencyMs = "last_write_latency_ms" + } + + public init( + rttMs: Double = 0, + jitterMs: Double = 0, + reconnectCount: UInt64 = 0, + websocketWriteFailures: UInt64 = 0, + lastWriteLatencyMs: Double = 0 + ) { + self.rttMs = rttMs + self.jitterMs = jitterMs + self.reconnectCount = reconnectCount + self.websocketWriteFailures = websocketWriteFailures + self.lastWriteLatencyMs = lastWriteLatencyMs + } +} + public struct UsageInfo: Codable, Sendable, Equatable { public var promptTokens: UInt64 public var completionTokens: UInt64 diff --git a/provider-swift/Sources/ProviderCore/Telemetry/InferenceLifecycleTelemetry.swift b/provider-swift/Sources/ProviderCore/Telemetry/InferenceLifecycleTelemetry.swift new file mode 100644 index 00000000..04a02b10 --- /dev/null +++ b/provider-swift/Sources/ProviderCore/Telemetry/InferenceLifecycleTelemetry.swift @@ -0,0 +1,172 @@ +import Foundation + +/// Lightweight request progress used to build capacity snapshots without +/// depending on MLX runtime state. 
`completionTokens == 0` is treated as +/// pending/prefill until the first token is observed. +public struct InferenceRequestProgress: Sendable, Equatable { + public let promptTokens: Int + public let completionTokens: Int + public let firstTokenReceived: Bool + + public init( + promptTokens: Int, + completionTokens: Int, + firstTokenReceived: Bool + ) { + self.promptTokens = promptTokens + self.completionTokens = completionTokens + self.firstTokenReceived = firstTokenReceived + } +} + +public struct SchedulerCapacityCounters: Sendable, Equatable { + public let activeRequests: Int + public let pendingRequests: Int + public let activeTokens: Int + + public init(activeRequests: Int, pendingRequests: Int, activeTokens: Int) { + self.activeRequests = activeRequests + self.pendingRequests = pendingRequests + self.activeTokens = activeTokens + } + + public static func from(_ requests: some Sequence) -> SchedulerCapacityCounters { + var activeRequests = 0 + var pendingRequests = 0 + var activeTokens = 0 + + for request in requests { + activeRequests += 1 + if !request.firstTokenReceived { + pendingRequests += 1 + } + activeTokens += max(0, request.promptTokens) + max(0, request.completionTokens) + } + + return SchedulerCapacityCounters( + activeRequests: activeRequests, + pendingRequests: pendingRequests, + activeTokens: activeTokens + ) + } +} + +public enum InferenceLifecycleTelemetryStage: String, Sendable, Equatable { + case schedulerAdmit = "scheduler_admit" + case firstToken = "first_token" + case inferenceComplete = "inference_complete" + case inferenceError = "inference_error" + case inferenceCancel = "inference_cancel" +} + +public struct InferenceLifecycleTelemetrySnapshot: Sendable, Equatable { + public let requestId: String + public let model: String + public let promptTokens: Int + public let completionTokens: Int + public let queueMilliseconds: Double? + public let admitMilliseconds: Double? + public let ttftMilliseconds: Double? 
+ public let totalMilliseconds: Double? + public let activeCount: Int? + + public init( + requestId: String, + model: String, + promptTokens: Int, + completionTokens: Int, + queueMilliseconds: Double? = nil, + admitMilliseconds: Double? = nil, + ttftMilliseconds: Double? = nil, + totalMilliseconds: Double? = nil, + activeCount: Int? = nil + ) { + self.requestId = requestId + self.model = model + self.promptTokens = promptTokens + self.completionTokens = completionTokens + self.queueMilliseconds = queueMilliseconds + self.admitMilliseconds = admitMilliseconds + self.ttftMilliseconds = ttftMilliseconds + self.totalMilliseconds = totalMilliseconds + self.activeCount = activeCount + } +} + +public enum InferenceLifecycleTelemetry { + public static func event( + _ stage: InferenceLifecycleTelemetryStage, + snapshot: InferenceLifecycleTelemetrySnapshot, + error: String? = nil, + reason: String? = nil + ) -> TelemetryEvent { + var fields: [String: AnyCodableValue] = [ + "component": .string("batch_scheduler"), + "operation": .string(stage.rawValue), + "model": .string(snapshot.model), + "prompt_tokens": .int(snapshot.promptTokens), + "completion_tokens": .int(snapshot.completionTokens), + ] + + if let queueMilliseconds = snapshot.queueMilliseconds { + fields["queue_ms"] = .double(queueMilliseconds) + } + if let admitMilliseconds = snapshot.admitMilliseconds { + fields["admit_ms"] = .double(admitMilliseconds) + } + if let ttftMilliseconds = snapshot.ttftMilliseconds { + fields["ttft_ms"] = .double(ttftMilliseconds) + } + if let totalMilliseconds = snapshot.totalMilliseconds { + fields["total_ms"] = .double(totalMilliseconds) + // Preserve compatibility with existing dashboards that key on the + // generic duration field while still emitting the explicit metric. 
+ fields["duration_ms"] = .double(totalMilliseconds) + } + if let activeCount = snapshot.activeCount { + fields["active_count"] = .int(activeCount) + fields["queue_depth"] = .int(activeCount) + } + if let error { + fields["error"] = .string(error) + } + if let reason { + fields["reason"] = .string(reason) + } + + return TelemetryEvent( + source: .provider, + severity: severity(for: stage), + kind: kind(for: stage), + message: stage.rawValue + ) + .withRequestId(snapshot.requestId) + .withFields(fields) + } + + public static func milliseconds(_ duration: Duration) -> Double { + let components = duration.components + return (Double(components.seconds) * 1_000.0) + + (Double(components.attoseconds) / 1e15) + } + + private static func severity(for stage: InferenceLifecycleTelemetryStage) -> TelemetrySeverity { + switch stage { + case .inferenceError: + return .error + case .inferenceCancel: + return .warn + case .schedulerAdmit, .firstToken, .inferenceComplete: + return .info + } + } + + private static func kind(for stage: InferenceLifecycleTelemetryStage) -> TelemetryKind { + switch stage { + case .inferenceError: + return .inferenceError + case .schedulerAdmit, .firstToken, .inferenceComplete, .inferenceCancel: + return .custom + } + } +} \ No newline at end of file diff --git a/provider-swift/Sources/ProviderCore/Telemetry/TelemetryEvent.swift b/provider-swift/Sources/ProviderCore/Telemetry/TelemetryEvent.swift index d592032b..a1678b98 100644 --- a/provider-swift/Sources/ProviderCore/Telemetry/TelemetryEvent.swift +++ b/provider-swift/Sources/ProviderCore/Telemetry/TelemetryEvent.swift @@ -231,6 +231,8 @@ public enum TelemetryFieldFilter { "status_code", "error_class", "error", "model", "backend", "exit_code", "signal", "hardware_chip", "memory_gb", "macos_version", "handler", "provider_id", "trust_level", "queue_depth", "reason", + "queue_ms", "admit_ms", "prompt_tokens", "completion_tokens", + "ttft_ms", "total_ms", "active_count", "runtime_component", 
"reconnect_count", "last_error", "ws_state", "billing_method", "payment_failed", "target", ] diff --git a/provider-swift/Tests/ProviderCoreTests/BatchSchedulerTelemetryTests.swift b/provider-swift/Tests/ProviderCoreTests/BatchSchedulerTelemetryTests.swift new file mode 100644 index 00000000..28f84d33 --- /dev/null +++ b/provider-swift/Tests/ProviderCoreTests/BatchSchedulerTelemetryTests.swift @@ -0,0 +1,75 @@ +import Testing +@testable import ProviderCore + +@Suite("BatchScheduler telemetry + capacity helpers") +struct BatchSchedulerTelemetryTests { + + @Test("capacity counters report pending requests and active tokens") + func capacityCountersReflectRequestProgress() { + let counters = SchedulerCapacityCounters.from([ + InferenceRequestProgress(promptTokens: 12, completionTokens: 0, firstTokenReceived: false), + InferenceRequestProgress(promptTokens: 9, completionTokens: 3, firstTokenReceived: true), + InferenceRequestProgress(promptTokens: 4, completionTokens: 1, firstTokenReceived: true), + ]) + + #expect(counters.activeRequests == 3) + #expect(counters.pendingRequests == 1) + #expect(counters.activeTokens == 29) + } + + @Test("lifecycle telemetry event includes request correlation and timing fields") + func lifecycleTelemetryEventConstruction() { + let snapshot = InferenceLifecycleTelemetrySnapshot( + requestId: "req-123", + model: "mlx-community/test-model", + promptTokens: 17, + completionTokens: 5, + queueMilliseconds: 2.5, + admitMilliseconds: 3.25, + ttftMilliseconds: 42.75, + totalMilliseconds: 123.5, + activeCount: 4 + ) + + let event = InferenceLifecycleTelemetry.event(.inferenceComplete, snapshot: snapshot) + + #expect(event.requestId == "req-123") + #expect(event.message == "inference_complete") + #expect(event.severity == .info) + #expect(event.kind == .custom) + #expect(event.fields?["operation"]?.description == "inference_complete") + #expect(event.fields?["model"]?.description == "mlx-community/test-model") + 
#expect(event.fields?["prompt_tokens"]?.description == "17") + #expect(event.fields?["completion_tokens"]?.description == "5") + #expect(event.fields?["queue_ms"]?.description == "2.5") + #expect(event.fields?["admit_ms"]?.description == "3.25") + #expect(event.fields?["ttft_ms"]?.description == "42.75") + #expect(event.fields?["total_ms"]?.description == "123.5") + #expect(event.fields?["active_count"]?.description == "4") + } + + @Test("telemetry field filter preserves lifecycle metrics") + func telemetryFieldFilterAllowsLifecycleMetrics() { + let filtered = TelemetryFieldFilter.filter([ + "operation": .string("first_token"), + "prompt_tokens": .int(10), + "completion_tokens": .int(1), + "queue_ms": .double(1.0), + "admit_ms": .double(2.0), + "ttft_ms": .double(3.0), + "total_ms": .double(4.0), + "active_count": .int(2), + "not_allowed": .string("drop"), + ]) + + #expect(filtered?["operation"]?.description == "first_token") + #expect(filtered?["prompt_tokens"]?.description == "10") + #expect(filtered?["completion_tokens"]?.description == "1") + #expect(filtered?["queue_ms"]?.description == "1.0") + #expect(filtered?["admit_ms"]?.description == "2.0") + #expect(filtered?["ttft_ms"]?.description == "3.0") + #expect(filtered?["total_ms"]?.description == "4.0") + #expect(filtered?["active_count"]?.description == "2") + #expect(filtered?["not_allowed"] == nil) + } +} diff --git a/provider-swift/Tests/ProviderCoreTests/NetworkQualityTests.swift b/provider-swift/Tests/ProviderCoreTests/NetworkQualityTests.swift new file mode 100644 index 00000000..40358d80 --- /dev/null +++ b/provider-swift/Tests/ProviderCoreTests/NetworkQualityTests.swift @@ -0,0 +1,34 @@ +import Testing +@testable import ProviderCore + +@Test func providerStateStoresNetworkQuality() { + let state = ProviderState() + let quality = NetworkQuality( + rttMs: 80, + jitterMs: 12, + reconnectCount: 2, + websocketWriteFailures: 1, + lastWriteLatencyMs: 4 + ) + + state.networkQuality = quality + + 
#expect(state.networkQuality == quality) +} + +@Test func networkQualityTrackerComputesPingRttAndJitterAndCounters() { + let tracker = NetworkQualityTracker() + + tracker.recordPong(rttMs: 100) + tracker.recordPong(rttMs: 140) + tracker.recordReconnect() + tracker.recordWriteFailure() + tracker.recordWriteLatency(ms: 8) + + let snapshot = tracker.snapshot() + #expect(snapshot.rttMs == 140) + #expect(snapshot.jitterMs == 40) + #expect(snapshot.reconnectCount == 1) + #expect(snapshot.websocketWriteFailures == 1) + #expect(snapshot.lastWriteLatencyMs == 8) +} diff --git a/provider-swift/Tests/ProviderCoreTests/ProtocolTests.swift b/provider-swift/Tests/ProviderCoreTests/ProtocolTests.swift index 6d395b7e..21f5d823 100644 --- a/provider-swift/Tests/ProviderCoreTests/ProtocolTests.swift +++ b/provider-swift/Tests/ProviderCoreTests/ProtocolTests.swift @@ -53,6 +53,13 @@ import Testing warmModels: ["mlx-community/Qwen2.5-7B-4bit"], stats: ProviderStats(requestsServed: 4, tokensGenerated: 4096), systemMetrics: SystemMetrics(memoryPressure: 0.2, cpuUsage: 0.3, thermalState: .nominal), + networkQuality: NetworkQuality( + rttMs: 42.5, + jitterMs: 7.25, + reconnectCount: 2, + websocketWriteFailures: 1, + lastWriteLatencyMs: 3.5 + ), backendCapacity: BackendCapacity( slots: [BackendSlotCapacity( model: "mlx-community/Qwen2.5-7B-4bit", @@ -160,6 +167,37 @@ import Testing #expect(failedObj["error"] as? 
String == "GPU OOM") } +@Test func heartbeatNetworkQualityUsesSnakeCaseAndDefaultsToZero() throws { + let heartbeat = ProviderMessage.heartbeat(ProviderMessage.Heartbeat( + status: .idle, + stats: ProviderStats(), + systemMetrics: SystemMetrics(memoryPressure: 0, cpuUsage: 0, thermalState: .nominal), + networkQuality: NetworkQuality( + rttMs: 125.5, + jitterMs: 24.25, + reconnectCount: 3, + websocketWriteFailures: 2, + lastWriteLatencyMs: 9.75 + ) + )) + + let data = try ProviderProtocolCodec.encodeProviderMessage(heartbeat) + let object = try jsonObject(data) + let network = object["network_quality"] as? [String: Any] + #expect(network?["rtt_ms"] as? Double == 125.5) + #expect(network?["jitter_ms"] as? Double == 24.25) + #expect(network?["reconnect_count"] as? Int == 3) + #expect(network?["websocket_write_failures"] as? Int == 2) + #expect(network?["last_write_latency_ms"] as? Double == 9.75) + + let legacyJSON = #"{"type":"heartbeat","status":"idle","stats":{"requests_served":0,"tokens_generated":0},"system_metrics":{"memory_pressure":0,"cpu_usage":0,"thermal_state":"nominal"}}"# + let decoded = try ProviderProtocolCodec.decodeProviderMessage(from: legacyJSON) + guard case .heartbeat(let legacy) = decoded else { + throw TestFailure.unexpectedMessage + } + #expect(legacy.networkQuality == NetworkQuality()) +} + @Test func coordinatorMessagesDecodeAndEncodeWithSnakeCaseKeys() throws { let encryptedRequest = #"{"type":"inference_request","request_id":"go-enc-req-1","body":null,"encrypted_body":{"ephemeral_public_key":"ZXBoZW1lcmFs","ciphertext":"Y2lwaGVy"}}"# let request = try ProviderProtocolCodec.decodeCoordinatorMessage(from: encryptedRequest)