diff --git a/coordinator/internal/api/consumer.go b/coordinator/internal/api/consumer.go
index 584dda01..5b78b26b 100644
--- a/coordinator/internal/api/consumer.go
+++ b/coordinator/internal/api/consumer.go
@@ -53,9 +53,16 @@ const (
// backend crashed, model not loaded after idle shutdown).
maxDispatchAttempts = 3
- // firstChunkTimeout is how long to wait for the first chunk from a provider
- // before considering the attempt failed and retrying.
- firstChunkTimeout = 10 * time.Second
+ // defaultFirstResponseTimeout is the baseline TTFT deadline before the
+ // coordinator retries on another provider. Keep this short: retries are
+ // invisible until the first chunk, so this is our Valorant-style timing
+ // authority guardrail for fast TTFT.
+ defaultFirstResponseTimeout = 4 * time.Second
+
+ // maxFirstResponseTimeout caps adaptive first-response deadlines for
+ // intentionally queued/backlogged routes. A route can earn extra time from
+ // its measured queue/backlog/network cost, but never an unbounded wait.
+ maxFirstResponseTimeout = 15 * time.Second
// cancelWriteTimeout bounds how long a cancel write to the provider can
// block. Using context.Background() unbounded here risks hanging the HTTP
@@ -65,6 +72,18 @@ const (
var thinkBlockPattern = regexp.MustCompile(`(?is)(.*?)\s*`)
+func firstResponseTimeout(decision registry.RoutingDecision) time.Duration {
+ costMs := decision.StateMs + decision.QueueMs + decision.PendingMs + decision.BacklogMs + decision.NetworkMs
+ if costMs <= 0 {
+ return defaultFirstResponseTimeout
+ }
+ adaptive := defaultFirstResponseTimeout + time.Duration(costMs*float64(time.Millisecond))
+ if adaptive > maxFirstResponseTimeout {
+ return maxFirstResponseTimeout
+ }
+ return adaptive
+}
+
// sendProviderCancel sends a Cancel message for the given request to the
// provider with a bounded timeout so a half-dead WebSocket doesn't hang the
// caller. Errors are logged at debug level because a disconnect race is the
@@ -786,7 +805,7 @@ func (s *Server) handleChatCompletions(w http.ResponseWriter, r *http.Request) {
// Wait for an accepted signal, first chunk, or error before committing.
// No HTTP response has been written yet, so retries are invisible.
- timer := time.NewTimer(firstChunkTimeout)
+ timer := time.NewTimer(firstResponseTimeout(decision))
accepted := false
select {
case <-pr.AcceptedCh:
@@ -888,6 +907,7 @@ func (s *Server) handleChatCompletions(w http.ResponseWriter, r *http.Request) {
}
// Provider accepted or sent first chunk — commit to this provider.
+ s.dispatchStandbyPreloads(r.Context(), model, provider)
// If only accepted (no chunk yet), wait for the first chunk with
// the full inference timeout since the backend may be reloading.
if accepted && !committed {
diff --git a/coordinator/internal/api/consumer_test.go b/coordinator/internal/api/consumer_test.go
index 6b6eeab2..4dea3bf4 100644
--- a/coordinator/internal/api/consumer_test.go
+++ b/coordinator/internal/api/consumer_test.go
@@ -32,6 +32,33 @@ func testServer(t *testing.T) (*Server, *store.MemoryStore) {
return srv, st
}
+func TestFirstResponseTimeoutUsesFastDefaultForHealthyRoutes(t *testing.T) {
+ got := firstResponseTimeout(registry.RoutingDecision{})
+ if got != defaultFirstResponseTimeout {
+ t.Fatalf("firstResponseTimeout(zero)=%s, want %s", got, defaultFirstResponseTimeout)
+ }
+}
+
+func TestFirstResponseTimeoutAccountsForQueueAndNetworkButCaps(t *testing.T) {
+ decision := registry.RoutingDecision{
+ StateMs: 500,
+ QueueMs: 3_000,
+ PendingMs: 750,
+ BacklogMs: 2_000,
+ NetworkMs: 1_000,
+ }
+ got := firstResponseTimeout(decision)
+ if got <= defaultFirstResponseTimeout {
+ t.Fatalf("firstResponseTimeout(degraded)=%s, want > %s", got, defaultFirstResponseTimeout)
+ }
+
+ decision.BacklogMs = 100_000
+ got = firstResponseTimeout(decision)
+ if got != maxFirstResponseTimeout {
+ t.Fatalf("firstResponseTimeout(huge backlog)=%s, want cap %s", got, maxFirstResponseTimeout)
+ }
+}
+
func TestHealthEndpoint(t *testing.T) {
srv, _ := testServer(t)
diff --git a/coordinator/internal/api/prewarm.go b/coordinator/internal/api/prewarm.go
new file mode 100644
index 00000000..627520be
--- /dev/null
+++ b/coordinator/internal/api/prewarm.go
@@ -0,0 +1,90 @@
+package api
+
+import (
+ "context"
+ "encoding/json"
+ "time"
+
+ "github.com/eigeninference/coordinator/internal/protocol"
+ "github.com/eigeninference/coordinator/internal/registry"
+ "nhooyr.io/websocket"
+)
+
+const (
+ standbyPrewarmTimeout = 2 * time.Second
+ maxStandbyPrewarmHints = 2
+)
+
+func providerIsReadyForModel(p *registry.Provider, model string) bool {
+ if p == nil || model == "" {
+ return false
+ }
+ p.Mu().Lock()
+ defer p.Mu().Unlock()
+ if p.CurrentModel == model {
+ return true
+ }
+ for _, warm := range p.WarmModels {
+ if warm == model {
+ return true
+ }
+ }
+ return false
+}
+
+func (s *Server) dispatchStandbyPreloads(ctx context.Context, model string, primary *registry.Provider) {
+ if s == nil || s.registry == nil || model == "" {
+ return
+ }
+ writer := s.providerPrewarmWriter
+ if writer == nil {
+ writer = s.writeProviderPrewarm
+ }
+ primaryID := ""
+ if primary != nil {
+ primaryID = primary.ID
+ }
+ count := 0
+ s.registry.ForEachProvider(func(p *registry.Provider) {
+ if count >= maxStandbyPrewarmHints || p == nil || p.ID == primaryID {
+ return
+ }
+ p.Mu().Lock()
+ eligible := p.Backend == registry.BackendMLXSwift &&
+ p.Status != registry.StatusOffline && p.Status != registry.StatusUntrusted &&
+ p.RuntimeVerified && providerHasModelLocked(p, model)
+ p.Mu().Unlock()
+ if !eligible || providerIsReadyForModel(p, model) {
+ return
+ }
+ count++
+ go func(provider *registry.Provider) {
+ prewarmCtx, cancel := context.WithTimeout(ctx, standbyPrewarmTimeout)
+ defer cancel()
+ msg := protocol.LoadModelMessage{Type: protocol.TypeLoadModel, ModelID: model}
+ if err := writer(prewarmCtx, provider, msg); err != nil && s.logger != nil {
+ s.logger.Debug("failed to dispatch standby preload", "provider_id", provider.ID, "model", model, "error", err)
+ }
+ }(p)
+ })
+}
+
+func providerHasModelLocked(p *registry.Provider, model string) bool {
+ for _, m := range p.Models {
+ if m.ID == model {
+ return true
+ }
+ }
+ return false
+}
+
+func (s *Server) writeProviderPrewarm(ctx context.Context, p *registry.Provider, msg protocol.LoadModelMessage) error {
+ if p == nil || p.Conn == nil {
+ return nil
+ }
+ data, err := json.Marshal(msg)
+ if err != nil {
+ return err
+ }
+ return p.Conn.Write(ctx, websocket.MessageText, data)
+}
diff --git a/coordinator/internal/api/prewarm_test.go b/coordinator/internal/api/prewarm_test.go
new file mode 100644
index 00000000..58c01ec0
--- /dev/null
+++ b/coordinator/internal/api/prewarm_test.go
@@ -0,0 +1,96 @@
+package api
+
+import (
+ "context"
+ "log/slog"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/eigeninference/coordinator/internal/protocol"
+ "github.com/eigeninference/coordinator/internal/registry"
+)
+
+func TestProviderIsReadyForModel(t *testing.T) {
+ p := ®istry.Provider{CurrentModel: "m1", WarmModels: []string{"m2"}}
+ if !providerIsReadyForModel(p, "m1") {
+ t.Fatal("current model should be ready")
+ }
+ if !providerIsReadyForModel(p, "m2") {
+ t.Fatal("warm model should be ready")
+ }
+ if providerIsReadyForModel(p, "m3") {
+ t.Fatal("unknown model should not be ready")
+ }
+}
+
+func TestDispatchStandbyPreloadsColdSwiftProviders(t *testing.T) {
+ logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
+ reg := registry.New(logger)
+ model := "standby-model"
+ primary := ®istry.Provider{
+ ID: "primary",
+ Backend: registry.BackendMLXSwift,
+ Models: []protocol.ModelInfo{{ID: model}},
+ WarmModels: []string{model},
+ Status: registry.StatusServing,
+ TrustLevel: registry.TrustHardware,
+ RuntimeVerified: true,
+ LastHeartbeat: time.Now(),
+ }
+ cold := ®istry.Provider{
+ ID: "cold",
+ Backend: registry.BackendMLXSwift,
+ Models: []protocol.ModelInfo{{ID: model}},
+ Status: registry.StatusOnline,
+ TrustLevel: registry.TrustHardware,
+ RuntimeVerified: true,
+ LastHeartbeat: time.Now(),
+ }
+
+ primaryMsg := protocol.RegisterMessage{
+ Type: protocol.TypeRegister,
+ Models: []protocol.ModelInfo{{ID: model}},
+ Backend: registry.BackendMLXSwift,
+ EncryptedResponseChunks: true,
+ PublicKey: "pub-primary",
+ PrivacyCapabilities: testPrivacyCaps(),
+ }
+ coldMsg := primaryMsg
+ coldMsg.PublicKey = "pub-cold"
+ primary = reg.Register("primary", nil, &primaryMsg)
+ cold = reg.Register("cold", nil, &coldMsg)
+ primary.WarmModels = []string{model}
+ primary.RuntimeVerified = true
+ cold.RuntimeVerified = true
+ reg.SetTrustLevel(primary.ID, registry.TrustHardware)
+ reg.SetTrustLevel(cold.ID, registry.TrustHardware)
+ reg.RecordChallengeSuccess(primary.ID)
+ reg.RecordChallengeSuccess(cold.ID)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+ sent := make(chan protocol.LoadModelMessage, 1)
+ server := &Server{registry: reg, logger: logger, providerPrewarmWriter: func(ctx context.Context, p *registry.Provider, msg protocol.LoadModelMessage) error {
+ sent <- msg
+ return nil
+ }}
+
+ server.dispatchStandbyPreloads(ctx, model, primary)
+ select {
+ case msg := <-sent:
+ if msg.Type != protocol.TypeLoadModel || msg.ModelID != model {
+ t.Fatalf("unexpected preload message: %+v", msg)
+ }
+ case <-time.After(time.Second):
+ t.Fatal("expected standby preload message")
+ }
+
+ cold.WarmModels = []string{model}
+ server.dispatchStandbyPreloads(ctx, model, primary)
+ select {
+ case msg := <-sent:
+ t.Fatalf("unexpected preload for already-warm provider: %+v", msg)
+ default:
+ }
+}
diff --git a/coordinator/internal/api/server.go b/coordinator/internal/api/server.go
index aad12623..2ebe4bcd 100644
--- a/coordinator/internal/api/server.go
+++ b/coordinator/internal/api/server.go
@@ -182,6 +182,10 @@ type Server struct {
// and used by internal counters/histograms. Never nil.
metrics *Metrics
+ // providerPrewarmWriter sends coordinator-driven preload hints. Tests can
+ // replace it to assert standby behavior without a real WebSocket.
+ providerPrewarmWriter func(context.Context, *registry.Provider, protocol.LoadModelMessage) error
+
// telemetryLimiter throttles telemetry ingestion per submitter.
telemetryLimiter *telemetryLimiter
@@ -230,15 +234,16 @@ func NewServer(reg *registry.Registry, st store.Store, logger *slog.Logger) *Ser
reg.SetStore(st)
s := &Server{
- registry: reg,
- store: st,
- ledger: payments.NewLedger(st),
- logger: logger,
- mux: http.NewServeMux(),
- knownRuntimeManifest: &RuntimeManifest{},
- metrics: NewMetrics(),
- telemetryLimiter: newTelemetryLimiter(),
- readCache: newTTLCache(),
+ registry: reg,
+ store: st,
+ ledger: payments.NewLedger(st),
+ logger: logger,
+ mux: http.NewServeMux(),
+ knownRuntimeManifest: &RuntimeManifest{},
+ metrics: NewMetrics(),
+ providerPrewarmWriter: nil,
+ telemetryLimiter: newTelemetryLimiter(),
+ readCache: newTTLCache(),
}
s.registerDefaultGauges()
s.routes()
diff --git a/coordinator/internal/api/telemetry_handlers.go b/coordinator/internal/api/telemetry_handlers.go
index 44c73674..1e4e2518 100644
--- a/coordinator/internal/api/telemetry_handlers.go
+++ b/coordinator/internal/api/telemetry_handlers.go
@@ -57,13 +57,20 @@ var telemetryFieldAllowlist = map[string]struct{}{
"error": {},
"target": {},
// Provider / backend
- "model": {},
- "backend": {},
- "exit_code": {},
- "signal": {},
- "hardware_chip": {},
- "memory_gb": {},
- "macos_version": {},
+ "model": {},
+ "backend": {},
+ "queue_ms": {},
+ "admit_ms": {},
+ "prompt_tokens": {},
+ "completion_tokens": {},
+ "ttft_ms": {},
+ "total_ms": {},
+ "active_count": {},
+ "exit_code": {},
+ "signal": {},
+ "hardware_chip": {},
+ "memory_gb": {},
+ "macos_version": {},
// Coordinator
"handler": {},
"provider_id": {},
diff --git a/coordinator/internal/protocol/messages.go b/coordinator/internal/protocol/messages.go
index ab94c4a1..41a33668 100644
--- a/coordinator/internal/protocol/messages.go
+++ b/coordinator/internal/protocol/messages.go
@@ -134,6 +134,7 @@ type HeartbeatMessage struct {
Stats HeartbeatStats `json:"stats"`
WarmModels []string `json:"warm_models,omitempty"` // models currently loaded in memory
SystemMetrics SystemMetrics `json:"system_metrics"` // live resource utilization
+ NetworkQuality NetworkQuality `json:"network_quality"` // provider-observed coordinator transport quality
BackendCapacity *BackendCapacity `json:"backend_capacity,omitempty"` // live backend capacity (nil for old providers)
}
@@ -166,6 +167,17 @@ type SystemMetrics struct {
ThermalState string `json:"thermal_state"` // nominal, fair, serious, critical
}
+// NetworkQuality contains provider-observed coordinator WebSocket transport
+// health. All fields default to zero for backward compatibility with older
+// providers, and zero means "no measured penalty" rather than "bad".
+type NetworkQuality struct {
+ RTTMs float64 `json:"rtt_ms"` // latest WebSocket ping/pong round-trip time
+ JitterMs float64 `json:"jitter_ms"` // absolute delta between latest and previous RTT
+ ReconnectCount int64 `json:"reconnect_count"` // reconnect attempts since provider process start
+ WebSocketWriteFailures int64 `json:"websocket_write_failures"` // failed WebSocket writes since provider process start
+ LastWriteLatencyMs float64 `json:"last_write_latency_ms"` // duration of most recent successful WebSocket write
+}
+
// HeartbeatStats contains counters reported in heartbeats.
type HeartbeatStats struct {
RequestsServed int64 `json:"requests_served"`
diff --git a/coordinator/internal/protocol/network_quality_test.go b/coordinator/internal/protocol/network_quality_test.go
new file mode 100644
index 00000000..babe7da2
--- /dev/null
+++ b/coordinator/internal/protocol/network_quality_test.go
@@ -0,0 +1,61 @@
+package protocol
+
+import (
+ "encoding/json"
+ "testing"
+)
+
+func TestHeartbeatNetworkQualityRoundTrip(t *testing.T) {
+ msg := HeartbeatMessage{
+ Type: TypeHeartbeat,
+ Status: "idle",
+ Stats: HeartbeatStats{},
+ NetworkQuality: NetworkQuality{
+ RTTMs: 125.5,
+ JitterMs: 42.25,
+ ReconnectCount: 3,
+ WebSocketWriteFailures: 2,
+ LastWriteLatencyMs: 17.75,
+ },
+ }
+
+ data, err := json.Marshal(msg)
+ if err != nil {
+ t.Fatalf("marshal: %v", err)
+ }
+
+ var decoded HeartbeatMessage
+ if err := json.Unmarshal(data, &decoded); err != nil {
+ t.Fatalf("unmarshal: %v", err)
+ }
+
+ if decoded.NetworkQuality.RTTMs != 125.5 {
+ t.Fatalf("rtt_ms=%f, want 125.5", decoded.NetworkQuality.RTTMs)
+ }
+ if decoded.NetworkQuality.JitterMs != 42.25 {
+ t.Fatalf("jitter_ms=%f, want 42.25", decoded.NetworkQuality.JitterMs)
+ }
+ if decoded.NetworkQuality.ReconnectCount != 3 {
+ t.Fatalf("reconnect_count=%d, want 3", decoded.NetworkQuality.ReconnectCount)
+ }
+ if decoded.NetworkQuality.WebSocketWriteFailures != 2 {
+ t.Fatalf("websocket_write_failures=%d, want 2", decoded.NetworkQuality.WebSocketWriteFailures)
+ }
+ if decoded.NetworkQuality.LastWriteLatencyMs != 17.75 {
+ t.Fatalf("last_write_latency_ms=%f, want 17.75", decoded.NetworkQuality.LastWriteLatencyMs)
+ }
+}
+
+func TestHeartbeatMissingNetworkQualityDefaultsToZero(t *testing.T) {
+ raw := `{"type":"heartbeat","status":"idle","active_model":null,"stats":{"requests_served":0,"tokens_generated":0}}`
+
+ var pm ProviderMessage
+ if err := json.Unmarshal([]byte(raw), &pm); err != nil {
+ t.Fatalf("unmarshal: %v", err)
+ }
+
+ hb := pm.Payload.(*HeartbeatMessage)
+ if hb.NetworkQuality != (NetworkQuality{}) {
+ t.Fatalf("NetworkQuality=%+v, want zero value", hb.NetworkQuality)
+ }
+}
diff --git a/coordinator/internal/registry/network_quality_test.go b/coordinator/internal/registry/network_quality_test.go
new file mode 100644
index 00000000..ae8321ab
--- /dev/null
+++ b/coordinator/internal/registry/network_quality_test.go
@@ -0,0 +1,95 @@
+package registry
+
+import (
+ "testing"
+
+ "github.com/eigeninference/coordinator/internal/protocol"
+)
+
+func TestNetworkPenaltyDefaultsToZeroForMissingMetrics(t *testing.T) {
+ if got := networkPenaltyMs(protocol.NetworkQuality{}); got != 0 {
+ t.Fatalf("networkPenaltyMs(zero)=%f, want 0", got)
+ }
+}
+
+func TestNetworkPenaltyDegradesHighLatencyAndInstability(t *testing.T) {
+ good := protocol.NetworkQuality{RTTMs: 20, JitterMs: 5}
+ bad := protocol.NetworkQuality{
+ RTTMs: 900,
+ JitterMs: 250,
+ ReconnectCount: 4,
+ WebSocketWriteFailures: 3,
+ LastWriteLatencyMs: 150,
+ }
+
+ goodPenalty := networkPenaltyMs(good)
+ badPenalty := networkPenaltyMs(bad)
+ if goodPenalty != 0 {
+ t.Fatalf("good network penalty=%f, want 0 below thresholds", goodPenalty)
+ }
+ if badPenalty <= goodPenalty {
+ t.Fatalf("bad penalty=%f, want > good penalty=%f", badPenalty, goodPenalty)
+ }
+ if badPenalty > networkQualityMaxPenaltyMs {
+ t.Fatalf("bad penalty=%f exceeds max=%f", badPenalty, networkQualityMaxPenaltyMs)
+ }
+}
+
+func TestReserveProviderPenalizesLowerNetworkQuality(t *testing.T) {
+ reg := New(testLogger())
+ model := "network-quality-model"
+ good := makeSchedulerProvider(t, reg, "good-network", model, 100)
+ bad := makeSchedulerProvider(t, reg, "bad-network", model, 100)
+
+ bad.mu.Lock()
+ bad.NetworkQuality = protocol.NetworkQuality{
+ RTTMs: 1000,
+ JitterMs: 300,
+ ReconnectCount: 5,
+ WebSocketWriteFailures: 5,
+ LastWriteLatencyMs: 250,
+ }
+ bad.mu.Unlock()
+
+ selected, decision := reg.ReserveProviderEx(model, &PendingRequest{
+ RequestID: "network-quality-req",
+ Model: model,
+ RequestedMaxTokens: 128,
+ })
+ if selected == nil {
+ t.Fatal("ReserveProviderEx returned nil")
+ }
+ if selected.ID != good.ID {
+ t.Fatalf("selected %q, want good-network; decision=%+v", selected.ID, decision)
+ }
+ if decision.NetworkMs != 0 {
+ t.Fatalf("winning good provider NetworkMs=%f, want 0", decision.NetworkMs)
+ }
+}
+
+func TestHeartbeatStoresNetworkQualityForRoutingSnapshot(t *testing.T) {
+ reg := New(testLogger())
+ model := "heartbeat-network-model"
+ p := makeSchedulerProvider(t, reg, "heartbeat-network", model, 100)
+
+ reg.Heartbeat(p.ID, &protocol.HeartbeatMessage{
+ Type: protocol.TypeHeartbeat,
+ Status: "idle",
+ Stats: protocol.HeartbeatStats{},
+ NetworkQuality: protocol.NetworkQuality{
+ RTTMs: 321,
+ JitterMs: 45,
+ ReconnectCount: 2,
+ WebSocketWriteFailures: 1,
+ LastWriteLatencyMs: 12,
+ },
+ })
+
+ snap, ok := reg.snapshotProviderLocked(p, model, &PendingRequest{RequestID: "snap", Model: model})
+ if !ok {
+ t.Fatal("snapshotProviderLocked returned !ok")
+ }
+ if snap.networkQuality.RTTMs != 321 {
+ t.Fatalf("snapshot RTTMs=%f, want 321", snap.networkQuality.RTTMs)
+ }
+}
diff --git a/coordinator/internal/registry/registry.go b/coordinator/internal/registry/registry.go
index 5cd4c013..c2beb511 100644
--- a/coordinator/internal/registry/registry.go
+++ b/coordinator/internal/registry/registry.go
@@ -130,7 +130,8 @@ type Provider struct {
CurrentModel string // model currently being served
// Live system metrics from heartbeats
- SystemMetrics protocol.SystemMetrics
+ SystemMetrics protocol.SystemMetrics
+ NetworkQuality protocol.NetworkQuality
// Live backend capacity from heartbeats (nil for providers without capacity reporting)
BackendCapacity *protocol.BackendCapacity
@@ -696,6 +697,33 @@ func clampNonNeg(v, max float64) (float64, bool) {
// TotalMemoryGB=1e9 would make gpuUtil ~= 0 and dodge health penalties, so
// we cap it at maxMemoryGBFloat. Same for MaxTokensPotential which directly
// controls backlog cost. NaN/negative become 0.
+func clampNetworkQuality(n *protocol.NetworkQuality) {
+ if n == nil {
+ return
+ }
+ if v, changed := clampNonNeg(n.RTTMs, 10_000); changed {
+ n.RTTMs = v
+ }
+ if v, changed := clampNonNeg(n.JitterMs, 10_000); changed {
+ n.JitterMs = v
+ }
+ if v, changed := clampNonNeg(n.LastWriteLatencyMs, 10_000); changed {
+ n.LastWriteLatencyMs = v
+ }
+ if n.ReconnectCount < 0 {
+ n.ReconnectCount = 0
+ }
+ if n.ReconnectCount > 1_000_000 {
+ n.ReconnectCount = 1_000_000
+ }
+ if n.WebSocketWriteFailures < 0 {
+ n.WebSocketWriteFailures = 0
+ }
+ if n.WebSocketWriteFailures > 1_000_000 {
+ n.WebSocketWriteFailures = 1_000_000
+ }
+}
+
func clampBackendCapacity(logger *slog.Logger, providerID string, bc *protocol.BackendCapacity) {
if bc == nil {
return
@@ -923,6 +951,7 @@ func (r *Registry) Heartbeat(id string, msg *protocol.HeartbeatMessage) {
if v, changed := clampNonNeg(msg.SystemMetrics.CPUUsage, 1.0); changed {
msg.SystemMetrics.CPUUsage = v
}
+ clampNetworkQuality(&msg.NetworkQuality)
p.mu.Lock()
p.LastHeartbeat = time.Now()
@@ -930,6 +959,7 @@ func (r *Registry) Heartbeat(id string, msg *protocol.HeartbeatMessage) {
p.Stats.TokensGenerated += cumulativeDelta(p.lastSessionStats.TokensGenerated, msg.Stats.TokensGenerated)
p.lastSessionStats = msg.Stats
p.SystemMetrics = msg.SystemMetrics
+ p.NetworkQuality = msg.NetworkQuality
// Update backend capacity from heartbeat (nil-safe for old providers).
if msg.BackendCapacity != nil {
p.BackendCapacity = msg.BackendCapacity
diff --git a/coordinator/internal/registry/scheduler.go b/coordinator/internal/registry/scheduler.go
index 9f516df3..497f9153 100644
--- a/coordinator/internal/registry/scheduler.go
+++ b/coordinator/internal/registry/scheduler.go
@@ -26,15 +26,16 @@ const (
// slow-provider decode, so the cost function actually spreads load
// across the fleet. Wider tie window admits more candidates to the
// queue-depth tie-break + random distribution.
- queueDepthPenaltyMs = 3_000.0
- totalPendingPenaltyMs = 750.0
- memoryPressurePenaltyMs = 4_000.0
- cpuUsagePenaltyMs = 1_500.0
- gpuUtilizationPenaltyMs = 5_000.0
- thermalPenaltyFairMs = 2_000.0
- thermalPenaltySeriousMs = 8_000.0
- nearTieCostWindowMs = 3_000.0
- challengeFreshnessMaxAge = 6 * time.Minute
+ queueDepthPenaltyMs = 3_000.0
+ totalPendingPenaltyMs = 750.0
+ memoryPressurePenaltyMs = 4_000.0
+ cpuUsagePenaltyMs = 1_500.0
+ gpuUtilizationPenaltyMs = 5_000.0
+ thermalPenaltyFairMs = 2_000.0
+ thermalPenaltySeriousMs = 8_000.0
+ networkQualityMaxPenaltyMs = 5_000.0
+ nearTieCostWindowMs = 3_000.0
+ challengeFreshnessMaxAge = 6 * time.Minute
// kvCacheBytesPerToken is a per-token KV-cache size estimate used by
// the free-memory admission gate.
@@ -77,6 +78,7 @@ type routingSnapshot struct {
decodeTPS float64
prefillTPS float64
systemMetrics protocol.SystemMetrics
+ networkQuality protocol.NetworkQuality
gpuMemoryActiveGB float64
totalMemoryGB float64
modelSizeGB float64 // catalog-reported weight footprint (0 = unknown, gate disabled)
@@ -115,6 +117,7 @@ type costBreakdown struct {
BacklogMs float64
ThisReqMs float64
HealthMs float64
+ NetworkMs float64
Total float64
}
@@ -131,6 +134,7 @@ type RoutingDecision struct {
BacklogMs float64 // tokens-ahead / decodeTPS contribution
ThisReqMs float64 // this request's prefill+decode contribution
HealthMs float64 // memory/CPU/thermal/GPU-util contribution
+ NetworkMs float64 // bounded network-quality contribution
EffectiveQueue int // max(pendingForModel, backendRunning+backendWaiting)
CandidateCount int // total candidates that passed all gates
CapacityRejections int // candidates rejected by the free-memory admission gate
@@ -206,6 +210,7 @@ func (r *Registry) ReserveProviderEx(model string, pr *PendingRequest, excludeID
BacklogMs: bd.BacklogMs,
ThisReqMs: bd.ThisReqMs,
HealthMs: bd.HealthMs,
+ NetworkMs: bd.NetworkMs,
EffectiveQueue: selected.effectiveQueue,
CandidateCount: candidateCount,
CapacityRejections: capacityRejections,
@@ -349,6 +354,7 @@ func (r *Registry) logRoutingDecision(model string, pr *PendingRequest, winner *
"backlog_ms", bd.BacklogMs,
"this_req_ms", bd.ThisReqMs,
"health_ms", bd.HealthMs,
+ "network_ms", bd.NetworkMs,
"effective_tps", winner.effectiveTPS,
"effective_queue", winner.effectiveQueue,
"candidates", candidates,
@@ -384,15 +390,16 @@ func (r *Registry) snapshotProviderLocked(p *Provider, model string, pr *Pending
}
snap := routingSnapshot{
- provider: p,
- model: model,
- slotState: "unknown",
- totalPending: p.pendingCount(),
- systemMetrics: p.SystemMetrics,
- decodeTPS: resolvedDecodeTPS(p),
- prefillTPS: resolvedPrefillTPS(p),
- totalMemoryGB: float64(p.Hardware.MemoryGB),
- modelSizeGB: r.catalogSizeGBLocked(model),
+ provider: p,
+ model: model,
+ slotState: "unknown",
+ totalPending: p.pendingCount(),
+ systemMetrics: p.SystemMetrics,
+ networkQuality: p.NetworkQuality,
+ decodeTPS: resolvedDecodeTPS(p),
+ prefillTPS: resolvedPrefillTPS(p),
+ totalMemoryGB: float64(p.Hardware.MemoryGB),
+ modelSizeGB: r.catalogSizeGBLocked(model),
}
for _, pr := range p.pendingReqs {
@@ -513,7 +520,8 @@ func (r *Registry) buildCandidateWithReason(snap routingSnapshot, pr *PendingReq
backlogMs := backlogTokenMs(snap.maxTokensPotential, waitingBacklogTokens, unaccountedPendingTokens, effectiveTPS)
thisReqMs := float64(reqPrompt)/snap.prefillTPS*1000.0 + float64(reqMax)/effectiveTPS*1000.0
healthMs := healthPenaltyMs(snap.systemMetrics, snap.gpuMemoryActiveGB, snap.totalMemoryGB)
- cost := statePenalty + queueMs + pendingMs + backlogMs + thisReqMs + healthMs
+ networkMs := networkPenaltyMs(snap.networkQuality)
+ cost := statePenalty + queueMs + pendingMs + backlogMs + thisReqMs + healthMs + networkMs
return &routingCandidate{
provider: snap.provider,
@@ -528,6 +536,7 @@ func (r *Registry) buildCandidateWithReason(snap routingSnapshot, pr *PendingReq
BacklogMs: backlogMs,
ThisReqMs: thisReqMs,
HealthMs: healthMs,
+ NetworkMs: networkMs,
Total: cost,
},
}, rejectNone, true
@@ -580,6 +589,35 @@ func healthPenaltyMs(m protocol.SystemMetrics, gpuActiveGB, totalMemGB float64)
return penalty
}
+func networkPenaltyMs(n protocol.NetworkQuality) float64 {
+ // Keep normal regional RTT/jitter free, but penalize degraded transport
+ // enough to break otherwise-similar ties. The cap preserves hardware/load
+ // as dominant signals and keeps malicious counters bounded.
+ penalty := 0.0
+ if n.RTTMs > 50 {
+ penalty += (n.RTTMs - 50) * 2.0
+ }
+ if n.JitterMs > 20 {
+ penalty += (n.JitterMs - 20) * 5.0
+ }
+ if n.LastWriteLatencyMs > 25 {
+ penalty += (n.LastWriteLatencyMs - 25) * 2.0
+ }
+ if n.ReconnectCount > 0 {
+ penalty += float64(n.ReconnectCount) * 250.0
+ }
+ if n.WebSocketWriteFailures > 0 {
+ penalty += float64(n.WebSocketWriteFailures) * 500.0
+ }
+ if penalty < 0 || math.IsNaN(penalty) {
+ return 0
+ }
+ if penalty > networkQualityMaxPenaltyMs {
+ return networkQualityMaxPenaltyMs
+ }
+ return penalty
+}
+
// effectiveDecodeTPS scales the static decode TPS down by current
// backend batch size. Returns the static value when the load factor is
// disabled or batch is unknown. Floored at 1 token/s to avoid divide-
diff --git a/coordinator/internal/registry/scheduler_test.go b/coordinator/internal/registry/scheduler_test.go
index d10768b0..c6134ddd 100644
--- a/coordinator/internal/registry/scheduler_test.go
+++ b/coordinator/internal/registry/scheduler_test.go
@@ -107,7 +107,7 @@ func TestReserveProviderExReturnsCostBreakdown(t *testing.T) {
}
// Sum of components should approximately equal the total cost.
sum := decision.StateMs + decision.QueueMs + decision.PendingMs +
- decision.BacklogMs + decision.ThisReqMs + decision.HealthMs
+ decision.BacklogMs + decision.ThisReqMs + decision.HealthMs + decision.NetworkMs
if diff := sum - decision.CostMs; diff > 0.001 || diff < -0.001 {
t.Fatalf("breakdown sum %f != CostMs %f", sum, decision.CostMs)
}
diff --git a/coordinator/internal/registry/simulator.go b/coordinator/internal/registry/simulator.go
new file mode 100644
index 00000000..c8555661
--- /dev/null
+++ b/coordinator/internal/registry/simulator.go
@@ -0,0 +1,426 @@
+package registry
+
+import (
+ "math"
+ "math/rand"
+ "sort"
+
+ "github.com/eigeninference/coordinator/internal/protocol"
+)
+
+// RoutingReplayStrategy names a local-only simulator routing policy.
+type RoutingReplayStrategy string
+
+const (
+ RoutingStrategyCurrentCostModel RoutingReplayStrategy = "current_cost_model"
+ RoutingStrategyRoundRobin RoutingReplayStrategy = "round_robin"
+ RoutingStrategyLeastActive RoutingReplayStrategy = "least_active"
+ // RoutingStrategyLeastMetric intentionally uses only stale snapshot metrics
+ // (CPU/memory) and ignores assignments made during the replay. It models the
+ // Riot/Valorant failure mode where a burst herds onto the machine that looked
+ // best in the last heartbeat.
+ RoutingStrategyLeastMetric RoutingReplayStrategy = "least_metric"
+ RoutingStrategyRandomNearTie RoutingReplayStrategy = "random_near_tie"
+)
+
+// RoutingReplayProvider is the synthetic provider snapshot used by the replay
+// harness. It is intentionally small and independent of live registry state.
+type RoutingReplayProvider struct {
+ ID string
+ Model string
+ DecodeTPS float64
+ PrefillTPS float64
+ MaxConcurrency int
+ CPUUsage float64
+ MemoryPressure float64
+ ThermalState string
+ TotalMemoryGB float64
+ GPUActiveGB float64
+}
+
+// RoutingReplayRequest is a synthetic request arrival.
+type RoutingReplayRequest struct {
+ ID string
+ Model string
+ ArrivalMs float64
+ PromptTokens int
+ MaxTokens int
+ TimeoutMs float64
+}
+
+// RoutingReplayScenario contains all inputs for a deterministic replay.
+type RoutingReplayScenario struct {
+ Providers []RoutingReplayProvider
+ Requests []RoutingReplayRequest
+ Strategies []RoutingReplayStrategy
+ RandomSeed int64
+}
+
+// RoutingReplayAssignment records one replayed routing decision.
+type RoutingReplayAssignment struct {
+ RequestID string
+ ProviderID string
+ ArrivalMs float64
+ StartMs float64
+ TTFTMs float64
+ TotalLatencyMs float64
+ OverCapacity bool
+ TimedOut bool
+}
+
+// RoutingReplayMetrics contains per-strategy replay output.
+type RoutingReplayMetrics struct {
+ Strategy RoutingReplayStrategy
+ Assignments []RoutingReplayAssignment
+ ProviderRequestCounts map[string]int
+ ProviderUtilization map[string]float64
+ TTFTP50Ms float64
+ TTFTP95Ms float64
+ TTFTP99Ms float64
+ TotalLatencyP50Ms float64
+ TotalLatencyP95Ms float64
+ TotalLatencyP99Ms float64
+ TotalLatencyMs float64
+ TimeoutCount int
+ OverCapacityCount int
+}
+
+// ReplayRoutingScenario replays synthetic request arrivals against provider
+// snapshots for each requested strategy. It has no side effects on the live
+// registry and is intended for local unit tests and routing experiments.
+func ReplayRoutingScenario(s RoutingReplayScenario) map[RoutingReplayStrategy]RoutingReplayMetrics {
+ strategies := s.Strategies
+ if len(strategies) == 0 {
+ strategies = []RoutingReplayStrategy{RoutingStrategyCurrentCostModel, RoutingStrategyRoundRobin, RoutingStrategyLeastActive, RoutingStrategyLeastMetric}
+ }
+ out := make(map[RoutingReplayStrategy]RoutingReplayMetrics, len(strategies))
+ for _, strategy := range strategies {
+ out[strategy] = replayRoutingStrategy(s, strategy)
+ }
+ return out
+}
+
+type replayProviderState struct {
+ spec RoutingReplayProvider
+ active []replayActiveRequest
+ busyMs float64
+ rrCursor int
+ inputSlot int
+}
+
+type replayActiveRequest struct {
+ endMs float64
+ maxTokens int
+}
+
+func replayRoutingStrategy(s RoutingReplayScenario, strategy RoutingReplayStrategy) RoutingReplayMetrics {
+ providers := make([]*replayProviderState, 0, len(s.Providers))
+ for i, p := range s.Providers {
+ if p.ID == "" {
+ p.ID = string(rune('a' + i))
+ }
+ if p.DecodeTPS <= 0 {
+ p.DecodeTPS = 1
+ }
+ if p.PrefillTPS <= 0 {
+ p.PrefillTPS = p.DecodeTPS * 4
+ }
+ if p.MaxConcurrency <= 0 {
+ p.MaxConcurrency = DefaultMaxConcurrent
+ }
+ if p.ThermalState == "" {
+ p.ThermalState = "nominal"
+ }
+ providers = append(providers, &replayProviderState{spec: p, inputSlot: i})
+ }
+
+ requests := append([]RoutingReplayRequest(nil), s.Requests...)
+ sort.SliceStable(requests, func(i, j int) bool {
+ if requests[i].ArrivalMs == requests[j].ArrivalMs {
+ return requests[i].ID < requests[j].ID
+ }
+ return requests[i].ArrivalMs < requests[j].ArrivalMs
+ })
+
+ seed := s.RandomSeed
+ if seed == 0 {
+ seed = 1
+ }
+ rng := rand.New(rand.NewSource(seed))
+ counts := make(map[string]int, len(providers))
+ util := make(map[string]float64, len(providers))
+ assignments := make([]RoutingReplayAssignment, 0, len(requests))
+ var rr int
+
+ for _, req := range requests {
+ if req.MaxTokens <= 0 {
+ req.MaxTokens = defaultRequestedMaxTokens
+ }
+ for _, p := range providers {
+ p.completeThrough(req.ArrivalMs)
+ }
+
+ selected := selectReplayProvider(providers, req, strategy, rr, rng)
+ if selected == nil {
+ assignments = append(assignments, RoutingReplayAssignment{RequestID: req.ID, ArrivalMs: req.ArrivalMs, OverCapacity: true, TimedOut: true})
+ continue
+ }
+ if strategy == RoutingStrategyRoundRobin {
+ rr = selected.inputSlot + 1
+ }
+
+ activeBefore := len(selected.active)
+ overCapacity := activeBefore >= selected.spec.MaxConcurrency
+ start := req.ArrivalMs
+ if overCapacity {
+ start = selected.nextAvailableMs(req.ArrivalMs)
+ }
+ prefillMs := float64(req.PromptTokens) / selected.spec.PrefillTPS * 1000
+ decodeTPS := effectiveDecodeTPS(selected.spec.DecodeTPS, minInt(activeBefore, selected.spec.MaxConcurrency))
+ decodeMs := float64(req.MaxTokens) / decodeTPS * 1000
+ serviceMs := prefillMs + decodeMs
+ end := start + serviceMs
+ ttft := (start - req.ArrivalMs) + prefillMs
+ latency := end - req.ArrivalMs
+ timedOut := req.TimeoutMs > 0 && ttft > req.TimeoutMs
+
+ selected.active = append(selected.active, replayActiveRequest{endMs: end, maxTokens: req.MaxTokens})
+ selected.busyMs += serviceMs
+ counts[selected.spec.ID]++
+ assignments = append(assignments, RoutingReplayAssignment{
+ RequestID: req.ID,
+ ProviderID: selected.spec.ID,
+ ArrivalMs: req.ArrivalMs,
+ StartMs: start,
+ TTFTMs: ttft,
+ TotalLatencyMs: latency,
+ OverCapacity: overCapacity,
+ TimedOut: timedOut,
+ })
+ }
+
+ var horizon float64
+ for _, req := range requests {
+ if req.ArrivalMs > horizon {
+ horizon = req.ArrivalMs
+ }
+ }
+ for _, p := range providers {
+ for _, a := range p.active {
+ if a.endMs > horizon {
+ horizon = a.endMs
+ }
+ }
+ }
+ if horizon <= 0 {
+ horizon = 1
+ }
+ for _, p := range providers {
+ counts[p.spec.ID] += 0
+ util[p.spec.ID] = p.busyMs / horizon
+ }
+
+ metrics := RoutingReplayMetrics{
+ Strategy: strategy,
+ Assignments: assignments,
+ ProviderRequestCounts: counts,
+ ProviderUtilization: util,
+ }
+ var ttfts, latencies []float64
+ for _, a := range assignments {
+ if a.ProviderID == "" {
+ metrics.OverCapacityCount++
+ metrics.TimeoutCount++
+ continue
+ }
+ ttfts = append(ttfts, a.TTFTMs)
+ latencies = append(latencies, a.TotalLatencyMs)
+ metrics.TotalLatencyMs += a.TotalLatencyMs
+ if a.OverCapacity {
+ metrics.OverCapacityCount++
+ }
+ if a.TimedOut {
+ metrics.TimeoutCount++
+ }
+ }
+ metrics.TTFTP50Ms = percentile(ttfts, 0.50)
+ metrics.TTFTP95Ms = percentile(ttfts, 0.95)
+ metrics.TTFTP99Ms = percentile(ttfts, 0.99)
+ metrics.TotalLatencyP50Ms = percentile(latencies, 0.50)
+ metrics.TotalLatencyP95Ms = percentile(latencies, 0.95)
+ metrics.TotalLatencyP99Ms = percentile(latencies, 0.99)
+ return metrics
+}
+
+func (p *replayProviderState) completeThrough(nowMs float64) {
+ kept := p.active[:0]
+ for _, a := range p.active {
+ if a.endMs > nowMs {
+ kept = append(kept, a)
+ }
+ }
+ p.active = kept
+}
+
+func (p *replayProviderState) nextAvailableMs(arrivalMs float64) float64 {
+ if len(p.active) < p.spec.MaxConcurrency {
+ return arrivalMs
+ }
+ ends := make([]float64, 0, len(p.active))
+ for _, a := range p.active {
+ ends = append(ends, a.endMs)
+ }
+ sort.Float64s(ends)
+ idx := len(p.active) - p.spec.MaxConcurrency
+ if idx < 0 {
+ idx = 0
+ }
+ if ends[idx] < arrivalMs {
+ return arrivalMs
+ }
+ return ends[idx]
+}
+
+func selectReplayProvider(providers []*replayProviderState, req RoutingReplayRequest, strategy RoutingReplayStrategy, rr int, rng *rand.Rand) *replayProviderState {
+ eligible := make([]*replayProviderState, 0, len(providers))
+ for _, p := range providers {
+ if p.spec.Model == "" || p.spec.Model == req.Model {
+ eligible = append(eligible, p)
+ }
+ }
+ if len(eligible) == 0 {
+ return nil
+ }
+
+ switch strategy {
+ case RoutingStrategyRoundRobin:
+ for i := 0; i < len(providers); i++ {
+ p := providers[(rr+i)%len(providers)]
+ if (p.spec.Model == "" || p.spec.Model == req.Model) && len(p.active) < p.spec.MaxConcurrency {
+ return p
+ }
+ }
+ return eligible[rr%len(eligible)]
+ case RoutingStrategyLeastActive:
+ return minReplayProvider(eligible, func(p *replayProviderState) float64 { return float64(len(p.active)) })
+ case RoutingStrategyLeastMetric:
+ return minReplayProvider(eligible, func(p *replayProviderState) float64 { return p.spec.CPUUsage + p.spec.MemoryPressure })
+ case RoutingStrategyRandomNearTie:
+ costs := replayCostCandidates(eligible, req)
+ if len(costs) == 0 {
+ return minReplayProvider(eligible, func(p *replayProviderState) float64 { return float64(len(p.active)) })
+ }
+ best := costs[0].cost
+ pool := make([]*replayProviderState, 0, len(costs))
+ for _, c := range costs {
+ if math.Abs(c.cost-best) <= nearTieCostWindowMs {
+ pool = append(pool, c.provider)
+ }
+ }
+ return pool[rng.Intn(len(pool))]
+ case RoutingStrategyCurrentCostModel:
+ fallthrough
+ default:
+ costs := replayCostCandidates(eligible, req)
+ if len(costs) == 0 {
+ return nil
+ }
+ return costs[0].provider
+ }
+}
+
+type replayCostCandidate struct {
+ provider *replayProviderState
+ cost float64
+}
+
+func replayCostCandidates(providers []*replayProviderState, req RoutingReplayRequest) []replayCostCandidate {
+ out := make([]replayCostCandidate, 0, len(providers))
+ reg := &Registry{}
+ for _, p := range providers {
+ if len(p.active) >= p.spec.MaxConcurrency {
+ continue
+ }
+ pendingTokens := 0
+ maxPotential := int64(0)
+ for _, a := range p.active {
+ pendingTokens += a.maxTokens
+ maxPotential += int64(a.maxTokens)
+ }
+ snap := routingSnapshot{
+ provider: &Provider{ID: p.spec.ID},
+ model: req.Model,
+ slotState: "running",
+ totalPending: len(p.active),
+ pendingForModel: len(p.active),
+ pendingMaxTokens: pendingTokens,
+ backendRunning: len(p.active),
+ maxTokensPotential: maxPotential,
+ decodeTPS: p.spec.DecodeTPS,
+ prefillTPS: p.spec.PrefillTPS,
+ systemMetrics: protocol.SystemMetrics{
+ MemoryPressure: p.spec.MemoryPressure,
+ CPUUsage: p.spec.CPUUsage,
+ ThermalState: p.spec.ThermalState,
+ },
+ gpuMemoryActiveGB: p.spec.GPUActiveGB,
+ totalMemoryGB: p.spec.TotalMemoryGB,
+ modelLoaded: true,
+ }
+ candidate, _, ok := reg.buildCandidateWithReason(snap, &PendingRequest{
+ RequestID: req.ID,
+ Model: req.Model,
+ EstimatedPromptTokens: req.PromptTokens,
+ RequestedMaxTokens: req.MaxTokens,
+ })
+ if ok {
+ out = append(out, replayCostCandidate{provider: p, cost: candidate.costMs})
+ }
+ }
+ sort.SliceStable(out, func(i, j int) bool {
+ if out[i].cost == out[j].cost {
+ if len(out[i].provider.active) == len(out[j].provider.active) {
+ return out[i].provider.spec.ID < out[j].provider.spec.ID
+ }
+ return len(out[i].provider.active) < len(out[j].provider.active)
+ }
+ return out[i].cost < out[j].cost
+ })
+ return out
+}
+
+func minReplayProvider(providers []*replayProviderState, score func(*replayProviderState) float64) *replayProviderState {
+ var best *replayProviderState
+ var bestScore float64
+ for _, p := range providers {
+ s := score(p)
+ if best == nil || s < bestScore || (s == bestScore && p.spec.ID < best.spec.ID) {
+ best = p
+ bestScore = s
+ }
+ }
+ return best
+}
+
+func percentile(values []float64, q float64) float64 {
+ if len(values) == 0 {
+ return 0
+ }
+ copyVals := append([]float64(nil), values...)
+ sort.Float64s(copyVals)
+ idx := int(math.Ceil(q*float64(len(copyVals)))) - 1
+ if idx < 0 {
+ idx = 0
+ }
+ if idx >= len(copyVals) {
+ idx = len(copyVals) - 1
+ }
+ return copyVals[idx]
+}
+
+func minInt(a, b int) int {
+ if a < b {
+ return a
+ }
+ return b
+}
diff --git a/coordinator/internal/registry/simulator_test.go b/coordinator/internal/registry/simulator_test.go
new file mode 100644
index 00000000..4fcb2a60
--- /dev/null
+++ b/coordinator/internal/registry/simulator_test.go
@@ -0,0 +1,110 @@
+package registry
+
+import "testing"
+
+func TestRoutingReplaySimulatorRiotLessonStaleMetricHerdsWhileRoundRobinSpreads(t *testing.T) {
+ model := "riot-lesson-model"
+ scenario := RoutingReplayScenario{
+ Providers: []RoutingReplayProvider{
+ {ID: "p0", Model: model, DecodeTPS: 100, PrefillTPS: 400, MaxConcurrency: 2, CPUUsage: 0.10, MemoryPressure: 0.10, ThermalState: "nominal"},
+ {ID: "p1", Model: model, DecodeTPS: 100, PrefillTPS: 400, MaxConcurrency: 2, CPUUsage: 0.50, MemoryPressure: 0.10, ThermalState: "nominal"},
+ {ID: "p2", Model: model, DecodeTPS: 100, PrefillTPS: 400, MaxConcurrency: 2, CPUUsage: 0.50, MemoryPressure: 0.10, ThermalState: "nominal"},
+ },
+ Requests: burstRequests(model, 6, 0, 100, 200, 1_000),
+ Strategies: []RoutingReplayStrategy{
+ RoutingStrategyLeastMetric,
+ RoutingStrategyRoundRobin,
+ },
+ }
+
+ results := ReplayRoutingScenario(scenario)
+ leastMetric := results[RoutingStrategyLeastMetric]
+ roundRobin := results[RoutingStrategyRoundRobin]
+
+ if got := leastMetric.ProviderRequestCounts["p0"]; got != 6 {
+ t.Fatalf("stale least_metric routed %d requests to p0, want 6-request herd", got)
+ }
+ if leastMetric.ProviderRequestCounts["p1"] != 0 || leastMetric.ProviderRequestCounts["p2"] != 0 {
+ t.Fatalf("stale least_metric should not spread: counts=%v", leastMetric.ProviderRequestCounts)
+ }
+ if leastMetric.OverCapacityCount != 4 {
+ t.Fatalf("least_metric over-capacity count=%d, want 4", leastMetric.OverCapacityCount)
+ }
+ if leastMetric.TimeoutCount == 0 {
+ t.Fatal("least_metric should produce TTFT timeouts when stale CPU herds the burst")
+ }
+
+ for _, id := range []string{"p0", "p1", "p2"} {
+ if got := roundRobin.ProviderRequestCounts[id]; got != 2 {
+ t.Fatalf("round_robin count for %s=%d, want 2 (counts=%v)", id, got, roundRobin.ProviderRequestCounts)
+ }
+ }
+ if roundRobin.OverCapacityCount != 0 {
+ t.Fatalf("round_robin over-capacity count=%d, want 0", roundRobin.OverCapacityCount)
+ }
+ if roundRobin.TimeoutCount != 0 {
+ t.Fatalf("round_robin timeout count=%d, want 0", roundRobin.TimeoutCount)
+ }
+ if !(leastMetric.TTFTP95Ms > roundRobin.TTFTP95Ms) {
+ t.Fatalf("expected least_metric p95 TTFT (%f) > round_robin (%f)", leastMetric.TTFTP95Ms, roundRobin.TTFTP95Ms)
+ }
+}
+
+func TestRoutingReplaySimulatorCurrentCostModelDeterministicComparison(t *testing.T) {
+ model := "cost-model-replay"
+ scenario := RoutingReplayScenario{
+ Providers: []RoutingReplayProvider{
+ {ID: "fast", Model: model, DecodeTPS: 100, PrefillTPS: 400, MaxConcurrency: 2, CPUUsage: 0.10, MemoryPressure: 0.10, ThermalState: "nominal"},
+ {ID: "slow", Model: model, DecodeTPS: 80, PrefillTPS: 320, MaxConcurrency: 2, CPUUsage: 0.10, MemoryPressure: 0.10, ThermalState: "nominal"},
+ },
+ Requests: burstRequests(model, 4, 0, 100, 200, 10_000),
+ Strategies: []RoutingReplayStrategy{
+ RoutingStrategyCurrentCostModel,
+ RoutingStrategyLeastActive,
+ },
+ }
+
+ first := ReplayRoutingScenario(scenario)
+ second := ReplayRoutingScenario(scenario)
+
+ costA := first[RoutingStrategyCurrentCostModel]
+ costB := second[RoutingStrategyCurrentCostModel]
+ if len(costA.Assignments) != len(costB.Assignments) {
+ t.Fatalf("assignment lengths differ: %d vs %d", len(costA.Assignments), len(costB.Assignments))
+ }
+ for i := range costA.Assignments {
+ if costA.Assignments[i].ProviderID != costB.Assignments[i].ProviderID {
+ t.Fatalf("current cost model assignment %d not deterministic: %q vs %q", i, costA.Assignments[i].ProviderID, costB.Assignments[i].ProviderID)
+ }
+ }
+ if got := costA.ProviderRequestCounts["fast"]; got != 2 {
+ t.Fatalf("current cost model fast count=%d, want 2 (counts=%v assignments=%v)", got, costA.ProviderRequestCounts, costA.Assignments)
+ }
+ if got := costA.ProviderRequestCounts["slow"]; got != 2 {
+ t.Fatalf("current cost model slow count=%d, want 2 (counts=%v assignments=%v)", got, costA.ProviderRequestCounts, costA.Assignments)
+ }
+ if costA.TTFTP50Ms <= 0 || costA.TTFTP95Ms <= 0 || costA.TTFTP99Ms <= 0 {
+ t.Fatalf("TTFT percentiles must be populated: %+v", costA)
+ }
+ if costA.TotalLatencyP50Ms <= 0 || costA.TotalLatencyP95Ms <= 0 || costA.TotalLatencyP99Ms <= 0 || costA.TotalLatencyMs <= 0 {
+ t.Fatalf("latency metrics must be populated: %+v", costA)
+ }
+ if len(costA.ProviderUtilization) != 2 || costA.ProviderUtilization["fast"] <= 0 || costA.ProviderUtilization["slow"] <= 0 {
+ t.Fatalf("provider utilization must be populated: %+v", costA.ProviderUtilization)
+ }
+}
+
+func burstRequests(model string, n int, arrivalMs float64, promptTokens, maxTokens int, timeoutMs float64) []RoutingReplayRequest {
+ reqs := make([]RoutingReplayRequest, 0, n)
+ for i := 0; i < n; i++ {
+ reqs = append(reqs, RoutingReplayRequest{
+ ID: string(rune('a' + i)),
+ Model: model,
+ ArrivalMs: arrivalMs,
+ PromptTokens: promptTokens,
+ MaxTokens: maxTokens,
+ TimeoutMs: timeoutMs,
+ })
+ }
+ return reqs
+}
diff --git a/provider-swift/Sources/ProviderCore/Coordinator/CoordinatorClient.swift b/provider-swift/Sources/ProviderCore/Coordinator/CoordinatorClient.swift
index ffc5a892..b97e7155 100644
--- a/provider-swift/Sources/ProviderCore/Coordinator/CoordinatorClient.swift
+++ b/provider-swift/Sources/ProviderCore/Coordinator/CoordinatorClient.swift
@@ -59,6 +59,7 @@ public final class ProviderState: @unchecked Sendable {
private var _warmModels: [String] = []
private var _currentModelHash: String? = nil
private var _backendCapacity: BackendCapacity? = nil
+ private var _networkQuality = NetworkQuality()
public init() {}
@@ -86,6 +87,11 @@ public final class ProviderState: @unchecked Sendable {
get { lock.withLock { _backendCapacity } }
set { lock.withLock { _backendCapacity = newValue } }
}
+
+ public var networkQuality: NetworkQuality {
+ get { lock.withLock { _networkQuality } }
+ set { lock.withLock { _networkQuality = newValue } }
+ }
}
// MARK: - os_unfair_lock wrapper (Sendable-safe)
@@ -280,6 +286,7 @@ public actor CoordinatorClient {
private let state: ProviderState
private let logger = Logger(subsystem: "dev.darkbloom.provider", category: "coordinator")
+ private let networkQualityTracker = NetworkQualityTracker()
private var eventContinuation: AsyncStream.Continuation?
private var outboundContinuation: AsyncStream.Continuation?
@@ -349,6 +356,8 @@ public actor CoordinatorClient {
logger.warning("Coordinator connection error: \(error.localizedDescription). Reconnecting in \(delay)s")
reconnectCount += 1
+ networkQualityTracker.recordReconnect()
+ state.networkQuality = networkQualityTracker.snapshot()
if shouldEmitReconnectTelemetry(count: reconnectCount) {
emitReconnectTelemetry(count: reconnectCount, error: error)
}
@@ -413,7 +422,15 @@ public actor CoordinatorClient {
let shutting = await self.shutdownRequested
if shutting { break }
let json = await self.encodeOutbound(msg)
- try await ws.send(.string(json))
+ let started = CFAbsoluteTimeGetCurrent()
+ do {
+ try await ws.send(.string(json))
+ let elapsedMs = (CFAbsoluteTimeGetCurrent() - started) * 1000.0
+ await self.recordWriteLatency(elapsedMs)
+ } catch {
+ await self.recordWriteFailure()
+ throw error
+ }
}
}
@@ -442,9 +459,18 @@ public actor CoordinatorClient {
throw CoordinatorError.pongTimeout
}
- ws.sendPing { error in
+ let started = CFAbsoluteTimeGetCurrent()
+ ws.sendPing { [weak self] error in
if error == nil {
pongTracker.recordPong()
+ let elapsedMs = (CFAbsoluteTimeGetCurrent() - started) * 1000.0
+ self?.networkQualityTracker.recordPong(rttMs: elapsedMs)
+ if let self {
+ Task { await self.publishNetworkQualitySnapshot() }
+ }
+ } else if let self {
+ self.networkQualityTracker.recordWriteFailure()
+ Task { await self.publishNetworkQualitySnapshot() }
}
}
}
@@ -613,6 +639,7 @@ public actor CoordinatorClient {
tokensGenerated: stats.tokensGenerated
),
systemMetrics: metrics,
+ networkQuality: state.networkQuality,
backendCapacity: capacity
)
@@ -623,6 +650,20 @@ public actor CoordinatorClient {
return json
}
+ private func recordWriteLatency(_ ms: Double) {
+ networkQualityTracker.recordWriteLatency(ms: ms)
+ state.networkQuality = networkQualityTracker.snapshot()
+ }
+
+ private func recordWriteFailure() {
+ networkQualityTracker.recordWriteFailure()
+ state.networkQuality = networkQualityTracker.snapshot()
+ }
+
+ private func publishNetworkQualitySnapshot() {
+ state.networkQuality = networkQualityTracker.snapshot()
+ }
+
// MARK: - Outbound Encoding
private func encodeOutbound(_ msg: OutboundMessage) -> String {
diff --git a/provider-swift/Sources/ProviderCore/Coordinator/CoordinatorClientCodec.swift b/provider-swift/Sources/ProviderCore/Coordinator/CoordinatorClientCodec.swift
index ed340c59..da43b96b 100644
--- a/provider-swift/Sources/ProviderCore/Coordinator/CoordinatorClientCodec.swift
+++ b/provider-swift/Sources/ProviderCore/Coordinator/CoordinatorClientCodec.swift
@@ -46,6 +46,7 @@ public enum CoordinatorClientCodec {
warmModels: [String],
stats: ProviderStats,
systemMetrics: SystemMetrics,
+ networkQuality: NetworkQuality = NetworkQuality(),
backendCapacity: BackendCapacity?
) -> ProviderMessage {
.heartbeat(ProviderMessage.Heartbeat(
@@ -54,6 +55,7 @@ public enum CoordinatorClientCodec {
warmModels: warmModels,
stats: stats,
systemMetrics: systemMetrics,
+ networkQuality: networkQuality,
backendCapacity: backendCapacity
))
}
diff --git a/provider-swift/Sources/ProviderCore/Coordinator/NetworkQualityTracker.swift b/provider-swift/Sources/ProviderCore/Coordinator/NetworkQualityTracker.swift
new file mode 100644
index 00000000..e9895f16
--- /dev/null
+++ b/provider-swift/Sources/ProviderCore/Coordinator/NetworkQualityTracker.swift
@@ -0,0 +1,59 @@
+import Foundation
+
+/// Tracks provider-observed coordinator WebSocket transport quality.
+/// All methods are synchronous and lock-backed so ping/write callbacks can
+/// update from arbitrary URLSession queues without actor hops.
+public final class NetworkQualityTracker: @unchecked Sendable {
+ private let lock = NSLock()
+ private var latestRTTMs: Double = 0
+ private var latestJitterMs: Double = 0
+ private var previousRTTMs: Double?
+ private var reconnects: UInt64 = 0
+ private var writeFailures: UInt64 = 0
+ private var latestWriteLatencyMs: Double = 0
+
+ public init() {}
+
+ public func recordPong(rttMs: Double) {
+ let bounded = max(0, rttMs)
+ lock.withLock {
+ if let previousRTTMs {
+ latestJitterMs = abs(bounded - previousRTTMs)
+ }
+ previousRTTMs = bounded
+ latestRTTMs = bounded
+ }
+ }
+
+ public func recordReconnect() {
+ lock.withLock { reconnects &+= 1 }
+ }
+
+ public func recordWriteFailure() {
+ lock.withLock { writeFailures &+= 1 }
+ }
+
+ public func recordWriteLatency(ms: Double) {
+ lock.withLock { latestWriteLatencyMs = max(0, ms) }
+ }
+
+ public func snapshot() -> NetworkQuality {
+ lock.withLock {
+ NetworkQuality(
+ rttMs: latestRTTMs,
+ jitterMs: latestJitterMs,
+ reconnectCount: reconnects,
+ websocketWriteFailures: writeFailures,
+ lastWriteLatencyMs: latestWriteLatencyMs
+ )
+ }
+ }
+}
+
+private extension NSLock {
+ func withLock(_ body: () -> T) -> T {
+ lock()
+ defer { unlock() }
+ return body()
+ }
+}
diff --git a/provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift b/provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift
index e6d27b08..eb51d0f9 100644
--- a/provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift
+++ b/provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift
@@ -21,6 +21,7 @@ public struct SchedulerCapacity: Sendable {
public let model: String
public let activeRequests: Int
public let pendingRequests: Int
+ public let activeTokens: Int
public let maxConcurrent: Int
public let gpuMemoryActiveBytes: Int
public let gpuMemoryPeakBytes: Int
@@ -168,11 +169,13 @@ public actor BatchScheduler {
topK: request.top_k ?? 0,
seed: request.seed
)
+ let admitStartedAt = ContinuousClock.now
let assignedUids = gen.insert(
prompts: [promptTokens],
maxTokens: [maxTokens],
samplers: [sampler]
)
+ let admittedAt = ContinuousClock.now
guard let uid = assignedUids.first else {
continuation.yield(.error("BatchGenerator rejected the prompt"))
continuation.finish()
@@ -185,10 +188,15 @@ public actor BatchScheduler {
detokenizer: NaiveStreamingDetokenizer(tokenizer: tk.inner),
promptTokens: promptTokens.count,
completionTokens: 0,
- submittedAt: .now
+ submittedAt: admitStartedAt,
+ admittedAt: admittedAt,
+ firstTokenAt: nil,
+ model: telemetryModelId(for: request)
)
requestIdToUid[id] = uid
+ emitLifecycle(.schedulerAdmit, entry: active[uid]!, activeCount: active.count)
+
let scheduler = self
continuation.onTermination = { @Sendable termination in
if case .cancelled = termination {
@@ -214,10 +222,12 @@ public actor BatchScheduler {
// MARK: - Capacity
public func capacity() -> SchedulerCapacity {
- SchedulerCapacity(
+ let counters = capacityCounters()
+ return SchedulerCapacity(
model: modelId,
- activeRequests: active.count,
- pendingRequests: 0,
+ activeRequests: counters.activeRequests,
+ pendingRequests: counters.pendingRequests,
+ activeTokens: counters.activeTokens,
maxConcurrent: maxConcurrentRequests,
gpuMemoryActiveBytes: gpuMemory(.active),
gpuMemoryPeakBytes: gpuMemory(.peak),
@@ -234,7 +244,7 @@ public actor BatchScheduler {
state: cap.activeRequests > 0 ? "running" : "idle",
numRunning: UInt32(cap.activeRequests),
numWaiting: UInt32(cap.pendingRequests),
- activeTokens: 0,
+ activeTokens: Int64(cap.activeTokens),
maxTokensPotential: Int64(defaultMaxTokens * maxConcurrentRequests)
)
return BackendCapacity(
@@ -279,6 +289,10 @@ public actor BatchScheduler {
entry.detokenizer.append(token: response.token)
entry.completionTokens += 1
+ if entry.firstTokenAt == nil {
+ entry.firstTokenAt = .now
+ emitLifecycle(.firstToken, entry: entry, activeCount: active.count)
+ }
if let chunk = entry.detokenizer.next() {
entry.continuation.yield(.chunk(chunk))
}
@@ -307,6 +321,7 @@ public actor BatchScheduler {
entry.continuation.finish()
active.removeValue(forKey: response.uid)
requestIdToUid.removeValue(forKey: entry.requestId)
+ emitLifecycle(.inferenceComplete, entry: entry, activeCount: active.count)
}
}
@@ -315,6 +330,50 @@ public actor BatchScheduler {
requestIdToUid.removeValue(forKey: entry.requestId)
entry.continuation.yield(.error(error))
entry.continuation.finish()
+ let stage: InferenceLifecycleTelemetryStage = error == "Request cancelled"
+ ? .inferenceCancel
+ : .inferenceError
+ emitLifecycle(stage, entry: entry, activeCount: active.count, error: error, reason: error)
+ }
+
+ private func capacityCounters() -> SchedulerCapacityCounters {
+ SchedulerCapacityCounters.from(active.values.map { entry in
+ InferenceRequestProgress(
+ promptTokens: entry.promptTokens,
+ completionTokens: entry.completionTokens,
+ firstTokenReceived: entry.firstTokenAt != nil
+ )
+ })
+ }
+
+ private func emitLifecycle(
+ _ stage: InferenceLifecycleTelemetryStage,
+ entry: ActiveRequest,
+ activeCount: Int,
+ error: String? = nil,
+ reason: String? = nil
+ ) {
+ TelemetryClient.shared.emit(InferenceLifecycleTelemetry.event(
+ stage,
+ snapshot: InferenceLifecycleTelemetrySnapshot(
+ requestId: entry.requestId,
+ model: entry.model,
+ promptTokens: entry.promptTokens,
+ completionTokens: entry.completionTokens,
+ queueMilliseconds: InferenceLifecycleTelemetry.milliseconds(entry.admittedAt - entry.submittedAt),
+ admitMilliseconds: InferenceLifecycleTelemetry.milliseconds(entry.admittedAt - entry.submittedAt),
+ ttftMilliseconds: entry.firstTokenAt.map { InferenceLifecycleTelemetry.milliseconds($0 - entry.submittedAt) },
+ totalMilliseconds: InferenceLifecycleTelemetry.milliseconds(ContinuousClock.now - entry.submittedAt),
+ activeCount: activeCount
+ ),
+ error: error,
+ reason: reason
+ ))
+ }
+
+ private func telemetryModelId(for request: ChatCompletionRequest) -> String {
+ if !request.model.isEmpty { return request.model }
+ return modelId
}
private enum MemoryKind { case active, peak, cache }
@@ -341,6 +400,9 @@ private struct ActiveRequest {
var promptTokens: Int
var completionTokens: Int
let submittedAt: ContinuousClock.Instant
+ let admittedAt: ContinuousClock.Instant
+ var firstTokenAt: ContinuousClock.Instant?
+ let model: String
}
private struct LoadSnapshot: @unchecked Sendable {
diff --git a/provider-swift/Sources/ProviderCore/Protocol/Messages.swift b/provider-swift/Sources/ProviderCore/Protocol/Messages.swift
index 5b75c796..ebf5a1a4 100644
--- a/provider-swift/Sources/ProviderCore/Protocol/Messages.swift
+++ b/provider-swift/Sources/ProviderCore/Protocol/Messages.swift
@@ -70,6 +70,7 @@ public enum ProviderMessage: Sendable, Equatable {
public var warmModels: [String]
public var stats: ProviderStats
public var systemMetrics: SystemMetrics
+ public var networkQuality: NetworkQuality
public var backendCapacity: BackendCapacity?
public init(
@@ -78,6 +79,7 @@ public enum ProviderMessage: Sendable, Equatable {
warmModels: [String] = [],
stats: ProviderStats,
systemMetrics: SystemMetrics,
+ networkQuality: NetworkQuality = NetworkQuality(),
backendCapacity: BackendCapacity? = nil
) {
self.status = status
@@ -85,6 +87,7 @@ public enum ProviderMessage: Sendable, Equatable {
self.warmModels = warmModels
self.stats = stats
self.systemMetrics = systemMetrics
+ self.networkQuality = networkQuality
self.backendCapacity = backendCapacity
}
}
@@ -238,6 +241,7 @@ extension ProviderMessage: Codable {
case warmModels = "warm_models"
case stats
case systemMetrics = "system_metrics"
+ case networkQuality = "network_quality"
case backendCapacity = "backend_capacity"
// Common
case requestId = "request_id"
@@ -300,6 +304,7 @@ extension ProviderMessage: Codable {
}
try container.encode(h.stats, forKey: .stats)
try container.encode(h.systemMetrics, forKey: .systemMetrics)
+ try container.encode(h.networkQuality, forKey: .networkQuality)
try container.encodeIfPresent(h.backendCapacity, forKey: .backendCapacity)
case .inferenceAccepted(let a):
@@ -387,6 +392,7 @@ extension ProviderMessage: Codable {
warmModels: try container.decodeIfPresent([String].self, forKey: .warmModels) ?? [],
stats: try container.decode(ProviderStats.self, forKey: .stats),
systemMetrics: try container.decode(SystemMetrics.self, forKey: .systemMetrics),
+ networkQuality: try container.decodeIfPresent(NetworkQuality.self, forKey: .networkQuality) ?? NetworkQuality(),
backendCapacity: try container.decodeIfPresent(BackendCapacity.self, forKey: .backendCapacity)
))
diff --git a/provider-swift/Sources/ProviderCore/Protocol/Types.swift b/provider-swift/Sources/ProviderCore/Protocol/Types.swift
index 1417a8db..993467ea 100644
--- a/provider-swift/Sources/ProviderCore/Protocol/Types.swift
+++ b/provider-swift/Sources/ProviderCore/Protocol/Types.swift
@@ -140,6 +140,36 @@ public struct ProviderStats: Codable, Sendable, Equatable {
}
}
+public struct NetworkQuality: Codable, Sendable, Equatable {
+ public var rttMs: Double
+ public var jitterMs: Double
+ public var reconnectCount: UInt64
+ public var websocketWriteFailures: UInt64
+ public var lastWriteLatencyMs: Double
+
+ enum CodingKeys: String, CodingKey {
+ case rttMs = "rtt_ms"
+ case jitterMs = "jitter_ms"
+ case reconnectCount = "reconnect_count"
+ case websocketWriteFailures = "websocket_write_failures"
+ case lastWriteLatencyMs = "last_write_latency_ms"
+ }
+
+ public init(
+ rttMs: Double = 0,
+ jitterMs: Double = 0,
+ reconnectCount: UInt64 = 0,
+ websocketWriteFailures: UInt64 = 0,
+ lastWriteLatencyMs: Double = 0
+ ) {
+ self.rttMs = rttMs
+ self.jitterMs = jitterMs
+ self.reconnectCount = reconnectCount
+ self.websocketWriteFailures = websocketWriteFailures
+ self.lastWriteLatencyMs = lastWriteLatencyMs
+ }
+}
+
public struct UsageInfo: Codable, Sendable, Equatable {
public var promptTokens: UInt64
public var completionTokens: UInt64
diff --git a/provider-swift/Sources/ProviderCore/Telemetry/InferenceLifecycleTelemetry.swift b/provider-swift/Sources/ProviderCore/Telemetry/InferenceLifecycleTelemetry.swift
new file mode 100644
index 00000000..04a02b10
--- /dev/null
+++ b/provider-swift/Sources/ProviderCore/Telemetry/InferenceLifecycleTelemetry.swift
@@ -0,0 +1,172 @@
+import Foundation
+
+/// Lightweight request progress used to build capacity snapshots without
+/// depending on MLX runtime state. `completionTokens == 0` is treated as
+/// pending/prefill until the first token is observed.
+public struct InferenceRequestProgress: Sendable, Equatable {
+ public let promptTokens: Int
+ public let completionTokens: Int
+ public let firstTokenReceived: Bool
+
+ public init(
+ promptTokens: Int,
+ completionTokens: Int,
+ firstTokenReceived: Bool
+ ) {
+ self.promptTokens = promptTokens
+ self.completionTokens = completionTokens
+ self.firstTokenReceived = firstTokenReceived
+ }
+}
+
+public struct SchedulerCapacityCounters: Sendable, Equatable {
+ public let activeRequests: Int
+ public let pendingRequests: Int
+ public let activeTokens: Int
+
+ public init(activeRequests: Int, pendingRequests: Int, activeTokens: Int) {
+ self.activeRequests = activeRequests
+ self.pendingRequests = pendingRequests
+ self.activeTokens = activeTokens
+ }
+
+ public static func from(_ requests: some Sequence) -> SchedulerCapacityCounters {
+ var activeRequests = 0
+ var pendingRequests = 0
+ var activeTokens = 0
+
+ for request in requests {
+ activeRequests += 1
+ if !request.firstTokenReceived {
+ pendingRequests += 1
+ }
+ activeTokens += max(0, request.promptTokens) + max(0, request.completionTokens)
+ }
+
+ return SchedulerCapacityCounters(
+ activeRequests: activeRequests,
+ pendingRequests: pendingRequests,
+ activeTokens: activeTokens
+ )
+ }
+}
+
+public enum InferenceLifecycleTelemetryStage: String, Sendable, Equatable {
+ case schedulerAdmit = "scheduler_admit"
+ case firstToken = "first_token"
+ case inferenceComplete = "inference_complete"
+ case inferenceError = "inference_error"
+ case inferenceCancel = "inference_cancel"
+}
+
+public struct InferenceLifecycleTelemetrySnapshot: Sendable, Equatable {
+ public let requestId: String
+ public let model: String
+ public let promptTokens: Int
+ public let completionTokens: Int
+ public let queueMilliseconds: Double?
+ public let admitMilliseconds: Double?
+ public let ttftMilliseconds: Double?
+ public let totalMilliseconds: Double?
+ public let activeCount: Int?
+
+ public init(
+ requestId: String,
+ model: String,
+ promptTokens: Int,
+ completionTokens: Int,
+ queueMilliseconds: Double? = nil,
+ admitMilliseconds: Double? = nil,
+ ttftMilliseconds: Double? = nil,
+ totalMilliseconds: Double? = nil,
+ activeCount: Int? = nil
+ ) {
+ self.requestId = requestId
+ self.model = model
+ self.promptTokens = promptTokens
+ self.completionTokens = completionTokens
+ self.queueMilliseconds = queueMilliseconds
+ self.admitMilliseconds = admitMilliseconds
+ self.ttftMilliseconds = ttftMilliseconds
+ self.totalMilliseconds = totalMilliseconds
+ self.activeCount = activeCount
+ }
+}
+
+public enum InferenceLifecycleTelemetry {
+ public static func event(
+ _ stage: InferenceLifecycleTelemetryStage,
+ snapshot: InferenceLifecycleTelemetrySnapshot,
+ error: String? = nil,
+ reason: String? = nil
+ ) -> TelemetryEvent {
+ var fields: [String: AnyCodableValue] = [
+ "component": .string("batch_scheduler"),
+ "operation": .string(stage.rawValue),
+ "model": .string(snapshot.model),
+ "prompt_tokens": .int(snapshot.promptTokens),
+ "completion_tokens": .int(snapshot.completionTokens),
+ ]
+
+ if let queueMilliseconds = snapshot.queueMilliseconds {
+ fields["queue_ms"] = .double(queueMilliseconds)
+ }
+ if let admitMilliseconds = snapshot.admitMilliseconds {
+ fields["admit_ms"] = .double(admitMilliseconds)
+ }
+ if let ttftMilliseconds = snapshot.ttftMilliseconds {
+ fields["ttft_ms"] = .double(ttftMilliseconds)
+ }
+ if let totalMilliseconds = snapshot.totalMilliseconds {
+ fields["total_ms"] = .double(totalMilliseconds)
+ // Preserve compatibility with existing dashboards that key on the
+ // generic duration field while still emitting the explicit metric.
+ fields["duration_ms"] = .double(totalMilliseconds)
+ }
+ if let activeCount = snapshot.activeCount {
+ fields["active_count"] = .int(activeCount)
+ fields["queue_depth"] = .int(activeCount)
+ }
+ if let error {
+ fields["error"] = .string(error)
+ }
+ if let reason {
+ fields["reason"] = .string(reason)
+ }
+
+ return TelemetryEvent(
+ source: .provider,
+ severity: severity(for: stage),
+ kind: kind(for: stage),
+ message: stage.rawValue
+ )
+ .withRequestId(snapshot.requestId)
+ .withFields(fields)
+ }
+
+ public static func milliseconds(_ duration: Duration) -> Double {
+ let components = duration.components
+ return (Double(components.seconds) * 1_000.0)
+ + (Double(components.attoseconds) / 1e15)
+ }
+
+ private static func severity(for stage: InferenceLifecycleTelemetryStage) -> TelemetrySeverity {
+ switch stage {
+ case .inferenceError:
+ return .error
+ case .inferenceCancel:
+ return .warn
+ case .schedulerAdmit, .firstToken, .inferenceComplete:
+ return .info
+ }
+ }
+
+ private static func kind(for stage: InferenceLifecycleTelemetryStage) -> TelemetryKind {
+ switch stage {
+ case .inferenceError:
+ return .inferenceError
+ case .schedulerAdmit, .firstToken, .inferenceComplete, .inferenceCancel:
+ return .custom
+ }
+ }
+}
\ No newline at end of file
diff --git a/provider-swift/Sources/ProviderCore/Telemetry/TelemetryEvent.swift b/provider-swift/Sources/ProviderCore/Telemetry/TelemetryEvent.swift
index d592032b..a1678b98 100644
--- a/provider-swift/Sources/ProviderCore/Telemetry/TelemetryEvent.swift
+++ b/provider-swift/Sources/ProviderCore/Telemetry/TelemetryEvent.swift
@@ -231,6 +231,8 @@ public enum TelemetryFieldFilter {
"status_code", "error_class", "error", "model", "backend",
"exit_code", "signal", "hardware_chip", "memory_gb", "macos_version",
"handler", "provider_id", "trust_level", "queue_depth", "reason",
+ "queue_ms", "admit_ms", "prompt_tokens", "completion_tokens",
+ "ttft_ms", "total_ms", "active_count",
"runtime_component", "reconnect_count", "last_error", "ws_state",
"billing_method", "payment_failed", "target",
]
diff --git a/provider-swift/Tests/ProviderCoreTests/BatchSchedulerTelemetryTests.swift b/provider-swift/Tests/ProviderCoreTests/BatchSchedulerTelemetryTests.swift
new file mode 100644
index 00000000..28f84d33
--- /dev/null
+++ b/provider-swift/Tests/ProviderCoreTests/BatchSchedulerTelemetryTests.swift
@@ -0,0 +1,75 @@
+import Testing
+@testable import ProviderCore
+
+@Suite("BatchScheduler telemetry + capacity helpers")
+struct BatchSchedulerTelemetryTests {
+
+ @Test("capacity counters report pending requests and active tokens")
+ func capacityCountersReflectRequestProgress() {
+ let counters = SchedulerCapacityCounters.from([
+ InferenceRequestProgress(promptTokens: 12, completionTokens: 0, firstTokenReceived: false),
+ InferenceRequestProgress(promptTokens: 9, completionTokens: 3, firstTokenReceived: true),
+ InferenceRequestProgress(promptTokens: 4, completionTokens: 1, firstTokenReceived: true),
+ ])
+
+ #expect(counters.activeRequests == 3)
+ #expect(counters.pendingRequests == 1)
+ #expect(counters.activeTokens == 29)
+ }
+
+ @Test("lifecycle telemetry event includes request correlation and timing fields")
+ func lifecycleTelemetryEventConstruction() {
+ let snapshot = InferenceLifecycleTelemetrySnapshot(
+ requestId: "req-123",
+ model: "mlx-community/test-model",
+ promptTokens: 17,
+ completionTokens: 5,
+ queueMilliseconds: 2.5,
+ admitMilliseconds: 3.25,
+ ttftMilliseconds: 42.75,
+ totalMilliseconds: 123.5,
+ activeCount: 4
+ )
+
+ let event = InferenceLifecycleTelemetry.event(.inferenceComplete, snapshot: snapshot)
+
+ #expect(event.requestId == "req-123")
+ #expect(event.message == "inference_complete")
+ #expect(event.severity == .info)
+ #expect(event.kind == .custom)
+ #expect(event.fields?["operation"]?.description == "inference_complete")
+ #expect(event.fields?["model"]?.description == "mlx-community/test-model")
+ #expect(event.fields?["prompt_tokens"]?.description == "17")
+ #expect(event.fields?["completion_tokens"]?.description == "5")
+ #expect(event.fields?["queue_ms"]?.description == "2.5")
+ #expect(event.fields?["admit_ms"]?.description == "3.25")
+ #expect(event.fields?["ttft_ms"]?.description == "42.75")
+ #expect(event.fields?["total_ms"]?.description == "123.5")
+ #expect(event.fields?["active_count"]?.description == "4")
+ }
+
+ @Test("telemetry field filter preserves lifecycle metrics")
+ func telemetryFieldFilterAllowsLifecycleMetrics() {
+ let filtered = TelemetryFieldFilter.filter([
+ "operation": .string("first_token"),
+ "prompt_tokens": .int(10),
+ "completion_tokens": .int(1),
+ "queue_ms": .double(1.0),
+ "admit_ms": .double(2.0),
+ "ttft_ms": .double(3.0),
+ "total_ms": .double(4.0),
+ "active_count": .int(2),
+ "not_allowed": .string("drop"),
+ ])
+
+ #expect(filtered?["operation"]?.description == "first_token")
+ #expect(filtered?["prompt_tokens"]?.description == "10")
+ #expect(filtered?["completion_tokens"]?.description == "1")
+ #expect(filtered?["queue_ms"]?.description == "1.0")
+ #expect(filtered?["admit_ms"]?.description == "2.0")
+ #expect(filtered?["ttft_ms"]?.description == "3.0")
+ #expect(filtered?["total_ms"]?.description == "4.0")
+ #expect(filtered?["active_count"]?.description == "2")
+ #expect(filtered?["not_allowed"] == nil)
+ }
+}
diff --git a/provider-swift/Tests/ProviderCoreTests/NetworkQualityTests.swift b/provider-swift/Tests/ProviderCoreTests/NetworkQualityTests.swift
new file mode 100644
index 00000000..40358d80
--- /dev/null
+++ b/provider-swift/Tests/ProviderCoreTests/NetworkQualityTests.swift
@@ -0,0 +1,34 @@
+import Testing
+@testable import ProviderCore
+
+@Test func providerStateStoresNetworkQuality() {
+ let state = ProviderState()
+ let quality = NetworkQuality(
+ rttMs: 80,
+ jitterMs: 12,
+ reconnectCount: 2,
+ websocketWriteFailures: 1,
+ lastWriteLatencyMs: 4
+ )
+
+ state.networkQuality = quality
+
+ #expect(state.networkQuality == quality)
+}
+
+@Test func networkQualityTrackerComputesPingRttAndJitterAndCounters() {
+ let tracker = NetworkQualityTracker()
+
+ tracker.recordPong(rttMs: 100)
+ tracker.recordPong(rttMs: 140)
+ tracker.recordReconnect()
+ tracker.recordWriteFailure()
+ tracker.recordWriteLatency(ms: 8)
+
+ let snapshot = tracker.snapshot()
+ #expect(snapshot.rttMs == 140)
+ #expect(snapshot.jitterMs == 40)
+ #expect(snapshot.reconnectCount == 1)
+ #expect(snapshot.websocketWriteFailures == 1)
+ #expect(snapshot.lastWriteLatencyMs == 8)
+}
diff --git a/provider-swift/Tests/ProviderCoreTests/ProtocolTests.swift b/provider-swift/Tests/ProviderCoreTests/ProtocolTests.swift
index 6d395b7e..21f5d823 100644
--- a/provider-swift/Tests/ProviderCoreTests/ProtocolTests.swift
+++ b/provider-swift/Tests/ProviderCoreTests/ProtocolTests.swift
@@ -53,6 +53,13 @@ import Testing
warmModels: ["mlx-community/Qwen2.5-7B-4bit"],
stats: ProviderStats(requestsServed: 4, tokensGenerated: 4096),
systemMetrics: SystemMetrics(memoryPressure: 0.2, cpuUsage: 0.3, thermalState: .nominal),
+ networkQuality: NetworkQuality(
+ rttMs: 42.5,
+ jitterMs: 7.25,
+ reconnectCount: 2,
+ websocketWriteFailures: 1,
+ lastWriteLatencyMs: 3.5
+ ),
backendCapacity: BackendCapacity(
slots: [BackendSlotCapacity(
model: "mlx-community/Qwen2.5-7B-4bit",
@@ -160,6 +167,37 @@ import Testing
#expect(failedObj["error"] as? String == "GPU OOM")
}
+@Test func heartbeatNetworkQualityUsesSnakeCaseAndDefaultsToZero() throws {
+ let heartbeat = ProviderMessage.heartbeat(ProviderMessage.Heartbeat(
+ status: .idle,
+ stats: ProviderStats(),
+ systemMetrics: SystemMetrics(memoryPressure: 0, cpuUsage: 0, thermalState: .nominal),
+ networkQuality: NetworkQuality(
+ rttMs: 125.5,
+ jitterMs: 24.25,
+ reconnectCount: 3,
+ websocketWriteFailures: 2,
+ lastWriteLatencyMs: 9.75
+ )
+ ))
+
+ let data = try ProviderProtocolCodec.encodeProviderMessage(heartbeat)
+ let object = try jsonObject(data)
+ let network = object["network_quality"] as? [String: Any]
+ #expect(network?["rtt_ms"] as? Double == 125.5)
+ #expect(network?["jitter_ms"] as? Double == 24.25)
+ #expect(network?["reconnect_count"] as? Int == 3)
+ #expect(network?["websocket_write_failures"] as? Int == 2)
+ #expect(network?["last_write_latency_ms"] as? Double == 9.75)
+
+ let legacyJSON = #"{"type":"heartbeat","status":"idle","stats":{"requests_served":0,"tokens_generated":0},"system_metrics":{"memory_pressure":0,"cpu_usage":0,"thermal_state":"nominal"}}"#
+ let decoded = try ProviderProtocolCodec.decodeProviderMessage(from: legacyJSON)
+ guard case .heartbeat(let legacy) = decoded else {
+ throw TestFailure.unexpectedMessage
+ }
+ #expect(legacy.networkQuality == NetworkQuality())
+}
+
@Test func coordinatorMessagesDecodeAndEncodeWithSnakeCaseKeys() throws {
let encryptedRequest = #"{"type":"inference_request","request_id":"go-enc-req-1","body":null,"encrypted_body":{"ephemeral_public_key":"ZXBoZW1lcmFs","ciphertext":"Y2lwaGVy"}}"#
let request = try ProviderProtocolCodec.decodeCoordinatorMessage(from: encryptedRequest)