diff --git a/cache.go b/cache.go index e9f6dea..07112d0 100644 --- a/cache.go +++ b/cache.go @@ -226,17 +226,14 @@ func (c *Cache[K, V]) Get(key K) (zero V, ok bool) { // key is not present, sets the given value. // The loaded result is true if the value was loaded, false if stored. func (c *Cache[K, V]) GetOrSet(key K, val V, opts ...ItemOption) (actual V, loaded bool) { - c.mu.Lock() - defer c.mu.Unlock() - item, ok := c.cache.Get(key) - - if !ok || item.Expired() { - item := newItem(key, val, opts...) - c.cache.Set(key, item) - return val, false - } - - return item.Value, true + actual = c.ModifyFn(key, func(key K, item *Item[K, V]) (V, []ItemOption) { + if item == nil || item.Expired() { + return val, opts + } + loaded = true + return item.Value, nil + }) + return } // DeleteExpired all expired items from the cache. @@ -271,13 +268,9 @@ func (c *Cache[K, V]) DeleteExpired() { // Set sets a value to the cache with key. replacing any existing value. func (c *Cache[K, V]) Set(key K, val V, opts ...ItemOption) { - c.mu.Lock() - defer c.mu.Unlock() - item := newItem(key, val, opts...) - if item.hasExpiration() { - c.expManager.update(key, item.Expiration) - } - c.cache.Set(key, item) + c.ModifyFn(key, func(K, *Item[K, V]) (V, []ItemOption) { + return val, opts + }) } // Keys returns the keys of the cache. the order is relied on algorithms. @@ -310,13 +303,40 @@ func (c *Cache[K, V]) Contains(key K) bool { return ok } +// GetOrSetFn atomically gets a key's value from the cache, or if the +// key is not present, calls the given function to calculate the value. +// Returns the value stored in the cache. +func (c *Cache[K, V]) GetOrSetFn(key K, calc func(K) (V, []ItemOption)) V { + return c.ModifyFn(key, func(key K, item *Item[K, V]) (V, []ItemOption) { + if item != nil { + return item.Value, nil + } + return calc(key) + }) +} + +// ModifyFn gets a key's value from the cache, and calls the given function to modify the value, +// stores the modified value back to the cache. Returns the value stored in the cache. +func (c *Cache[K, V]) ModifyFn(key K, modify func(K, *Item[K, V]) (V, []ItemOption)) V { + c.mu.Lock() + defer c.mu.Unlock() + + item, _ := c.cache.Get(key) + + val, options := modify(key, item) + + item = newItem(key, val, options...) + if item.hasExpiration() { + c.expManager.update(key, item.Expiration) + } + c.cache.Set(key, item) + + return val +} + // NumberCache is a in-memory cache which is able to store only Number constraint. type NumberCache[K comparable, V Number] struct { *Cache[K, V] - // nmu is used to do lock in Increment/Decrement process. - // Note that this must be here as a separate mutex because mu in Cache struct is Locked in Get, - // and if we call mu.Lock in Increment/Decrement, it will cause deadlock. - nmu sync.Mutex } // NewNumber creates a new cache for Number constraint. @@ -329,22 +349,16 @@ func NewNumber[K comparable, V Number](opts ...Option[K, V]) *NumberCache[K, V] // Increment an item of type Number constraint by n. // Returns the incremented value. func (nc *NumberCache[K, V]) Increment(key K, n V) V { - // In order to avoid lost update, we must lock whole Increment/Decrement process. - nc.nmu.Lock() - defer nc.nmu.Unlock() - got, _ := nc.Cache.Get(key) - nv := got + n - nc.Cache.Set(key, nv) - return nv + return nc.ModifyFn(key, func(key K, item *Item[K, V]) (V, []ItemOption) { + if item != nil { + return item.Value + n, nil + } + return n, nil + }) } // Decrement an item of type Number constraint by n. // Returns the decremented value. func (nc *NumberCache[K, V]) Decrement(key K, n V) V { - nc.nmu.Lock() - defer nc.nmu.Unlock() - got, _ := nc.Cache.Get(key) - nv := got - n - nc.Cache.Set(key, nv) - return nv + return nc.Increment(key, -n) }