diff --git a/dataloader.go b/dataloader.go index 3719df7..e50c1ab 100644 --- a/dataloader.go +++ b/dataloader.go @@ -368,7 +368,7 @@ func (l *Loader) reset() { l.curBatcher = nil if l.clearCacheOnBatch { - l.cache.Clear() + l.ClearAll() } } diff --git a/dataloader_test.go b/dataloader_test.go index fcb6dbf..0719a96 100644 --- a/dataloader_test.go +++ b/dataloader_test.go @@ -333,6 +333,10 @@ func TestLoader(t *testing.T) { t.Parallel() batchOnlyLoader, loadCalls := BatchOnlyLoader(0) ctx := context.Background() + + // The first two calls should be part of the same batch. + // After the batch fires, we expect the cache to be cleared, + // so subsequent calls should be recorded. future1 := batchOnlyLoader.Load(ctx, StringKey("1")) future2 := batchOnlyLoader.Load(ctx, StringKey("1")) @@ -352,8 +356,18 @@ func TestLoader(t *testing.T) { t.Errorf("did not batch queries. Expected %#v, got %#v", expected, calls) } - if _, found := batchOnlyLoader.cache.Get(ctx, StringKey("1")); found { - t.Errorf("did not clear cache after batch. Expected %#v, got %#v", false, found) + // This call should record another call + future3 := batchOnlyLoader.Load(ctx, StringKey("1")) + + _, err = future3() + if err != nil { + t.Error(err.Error()) + } + + calls = *loadCalls + expected = [][]string{{"1"}, {"1"}} + if !reflect.DeepEqual(calls, expected) { + t.Errorf("expected a second batch, got %#v", calls) } })