Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion coordinator/internal/api/device_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,10 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
})

case "approved":
// Generate a long-lived provider token.
// Generate and persist the provider token first, then consume the
// device code. This ordering ensures that a transient failure in token
// creation leaves the device code in "approved" state so the client
// can retry, rather than permanently invalidating it.
tokenBytes := make([]byte, 32)
if _, err := rand.Read(tokenBytes); err != nil {
writeJSON(w, http.StatusInternalServerError, errorResponse("server_error", "failed to generate token"))
Expand All @@ -146,6 +149,22 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
return
}

// Consume the device code to prevent a second poll from issuing a
// second token. If this fails (race: another request already consumed
// it, or a DB error), revoke the token we just created and return the
// appropriate status — the device code remains in its current state
// and the client can retry if it was a transient error.
if err := s.store.ConsumeDeviceCode(dc.DeviceCode); err != nil {
_ = s.store.RevokeProviderToken(rawToken)
// Distinguish race (already consumed) from transient store failure.
if strings.Contains(err.Error(), "not approved") || strings.Contains(err.Error(), "consumed") {
writeJSON(w, http.StatusGone, errorResponse("expired_token", "device code is no longer valid"))
} else {
writeJSON(w, http.StatusInternalServerError, errorResponse("server_error", "failed to finalize token"))
}
return
}

s.logger.Info("provider token issued",
"account_id", dc.AccountID,
"user_code", dc.UserCode,
Expand Down
6 changes: 6 additions & 0 deletions coordinator/internal/store/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,12 @@ type Store interface {
// ApproveDeviceCode links a device code to an account, marking it approved.
ApproveDeviceCode(deviceCode, accountID string) error

// ConsumeDeviceCode atomically transitions a device code from approved to
// consumed. Returns an error if the code is not in the approved state.
// Must be called when issuing a provider token to prevent a second poll
// from minting an additional token for the same approval.
ConsumeDeviceCode(deviceCode string) error

// DeleteExpiredDeviceCodes removes device codes that have passed their expiry.
DeleteExpiredDeviceCodes() error

Expand Down
15 changes: 15 additions & 0 deletions coordinator/internal/store/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -1095,6 +1095,21 @@ func (s *MemoryStore) ApproveDeviceCode(deviceCode, accountID string) error {
return nil
}

func (s *MemoryStore) ConsumeDeviceCode(deviceCode string) error {
s.mu.Lock()
defer s.mu.Unlock()

dc, ok := s.deviceCodesByCode[deviceCode]
if !ok {
return errors.New("device code not found")
}
if dc.Status != "approved" {
return fmt.Errorf("device code is %s, not approved", dc.Status)
}
dc.Status = "consumed"
return nil
}

func (s *MemoryStore) DeleteExpiredDeviceCodes() error {
s.mu.Lock()
defer s.mu.Unlock()
Expand Down
18 changes: 18 additions & 0 deletions coordinator/internal/store/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -1829,6 +1829,24 @@ func (s *PostgresStore) ApproveDeviceCode(deviceCode, accountID string) error {
return nil
}

func (s *PostgresStore) ConsumeDeviceCode(deviceCode string) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

tag, err := s.pool.Exec(ctx,
`UPDATE device_codes SET status = 'consumed'
WHERE device_code = $1 AND status = 'approved'`,
deviceCode,
)
if err != nil {
return fmt.Errorf("store: consume device code: %w", err)
}
if tag.RowsAffected() == 0 {
return errors.New("device code not found or not in approved state")
}
return nil
}

func (s *PostgresStore) DeleteExpiredDeviceCodes() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand Down
Loading