Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 35 additions & 44 deletions cmd/lakectl/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"reflect"
"slices"
"strings"
"sync"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
Expand Down Expand Up @@ -184,17 +183,12 @@ const (
)

const (
CacheFileName = "lakectl_token_cache.json"
LakectlDirName = ".lakectl"
CacheDirName = "cache"
cacheFileName = ".lakectl_token_cache.json"
)

var (
cachedToken *apigen.AuthenticationToken
tokenLoadOnce sync.Once
tokenCache *awsiam.JWTCache
tokenCacheOnce sync.Once
ErrTokenUnavailable = fmt.Errorf("token is not available")
ErrCacheUnavailable = fmt.Errorf("cache is not available")
)

func withRecursiveFlag(cmd *cobra.Command, usage string) {
Expand Down Expand Up @@ -654,15 +648,17 @@ func getClient() *apigen.ClientWithResponses {

func CreateTokenCacheCallback() awsiam.TokenCacheCallback {
return func(newToken *apigen.AuthenticationToken) {
cachedToken = newToken
if err := SaveTokenToCache(); err != nil {
if err := SaveTokenToCache(newToken); err != nil {
logging.ContextUnavailable().Debugf("error saving token to cache: %w", err)
}
}
}

func getClientOptions(awsIAMparams *awsiam.IAMAuthParams, serverEndpoint string) []apigen.ClientOption {
token := getTokenOnce()
token, err := getToken()
if err != nil {
logging.ContextUnavailable().Debugf("no token available in cache: %w", err)
}

tokenCacheCallback := CreateTokenCacheCallback()

Expand All @@ -680,7 +676,6 @@ func getClientOptions(awsIAMparams *awsiam.IAMAuthParams, serverEndpoint string)
DieErr(err)
}
loginClient := &awsiam.ExternalPrincipalLoginClient{Client: noAuthClient}

awsAuthProvider := awsiam.WithAWSIAMRoleAuthProviderOption(
awsIAMparams,
logging.ContextUnavailable(),
Expand All @@ -692,47 +687,43 @@ func getClientOptions(awsIAMparams *awsiam.IAMAuthParams, serverEndpoint string)
return []apigen.ClientOption{awsAuthProvider}
}

func getTokenOnce() *apigen.AuthenticationToken {
tokenLoadOnce.Do(func() {
cache := getTokenCacheOnce()
var err error
if cache != nil {
if token, err := cache.GetToken(); err == nil {
cachedToken = token
return
}
logging.ContextUnavailable().Debugf("Error loading token from cache: %w", err)
func getToken() (*apigen.AuthenticationToken, error) {
cache := getTokenCache()
if cache != nil {
token, err := cache.GetToken()
if err != nil {
return nil, err
}
})
return cachedToken
return token, nil
}
return nil, ErrCacheUnavailable
}

func getTokenCacheOnce() *awsiam.JWTCache {
tokenCacheOnce.Do(func() {
homeDir, err := os.UserHomeDir()
if err != nil {
logging.ContextUnavailable().Debugf("Error getting user homedir: %w", err)
}
cache, err := awsiam.NewJWTCache(homeDir, LakectlDirName, CacheDirName, CacheFileName)
if err != nil {
logging.ContextUnavailable().Debugf("Error creating token cache: %w", err)
tokenCache = nil
} else {
tokenCache = cache
}
})
return tokenCache
func getTokenCache() *awsiam.JWTCache {
homeDir, err := os.UserHomeDir()
if err != nil {
logging.ContextUnavailable().WithError(err).Debug("unable to get user homedir")
return nil
}
cache, err := awsiam.NewJWTCache(homeDir, cacheFileName)
if err != nil {
logging.ContextUnavailable().WithError(err).Debug("unable to create token cache")
return nil
}
return cache
}

func SaveTokenToCache() error {
cache := getTokenCacheOnce()
if cache == nil || cachedToken == nil {
func SaveTokenToCache(newToken *apigen.AuthenticationToken) error {
cache := getTokenCache()
if cache == nil {
return ErrCacheUnavailable
}
if newToken == nil {
return ErrTokenUnavailable
}
if err := cache.SaveToken(cachedToken); err != nil {
if err := cache.SaveToken(newToken); err != nil {
return err
}
tokenLoadOnce = sync.Once{}
return nil
}

Expand Down
14 changes: 3 additions & 11 deletions pkg/authentication/externalidp/awsiam/token_caching.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,13 @@ type JWTCache struct {
FilePath string
}

func NewJWTCache(baseDir, lakectlDir, cacheDir, fileName string) (*JWTCache, error) {
if baseDir == "" {
var err error
baseDir, err = os.UserHomeDir()
if err != nil {
return nil, err
}
}
cachePath := filepath.Join(baseDir, lakectlDir, cacheDir)
if err := os.MkdirAll(cachePath, ReadWriteExecuteOwnerOnly); err != nil {
func NewJWTCache(baseDir, fileName string) (*JWTCache, error) {
if err := os.MkdirAll(baseDir, ReadWriteExecuteOwnerOnly); err != nil {
return nil, ErrFailedToCreateCacheDir
}

jwtCache := &JWTCache{
FilePath: filepath.Join(cachePath, fileName),
FilePath: filepath.Join(baseDir, fileName),
}
return jwtCache, nil
}
Expand Down
29 changes: 10 additions & 19 deletions pkg/authentication/externalidp/awsiam/token_caching_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,17 @@ import (
func TestNewJWTCache(t *testing.T) {
t.Run("with custom cache dir", func(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl_token_cache.json")
require.NoError(t, err)
require.NotEmpty(t, cache)
require.Equal(t, filepath.Join(tempDir, ".lakectl", "cache", "lakectl_token_cache.json"), cache.FilePath)
})

t.Run("with empty cache dir uses home dir", func(t *testing.T) {
cache, err := awsiam.NewJWTCache("", ".lakectl", "cache", "lakectl_token_cache.json")
require.NoError(t, err)
require.NotEmpty(t, cache)
homeDir, _ := os.UserHomeDir()
expectedPath := filepath.Join(homeDir, ".lakectl", "cache", "lakectl_token_cache.json")
require.Equal(t, expectedPath, cache.FilePath)
require.Equal(t, filepath.Join(tempDir, ".lakectl_token_cache.json"), cache.FilePath)
})
}

func TestJWTCacheSaveToken(t *testing.T) {
t.Run("saves valid token successfully", func(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl_token_cache.json")
require.NoError(t, err)

expirationTime := time.Now().Add(1 * time.Hour).Unix()
Expand Down Expand Up @@ -63,7 +54,7 @@ func TestJWTCacheSaveToken(t *testing.T) {

t.Run("handles nil token", func(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl_token_cache.json")
require.NoError(t, err)

err = cache.SaveToken(nil)
Expand All @@ -76,7 +67,7 @@ func TestJWTCacheSaveToken(t *testing.T) {

t.Run("handles token with empty string", func(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl_token_cache.json")
require.NoError(t, err)

token := &apigen.AuthenticationToken{
Expand All @@ -93,7 +84,7 @@ func TestJWTCacheSaveToken(t *testing.T) {

t.Run("handles token without expiration", func(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl_token_cache.json")
require.NoError(t, err)

token := &apigen.AuthenticationToken{
Expand All @@ -113,7 +104,7 @@ func TestJWTCacheSaveToken(t *testing.T) {
func TestJWTCacheGetToken(t *testing.T) {
t.Run("loads valid non-expired token", func(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl_token_cache.json")
require.NoError(t, err)

expirationTime := time.Now().Add(1 * time.Hour).Unix()
Expand All @@ -135,7 +126,7 @@ func TestJWTCacheGetToken(t *testing.T) {

t.Run("returns nil when cache file doesn't exist", func(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl_token_cache.json")
require.NoError(t, err)

loadedToken, err := cache.GetToken()
Expand All @@ -145,7 +136,7 @@ func TestJWTCacheGetToken(t *testing.T) {

t.Run("returns error for corrupted cache file", func(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl_token_cache.json")
require.NoError(t, err)

// Write invalid JSON
Expand All @@ -160,7 +151,7 @@ func TestJWTCacheGetToken(t *testing.T) {

func TestJWTCacheSaveAndLoad(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl_token_cache.json")
require.NoError(t, err)

expirationTime := time.Now().Add(30 * time.Minute).Unix()
Expand Down
Loading