Skip to content

Commit 1a4b692

Browse files
authored
User ID to context methods (#223)
1 parent 7a9c6ef commit 1a4b692

File tree

3 files changed

+56
-0
lines changed

3 files changed

+56
-0
lines changed

context/keys/keys.go

+1
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ const (
1515
TenantIdCtxKey = ContextKey(jwt.TenantIdCtxKey)
1616
AuthHeaderCtxKey = ContextKey(jwt.AuthHeaderCtxKey)
1717
WebTokenCtxKey = ContextKey(jwt.WebTokenCtxKey)
18+
UserIDCtxKey = ContextKey("userId")
1819
)

context/service_context.go

+18
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,21 @@ func GetIsTechnicalIssuerFromContext(ctx context.Context) bool {
8787

8888
return isTechnicalIsser
8989
}
90+
91+
func AddUserIDToContext(ctx context.Context, userID string) context.Context {
92+
return context.WithValue(ctx, keys.UserIDCtxKey, userID)
93+
}
94+
95+
func GetUserIDFromContext(ctx context.Context) (string, error) {
96+
userID, ok := ctx.Value(keys.UserIDCtxKey).(string)
97+
if !ok {
98+
return userID, fmt.Errorf("someone stored a wrong value in the [%s] key with type [%T], expected [string]", keys.UserIDCtxKey, ctx.Value(keys.UserIDCtxKey))
99+
}
100+
return userID, nil
101+
}
102+
103+
func HasUserIDInContext(ctx context.Context) bool {
104+
_, ok := ctx.Value(keys.UserIDCtxKey).(string)
105+
return ok
106+
}
107+

context/service_context_test.go

+37
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package context_test
22

33
import (
44
"context"
5+
"fmt"
56
"testing"
67

78
"github.com/go-jose/go-jose/v4"
@@ -196,3 +197,39 @@ func TestHasTenantInContextNegative(t *testing.T) {
196197
hasTenant := openmfpctx.HasTenantInContext(ctx)
197198
assert.False(t, hasTenant)
198199
}
200+
201+
func TestAddUserIDToContextAndGetUserIDFromContext(t *testing.T) {
202+
baseCtx := context.Background()
203+
userID := "testUser123"
204+
205+
ctxWithUserID := openmfpctx.AddUserIDToContext(baseCtx, userID)
206+
207+
retrievedUserID, err := openmfpctx.GetUserIDFromContext(ctxWithUserID)
208+
assert.NoError(t, err, "Expected no error when retrieving userID")
209+
assert.Equal(t, userID, retrievedUserID, "Retrieved userID should match the added value")
210+
}
211+
212+
func TestGetUserIDFromContextWrongType(t *testing.T) {
213+
baseCtx := context.Background()
214+
215+
ctxWithWrongType := context.WithValue(baseCtx, keys.UserIDCtxKey, 123)
216+
217+
retrievedUserID, err := openmfpctx.GetUserIDFromContext(ctxWithWrongType)
218+
assert.Error(t, err, "Expected an error when retrieving userID with the wrong type")
219+
expectedErrorMsg := fmt.Sprintf("someone stored a wrong value in the [%s] key with type [%T], expected [string]", keys.UserIDCtxKey, ctxWithWrongType.Value(keys.UserIDCtxKey))
220+
assert.Equal(t, expectedErrorMsg, err.Error(), "Error message should match the expected message")
221+
assert.Equal(t, "", retrievedUserID, "Retrieved userID should be an empty string when an error occurs")
222+
}
223+
224+
func TestHasUserIDInContext(t *testing.T) {
225+
baseCtx := context.Background()
226+
227+
assert.False(t, openmfpctx.HasUserIDInContext(baseCtx), "Expected false when userID is not set in context")
228+
229+
ctxWithUserID := openmfpctx.AddUserIDToContext(baseCtx, "user123")
230+
assert.True(t, openmfpctx.HasUserIDInContext(ctxWithUserID), "Expected true when a valid userID is set in context")
231+
232+
ctxWithWrongType := context.WithValue(baseCtx, keys.UserIDCtxKey, 456)
233+
assert.False(t, openmfpctx.HasUserIDInContext(ctxWithWrongType), "Expected false when the value stored is of the wrong type")
234+
}
235+

0 commit comments

Comments
 (0)