Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 17 additions & 29 deletions apikey.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ const ApiKeyTablePK = "value"
const (
paramNewKeyId = "newKeyId"
paramNewKeySecret = "newKeySecret"
paramOldKeyId = "oldKeyId"
paramOldKeySecret = "oldKeySecret"
)

const (
Expand Down Expand Up @@ -458,22 +456,28 @@ func (a *App) CreateApiKey(w http.ResponseWriter, r *http.Request) {
// any number of times to continue the process. A status of 200 does not indicate that all keys were encrypted using the
// new key. Check the response data to determine if the rotation process is complete.
func (a *App) RotateApiKey(w http.ResponseWriter, r *http.Request) {
requestBody, err := parseRotateKeyRequestBody(r.Body)
var requestBody map[string]string
err := json.NewDecoder(r.Body).Decode(&requestBody)
if err != nil {
if strings.HasSuffix(err.Error(), "is required") {
jsonResponse(w, err, http.StatusBadRequest)
} else {
log.Printf("invalid request in RotateApiKey: %s", err)
jsonResponse(w, invalidRequest, http.StatusBadRequest)
}
log.Printf("invalid request in ActivateApiKey: %s", err)
jsonResponse(w, invalidRequest, http.StatusBadRequest)
return
}

if requestBody[paramNewKeyId] == "" {
jsonResponse(w, paramNewKeyId+" is required", http.StatusBadRequest)
return
}

oldKey := ApiKey{Key: requestBody[paramOldKeyId], Store: a.GetDB()}
err = oldKey.loadAndCheck(requestBody[paramOldKeySecret])
if requestBody[paramNewKeySecret] == "" {
jsonResponse(w, paramNewKeySecret+" is required", http.StatusBadRequest)
return
}

oldKey, err := getAPIKey(r)
if err != nil {
log.Printf("old key is not valid: %s", err)
jsonResponse(w, apiKeyNotFound, http.StatusNotFound)
log.Printf("Rotate API key error: %v", err)
jsonResponse(w, internalServerError, http.StatusInternalServerError)
return
}

Expand Down Expand Up @@ -509,22 +513,6 @@ func (a *App) RotateApiKey(w http.ResponseWriter, r *http.Request) {
jsonResponse(w, responseBody, http.StatusOK)
}

func parseRotateKeyRequestBody(body io.Reader) (map[string]string, error) {
var requestBody map[string]string
err := json.NewDecoder(body).Decode(&requestBody)
if err != nil {
return nil, fmt.Errorf("invalid request in RotateApiKey: %w", err)
}

fields := []string{paramNewKeyId, paramNewKeySecret, paramOldKeyId, paramOldKeySecret}
for _, field := range fields {
if _, ok := requestBody[field]; !ok {
return nil, fmt.Errorf("%s is required", field)
}
}
return requestBody, nil
}

func (k *ApiKey) loadAndCheck(secret string) error {
err := k.Load()
if err != nil {
Expand Down
66 changes: 30 additions & 36 deletions apikey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mfa

import (
"bytes"
"context"
"crypto/aes"
"crypto/rand"
"encoding/base64"
Expand All @@ -10,6 +11,7 @@ import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"regexp"
"testing"
"time"
Expand Down Expand Up @@ -372,78 +374,70 @@ func (ms *MfaSuite) TestAppRotateApiKey() {
tests := []struct {
name string
body any
key ApiKey
wantStatus int
wantError error
wantError string
}{
{
name: "missing oldKeyId",
body: map[string]interface{}{
paramNewKeyId: newKey.Key,
paramNewKeySecret: newKey.Secret,
paramOldKeySecret: key.Secret,
},
wantStatus: http.StatusBadRequest,
wantError: errors.New("oldKeyId is required"),
},
{
name: "missing oldKeySecret",
name: "missing key",
body: map[string]interface{}{
paramNewKeyId: newKey.Key,
paramNewKeySecret: newKey.Secret,
paramOldKeyId: key.Key,
},
wantStatus: http.StatusBadRequest,
wantError: errors.New("oldKeySecret is required"),
wantStatus: http.StatusUnauthorized,
wantError: "Unauthorized",
},
{
name: "missing newKeyId",
body: map[string]interface{}{
paramNewKeySecret: newKey.Secret,
paramOldKeyId: key.Key,
paramOldKeySecret: key.Secret,
},
key: key,
wantStatus: http.StatusBadRequest,
wantError: errors.New("newKeyId is required"),
wantError: "newKeyId is required",
},
{
name: "missing newKeySecret",
body: map[string]interface{}{
paramNewKeyId: newKey.Key,
paramOldKeyId: key.Key,
paramOldKeySecret: key.Secret,
paramNewKeyId: newKey.Key,
},
key: key,
wantStatus: http.StatusBadRequest,
wantError: errors.New("newKeySecret is required"),
wantError: "newKeySecret is required",
},
{
name: "good",
body: map[string]interface{}{
paramNewKeyId: newKey.Key,
paramNewKeySecret: newKey.Secret,
paramOldKeyId: user.ApiKey.Key,
paramOldKeySecret: key.Secret,
},
key: key,
wantStatus: http.StatusOK,
},
}
for _, tt := range tests {
ms.Run(tt.name, func() {
res := &lambdaResponseWriter{Headers: http.Header{}}
req := requestWithUser(tt.body, key)
ms.app.RotateApiKey(res, req)

if tt.wantError != nil {
ms.Equal(tt.wantStatus, res.Status, fmt.Sprintf("CreateApiKey response: %s", res.Body))
var se simpleError
ms.decodeBody(res.Body, &se)
ms.ErrorIs(se, tt.wantError)
jsonBody, err := json.Marshal(tt.body)
must(err)
b := io.NopCloser(bytes.NewReader(jsonBody))
request, _ := http.NewRequest(http.MethodPost, "/api-key/rotate", b)
request.Header.Set(HeaderAPIKey, tt.key.Key)
request.Header.Set(HeaderAPISecret, tt.key.Secret)

ctxWithUser := context.WithValue(request.Context(), UserContextKey, tt.key)
request = request.WithContext(ctxWithUser)

res := httptest.NewRecorder()
Router(ms.app).ServeHTTP(res, request)
ms.Equal(tt.wantStatus, res.Code, "incorrect http status, body: %s", res.Body.String())

if tt.wantError != "" {
ms.Contains(res.Body.String(), tt.wantError)
return
}

ms.Equal(tt.wantStatus, res.Status, fmt.Sprintf("CreateApiKey response: %s", res.Body))

var response map[string]int
ms.decodeBody(res.Body, &response)
ms.decodeBody(res.Body.Bytes(), &response)
ms.Equal(1, response["totpComplete"])
ms.Equal(1, response["webauthnComplete"])

Expand Down
9 changes: 7 additions & 2 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@ import (
"strings"
)

const (
HeaderAPIKey = "x-mfa-apikey"
HeaderAPISecret = "x-mfa-apisecret"
)

type User interface{}

// AuthenticateRequest checks the provided API key against the keys stored in the database. If the key is active and
// valid, a Webauthn client and WebauthnUser are created and stored in the request context.
func AuthenticateRequest(r *http.Request) (User, error) {
// get key and secret from headers
key := r.Header.Get("x-mfa-apikey")
secret := r.Header.Get("x-mfa-apisecret")
key := r.Header.Get(HeaderAPIKey)
secret := r.Header.Get(HeaderAPISecret)

if key == "" || secret == "" {
return nil, fmt.Errorf("x-mfa-apikey and x-mfa-apisecret are required")
Expand Down
13 changes: 2 additions & 11 deletions openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,8 @@ paths:
operationId: rotateApiKey
summary: Rotate API Key
description: >
All data in webauthn and totp tables that is encrypted by the old key will be re-encrypted using the new key.
All data in webauthn and totp tables that is encrypted by the old key (identified by the request headers
`x-mfa-apikey` and `x-mfa-apisecret`) will be re-encrypted using a new key identified in the request body.
If the process does not run to completion, this endpoint can be called any number of times to continue the
process. A status of 200 does not indicate that all keys were encrypted using the new key. Check the response
data to determine if the rotation process is complete.
Expand All @@ -500,16 +501,6 @@ paths:
schema:
type: object
properties:
oldKeyId:
type: string
description: old API Key ID
required: true
example: 0123456789012345678901234567890123456789
oldKeySecret:
type: string
description: old API Key secret
required: true
example: 0123456789012345678901234567890123456789012=
newKeyId:
type: string
description: new API Key ID
Expand Down
10 changes: 7 additions & 3 deletions webauthn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/http/httptest"
"strings"
Expand Down Expand Up @@ -744,6 +745,7 @@ func Test_GetPublicKeyAsBytes(t *testing.T) {

func Router(app *App) http.Handler {
mux := &http.ServeMux{}
mux.HandleFunc("POST /api-key/rotate", app.RotateApiKey)
mux.HandleFunc(fmt.Sprintf("DELETE /webauthn/credential/{%s}", IDParam), app.DeleteCredential)
// Ensure a request without an id gets handled properly
mux.HandleFunc("DELETE /webauthn/credential/", app.DeleteCredential)
Expand All @@ -752,11 +754,13 @@ func Router(app *App) http.Handler {
return testAuthnMiddleware(mux)
}

// testAuthnMiddleware is a copy of the authenticationMiddleware function
func testAuthnMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, err := AuthenticateRequest(r)
if err != nil {
http.Error(w, fmt.Sprintf("unable to authenticate request: %s", err), http.StatusUnauthorized)
log.Printf("unable to authenticate request: %s", err)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}

Expand Down Expand Up @@ -841,8 +845,8 @@ func (ms *MfaSuite) Test_DeleteCredential() {
ms.T().Run(tt.name, func(t *testing.T) {
request, _ := http.NewRequest("DELETE", fmt.Sprintf("/webauthn/credential/%s", tt.credID), nil)

request.Header.Set("x-mfa-apikey", tt.user.ApiKeyValue)
request.Header.Set("x-mfa-apisecret", tt.user.ApiKey.Secret)
request.Header.Set(HeaderAPIKey, tt.user.ApiKeyValue)
request.Header.Set(HeaderAPISecret, tt.user.ApiKey.Secret)
request.Header.Set("x-mfa-RPDisplayName", "TestRPName")
request.Header.Set("x-mfa-RPID", "111.11.11.11")
request.Header.Set("x-mfa-UserUUID", tt.user.ID)
Expand Down