diff --git a/coordinator/internal/api/edge_case_test.go b/coordinator/internal/api/edge_case_test.go index a571d4c0..909aef09 100644 --- a/coordinator/internal/api/edge_case_test.go +++ b/coordinator/internal/api/edge_case_test.go @@ -7,9 +7,14 @@ package api // (no real backends needed) and run in CI. import ( + "archive/tar" + "bytes" + "compress/gzip" "context" "crypto/rand" + "crypto/sha256" "encoding/base64" + "encoding/hex" "encoding/json" "fmt" "io" @@ -766,7 +771,18 @@ func TestEdge_ReleaseRegisterAndRetrieve(t *testing.T) { srv.SetReleaseKey("release-key") // Register a release - body := `{"version":"1.0.0","platform":"macos-arm64","binary_hash":"abc123","bundle_hash":"def456","url":"http://example.com/bundle.tar.gz","changelog":"First release"}` + bundle, binaryHash, bundleHash := buildReleaseBundleForTest(t, []byte("provider-binary")) + cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz" { + http.NotFound(w, r) + return + } + w.Write(bundle) + })) + defer cdn.Close() + srv.SetR2CDNURL(cdn.URL + "/") + + body := fmt.Sprintf(`{"version":"1.0.0","platform":"macos-arm64","binary_hash":%q,"bundle_hash":%q,"url":%q,"changelog":"First release"}`, binaryHash, bundleHash, cdn.URL+"/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz") req := httptest.NewRequest(http.MethodPost, "/v1/releases", strings.NewReader(body)) req.Header.Set("Authorization", "Bearer release-key") w := httptest.NewRecorder() @@ -798,6 +814,351 @@ func TestEdge_ReleaseRegisterAndRetrieve(t *testing.T) { } } +func TestEdge_ReleaseRegisterRejectsInvalidHashMetadata(t *testing.T) { + srv, _ := testServer(t) + srv.SetReleaseKey("release-key") + + body := `{"version":"1.0.0","platform":"macos-arm64","binary_hash":"abc123","bundle_hash":"def456","url":"http://example.com/bundle.tar.gz"}` + req := httptest.NewRequest(http.MethodPost, "/v1/releases", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer release-key") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("release register with invalid hashes: status = %d, want 400, body = %s", w.Code, w.Body.String()) + } +} + +func TestEdge_ReleaseRegisterRejectsStoreOnlyFields(t *testing.T) { + srv, _ := testServer(t) + srv.SetReleaseKey("release-key") + + binaryHash := strings.Repeat("a", 64) + bundleHash := strings.Repeat("b", 64) + body := fmt.Sprintf(`{"version":"1.0.0","platform":"macos-arm64","binary_hash":%q,"bundle_hash":%q,"url":"https://r2.example.com/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz","active":true,"created_at":"2099-01-01T00:00:00Z"}`, binaryHash, bundleHash) + req := httptest.NewRequest(http.MethodPost, "/v1/releases", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer release-key") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("release register with store-only fields: status = %d, want 400, body = %s", w.Code, w.Body.String()) + } +} + +func TestEdge_ReleaseRegisterRejectsOffOriginURLWhenR2Configured(t *testing.T) { + srv, _ := testServer(t) + srv.SetReleaseKey("release-key") + srv.SetR2CDNURL("https://r2.example.com") + + binaryHash := strings.Repeat("a", 64) + bundleHash := strings.Repeat("b", 64) + body := fmt.Sprintf(`{"version":"1.0.0","platform":"macos-arm64","binary_hash":%q,"bundle_hash":%q,"url":"https://evil.example.com/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz"}`, binaryHash, bundleHash) + req := httptest.NewRequest(http.MethodPost, "/v1/releases", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer release-key") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("release register with off-origin URL: status = %d, want 400, body = %s", w.Code, w.Body.String()) + } +} + +func TestEdge_ReleaseRegisterRejectsHTTPArtifactOrigin(t *testing.T) { + srv, _ := testServer(t) + srv.SetReleaseKey("release-key") + srv.SetR2CDNURL("http://r2.example.com") + + binaryHash := strings.Repeat("a", 64) + bundleHash := strings.Repeat("b", 64) + body := fmt.Sprintf(`{"version":"1.0.0","platform":"macos-arm64","binary_hash":%q,"bundle_hash":%q,"url":"http://r2.example.com/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz"}`, binaryHash, bundleHash) + req := httptest.NewRequest(http.MethodPost, "/v1/releases", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer release-key") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("release register with http artifact origin: status = %d, want 400, body = %s", w.Code, w.Body.String()) + } +} + +func TestEdge_ReleaseRegisterRejectsCredentialedArtifactURL(t *testing.T) { + srv, _ := testServer(t) + srv.SetReleaseKey("release-key") + srv.SetR2CDNURL("https://r2.example.com") + + binaryHash := strings.Repeat("a", 64) + bundleHash := strings.Repeat("b", 64) + body := fmt.Sprintf(`{"version":"1.0.0","platform":"macos-arm64","binary_hash":%q,"bundle_hash":%q,"url":"https://user:pass@r2.example.com/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz"}`, binaryHash, bundleHash) + req := httptest.NewRequest(http.MethodPost, "/v1/releases", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer release-key") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("release register with credentialed artifact URL: status = %d, want 400, body = %s", w.Code, w.Body.String()) + } +} + +func TestEdge_ReleaseRegisterVerifiesBundleArtifact(t *testing.T) { + srv, st := testServer(t) + srv.SetReleaseKey("release-key") + + bundle, binaryHash, bundleHash := buildReleaseBundleForTest(t, []byte("provider-binary")) + cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz" { + http.NotFound(w, r) + return + } + w.Write(bundle) + })) + defer cdn.Close() + srv.SetR2CDNURL(cdn.URL) + + body := fmt.Sprintf(`{"version":"1.0.0","platform":"macos-arm64","binary_hash":%q,"bundle_hash":%q,"url":%q}`, binaryHash, bundleHash, cdn.URL+"/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz") + req := httptest.NewRequest(http.MethodPost, "/v1/releases", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer release-key") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("release register with verified artifact: status = %d, want 200, body = %s", w.Code, w.Body.String()) + } + releases := st.ListReleases() + if len(releases) != 1 || releases[0].BinaryHash != binaryHash { + t.Fatalf("release was not stored with verified binary hash: %+v", releases) + } +} + +func TestEdge_ReleaseRegisterAcceptsLegacyRegularBundleEntry(t *testing.T) { + srv, st := testServer(t) + srv.SetReleaseKey("release-key") + + bundle, binaryHash, bundleHash := buildReleaseBundleWithEntryForTest(t, "bin/darkbloom", tar.TypeRegA, []byte("provider-binary"), "") + cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz" { + http.NotFound(w, r) + return + } + w.Write(bundle) + })) + defer cdn.Close() + srv.SetR2CDNURL(cdn.URL) + + body := fmt.Sprintf(`{"version":"1.0.0","platform":"macos-arm64","binary_hash":%q,"bundle_hash":%q,"url":%q}`, binaryHash, bundleHash, cdn.URL+"/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz") + req := httptest.NewRequest(http.MethodPost, "/v1/releases", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer release-key") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("release register with legacy regular bundle entry: status = %d, want 200, body = %s", w.Code, w.Body.String()) + } + releases := st.ListReleases() + if len(releases) != 1 || releases[0].BinaryHash != binaryHash { + t.Fatalf("release was not stored with legacy regular bundle entry: %+v", releases) + } +} + +func TestEdge_ReleaseRegisterRejectsBundledBinaryHashMismatch(t *testing.T) { + srv, _ := testServer(t) + srv.SetReleaseKey("release-key") + + bundle, _, bundleHash := buildReleaseBundleForTest(t, []byte("provider-binary")) + cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz" { + http.NotFound(w, r) + return + } + w.Write(bundle) + })) + defer cdn.Close() + srv.SetR2CDNURL(cdn.URL) + + wrongBinaryHash := strings.Repeat("c", 64) + body := fmt.Sprintf(`{"version":"1.0.0","platform":"macos-arm64","binary_hash":%q,"bundle_hash":%q,"url":%q}`, wrongBinaryHash, bundleHash, cdn.URL+"/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz") + req := httptest.NewRequest(http.MethodPost, "/v1/releases", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer release-key") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("release register with mismatched binary hash: status = %d, want 400, body = %s", w.Code, w.Body.String()) + } +} + +func TestEdge_ReleaseRegisterRejectsOversizedBundledBinary(t *testing.T) { + srv, _ := testServer(t) + srv.SetReleaseKey("release-key") + + bundle, bundleHash := buildOversizedBinaryReleaseBundleForTest(t) + cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz" { + http.NotFound(w, r) + return + } + w.Write(bundle) + })) + defer cdn.Close() + srv.SetR2CDNURL(cdn.URL) + + binaryHash := strings.Repeat("d", 64) + body := fmt.Sprintf(`{"version":"1.0.0","platform":"macos-arm64","binary_hash":%q,"bundle_hash":%q,"url":%q}`, binaryHash, bundleHash, cdn.URL+"/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz") + req := httptest.NewRequest(http.MethodPost, "/v1/releases", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer release-key") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("release register with oversized bundled binary: status = %d, want 400, body = %s", w.Code, w.Body.String()) + } +} + +func TestEdge_ReleaseRegisterRejectsRedirectedBundleDownload(t *testing.T) { + srv, _ := testServer(t) + srv.SetReleaseKey("release-key") + + bundle, binaryHash, bundleHash := buildReleaseBundleForTest(t, []byte("provider-binary")) + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write(bundle) + })) + defer target.Close() + + cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, target.URL+"/bundle.tar.gz", http.StatusFound) + })) + defer cdn.Close() + srv.SetR2CDNURL(cdn.URL) + + body := fmt.Sprintf(`{"version":"1.0.0","platform":"macos-arm64","binary_hash":%q,"bundle_hash":%q,"url":%q}`, binaryHash, bundleHash, cdn.URL+"/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz") + req := httptest.NewRequest(http.MethodPost, "/v1/releases", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer release-key") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("release register with redirected bundle: status = %d, want 400, body = %s", w.Code, w.Body.String()) + } +} + +func TestEdge_ReleaseRegisterRejectsUnsafeBundlePath(t *testing.T) { + srv, _ := testServer(t) + srv.SetReleaseKey("release-key") + + bundle, binaryHash, bundleHash := buildReleaseBundleWithEntryForTest(t, "../bin/darkbloom", tar.TypeReg, []byte("provider-binary"), "") + cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz" { + http.NotFound(w, r) + return + } + w.Write(bundle) + })) + defer cdn.Close() + srv.SetR2CDNURL(cdn.URL) + + body := fmt.Sprintf(`{"version":"1.0.0","platform":"macos-arm64","binary_hash":%q,"bundle_hash":%q,"url":%q}`, binaryHash, bundleHash, cdn.URL+"/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz") + req := httptest.NewRequest(http.MethodPost, "/v1/releases", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer release-key") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("release register with unsafe bundle path: status = %d, want 400, body = %s", w.Code, w.Body.String()) + } +} + +func TestEdge_ReleaseRegisterRejectsNonRegularProviderBinary(t *testing.T) { + srv, _ := testServer(t) + srv.SetReleaseKey("release-key") + + bundle, _, bundleHash := buildReleaseBundleWithEntryForTest(t, "bin/darkbloom", tar.TypeSymlink, nil, "darkbloom.real") + cdn := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz" { + http.NotFound(w, r) + return + } + w.Write(bundle) + })) + defer cdn.Close() + srv.SetR2CDNURL(cdn.URL) + + binaryHash := strings.Repeat("e", 64) + body := fmt.Sprintf(`{"version":"1.0.0","platform":"macos-arm64","binary_hash":%q,"bundle_hash":%q,"url":%q}`, binaryHash, bundleHash, cdn.URL+"/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz") + req := httptest.NewRequest(http.MethodPost, "/v1/releases", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer release-key") + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("release register with non-regular provider binary: status = %d, want 400, body = %s", w.Code, w.Body.String()) + } +} + +func buildReleaseBundleForTest(t *testing.T, binary []byte) ([]byte, string, string) { + t.Helper() + + return buildReleaseBundleWithEntryForTest(t, "bin/darkbloom", tar.TypeReg, binary, "") +} + +func buildReleaseBundleWithEntryForTest(t *testing.T, name string, typeflag byte, binary []byte, linkname string) ([]byte, string, string) { + t.Helper() + + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + tw := tar.NewWriter(gz) + + header := &tar.Header{ + Name: name, + Mode: 0o755, + Typeflag: typeflag, + Linkname: linkname, + } + if typeflag == tar.TypeReg || typeflag == tar.TypeRegA { + header.Size = int64(len(binary)) + } + if err := tw.WriteHeader(header); err != nil { + t.Fatalf("write tar header: %v", err) + } + if len(binary) > 0 { + if _, err := tw.Write(binary); err != nil { + t.Fatalf("write binary: %v", err) + } + } + if err := tw.Close(); err != nil { + t.Fatalf("close tar: %v", err) + } + if err := gz.Close(); err != nil { + t.Fatalf("close gzip: %v", err) + } + + return buf.Bytes(), sha256HexBytesForReleaseTest(binary), sha256HexBytesForReleaseTest(buf.Bytes()) +} + +func buildOversizedBinaryReleaseBundleForTest(t *testing.T) ([]byte, string) { + t.Helper() + + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + tw := tar.NewWriter(gz) + if err := tw.WriteHeader(&tar.Header{ + Name: "bin/darkbloom", + Mode: 0o755, + Size: maxReleaseProviderBinBytes + 1, + }); err != nil { + t.Fatalf("write oversized tar header: %v", err) + } + if err := gz.Close(); err != nil { + t.Fatalf("close gzip: %v", err) + } + + return buf.Bytes(), sha256HexBytesForReleaseTest(buf.Bytes()) +} + +func sha256HexBytesForReleaseTest(data []byte) string { + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} + // --------------------------------------------------------------------------- // Error response format // --------------------------------------------------------------------------- diff --git a/coordinator/internal/api/provider.go b/coordinator/internal/api/provider.go index 4c0765cb..3cec2359 100644 --- a/coordinator/internal/api/provider.go +++ b/coordinator/internal/api/provider.go @@ -685,9 +685,40 @@ func (s *Server) verifyChallengeResponse(providerID string, provider *registry.P ) } - // Verify fresh binary hash if reported and known hashes are configured. - if resp.BinaryHash != "" && len(s.knownBinaryHashes) > 0 { - if !s.knownBinaryHashes[resp.BinaryHash] { + // Verify fresh binary hash when a known-good policy is configured. A + // reported binary hash only counts when the response is signed by the + // provider key from a valid registration attestation. + policyConfigured, knownBinaryHashes := s.binaryHashPolicySnapshot() + if policyConfigured { + attestationResult := provider.AttestationResult + if attestationResult == nil || !attestationResult.Valid || attestationResult.PublicKey == "" { + s.logger.Error("provider cannot prove binary hash without valid attestation", + "provider_id", providerID, + ) + s.registry.MarkUntrusted(providerID) + s.handleChallengeFailure(providerID, "valid attestation required for binary hash policy") + return + } + if resp.BinaryHash == "" { + s.logger.Error("provider omitted binary hash while known-good policy is configured", + "provider_id", providerID, + ) + s.registry.MarkUntrusted(providerID) + s.handleChallengeFailure(providerID, "binary hash missing") + return + } + attestedBinaryHash, err := normalizeSHA256Hex(attestationResult.BinaryHash, "attested binary_hash") + if err != nil { + s.logger.Error("provider attestation has no usable binary hash", + "provider_id", providerID, + "binary_hash", attestationResult.BinaryHash, + ) + s.registry.MarkUntrusted(providerID) + s.handleChallengeFailure(providerID, "attested binary hash missing") + return + } + binaryHash, err := normalizeSHA256Hex(resp.BinaryHash, "binary_hash") + if err != nil || !knownBinaryHashes[binaryHash] { s.logger.Error("provider binary hash changed — no longer matches known-good list", "provider_id", providerID, "binary_hash", resp.BinaryHash, @@ -696,6 +727,16 @@ func (s *Server) verifyChallengeResponse(providerID string, provider *registry.P s.handleChallengeFailure(providerID, "binary hash mismatch") return } + if binaryHash != attestedBinaryHash { + s.logger.Error("provider binary hash changed from registration attestation", + "provider_id", providerID, + "attested_binary_hash", registry.TruncHash(attestedBinaryHash), + "challenge_binary_hash", registry.TruncHash(binaryHash), + ) + s.registry.MarkUntrusted(providerID) + s.handleChallengeFailure(providerID, "binary hash changed from registration attestation") + return + } } // Verify active model hash if reported and catalog has expected hash. @@ -1129,12 +1170,22 @@ func (s *Server) handleInferenceError(providerID string, provider *registry.Prov // verifyProviderAttestation verifies a provider's Secure Enclave attestation // if one was included in the registration message. If the attestation is valid, // the provider is marked as attested. If missing or invalid, the provider is -// still accepted (Open Mode) but marked as not attested. +// accepted in Open Mode only when no binary hash policy is configured. func (s *Server) verifyProviderAttestation(providerID string, provider *registry.Provider, regMsg *protocol.RegisterMessage) { + policyConfigured, knownBinaryHashes := s.binaryHashPolicySnapshot() if len(regMsg.Attestation) == 0 { - s.logger.Info("provider registered without attestation (Open Mode)", - "provider_id", providerID, - ) + if policyConfigured { + s.logger.Warn("provider registered without attestation while binary hash policy is configured", + "provider_id", providerID, + ) + provider.SetAttestationResult(&attestation.VerificationResult{ + Valid: false, + Error: "attestation missing", + }) + s.registry.MarkUntrusted(providerID) + return + } + s.logger.Info("provider registered without attestation (Open Mode)", "provider_id", providerID) return } @@ -1183,9 +1234,21 @@ func (s *Server) verifyProviderAttestation(providerID string, provider *registry } } - // Verify binary hash against known-good hashes. - if len(s.knownBinaryHashes) > 0 && result.BinaryHash != "" { - if !s.knownBinaryHashes[result.BinaryHash] { + // Verify binary hash against known-good hashes. Once a binary hash policy is + // configured, omission is a policy violation, not an Open Mode downgrade. + if policyConfigured { + if result.BinaryHash == "" { + s.logger.Warn("provider binary hash missing while known-good policy is configured", + "provider_id", providerID, + ) + result.Valid = false + result.Error = "binary hash missing" + provider.SetAttestationResult(&result) + s.registry.MarkUntrusted(providerID) + return + } + binaryHash, err := normalizeSHA256Hex(result.BinaryHash, "binary_hash") + if err != nil || !knownBinaryHashes[binaryHash] { s.logger.Warn("provider binary hash not in known-good list", "provider_id", providerID, "binary_hash", result.BinaryHash, @@ -1193,6 +1256,7 @@ func (s *Server) verifyProviderAttestation(providerID string, provider *registry result.Valid = false result.Error = "binary hash not recognized" provider.SetAttestationResult(&result) + s.registry.MarkUntrusted(providerID) return } s.logger.Info("provider binary hash verified", diff --git a/coordinator/internal/api/provider_test.go b/coordinator/internal/api/provider_test.go index 4db3885f..46aa3a67 100644 --- a/coordinator/internal/api/provider_test.go +++ b/coordinator/internal/api/provider_test.go @@ -9,6 +9,7 @@ import ( "encoding/asn1" "encoding/base64" "encoding/json" + "fmt" "io" "log/slog" "math/big" @@ -28,6 +29,8 @@ import ( "nhooyr.io/websocket" ) +const knownGoodBinaryHashForTest = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + func TestProviderWebSocketConnect(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) st := store.NewMemory("test-key") @@ -288,6 +291,10 @@ func rawP256PublicKeyB64ForTest(t *testing.T, pubKey *ecdsa.PublicKey) string { } func createTestAttestationJSON(t *testing.T, encryptionKey string) json.RawMessage { + return createTestAttestationJSONWithBinaryHash(t, encryptionKey, "") +} + +func createTestAttestationJSONWithBinaryHash(t *testing.T, encryptionKey, binaryHash string) json.RawMessage { t.Helper() privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) @@ -322,6 +329,9 @@ func createTestAttestationJSON(t *testing.T, encryptionKey string) json.RawMessa blobMap["encryptionPublicKey"] = encryptionKey registerTestChallengeSigner(encryptionKey, privKey) } + if binaryHash != "" { + blobMap["binaryHash"] = binaryHash + } blobJSON, err := json.Marshal(blobMap) if err != nil { @@ -411,6 +421,252 @@ func TestProviderRegistrationWithValidAttestation(t *testing.T) { } } +func TestProviderRegistrationRequiresBinaryHashWhenPolicyConfigured(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + st := store.NewMemory("test-key") + reg := registry.New(logger) + srv := NewServer(reg, st, logger) + srv.SetKnownBinaryHashes([]string{knownGoodBinaryHashForTest}) + + pubKey := testPublicKeyB64() + regMsg := &protocol.RegisterMessage{ + Type: protocol.TypeRegister, + Hardware: protocol.Hardware{ChipName: "Apple M3 Max", MemoryGB: 64}, + Models: []protocol.ModelInfo{{ID: "missing-binary-hash-model", ModelType: "chat", Quantization: "4bit"}}, + Backend: "inprocess-mlx", + PublicKey: pubKey, + EncryptedResponseChunks: true, + PrivacyCapabilities: testPrivacyCaps(), + Attestation: createTestAttestationJSON(t, pubKey), + } + p := reg.Register("provider-1", nil, regMsg) + + srv.verifyProviderAttestation("provider-1", p, regMsg) + + if p.AttestationResult == nil { + t.Fatal("expected attestation result") + } + if p.AttestationResult.Valid { + t.Fatal("attestation should be invalid when binary hash policy is configured and hash is missing") + } + if p.AttestationResult.Error != "binary hash missing" { + t.Fatalf("attestation error = %q, want %q", p.AttestationResult.Error, "binary hash missing") + } + p.Mu().Lock() + defer p.Mu().Unlock() + if p.Status != registry.StatusUntrusted { + t.Fatalf("provider status = %q, want %q", p.Status, registry.StatusUntrusted) + } + if p.TrustLevel != registry.TrustNone { + t.Fatalf("provider trust = %q, want %q", p.TrustLevel, registry.TrustNone) + } +} + +func TestProviderRegistrationAcceptsKnownBinaryHash(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + st := store.NewMemory("test-key") + reg := registry.New(logger) + srv := NewServer(reg, st, logger) + srv.SetKnownBinaryHashes([]string{knownGoodBinaryHashForTest}) + + pubKey := testPublicKeyB64() + regMsg := &protocol.RegisterMessage{ + Type: protocol.TypeRegister, + Hardware: protocol.Hardware{ChipName: "Apple M3 Max", MemoryGB: 64}, + Models: []protocol.ModelInfo{{ID: "known-binary-hash-model", ModelType: "chat", Quantization: "4bit"}}, + Backend: "inprocess-mlx", + PublicKey: pubKey, + EncryptedResponseChunks: true, + PrivacyCapabilities: testPrivacyCaps(), + Attestation: createTestAttestationJSONWithBinaryHash(t, pubKey, knownGoodBinaryHashForTest), + } + p := reg.Register("provider-1", nil, regMsg) + + srv.verifyProviderAttestation("provider-1", p, regMsg) + + if p.AttestationResult == nil { + t.Fatal("expected attestation result") + } + if !p.AttestationResult.Valid { + t.Fatalf("attestation should be valid with a known binary hash, got %q", p.AttestationResult.Error) + } + p.Mu().Lock() + defer p.Mu().Unlock() + if p.Status == registry.StatusUntrusted { + t.Fatal("provider should not be marked untrusted with a known binary hash") + } + if p.TrustLevel != registry.TrustSelfSigned { + t.Fatalf("provider trust = %q, want %q", p.TrustLevel, registry.TrustSelfSigned) + } +} + +func TestProviderRegistrationRejectsInvalidConfiguredBinaryHash(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + st := store.NewMemory("test-key") + reg := registry.New(logger) + srv := NewServer(reg, st, logger) + srv.SetKnownBinaryHashes([]string{"not-a-sha256"}) + + pubKey := testPublicKeyB64() + regMsg := &protocol.RegisterMessage{ + Type: protocol.TypeRegister, + Hardware: protocol.Hardware{ChipName: "Apple M3 Max", MemoryGB: 64}, + Models: []protocol.ModelInfo{{ID: "invalid-configured-hash-model", ModelType: "chat", Quantization: "4bit"}}, + Backend: "inprocess-mlx", + PublicKey: pubKey, + EncryptedResponseChunks: true, + PrivacyCapabilities: testPrivacyCaps(), + Attestation: createTestAttestationJSONWithBinaryHash(t, pubKey, "not-a-sha256"), + } + p := reg.Register("provider-1", nil, regMsg) + + srv.verifyProviderAttestation("provider-1", p, regMsg) + + policyConfigured, knownHashes := srv.binaryHashPolicySnapshot() + if !policyConfigured { + t.Fatal("binary hash policy should remain configured even when configured hashes are invalid") + } + if len(knownHashes) != 0 { + t.Fatalf("known binary hashes = %d, want 0 valid hashes", len(knownHashes)) + } + if p.AttestationResult == nil { + t.Fatal("expected attestation result") + } + if p.AttestationResult.Valid { + t.Fatal("attestation should be invalid when configured hash and reported hash are invalid") + } + p.Mu().Lock() + defer p.Mu().Unlock() + if p.Status != registry.StatusUntrusted { + t.Fatalf("provider status = %q, want %q", p.Status, registry.StatusUntrusted) + } +} + +func TestSyncBinaryHashesRejectsInvalidStoredReleaseHashWithoutFailingOpen(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + st := store.NewMemory("test-key") + reg := registry.New(logger) + srv := NewServer(reg, st, logger) + if err := st.SetRelease(&store.Release{ + Version: "1.0.0", + Platform: "macos-arm64", + BinaryHash: "not-a-sha256", + BundleHash: strings.Repeat("b", 64), + URL: "https://r2.example.com/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz", + }); err != nil { + t.Fatalf("SetRelease: %v", err) + } + + srv.SyncBinaryHashes() + + policyConfigured, knownHashes := srv.binaryHashPolicySnapshot() + if !policyConfigured { + t.Fatal("binary hash policy should remain configured when an active release has an invalid hash") + } + if len(knownHashes) != 0 { + t.Fatalf("known binary hashes = %d, want 0 valid hashes", len(knownHashes)) + } +} + +func TestSyncBinaryHashesPreservesAdditionalConfiguredHashes(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + st := store.NewMemory("test-key") + reg := registry.New(logger) + srv := NewServer(reg, st, logger) + + manualHash := strings.Repeat("a", 64) + releaseHash := strings.Repeat("b", 64) + srv.AddKnownBinaryHashes([]string{manualHash}) + if err := st.SetRelease(&store.Release{ + Version: "1.0.0", + Platform: "macos-arm64", + BinaryHash: releaseHash, + BundleHash: strings.Repeat("c", 64), + URL: "https://r2.example.com/releases/v1.0.0/eigeninference-bundle-macos-arm64.tar.gz", + }); err != nil { + t.Fatalf("SetRelease: %v", err) + } + + srv.SyncBinaryHashes() + policyConfigured, knownHashes := srv.binaryHashPolicySnapshot() + if !policyConfigured { + t.Fatal("binary hash policy should be configured after manual hash and active release") + } + if !knownHashes[manualHash] { + t.Fatal("manual binary hash was dropped during release sync") + } + if !knownHashes[releaseHash] { + t.Fatal("release binary hash was not synced") + } + + if err := st.DeleteRelease("1.0.0", "macos-arm64"); err != nil { + t.Fatalf("DeleteRelease: %v", err) + } + srv.SyncBinaryHashes() + policyConfigured, knownHashes = srv.binaryHashPolicySnapshot() + if !policyConfigured { + t.Fatal("binary hash policy should remain configured after release deletion because manual hash remains") + } + if !knownHashes[manualHash] { + t.Fatal("manual binary hash was dropped during release deletion sync") + } + if knownHashes[releaseHash] { + t.Fatal("inactive release binary hash should not remain after sync") + } +} + +func TestBinaryHashPolicySnapshotConcurrentSync(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + st := store.NewMemory("test-key") + reg := registry.New(logger) + srv := NewServer(reg, st, logger) + manualHash := strings.Repeat("a", 64) + srv.AddKnownBinaryHashes([]string{manualHash}) + + done := make(chan struct{}) + var wg sync.WaitGroup + for i := 0; i < 4; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-done: + return + default: + policyConfigured, knownHashes := srv.binaryHashPolicySnapshot() + if policyConfigured && !knownHashes[manualHash] { + t.Errorf("manual hash missing from policy snapshot") + return + } + } + } + }() + } + + for i := 0; i < 50; i++ { + version := fmt.Sprintf("1.0.%d", i) + releaseHash := fmt.Sprintf("%064x", i+1) + if err := st.SetRelease(&store.Release{ + Version: version, + Platform: "macos-arm64", + BinaryHash: releaseHash, + BundleHash: strings.Repeat("c", 64), + URL: "https://r2.example.com/releases/v" + version + "/eigeninference-bundle-macos-arm64.tar.gz", + }); err != nil { + t.Fatalf("SetRelease: %v", err) + } + srv.SyncBinaryHashes() + if err := st.DeleteRelease(version, "macos-arm64"); err != nil { + t.Fatalf("DeleteRelease: %v", err) + } + srv.SyncBinaryHashes() + } + + close(done) + wg.Wait() +} + // TestProviderRegistrationWithInvalidAttestation verifies that a provider // with an invalid attestation is still registered but not marked as attested. func TestProviderRegistrationWithInvalidAttestation(t *testing.T) { @@ -504,6 +760,41 @@ func TestProviderRegistrationWithoutAttestation(t *testing.T) { } } +func TestProviderRegistrationWithoutAttestationRejectedWhenBinaryHashPolicyConfigured(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + st := store.NewMemory("test-key") + reg := registry.New(logger) + srv := NewServer(reg, st, logger) + srv.SetKnownBinaryHashes([]string{knownGoodBinaryHashForTest}) + + regMsg := &protocol.RegisterMessage{ + Type: protocol.TypeRegister, + Hardware: protocol.Hardware{ChipName: "Apple M3 Max", MemoryGB: 64}, + Models: []protocol.ModelInfo{{ID: "no-attestation-policy-model", ModelType: "chat", Quantization: "4bit"}}, + Backend: "inprocess-mlx", + EncryptedResponseChunks: true, + PrivacyCapabilities: testPrivacyCaps(), + } + p := reg.Register("provider-1", nil, regMsg) + + srv.verifyProviderAttestation("provider-1", p, regMsg) + + if p.AttestationResult == nil { + t.Fatal("expected attestation result") + } + if p.AttestationResult.Valid { + t.Fatal("missing attestation should be invalid when binary hash policy is configured") + } + if p.AttestationResult.Error != "attestation missing" { + t.Fatalf("attestation error = %q, want %q", p.AttestationResult.Error, "attestation missing") + } + p.Mu().Lock() + defer p.Mu().Unlock() + if p.Status != registry.StatusUntrusted { + t.Fatalf("provider status = %q, want %q", p.Status, registry.StatusUntrusted) + } +} + // TestListModelsWithAttestationInfo verifies that /v1/models includes // attestation metadata. func TestListModelsWithAttestationInfo(t *testing.T) { @@ -860,6 +1151,210 @@ func TestChallengeResponseRejectsMissingSIPStatus(t *testing.T) { } } +func TestChallengeResponseRequiresBinaryHashWhenPolicyConfigured(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + st := store.NewMemory("test-key") + reg := registry.New(logger) + srv := NewServer(reg, st, logger) + srv.SetKnownBinaryHashes([]string{knownGoodBinaryHashForTest}) + + pubKey := testPublicKeyB64() + regMsg := &protocol.RegisterMessage{ + Type: protocol.TypeRegister, + Hardware: protocol.Hardware{ChipName: "Apple M3 Max", MemoryGB: 64}, + Models: []protocol.ModelInfo{{ID: "missing-challenge-binary-hash-model", ModelType: "chat", Quantization: "4bit"}}, + Backend: "inprocess-mlx", + PublicKey: pubKey, + EncryptedResponseChunks: true, + PrivacyCapabilities: testPrivacyCaps(), + Attestation: createTestAttestationJSONWithBinaryHash(t, pubKey, knownGoodBinaryHashForTest), + } + p := reg.Register("provider-1", nil, regMsg) + srv.verifyProviderAttestation("provider-1", p, regMsg) + sipEnabled := true + secureBootEnabled := true + rdmaDisabled := true + challengeTimestamp := "2026-04-24T12:00:00Z" + + srv.verifyChallengeResponse("provider-1", p, &pendingChallenge{ + nonce: "nonce-1", + timestamp: challengeTimestamp, + }, &protocol.AttestationResponseMessage{ + Type: protocol.TypeAttestationResponse, + Nonce: "nonce-1", + Signature: testChallengeSignature("nonce-1", challengeTimestamp, pubKey), + PublicKey: pubKey, + SIPEnabled: &sipEnabled, + SecureBootEnabled: &secureBootEnabled, + RDMADisabled: &rdmaDisabled, + }) + + p.Mu().Lock() + defer p.Mu().Unlock() + if p.Status != registry.StatusUntrusted { + t.Fatalf("provider status = %q, want %q", p.Status, registry.StatusUntrusted) + } + if p.FailedChallenges != 1 { + t.Fatalf("failed challenges = %d, want 1", p.FailedChallenges) + } + if !p.LastChallengeVerified.IsZero() { + t.Fatal("provider should not record challenge success when binary hash is omitted") + } +} + +func TestChallengeResponseRejectsUnsignedBinaryHashWhenPolicyConfigured(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + st := store.NewMemory("test-key") + reg := registry.New(logger) + srv := NewServer(reg, st, logger) + srv.SetKnownBinaryHashes([]string{knownGoodBinaryHashForTest}) + + pubKey := testPublicKeyB64() + p := reg.Register("provider-1", nil, &protocol.RegisterMessage{ + Type: protocol.TypeRegister, + Hardware: protocol.Hardware{ChipName: "Apple M3 Max", MemoryGB: 64}, + Models: []protocol.ModelInfo{{ID: "unsigned-challenge-binary-hash-model", ModelType: "chat", Quantization: "4bit"}}, + Backend: "inprocess-mlx", + PublicKey: pubKey, + EncryptedResponseChunks: true, + PrivacyCapabilities: testPrivacyCaps(), + }) + sipEnabled := true + secureBootEnabled := true + rdmaDisabled := true + + srv.verifyChallengeResponse("provider-1", p, &pendingChallenge{ + nonce: "nonce-1", + timestamp: "2026-04-24T12:00:00Z", + }, &protocol.AttestationResponseMessage{ + Type: protocol.TypeAttestationResponse, + Nonce: "nonce-1", + Signature: "dGVzdHNpZ25hdHVyZQ==", + PublicKey: pubKey, + SIPEnabled: &sipEnabled, + SecureBootEnabled: &secureBootEnabled, + RDMADisabled: &rdmaDisabled, + BinaryHash: knownGoodBinaryHashForTest, + }) + + p.Mu().Lock() + defer p.Mu().Unlock() + if p.Status != registry.StatusUntrusted { + t.Fatalf("provider status = %q, want %q", p.Status, registry.StatusUntrusted) + } + if p.FailedChallenges != 1 { + t.Fatalf("failed challenges = %d, want 1", p.FailedChallenges) + } + if !p.LastChallengeVerified.IsZero() { + t.Fatal("provider should not record challenge success for an unsigned binary hash") + } +} + +func TestChallengeResponseRejectsHashChangedFromRegistrationAttestation(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + st := store.NewMemory("test-key") + reg := registry.New(logger) + srv := NewServer(reg, st, logger) + otherKnownHash := strings.Repeat("f", 64) + srv.SetKnownBinaryHashes([]string{knownGoodBinaryHashForTest, otherKnownHash}) + + pubKey := testPublicKeyB64() + regMsg := &protocol.RegisterMessage{ + Type: protocol.TypeRegister, + Hardware: protocol.Hardware{ChipName: "Apple M3 Max", MemoryGB: 64}, + Models: []protocol.ModelInfo{{ID: "changed-challenge-binary-hash-model", ModelType: "chat", Quantization: "4bit"}}, + Backend: "inprocess-mlx", + PublicKey: pubKey, + EncryptedResponseChunks: true, + PrivacyCapabilities: testPrivacyCaps(), + Attestation: createTestAttestationJSONWithBinaryHash(t, pubKey, knownGoodBinaryHashForTest), + } + p := reg.Register("provider-1", nil, regMsg) + srv.verifyProviderAttestation("provider-1", p, regMsg) + sipEnabled := true + secureBootEnabled := true + rdmaDisabled := true + challengeTimestamp := "2026-04-24T12:00:00Z" + + srv.verifyChallengeResponse("provider-1", p, &pendingChallenge{ + nonce: "nonce-1", + timestamp: challengeTimestamp, + }, &protocol.AttestationResponseMessage{ + Type: protocol.TypeAttestationResponse, + Nonce: "nonce-1", + Signature: testChallengeSignature("nonce-1", challengeTimestamp, pubKey), + PublicKey: pubKey, + SIPEnabled: &sipEnabled, + SecureBootEnabled: &secureBootEnabled, + RDMADisabled: &rdmaDisabled, + BinaryHash: otherKnownHash, + }) + + p.Mu().Lock() + defer p.Mu().Unlock() + if p.Status != registry.StatusUntrusted { + t.Fatalf("provider status = %q, want %q", p.Status, registry.StatusUntrusted) + } + if p.FailedChallenges != 1 { + t.Fatalf("failed challenges = %d, want 1", p.FailedChallenges) + } + if !p.LastChallengeVerified.IsZero() { + t.Fatal("provider should not record challenge success when binary hash changed from attestation") + } +} + +func TestChallengeResponseAcceptsKnownBinaryHash(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + st := store.NewMemory("test-key") + reg := registry.New(logger) + srv := NewServer(reg, st, logger) + srv.SetKnownBinaryHashes([]string{knownGoodBinaryHashForTest}) + + pubKey := testPublicKeyB64() + regMsg := &protocol.RegisterMessage{ + Type: protocol.TypeRegister, + Hardware: protocol.Hardware{ChipName: "Apple M3 Max", MemoryGB: 64}, + Models: []protocol.ModelInfo{{ID: "known-challenge-binary-hash-model", ModelType: "chat", Quantization: "4bit"}}, + Backend: "inprocess-mlx", + PublicKey: pubKey, + EncryptedResponseChunks: true, + PrivacyCapabilities: testPrivacyCaps(), + Attestation: createTestAttestationJSONWithBinaryHash(t, pubKey, knownGoodBinaryHashForTest), + } + p := reg.Register("provider-1", nil, regMsg) + srv.verifyProviderAttestation("provider-1", p, regMsg) + sipEnabled := true + secureBootEnabled := true + rdmaDisabled := true + challengeTimestamp := "2026-04-24T12:00:00Z" + + srv.verifyChallengeResponse("provider-1", p, &pendingChallenge{ + nonce: "nonce-1", + timestamp: challengeTimestamp, + }, &protocol.AttestationResponseMessage{ + Type: protocol.TypeAttestationResponse, + Nonce: "nonce-1", + Signature: testChallengeSignature("nonce-1", challengeTimestamp, pubKey), + PublicKey: pubKey, + SIPEnabled: &sipEnabled, + SecureBootEnabled: &secureBootEnabled, + RDMADisabled: &rdmaDisabled, + BinaryHash: knownGoodBinaryHashForTest, + }) + + p.Mu().Lock() + defer p.Mu().Unlock() + if p.Status == registry.StatusUntrusted { + t.Fatal("provider should not be marked untrusted with a known binary hash") + } + if p.FailedChallenges != 0 { + t.Fatalf("failed challenges = %d, want 0", p.FailedChallenges) + } + if p.LastChallengeVerified.IsZero() { + t.Fatal("provider should record challenge success with a known binary hash") + } +} + func TestChallengeResponseMissingSIPClearsExistingRoutingEligibility(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) st := store.NewMemory("test-key") diff --git a/coordinator/internal/api/release_handlers.go b/coordinator/internal/api/release_handlers.go index 7cdef4f3..dc5fd106 100644 --- a/coordinator/internal/api/release_handlers.go +++ b/coordinator/internal/api/release_handlers.go @@ -1,44 +1,118 @@ package api import ( + "archive/tar" + "compress/gzip" + "context" + "crypto/sha256" "crypto/subtle" + "encoding/hex" "encoding/json" + "fmt" + "io" + "net" "net/http" + "net/url" + "os" + "path" + "regexp" + "strings" "time" "github.com/eigeninference/coordinator/internal/auth" "github.com/eigeninference/coordinator/internal/store" ) +const ( + maxReleaseRegisterBodyBytes = 64 * 1024 + maxReleaseArtifactBytes = 2 << 30 // 2 GiB + maxReleaseProviderBinBytes = 512 << 20 + releaseArtifactTimeout = 2 * time.Minute +) + +var ( + releaseVersionPattern = regexp.MustCompile(`^[0-9]+\.[0-9]+\.[0-9]+(?:[-+][0-9A-Za-z.-]+)?$`) + releasePlatformPattern = regexp.MustCompile(`^[a-z0-9][a-z0-9._-]{0,63}$`) + releaseTemplateNamePattern = regexp.MustCompile(`^[A-Za-z0-9._-]+$`) +) + +type registerReleaseRequest struct { + Version string `json:"version"` + Platform string `json:"platform"` + BinaryHash string `json:"binary_hash"` + BundleHash string `json:"bundle_hash"` + PythonHash string `json:"python_hash,omitempty"` + RuntimeHash string `json:"runtime_hash,omitempty"` + TemplateHashes string `json:"template_hashes,omitempty"` + URL string `json:"url"` + Changelog string `json:"changelog"` +} + +func (req registerReleaseRequest) toRelease() store.Release { + return store.Release{ + Version: req.Version, + Platform: req.Platform, + BinaryHash: req.BinaryHash, + BundleHash: req.BundleHash, + PythonHash: req.PythonHash, + RuntimeHash: req.RuntimeHash, + TemplateHashes: req.TemplateHashes, + URL: req.URL, + Changelog: req.Changelog, + } +} + // handleRegisterRelease handles POST /v1/releases. // Called by GitHub Actions to register a new provider binary release. // Authenticated with a scoped release key (NOT admin credentials). func (s *Server) handleRegisterRelease(w http.ResponseWriter, r *http.Request) { // Verify scoped release key. token := extractBearerToken(r) - if s.releaseKey == "" || token != s.releaseKey { + if !s.releaseKeyAuthorized(token) { writeJSON(w, http.StatusUnauthorized, errorResponse("unauthorized", "invalid release key")) return } - var release store.Release - if err := json.NewDecoder(r.Body).Decode(&release); err != nil { + var req registerReleaseRequest + r.Body = http.MaxBytesReader(w, r.Body, maxReleaseRegisterBodyBytes) + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + if err := dec.Decode(&req); err != nil { writeJSON(w, http.StatusBadRequest, errorResponse("invalid_request_error", "invalid JSON: "+err.Error())) return } - if release.Version == "" { - writeJSON(w, http.StatusBadRequest, errorResponse("invalid_request_error", "version is required")) + if err := dec.Decode(&struct{}{}); err != io.EOF { + writeJSON(w, http.StatusBadRequest, errorResponse("invalid_request_error", "invalid JSON: multiple JSON values")) return } + release := req.toRelease() if release.Platform == "" { release.Platform = "macos-arm64" // default } - if release.BinaryHash == "" { - writeJSON(w, http.StatusBadRequest, errorResponse("invalid_request_error", "binary_hash is required")) + + if err := s.validateReleaseMetadata(&release); err != nil { + writeJSON(w, http.StatusBadRequest, errorResponse("invalid_request_error", err.Error())) return } - if release.URL == "" { - writeJSON(w, http.StatusBadRequest, errorResponse("invalid_request_error", "url is required")) + + if s.r2CDNURL == "" { + s.logger.Error("release: artifact verification unavailable because R2 CDN URL is not configured", + "version", release.Version, + "platform", release.Platform, + ) + writeJSON(w, http.StatusServiceUnavailable, errorResponse("not_configured", "release artifact verification requires R2 CDN URL")) + return + } + + ctx, cancel := context.WithTimeout(r.Context(), releaseArtifactTimeout) + defer cancel() + if err := s.verifyReleaseArtifact(ctx, &release); err != nil { + s.logger.Warn("release: artifact verification failed", + "version", release.Version, + "platform", release.Platform, + "error", err, + ) + writeJSON(w, http.StatusBadRequest, errorResponse("invalid_request_error", "release artifact verification failed: "+err.Error())) return } @@ -71,6 +145,298 @@ func (s *Server) handleRegisterRelease(w http.ResponseWriter, r *http.Request) { }) } +func (s *Server) releaseKeyAuthorized(token string) bool { + if s.releaseKey == "" || token == "" { + return false + } + return subtle.ConstantTimeCompare([]byte(token), []byte(s.releaseKey)) == 1 +} + +func (s *Server) validateReleaseMetadata(release *store.Release) error { + release.Version = strings.TrimSpace(release.Version) + release.Platform = strings.TrimSpace(release.Platform) + release.BinaryHash = strings.TrimSpace(release.BinaryHash) + release.BundleHash = strings.TrimSpace(release.BundleHash) + release.PythonHash = strings.TrimSpace(release.PythonHash) + release.RuntimeHash = strings.TrimSpace(release.RuntimeHash) + release.TemplateHashes = strings.TrimSpace(release.TemplateHashes) + release.URL = strings.TrimSpace(release.URL) + + if release.Version == "" { + return fmt.Errorf("version is required") + } + if !releaseVersionPattern.MatchString(release.Version) { + return fmt.Errorf("version must be semver, e.g. 1.2.3 or 1.2.3-dev.1") + } + if release.Platform == "" { + return fmt.Errorf("platform is required") + } + if !releasePlatformPattern.MatchString(release.Platform) { + return fmt.Errorf("platform contains invalid characters") + } + + var err error + if release.BinaryHash, err = normalizeSHA256Hex(release.BinaryHash, "binary_hash"); err != nil { + return err + } + if release.BundleHash, err = normalizeSHA256Hex(release.BundleHash, "bundle_hash"); err != nil { + return err + } + if release.PythonHash != "" { + if release.PythonHash, err = normalizeSHA256Hex(release.PythonHash, "python_hash"); err != nil { + return err + } + } + if release.RuntimeHash != "" { + if release.RuntimeHash, err = normalizeSHA256Hex(release.RuntimeHash, "runtime_hash"); err != nil { + return err + } + } + if release.TemplateHashes != "" { + if release.TemplateHashes, err = normalizeTemplateHashes(release.TemplateHashes); err != nil { + return err + } + } + if release.URL == "" { + return fmt.Errorf("url is required") + } + if s.r2CDNURL != "" { + if _, err := s.trustedReleaseArtifactURL(release); err != nil { + return err + } + } + return nil +} + +func (s *Server) trustedReleaseArtifactURL(release *store.Release) (*url.URL, error) { + expectedURL, err := expectedReleaseArtifactURL(s.r2CDNURL, release.Version, release.Platform) + if err != nil { + return nil, err + } + if !sameReleaseArtifactURL(release.URL, expectedURL) { + return nil, fmt.Errorf("url must match configured release artifact path") + } + parsed, err := url.Parse(expectedURL) + if err != nil { + return nil, fmt.Errorf("configured release artifact URL is invalid") + } + return parsed, nil +} + +func expectedReleaseArtifactURL(baseURL, version, platform string) (string, error) { + version = strings.TrimSpace(version) + platform = strings.TrimSpace(platform) + if !releaseVersionPattern.MatchString(version) { + return "", fmt.Errorf("version must be semver, e.g. 1.2.3 or 1.2.3-dev.1") + } + if !releasePlatformPattern.MatchString(platform) { + return "", fmt.Errorf("platform contains invalid characters") + } + + u, err := url.Parse(strings.TrimSpace(baseURL)) + if err != nil { + return "", fmt.Errorf("configured R2 CDN URL is invalid") + } + if u.User != nil || u.RawQuery != "" || u.Fragment != "" { + return "", fmt.Errorf("configured R2 CDN URL must not include credentials, query, or fragment") + } + if u.Host == "" { + return "", fmt.Errorf("configured R2 CDN URL must include a host") + } + if u.Scheme != "https" && u.Scheme != "http" { + return "", fmt.Errorf("configured R2 CDN URL must be absolute") + } + if u.Scheme == "http" && !isLoopbackHost(u.Hostname()) { + return "", fmt.Errorf("configured R2 CDN URL must use https") + } + u.Path = path.Join(u.Path, "releases", "v"+version, "eigeninference-bundle-"+platform+".tar.gz") + u.RawQuery = "" + u.Fragment = "" + return u.String(), nil +} + +func isLoopbackHost(host string) bool { + if strings.EqualFold(host, "localhost") { + return true + } + ip := net.ParseIP(host) + return ip != nil && ip.IsLoopback() +} + +func sameReleaseArtifactURL(actual, expected string) bool { + actualURL, err := url.Parse(strings.TrimSpace(actual)) + if err != nil { + return false + } + expectedURL, err := url.Parse(expected) + if err != nil { + return false + } + if actualURL.User != nil || expectedURL.User != nil { + return false + } + return strings.EqualFold(actualURL.Scheme, expectedURL.Scheme) && + strings.EqualFold(actualURL.Host, expectedURL.Host) && + path.Clean(actualURL.EscapedPath()) == path.Clean(expectedURL.EscapedPath()) && + actualURL.RawQuery == "" && + actualURL.Fragment == "" +} + +func normalizeSHA256Hex(value, field string) (string, error) { + value = strings.ToLower(strings.TrimSpace(value)) + if len(value) != sha256.Size*2 { + return "", fmt.Errorf("%s must be a 64-character SHA-256 hex digest", field) + } + if _, err := hex.DecodeString(value); err != nil { + return "", fmt.Errorf("%s must be a valid SHA-256 hex digest", field) + } + return value, nil +} + +func normalizeTemplateHashes(raw string) (string, error) { + entries := strings.Split(raw, ",") + normalized := make([]string, 0, len(entries)) + for _, entry := range entries { + entry = strings.TrimSpace(entry) + if entry == "" { + continue + } + name, hash, ok := strings.Cut(entry, "=") + if !ok { + return "", fmt.Errorf("template_hashes entries must be name=sha256") + } + name = strings.TrimSpace(name) + if name == "" || !releaseTemplateNamePattern.MatchString(name) { + return "", fmt.Errorf("template_hashes contains an invalid template name") + } + hash, err := normalizeSHA256Hex(hash, "template_hashes") + if err != nil { + return "", err + } + normalized = append(normalized, name+"="+hash) + } + return strings.Join(normalized, ","), nil +} + +func (s *Server) verifyReleaseArtifact(ctx context.Context, release *store.Release) error { + downloadURL, err := s.trustedReleaseArtifactURL(release) + if err != nil { + return err + } + req := &http.Request{ + Method: http.MethodGet, + URL: downloadURL, + Header: make(http.Header), + } + req = req.WithContext(ctx) + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("download bundle: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("download bundle returned status %d", resp.StatusCode) + } + + tmp, err := os.CreateTemp("", "darkbloom-release-*.tar.gz") + if err != nil { + return fmt.Errorf("create temp bundle: %w", err) + } + defer func() { + tmp.Close() + os.Remove(tmp.Name()) + }() + + bundleHash := sha256.New() + limited := io.LimitReader(resp.Body, maxReleaseArtifactBytes+1) + n, err := io.Copy(io.MultiWriter(tmp, bundleHash), limited) + if err != nil { + return fmt.Errorf("read bundle: %w", err) + } + if n > maxReleaseArtifactBytes { + return fmt.Errorf("bundle exceeds maximum size") + } + actualBundleHash := hex.EncodeToString(bundleHash.Sum(nil)) + if actualBundleHash != release.BundleHash { + return fmt.Errorf("bundle_hash does not match release artifact") + } + + if _, err := tmp.Seek(0, io.SeekStart); err != nil { + return fmt.Errorf("rewind bundle: %w", err) + } + + gz, err := gzip.NewReader(tmp) + if err != nil { + return fmt.Errorf("open bundle gzip: %w", err) + } + defer gz.Close() + + tarReader := tar.NewReader(gz) + binaryHash := sha256.New() + foundBinary := false + for { + header, err := tarReader.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("read bundle tar: %w", err) + } + cleanName, err := cleanReleaseTarPath(header.Name) + if err != nil { + return err + } + if cleanName != "bin/darkbloom" { + continue + } + if header.Typeflag != tar.TypeReg && header.Typeflag != tar.TypeRegA { + return fmt.Errorf("bundled provider binary is not a regular file") + } + if foundBinary { + return fmt.Errorf("bundle contains multiple provider binaries") + } + if header.Size < 0 || header.Size > maxReleaseProviderBinBytes { + return fmt.Errorf("provider binary exceeds maximum size") + } + n, err := io.Copy(binaryHash, io.LimitReader(tarReader, maxReleaseProviderBinBytes+1)) + if err != nil { + return fmt.Errorf("read provider binary: %w", err) + } + if n > maxReleaseProviderBinBytes { + return fmt.Errorf("provider binary exceeds maximum size") + } + foundBinary = true + } + if !foundBinary { + return fmt.Errorf("bundle is missing bin/darkbloom") + } + + actualBinaryHash := hex.EncodeToString(binaryHash.Sum(nil)) + if actualBinaryHash != release.BinaryHash { + return fmt.Errorf("binary_hash does not match bundled provider binary") + } + return nil +} + +func cleanReleaseTarPath(name string) (string, error) { + if name == "" || strings.HasPrefix(name, "/") { + return "", fmt.Errorf("bundle contains unsafe path") + } + for _, part := range strings.Split(name, "/") { + if part == ".." { + return "", fmt.Errorf("bundle contains unsafe path") + } + } + return strings.TrimPrefix(path.Clean(name), "./"), nil +} + // handleLatestRelease handles GET /v1/releases/latest. // Public endpoint — returns the latest active release for a platform. // Used by install.sh to get the download URL and expected hash. diff --git a/coordinator/internal/api/server.go b/coordinator/internal/api/server.go index 303d71ce..8d6cda64 100644 --- a/coordinator/internal/api/server.go +++ b/coordinator/internal/api/server.go @@ -113,9 +113,16 @@ type Server struct { stepCAIntermediateCert *x509.Certificate // step-ca intermediate CA // knownBinaryHashes is the set of accepted provider binary SHA-256 hashes. - // When non-empty, providers whose binary hash doesn't match are rejected. + // When binaryHashPolicyConfigured is true, providers whose binary hash is + // missing or doesn't match are rejected. // Auto-populated from active releases via SyncBinaryHashes(). - knownBinaryHashes map[string]bool + binaryHashPolicyMu sync.RWMutex + knownBinaryHashes map[string]bool + manualKnownBinaryHashes map[string]bool + releaseKnownBinaryHashes map[string]bool + manualBinaryHashPolicyConfigured bool + releaseBinaryHashPolicyConfigured bool + binaryHashPolicyConfigured bool // knownRuntimeManifest holds accepted runtime component hashes. // When set, providers whose runtime hashes don't match are marked as @@ -412,24 +419,57 @@ func (s *Server) SyncModelCatalog() { // SetKnownBinaryHashes configures the set of accepted provider binary hashes. // Providers whose binary SHA-256 doesn't match any known hash are rejected. func (s *Server) SetKnownBinaryHashes(hashes []string) { - s.knownBinaryHashes = make(map[string]bool, len(hashes)) + normalized := normalizeKnownBinaryHashes(hashes, s.logger) + + s.binaryHashPolicyMu.Lock() + defer s.binaryHashPolicyMu.Unlock() + + s.manualKnownBinaryHashes = normalized + s.manualBinaryHashPolicyConfigured = hasConfiguredHashInput(hashes) + s.rebuildBinaryHashPolicyLocked() +} + +func normalizeKnownBinaryHashes(hashes []string, logger *slog.Logger) map[string]bool { + normalizedHashes := make(map[string]bool, len(hashes)) for _, h := range hashes { - if h != "" { - s.knownBinaryHashes[h] = true + normalized, err := normalizeSHA256Hex(h, "known_binary_hashes") + if err != nil { + if strings.TrimSpace(h) != "" { + logger.Warn("invalid known binary hash ignored", "hash", h, "error", err) + } + continue } + normalizedHashes[normalized] = true } + return normalizedHashes } // AddKnownBinaryHashes adds hashes to the existing known set (for env var fallback). func (s *Server) AddKnownBinaryHashes(hashes []string) { - if s.knownBinaryHashes == nil { - s.knownBinaryHashes = make(map[string]bool) + normalized := normalizeKnownBinaryHashes(hashes, s.logger) + + s.binaryHashPolicyMu.Lock() + defer s.binaryHashPolicyMu.Unlock() + + if s.manualKnownBinaryHashes == nil { + s.manualKnownBinaryHashes = make(map[string]bool) + } + if hasConfiguredHashInput(hashes) { + s.manualBinaryHashPolicyConfigured = true } + for h := range normalized { + s.manualKnownBinaryHashes[h] = true + } + s.rebuildBinaryHashPolicyLocked() +} + +func hasConfiguredHashInput(hashes []string) bool { for _, h := range hashes { - if h != "" { - s.knownBinaryHashes[h] = true + if strings.TrimSpace(h) != "" { + return true } } + return false } // SetConsoleURL sets the frontend URL for device auth verification links. @@ -464,13 +504,52 @@ func (s *Server) CoordinatorKey() *e2e.CoordinatorKey { func (s *Server) SyncBinaryHashes() { releases := s.store.ListReleases() hashes := make(map[string]bool) + policyConfigured := false for _, r := range releases { - if r.Active && r.BinaryHash != "" { - hashes[r.BinaryHash] = true + if !r.Active { + continue } + policyConfigured = true + normalized, err := normalizeSHA256Hex(r.BinaryHash, "release.binary_hash") + if err != nil { + s.logger.Warn("invalid release binary hash ignored", + "version", r.Version, + "platform", r.Platform, + "error", err, + ) + continue + } + hashes[normalized] = true + } + + s.binaryHashPolicyMu.Lock() + s.releaseKnownBinaryHashes = hashes + s.releaseBinaryHashPolicyConfigured = policyConfigured + s.rebuildBinaryHashPolicyLocked() + knownHashCount := len(s.knownBinaryHashes) + effectivePolicyConfigured := s.binaryHashPolicyConfigured + s.binaryHashPolicyMu.Unlock() + + s.logger.Info("binary hashes synced from releases", "known_hashes", knownHashCount, "policy_configured", effectivePolicyConfigured) +} + +func (s *Server) rebuildBinaryHashPolicyLocked() { + hashes := make(map[string]bool, len(s.manualKnownBinaryHashes)+len(s.releaseKnownBinaryHashes)) + for h := range s.releaseKnownBinaryHashes { + hashes[h] = true + } + for h := range s.manualKnownBinaryHashes { + hashes[h] = true } s.knownBinaryHashes = hashes - s.logger.Info("binary hashes synced from releases", "known_hashes", len(hashes)) + s.binaryHashPolicyConfigured = s.manualBinaryHashPolicyConfigured || s.releaseBinaryHashPolicyConfigured +} + +func (s *Server) binaryHashPolicySnapshot() (bool, map[string]bool) { + s.binaryHashPolicyMu.RLock() + defer s.binaryHashPolicyMu.RUnlock() + + return s.binaryHashPolicyConfigured, s.knownBinaryHashes } // SyncRuntimeManifest builds the runtime manifest from active releases. diff --git a/docs/release-runbook.md b/docs/release-runbook.md index b266876d..c5f3855d 100644 --- a/docs/release-runbook.md +++ b/docs/release-runbook.md @@ -187,7 +187,7 @@ To rollback to a previous version, deactivate the bad version. The old version i | Env var (coordinator) | `EIGENINFERENCE_RELEASE_KEY` | | GitHub Secret | `EIGENINFERENCE_RELEASE_KEY` | | Scope | Can only `POST /v1/releases` — no admin access | -| If leaked | Attacker can register fake releases, but binary hash must match what providers actually run — no impact | +| If leaked | Release registration still requires the URL to match `EIGENINFERENCE_R2_CDN_URL` and the coordinator verifies the downloaded bundle hash plus bundled `bin/darkbloom` hash before whitelisting it. Treat leakage as serious, but the key alone should not be enough to whitelist an arbitrary provider binary unless the release artifact origin is also compromised. | ### Admin access (managing releases) @@ -210,11 +210,11 @@ To rollback to a previous version, deactivate the bad version. The old version i ## How Binary Verification Works 1. **At build time**: SHA-256 of `darkbloom` binary is computed -2. **At release registration**: hash stored in coordinator's release table -3. **At startup**: `SyncBinaryHashes()` loads all active release hashes into `knownBinaryHashes` -4. **At provider registration**: attestation blob contains `binaryHash` → checked against known set -5. **At every challenge** (every 3 min): provider re-computes its binary hash → sent in response → checked against known set -6. **Unknown hash**: provider's attestation is rejected (stays at TrustNone, no requests routed) +2. **At release registration**: coordinator downloads the R2 bundle, verifies `bundle_hash`, extracts `bin/darkbloom`, and verifies `binary_hash` before storing the release +3. **At startup and release changes**: `SyncBinaryHashes()` loads all active release hashes and preserves additive env/manual hashes in `knownBinaryHashes` +4. **At provider registration**: attestation blob must contain `binaryHash` → checked against known set; Open Mode is rejected when a binary hash policy is configured +5. **At every challenge** (every 3 min): provider re-computes its binary hash → challenge signature must verify against the attested Secure Enclave key → hash must still match the signed registration attestation and known set +6. **Missing or unknown hash**: provider's attestation/challenge is rejected and the provider is marked untrusted ## How Install Verification Works diff --git a/provider/src/coordinator.rs b/provider/src/coordinator.rs index 1c7d5a38..1d02f11a 100644 --- a/provider/src/coordinator.rs +++ b/provider/src/coordinator.rs @@ -1222,9 +1222,34 @@ mod tests { .await .unwrap(); - // Read heartbeat or any response - if let Some(Ok(Message::Text(text))) = read.next().await { - received_messages.push(text.to_string()); + // Read until the plaintext rejection is observed. The client may + // emit WebSocket control frames or heartbeats before the error. + let deadline = tokio::time::Instant::now() + Duration::from_secs(5); + loop { + let Some(remaining) = deadline.checked_duration_since(tokio::time::Instant::now()) + else { + break; + }; + let frame = match tokio::time::timeout(remaining, read.next()).await { + Ok(Some(Ok(frame))) => frame, + _ => break, + }; + + match frame { + Message::Text(text) => { + let is_inference_error = serde_json::from_str::(&text) + .map(|v| v["type"] == "inference_error") + .unwrap_or(false); + received_messages.push(text.to_string()); + if is_inference_error { + break; + } + } + Message::Ping(data) => { + let _ = write.send(Message::Pong(data)).await; + } + _ => {} + } } // Send cancel