Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions pkg/transport/session/storage_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ package session
import (
"context"
"fmt"
"io"
"log/slog"
"sync"
"time"
)
Expand Down Expand Up @@ -66,9 +68,12 @@ func (s *LocalStorage) Delete(_ context.Context, id string) error {

// DeleteExpired removes all sessions that haven't been updated since the given time.
func (s *LocalStorage) DeleteExpired(ctx context.Context, before time.Time) error {
var toDelete []string
var toDelete []struct {
id string
session Session
}

// First pass: collect IDs of expired sessions
// First pass: collect expired sessions
s.sessions.Range(func(key, val any) bool {
// Check for context cancellation
select {
Expand All @@ -80,16 +85,44 @@ func (s *LocalStorage) DeleteExpired(ctx context.Context, before time.Time) erro
if session, ok := val.(Session); ok {
if session.UpdatedAt().Before(before) {
if id, ok := key.(string); ok {
toDelete = append(toDelete, id)
toDelete = append(toDelete, struct {
id string
session Session
}{id, session})
}
}
}
return true
})

// Second pass: delete expired sessions
for _, id := range toDelete {
s.sessions.Delete(id)
// Second pass: close and delete expired sessions
for _, item := range toDelete {
// Check for context cancellation before processing each session
select {
case <-ctx.Done():
return ctx.Err()
default:
}

// Re-check expiration and use CompareAndDelete to handle race conditions:
// - Session may have been touched via Manager.Get().Touch() and is no longer expired
// - Session may have been replaced via Store/UpsertSession with a new object
// Only proceed if the stored value is still the same session object and still expired
if item.session.UpdatedAt().Before(before) {
// CompareAndDelete ensures we only delete if the value hasn't been replaced
if deleted := s.sessions.CompareAndDelete(item.id, item.session); deleted {
// Successfully deleted - now close if implements io.Closer
if closer, ok := item.session.(io.Closer); ok {
if err := closer.Close(); err != nil {
slog.Warn("failed to close session during cleanup",
"session_id", item.id,
"error", err)
}
}
}
// If CompareAndDelete returned false, the session was already replaced/deleted - skip it
}
// If re-check shows session is no longer expired (was touched), skip it
}

return nil
Expand Down
302 changes: 302 additions & 0 deletions pkg/transport/session/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package session

import (
"context"
"errors"
"fmt"
"testing"
"time"
Expand All @@ -13,6 +14,24 @@ import (
"github.com/stretchr/testify/require"
)

// mockClosableSession is a test session that implements io.Closer
type mockClosableSession struct {
*ProxySession
closeCalled bool
closeError error
}

func newMockClosableSession(id string) *mockClosableSession {
return &mockClosableSession{
ProxySession: NewProxySession(id),
}
}

func (m *mockClosableSession) Close() error {
m.closeCalled = true
return m.closeError
}

// TestLocalStorage tests the LocalStorage implementation
func TestLocalStorage(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -302,6 +321,289 @@ func TestLocalStorage(t *testing.T) {
// Should not error, just stop early
assert.NoError(t, err)
})

t.Run("DeleteExpired calls Close on io.Closer sessions", func(t *testing.T) {
t.Parallel()
storage := NewLocalStorage()
defer storage.Close()

ctx := context.Background()

// Create a closable session (implements io.Closer)
closableSession := newMockClosableSession("closable-session")
closableSession.updated = time.Now().Add(-2 * time.Hour)

// Create a regular session (does not implement io.Closer)
regularSession := NewProxySession("regular-session")
regularSession.updated = time.Now().Add(-2 * time.Hour)

// Store both sessions
err := storage.Store(ctx, closableSession)
require.NoError(t, err)
err = storage.Store(ctx, regularSession)
require.NoError(t, err)

// Delete sessions older than 1 hour
cutoff := time.Now().Add(-1 * time.Hour)
err = storage.DeleteExpired(ctx, cutoff)
require.NoError(t, err)

// Both sessions should be deleted
_, err = storage.Load(ctx, "closable-session")
assert.Equal(t, ErrSessionNotFound, err)
_, err = storage.Load(ctx, "regular-session")
assert.Equal(t, ErrSessionNotFound, err)

// Close() is called synchronously, so it should already be done
assert.True(t, closableSession.closeCalled,
"Close() should have been called on closable session")
})

t.Run("DeleteExpired continues deletion even if Close fails", func(t *testing.T) {
t.Parallel()
storage := NewLocalStorage()
defer storage.Close()

ctx := context.Background()

// Create a closable session that returns an error on Close()
failingSession := newMockClosableSession("failing-session")
failingSession.closeError = errors.New("close failed")
failingSession.updated = time.Now().Add(-2 * time.Hour)

// Store the session
err := storage.Store(ctx, failingSession)
require.NoError(t, err)

// Delete expired sessions - should not fail even if Close() returns an error
cutoff := time.Now().Add(-1 * time.Hour)
err = storage.DeleteExpired(ctx, cutoff)
require.NoError(t, err)

// Session should be deleted from storage even though Close() failed
_, err = storage.Load(ctx, "failing-session")
assert.Equal(t, ErrSessionNotFound, err)

// Close() is called synchronously, so it should already be done
assert.True(t, failingSession.closeCalled,
"Close() should have been called even though it returned an error")

// Note: We don't verify log output to maintain t.Parallel() compatibility.
// The important behavior is that deletion continues even when Close() fails.
})

t.Run("DeleteExpired handles non-io.Closer sessions without error", func(t *testing.T) {
t.Parallel()
storage := NewLocalStorage()
defer storage.Close()

ctx := context.Background()

// Create multiple regular sessions (do not implement io.Closer)
for i := 0; i < 5; i++ {
session := NewProxySession(fmt.Sprintf("session-%d", i))
session.updated = time.Now().Add(-2 * time.Hour)
err := storage.Store(ctx, session)
require.NoError(t, err)
}

// Delete expired sessions
cutoff := time.Now().Add(-1 * time.Hour)
err := storage.DeleteExpired(ctx, cutoff)
require.NoError(t, err)

// All sessions should be deleted
for i := 0; i < 5; i++ {
_, err := storage.Load(ctx, fmt.Sprintf("session-%d", i))
assert.Equal(t, ErrSessionNotFound, err)
}
})

t.Run("DeleteExpired with mixed session types", func(t *testing.T) {
t.Parallel()
storage := NewLocalStorage()
defer storage.Close()

ctx := context.Background()

// Create a mix of closable and regular expired sessions
closable1 := newMockClosableSession("closable-1")
closable1.updated = time.Now().Add(-2 * time.Hour)
closable2 := newMockClosableSession("closable-2")
closable2.updated = time.Now().Add(-2 * time.Hour)

regular1 := NewProxySession("regular-1")
regular1.updated = time.Now().Add(-2 * time.Hour)
regular2 := NewProxySession("regular-2")
regular2.updated = time.Now().Add(-2 * time.Hour)

// Store all sessions
err := storage.Store(ctx, closable1)
require.NoError(t, err)
err = storage.Store(ctx, closable2)
require.NoError(t, err)
err = storage.Store(ctx, regular1)
require.NoError(t, err)
err = storage.Store(ctx, regular2)
require.NoError(t, err)

// Delete expired sessions
cutoff := time.Now().Add(-1 * time.Hour)
err = storage.DeleteExpired(ctx, cutoff)
require.NoError(t, err)

// All sessions should be deleted
_, err = storage.Load(ctx, "closable-1")
assert.Equal(t, ErrSessionNotFound, err)
_, err = storage.Load(ctx, "closable-2")
assert.Equal(t, ErrSessionNotFound, err)
_, err = storage.Load(ctx, "regular-1")
assert.Equal(t, ErrSessionNotFound, err)
_, err = storage.Load(ctx, "regular-2")
assert.Equal(t, ErrSessionNotFound, err)

// Close() is called synchronously, so it should already be done
assert.True(t, closable1.closeCalled,
"Close() should have been called on closable-1")
assert.True(t, closable2.closeCalled,
"Close() should have been called on closable-2")
})

t.Run("DeleteExpired respects context cancellation during deletion", func(t *testing.T) {
t.Parallel()
storage := NewLocalStorage()
defer storage.Close()

ctx := context.Background()

// Create many expired sessions to increase chance of context check
for i := 0; i < 10000; i++ {
session := NewProxySession(fmt.Sprintf("session-%d", i))
session.updated = time.Now().Add(-2 * time.Hour)
err := storage.Store(ctx, session)
require.NoError(t, err)
}

// Create a context with a very short timeout
timeoutCtx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
defer cancel()

// Wait a bit to ensure context times out
time.Sleep(10 * time.Millisecond)

// DeleteExpired should respect context timeout
cutoff := time.Now().Add(-1 * time.Hour)
err := storage.DeleteExpired(timeoutCtx, cutoff)

// With 10000 sessions, the context check should trigger during cleanup
// If it completes too quickly, that's also acceptable behavior
if err != nil {
assert.Equal(t, context.DeadlineExceeded, err)
// Some sessions deleted, but not all due to timeout
remaining := storage.Count()
assert.Greater(t, remaining, 0, "Some sessions should remain due to context timeout")
}
})

t.Run("DeleteExpired handles concurrent Touch() race condition", func(t *testing.T) {
t.Parallel()
storage := NewLocalStorage()
defer storage.Close()

ctx := context.Background()

// Create an expired session
session := NewProxySession("race-session")
session.updated = time.Now().Add(-2 * time.Hour)
err := storage.Store(ctx, session)
require.NoError(t, err)

// Create many other expired sessions to slow down the deletion loop
for i := 0; i < 1000; i++ {
dummySession := NewProxySession(fmt.Sprintf("dummy-%d", i))
dummySession.updated = time.Now().Add(-2 * time.Hour)
err := storage.Store(ctx, dummySession)
require.NoError(t, err)
}

// Start DeleteExpired in a goroutine
done := make(chan error, 1)
go func() {
cutoff := time.Now().Add(-1 * time.Hour)
done <- storage.DeleteExpired(ctx, cutoff)
}()

// Concurrently touch the session to make it non-expired
// This simulates Manager.Get().Touch() being called during cleanup
session.Touch()

// Wait for DeleteExpired to complete
err = <-done
require.NoError(t, err)

// The session should NOT be deleted because it was touched
// (CompareAndDelete would fail due to updated timestamp or re-check would skip it)
loaded, err := storage.Load(ctx, "race-session")
if err == nil {
// Session still exists - this is correct behavior
assert.NotNil(t, loaded)
assert.True(t, loaded.UpdatedAt().After(time.Now().Add(-1*time.Hour)),
"Session should have recent timestamp after Touch()")
}
// Note: Due to timing, the session might still be deleted if Touch() happened
// after the re-check but before CompareAndDelete. This is acceptable as the
// important thing is we don't close a session that's been replaced.
})

t.Run("DeleteExpired handles concurrent Store() replacement race condition", func(t *testing.T) {
t.Parallel()
storage := NewLocalStorage()
defer storage.Close()

ctx := context.Background()

// Create an expired closable session
oldSession := newMockClosableSession("replace-session")
oldSession.updated = time.Now().Add(-2 * time.Hour)
err := storage.Store(ctx, oldSession)
require.NoError(t, err)

// Create many other expired sessions to slow down the deletion loop
for i := 0; i < 1000; i++ {
dummySession := NewProxySession(fmt.Sprintf("dummy-%d", i))
dummySession.updated = time.Now().Add(-2 * time.Hour)
err := storage.Store(ctx, dummySession)
require.NoError(t, err)
}

// Start DeleteExpired in a goroutine
done := make(chan error, 1)
go func() {
cutoff := time.Now().Add(-1 * time.Hour)
done <- storage.DeleteExpired(ctx, cutoff)
}()

// Concurrently replace the session with a new one (same ID, different object)
// This simulates UpsertSession being called during cleanup
newSession := newMockClosableSession("replace-session")
err = storage.Store(ctx, newSession)
require.NoError(t, err)

// Wait for DeleteExpired to complete
err = <-done
require.NoError(t, err)

// The new session should still exist (CompareAndDelete prevents deleting it)
loaded, err := storage.Load(ctx, "replace-session")
require.NoError(t, err)
assert.NotNil(t, loaded)

// The old session's Close() may or may not have been called depending on timing
// The important thing is the new session is not closed
// Since Close() is now synchronous, we can check immediately
assert.False(t, newSession.closeCalled,
"New session should not be closed (CompareAndDelete should prevent this)")
})
}

// TestManagerWithStorage tests the Manager with the Storage interface
Expand Down
Loading