diff --git a/go/appencryption/session_cache.go b/go/appencryption/session_cache.go index ae7e285eb..fc191ad66 100644 --- a/go/appencryption/session_cache.go +++ b/go/appencryption/session_cache.go @@ -27,6 +27,7 @@ type sharedEncryption struct { accessCounter int mu *sync.Mutex cond *sync.Cond + closeOnce sync.Once // Prevents double-close of underlying Encryption } func (s *sharedEncryption) incrementUsage() { @@ -53,7 +54,10 @@ func (s *sharedEncryption) Remove() { s.cond.Wait() } - s.Encryption.Close() + // Use sync.Once to prevent double-close + s.closeOnce.Do(func() { + s.Encryption.Close() + }) s.mu.Unlock() } diff --git a/go/appencryption/session_cache_double_close_test.go b/go/appencryption/session_cache_double_close_test.go new file mode 100644 index 000000000..31d200644 --- /dev/null +++ b/go/appencryption/session_cache_double_close_test.go @@ -0,0 +1,132 @@ +package appencryption + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// TestSharedEncryption_DoubleCloseProtection verifies that calling Remove() multiple times +// only calls the underlying Encryption.Close() once, preventing panics or undefined behavior. +func TestSharedEncryption_DoubleCloseProtection(t *testing.T) { + var closeCount int + mockEncryption := &testEncryption{ + closeFunc: func() error { + closeCount++ + return nil + }, + } + + sharedEnc := &sharedEncryption{ + Encryption: mockEncryption, + created: time.Now(), + accessCounter: 0, // No active users + mu: new(sync.Mutex), + } + sharedEnc.cond = sync.NewCond(sharedEnc.mu) + + // Call Remove() multiple times + sharedEnc.Remove() + sharedEnc.Remove() + sharedEnc.Remove() + + // Verify Close() was called exactly once + assert.Equal(t, 1, closeCount, "Close() should be called exactly once despite multiple Remove() calls") +} + +// TestSharedEncryption_ConcurrentDoubleClose tests concurrent calls to Remove() +// to ensure no race conditions in the double-close protection. +func TestSharedEncryption_ConcurrentDoubleClose(t *testing.T) { + var closeCount int + var mu sync.Mutex + mockEncryption := &testEncryption{ + closeFunc: func() error { + mu.Lock() + closeCount++ + mu.Unlock() + return nil + }, + } + + sharedEnc := &sharedEncryption{ + Encryption: mockEncryption, + created: time.Now(), + accessCounter: 0, // No active users + mu: new(sync.Mutex), + } + sharedEnc.cond = sync.NewCond(sharedEnc.mu) + + const numGoroutines = 10 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + // Launch multiple goroutines calling Remove() concurrently + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + sharedEnc.Remove() + }() + } + + wg.Wait() + + // Verify Close() was called exactly once despite concurrent calls + assert.Equal(t, 1, closeCount, "Close() should be called exactly once despite concurrent Remove() calls") +} + +// TestSharedEncryption_CloseAndRemove tests calling both Close() and Remove() +// to ensure they work together correctly. +func TestSharedEncryption_CloseAndRemove(t *testing.T) { + var closeCount int + mockEncryption := &testEncryption{ + closeFunc: func() error { + closeCount++ + return nil + }, + } + + sharedEnc := &sharedEncryption{ + Encryption: mockEncryption, + created: time.Now(), + accessCounter: 1, // One active user + mu: new(sync.Mutex), + } + sharedEnc.cond = sync.NewCond(sharedEnc.mu) + + // Simulate user calling Close() (decrements counter) + err := sharedEnc.Close() + assert.NoError(t, err) + assert.Equal(t, 0, sharedEnc.accessCounter) + + // Now call Remove() - should close the underlying encryption + sharedEnc.Remove() + + // Call Remove() again - should not double-close + sharedEnc.Remove() + + // Verify Close() was called exactly once + assert.Equal(t, 1, closeCount, "Close() should be called exactly once") +} + +// testEncryption is a simple test double for the Encryption interface. +type testEncryption struct { + closeFunc func() error +} + +func (t *testEncryption) EncryptPayload(ctx context.Context, data []byte) (*DataRowRecord, error) { + return nil, nil +} + +func (t *testEncryption) DecryptDataRowRecord(ctx context.Context, record DataRowRecord) ([]byte, error) { + return nil, nil +} + +func (t *testEncryption) Close() error { + if t.closeFunc != nil { + return t.closeFunc() + } + return nil +}