diff --git a/coordinator/internal/api/device_auth.go b/coordinator/internal/api/device_auth.go index 778c85150..def1abc19 100644 --- a/coordinator/internal/api/device_auth.go +++ b/coordinator/internal/api/device_auth.go @@ -126,7 +126,10 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { }) case "approved": - // Generate a long-lived provider token. + // Generate and persist the provider token first, then consume the + // device code. This ordering ensures that a transient failure in token + // creation leaves the device code in "approved" state so the client + // can retry, rather than permanently invalidating it. tokenBytes := make([]byte, 32) if _, err := rand.Read(tokenBytes); err != nil { writeJSON(w, http.StatusInternalServerError, errorResponse("server_error", "failed to generate token")) @@ -146,6 +149,22 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { return } + // Consume the device code to prevent a second poll from issuing a + // second token. If this fails (race: another request already consumed + // it, or a DB error), revoke the token we just created and return the + // appropriate status — the device code remains in its current state + // and the client can retry if it was a transient error. + if err := s.store.ConsumeDeviceCode(dc.DeviceCode); err != nil { + _ = s.store.RevokeProviderToken(rawToken) + // Distinguish race (already consumed) from transient store failure. + if strings.Contains(err.Error(), "not approved") || strings.Contains(err.Error(), "consumed") { + writeJSON(w, http.StatusGone, errorResponse("expired_token", "device code is no longer valid")) + } else { + writeJSON(w, http.StatusInternalServerError, errorResponse("server_error", "failed to finalize token")) + } + return + } + s.logger.Info("provider token issued", "account_id", dc.AccountID, "user_code", dc.UserCode, diff --git a/coordinator/internal/store/interface.go b/coordinator/internal/store/interface.go index 4c2322190..0be9b6b2d 100644 --- a/coordinator/internal/store/interface.go +++ b/coordinator/internal/store/interface.go @@ -234,6 +234,12 @@ type Store interface { // ApproveDeviceCode links a device code to an account, marking it approved. ApproveDeviceCode(deviceCode, accountID string) error + // ConsumeDeviceCode atomically transitions a device code from approved to + // consumed. Returns an error if the code is not in the approved state. + // Must be called when issuing a provider token to prevent a second poll + // from minting an additional token for the same approval. + ConsumeDeviceCode(deviceCode string) error + // DeleteExpiredDeviceCodes removes device codes that have passed their expiry. DeleteExpiredDeviceCodes() error diff --git a/coordinator/internal/store/memory.go b/coordinator/internal/store/memory.go index 69544f0ef..e7117d859 100644 --- a/coordinator/internal/store/memory.go +++ b/coordinator/internal/store/memory.go @@ -1095,6 +1095,21 @@ func (s *MemoryStore) ApproveDeviceCode(deviceCode, accountID string) error { return nil } +func (s *MemoryStore) ConsumeDeviceCode(deviceCode string) error { + s.mu.Lock() + defer s.mu.Unlock() + + dc, ok := s.deviceCodesByCode[deviceCode] + if !ok { + return errors.New("device code not found") + } + if dc.Status != "approved" { + return fmt.Errorf("device code is %s, not approved", dc.Status) + } + dc.Status = "consumed" + return nil +} + func (s *MemoryStore) DeleteExpiredDeviceCodes() error { s.mu.Lock() defer s.mu.Unlock() diff --git a/coordinator/internal/store/postgres.go b/coordinator/internal/store/postgres.go index dbea9e43c..e43b68a71 100644 --- a/coordinator/internal/store/postgres.go +++ b/coordinator/internal/store/postgres.go @@ -1829,6 +1829,24 @@ func (s *PostgresStore) ApproveDeviceCode(deviceCode, accountID string) error { return nil } +func (s *PostgresStore) ConsumeDeviceCode(deviceCode string) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + tag, err := s.pool.Exec(ctx, + `UPDATE device_codes SET status = 'consumed' + WHERE device_code = $1 AND status = 'approved'`, + deviceCode, + ) + if err != nil { + return fmt.Errorf("store: consume device code: %w", err) + } + if tag.RowsAffected() == 0 { + return errors.New("device code not found or not in approved state") + } + return nil +} + func (s *PostgresStore) DeleteExpiredDeviceCodes() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel()