Skip to content

implement token cache #269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions auth/oauth/tokencache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package oauth

import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"

"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
)

// TokenCache defines the interface for storing and retrieving OAuth tokens
type TokenCache interface {
// Save stores a token in the cache
Save(token *oauth2.Token) error

Check failure on line 20 in auth/oauth/tokencache.go

View workflow job for this annotation

GitHub Actions / Lint

File is not `gofmt`-ed with `-s` (gofmt)
// Load retrieves a token from the cache
Load() (*oauth2.Token, error)
}

// FileTokenCache implements TokenCache using file-based storage
type FileTokenCache struct {
path string
}

// NewFileTokenCache creates a new file-based token cache
func NewFileTokenCache(path string) *FileTokenCache {
return &FileTokenCache{path: path}
}

// Save stores a token in the file cache
func (c *FileTokenCache) Save(token *oauth2.Token) error {
// Create parent directories if they don't exist
if err := os.MkdirAll(filepath.Dir(c.path), 0700); err != nil {
return fmt.Errorf("failed to create token cache directory: %w", err)
}

// Serialize token to JSON
data, err := json.Marshal(token)
if err != nil {
return fmt.Errorf("failed to serialize token: %w", err)
}

// Write to file with secure permissions
if err := os.WriteFile(c.path, data, 0600); err != nil {
return fmt.Errorf("failed to write token to cache: %w", err)
}

log.Debug().Str("path", c.path).Msg("Token saved to cache")
return nil
}

// Load retrieves a token from the file cache
func (c *FileTokenCache) Load() (*oauth2.Token, error) {
data, err := os.ReadFile(c.path)
if err != nil {
if os.IsNotExist(err) {
return nil, nil // Cache doesn't exist yet
}
return nil, fmt.Errorf("failed to read token from cache: %w", err)
}

var token oauth2.Token
if err := json.Unmarshal(data, &token); err != nil {
return nil, fmt.Errorf("failed to deserialize token: %w", err)
}

log.Debug().Str("path", c.path).Msg("Token loaded from cache")
return &token, nil
}

// GetCacheFilePath returns the path to the token cache file
func GetCacheFilePath(host, clientID string, scopes []string) string {
// Create SHA-256 hash of host, client_id, and scopes
h := sha256.New()
h.Write([]byte(host))
h.Write([]byte(clientID))
h.Write([]byte(strings.Join(scopes, ",")))
hash := hex.EncodeToString(h.Sum(nil))

// Get user's home directory
homeDir, err := os.UserHomeDir()
if err != nil {
homeDir = "."
}

return filepath.Join(homeDir, ".config", "databricks-sql-go", "oauth", hash)
}

// CachingTokenSource wraps a TokenSource with a TokenCache
type CachingTokenSource struct {
src oauth2.TokenSource
cache TokenCache
}

// NewCachingTokenSource creates a new TokenSource that caches tokens
func NewCachingTokenSource(src oauth2.TokenSource, cache TokenCache) oauth2.TokenSource {
return &CachingTokenSource{
src: src,
cache: cache,
}
}

// Token returns a valid token from either the cache or the underlying source
func (cts *CachingTokenSource) Token() (*oauth2.Token, error) {
// Try to get token from cache first
if cts.cache != nil {
token, err := cts.cache.Load()
if err == nil && token != nil && token.Valid() {
log.Debug().Msg("Using cached token")
return token, nil
}
}

// Get a new token from the source
token, err := cts.src.Token()
if err != nil {
return nil, err
}

// Save the token to cache
if cts.cache != nil {
if err := cts.cache.Save(token); err != nil {
log.Warn().Err(err).Msg("Failed to save token to cache")
}
}

return token, nil
}
155 changes: 155 additions & 0 deletions auth/oauth/tokencache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package oauth

import (
"os"
"path/filepath"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)

func TestFileTokenCache(t *testing.T) {
// Create a temporary directory for testing
tempDir, err := os.MkdirTemp("", "token-cache-test")
require.NoError(t, err)
defer os.RemoveAll(tempDir)

cachePath := filepath.Join(tempDir, "token.json")
cache := NewFileTokenCache(cachePath)

// Create a test token
expiry := time.Now().Add(1 * time.Hour)
token := &oauth2.Token{
AccessToken: "test-access-token",
TokenType: "Bearer",
RefreshToken: "test-refresh-token",
Expiry: expiry,
}

// Test saving the token
err = cache.Save(token)
require.NoError(t, err)

// Verify the file exists
_, err = os.Stat(cachePath)
assert.NoError(t, err)

// Test loading the token
loadedToken, err := cache.Load()
require.NoError(t, err)
assert.NotNil(t, loadedToken)
assert.Equal(t, token.AccessToken, loadedToken.AccessToken)
assert.Equal(t, token.TokenType, loadedToken.TokenType)
assert.Equal(t, token.RefreshToken, loadedToken.RefreshToken)
assert.WithinDuration(t, token.Expiry, loadedToken.Expiry, time.Second)

// Test file permissions
fileInfo, err := os.Stat(cachePath)
require.NoError(t, err)
assert.Equal(t, os.FileMode(0600), fileInfo.Mode().Perm())
}

func TestCachingTokenSource(t *testing.T) {
// Create a temporary directory for testing
tempDir, err := os.MkdirTemp("", "token-source-test")
require.NoError(t, err)
defer os.RemoveAll(tempDir)

cachePath := filepath.Join(tempDir, "token.json")
cache := NewFileTokenCache(cachePath)

// Create a mock token source
mockSource := &mockTokenSource{
token: &oauth2.Token{
AccessToken: "test-access-token",
TokenType: "Bearer",
RefreshToken: "test-refresh-token",
Expiry: time.Now().Add(1 * time.Hour),
},
}

// Create a caching token source
cachingSource := NewCachingTokenSource(mockSource, cache)

// First call should use the mock source
token, err := cachingSource.Token()
require.NoError(t, err)
assert.Equal(t, mockSource.token.AccessToken, token.AccessToken)
assert.Equal(t, 1, mockSource.callCount)

// Second call should use the cached token
token, err = cachingSource.Token()
require.NoError(t, err)
assert.Equal(t, mockSource.token.AccessToken, token.AccessToken)
assert.Equal(t, 1, mockSource.callCount) // Call count should still be 1

// Create a new caching token source with the same cache
cachingSource2 := NewCachingTokenSource(mockSource, cache)

// This should use the cached token
token, err = cachingSource2.Token()
require.NoError(t, err)
assert.Equal(t, mockSource.token.AccessToken, token.AccessToken)
assert.Equal(t, 1, mockSource.callCount) // Call count should still be 1

// Create a token that's expired
mockSource.token = &oauth2.Token{
AccessToken: "expired-token",
TokenType: "Bearer",
RefreshToken: "test-refresh-token",
Expiry: time.Now().Add(-1 * time.Hour),
}

// Save the expired token to cache
err = cache.Save(mockSource.token)
require.NoError(t, err)

// Create a new mock source with a fresh token
freshMockSource := &mockTokenSource{
token: &oauth2.Token{
AccessToken: "fresh-token",
TokenType: "Bearer",
RefreshToken: "test-refresh-token",
Expiry: time.Now().Add(1 * time.Hour),
},
}

// Create a new caching token source with the expired cache
cachingSource3 := NewCachingTokenSource(freshMockSource, cache)

// This should detect the expired token and get a fresh one
token, err = cachingSource3.Token()
require.NoError(t, err)
assert.Equal(t, freshMockSource.token.AccessToken, token.AccessToken)
assert.Equal(t, 1, freshMockSource.callCount) // Should have called the fresh source
}

func TestGetCacheFilePath(t *testing.T) {
host := "test-host.cloud.databricks.com"
clientID := "test-client-id"
scopes := []string{"scope1", "scope2"}

path1 := GetCacheFilePath(host, clientID, scopes)
path2 := GetCacheFilePath(host, clientID, scopes)

// Same inputs should produce the same path
assert.Equal(t, path1, path2)

// Different inputs should produce different paths
path3 := GetCacheFilePath("different-host", clientID, scopes)
assert.NotEqual(t, path1, path3)
}

// mockTokenSource is a simple implementation of oauth2.TokenSource for testing
type mockTokenSource struct {
token *oauth2.Token
callCount int
}

func (m *mockTokenSource) Token() (*oauth2.Token, error) {
m.callCount++
return m.token, nil
}

Check failure on line 155 in auth/oauth/tokencache_test.go

View workflow job for this annotation

GitHub Actions / Lint

File is not `gofmt`-ed with `-s` (gofmt)
Loading
Loading