From feb9c9d1fe4baa035ec4e69f8ea947c422a6c8aa Mon Sep 17 00:00:00 2001 From: Taimoor Zaeem Date: Mon, 24 Mar 2025 12:10:18 +0500 Subject: [PATCH] feat: add metric pgrst_jwt_cache_size_bytes in admin server --- CHANGELOG.md | 1 + default.nix | 3 + docs/references/observability.rst | 14 ++++ nix/overlays/haskell-packages.nix | 4 ++ postgrest.cabal | 11 +++ src/PostgREST/Auth.hs | 7 +- src/PostgREST/Auth/JwtCache.hs | 114 ++++++++++++++++++++++++++---- src/PostgREST/Logger.hs | 3 + src/PostgREST/Metrics.hs | 11 +-- src/PostgREST/Observation.hs | 3 + stack-21.7.yaml | 1 + stack.yaml | 1 + test/io/test_io.py | 36 ++++++++++ 13 files changed, 187 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b0a336cb9b..997fe45419 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ This project adheres to [Semantic Versioning](http://semver.org/). - #3041, Allow spreading one-to-many and many-to-many embedded resources - @laurenceisla + The selected columns in the embedded resources are aggregated into arrays + Aggregates are not supported + - #3802, Add metric `pgrst_jwt_cache_size_bytes` in admin server - @taimoorzaeem ### Fixed diff --git a/default.nix b/default.nix index b5f0cd15dc..9256afb5ea 100644 --- a/default.nix +++ b/default.nix @@ -85,6 +85,9 @@ rec { lib.enableExecutableProfiling lib.enableLibraryProfiling lib.dontHaddock + # we disable the jwt-cache-metric flag in cabal to disable + # jwt cache size calculation + (drv: lib.appendConfigureFlags drv [ "--flags=-jwt-cache-metric" ]) ]; inherit (postgrest) env; diff --git a/docs/references/observability.rst b/docs/references/observability.rst index 603d8e3de2..ed5a84b9b8 100644 --- a/docs/references/observability.rst +++ b/docs/references/observability.rst @@ -195,6 +195,20 @@ pgrst_db_pool_max Max pool connections. +JWT Cache Metric +---------------- + +Related to the :ref:`jwt_caching`. 
+ +pgrst_jwt_cache_size_bytes +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +======== ======= +**Type** Gauge +======== ======= + +The JWT cache size in bytes. + Traces ====== diff --git a/nix/overlays/haskell-packages.nix b/nix/overlays/haskell-packages.nix index 6d14cf0d4d..9b5078f64e 100644 --- a/nix/overlays/haskell-packages.nix +++ b/nix/overlays/haskell-packages.nix @@ -50,6 +50,10 @@ let # jailbreak, because hspec limit for tests fuzzyset = prev.fuzzyset_0_2_4; + # TODO: Remove this once https://github.com/NixOS/nixpkgs/pull/375121 + # has made it to us. + ghc-datasize = lib.markUnbroken prev.ghc-datasize; + hasql-pool = lib.dontCheck (prev.callHackageDirect { pkg = "hasql-pool"; diff --git a/postgrest.cabal b/postgrest.cabal index f06fe783b5..c4dc1bb38b 100644 --- a/postgrest.cabal +++ b/postgrest.cabal @@ -40,6 +40,13 @@ flag hpc manual: True description: Enable HPC (dev only) +-- this flag is set to false when running memory tests using a profiled +-- build because profiling is disabled for ghc-datasize library +flag jwt-cache-metric + default: True + manual: True + description: Jwt cache metric is calculated + library default-language: Haskell2010 default-extensions: OverloadedStrings @@ -168,6 +175,10 @@ library else ghc-options: -O2 + if flag(jwt-cache-metric) + cpp-options: -DJWT_CACHE_METRIC + build-depends: ghc-datasize >= 0.2.7 && < 0.3 + if !os(windows) build-depends: unix diff --git a/src/PostgREST/Auth.hs b/src/PostgREST/Auth.hs index cbcef6b5ec..d1506df0ad 100644 --- a/src/PostgREST/Auth.hs +++ b/src/PostgREST/Auth.hs @@ -44,7 +44,7 @@ import System.IO.Unsafe (unsafePerformIO) import System.TimeIt (timeItT) import PostgREST.AppState (AppState, getConfig, getJwtCacheState, - getTime) + getObserver, getTime) import PostgREST.Auth.JwtCache (lookupJwtCache) import PostgREST.Auth.Types (AuthResult (..)) import PostgREST.Config (AppConfig (..), FilterExp (..), @@ -161,6 +161,7 @@ middleware appState app req respond = do let token = Wai.extractBearerAuth =<< lookup 
HTTP.hAuthorization (Wai.requestHeaders req) parseJwt = runExceptT $ parseToken conf token time >>= parseClaims conf jwtCacheState = getJwtCacheState appState + observer = getObserver appState -- If ServerTimingEnabled -> calculate JWT validation time -- If JwtCacheMaxLifetime -> cache JWT validation result @@ -171,7 +172,7 @@ middleware appState app req respond = do (True, maxLifetime) -> do (dur, authResult) <- timeItT $ case token of - Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time + Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time observer Nothing -> parseJwt return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur } @@ -181,7 +182,7 @@ middleware appState app req respond = do (False, maxLifetime) -> do authResult <- case token of - Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time + Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time observer Nothing -> parseJwt return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult } diff --git a/src/PostgREST/Auth/JwtCache.hs b/src/PostgREST/Auth/JwtCache.hs index e02193a9ce..9aa2097513 100644 --- a/src/PostgREST/Auth/JwtCache.hs +++ b/src/PostgREST/Auth/JwtCache.hs @@ -4,7 +4,7 @@ Description : PostgREST Jwt Authentication Result Cache. 
This module provides functions to deal with the JWT cache -} -{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE CPP #-} module PostgREST.Auth.JwtCache ( init , JwtCacheState @@ -18,30 +18,50 @@ import qualified Data.Scientific as Sci import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds) import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds) -import System.Clock (TimeSpec (..)) +#ifdef JWT_CACHE_METRIC /* Include this in a non-profiled postgrest build */ +import GHC.DataSize (recursiveSizeNF) +#endif +import System.Clock (TimeSpec (..)) -import PostgREST.Auth.Types (AuthResult (..)) -import PostgREST.Error (Error (..)) +import PostgREST.Auth.Types (AuthResult (..)) +import PostgREST.Error (Error (..)) +import PostgREST.Observation (Observation (..), ObservationHandler) +import Control.Debounce import Protolude -newtype JwtCacheState = JwtCacheState - { jwtCache :: C.Cache ByteString AuthResult +-- Jwt Cache State +-- +-- Calculating the size of each cache entry is an expensive operation. We don't +-- want to recalculate the size of each entry after the cache eviction/purging. +-- +-- To avoid this, we store the size of each cache entry with the value of the +-- cache entry as a tuple (AuthResult,SizeInBytes). 
Now after the purging +-- operation, the size of cache entry will be evicted along with the entry and +-- updating the cache size becomes a simple sum of all sizes that are stored in +-- the cache +data JwtCacheState = JwtCacheState + -- | Jwt Cache + { jwtCache :: C.Cache ByteString (AuthResult,SizeInBytes) + -- | Calculate cache size with debounce + , cacheSizeCalcDebounceTimeout :: MVar (IO ()) } +type SizeInBytes = Int + -- | Initialize JwtCacheState init :: IO JwtCacheState init = do cache <- C.newCache Nothing -- no default expiration - return $ JwtCacheState cache + JwtCacheState cache <$> newEmptyMVar -- | Used to retrieve and insert JWT to JWT Cache -lookupJwtCache :: JwtCacheState -> ByteString -> Int -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult) -lookupJwtCache JwtCacheState{jwtCache} token maxLifetime parseJwt utc = do - checkCache <- C.lookup jwtCache token - authResult <- maybe parseJwt (pure . Right) checkCache +lookupJwtCache :: JwtCacheState -> ByteString -> Int -> IO (Either Error AuthResult) -> UTCTime -> ObservationHandler -> IO (Either Error AuthResult) +lookupJwtCache jwtCacheState token maxLifetime parseJwt utc observer = do + checkCache <- C.lookup (jwtCache jwtCacheState) token + authResult <- maybe parseJwt (pure . Right . 
fst) checkCache - case (authResult,checkCache) of + case (authResult, checkCache) of -- From comment: -- https://github.com/PostgREST/postgrest/pull/3801#discussion_r1857987914 -- @@ -56,13 +76,20 @@ lookupJwtCache JwtCacheState{jwtCache} token maxLifetime parseJwt utc = do (Right res, Nothing) -> do -- cache miss + -- get expiration time let timeSpec = getTimeSpec res maxLifetime utc -- purge expired cache entries - C.purgeExpired jwtCache + C.purgeExpired (jwtCache jwtCacheState) + + -- calculate size of the cache entry to store it with authResult + sz <- calcCacheEntrySizeInBytes (token,res,timeSpec) - -- insert new cache entry - C.insert' jwtCache (Just timeSpec) token res + -- insert new cache entry with byte size + C.insert' (jwtCache jwtCacheState) (Just timeSpec) token (res,sz) + + -- calculate complete cache size with debounce and log it + updateCacheSizeWithDebounce jwtCacheState observer _ -> pure () @@ -77,3 +104,60 @@ getTimeSpec res maxLifetime utc = do case expireJSON of Just (JSON.Number seconds) -> TimeSpec (sciToInt seconds - utcToSecs utc) 0 _ -> TimeSpec (fromIntegral maxLifetime :: Int64) 0 + +-- | Update JwtCacheSize Metric +-- +-- Runs the cache size calculation with debounce +updateCacheSizeWithDebounce :: JwtCacheState -> ObservationHandler -> IO () +updateCacheSizeWithDebounce jwtCacheState observer = do + cSizeDebouncer <- tryReadMVar $ cacheSizeCalcDebounceTimeout jwtCacheState + case cSizeDebouncer of + Just d -> d + Nothing -> do + newDebouncer <- + mkDebounce defaultDebounceSettings + -- debounceFreq is set to default 1 second + { debounceAction = calculateSizeThenLog + , debounceEdge = leadingEdge -- logs at the start and the end + } + putMVar (cacheSizeCalcDebounceTimeout jwtCacheState) newDebouncer + newDebouncer + where + calculateSizeThenLog :: IO () + calculateSizeThenLog = do + entries <- C.toList $ jwtCache jwtCacheState + -- extract the size from each entry and sum them all + let size = sum [ sz | (_,(_,sz),_) <- entries] + 
observer $ JwtCache size -- updates and logs the metric + +-- | Calculate JWT Cache Size in Bytes +-- +-- The cache size is updated by calculating the size of every +-- cache entry and updating the metric. +-- +-- The cache entry consists of +-- key :: ByteString +-- value :: AuthResult +-- expire value :: TimeSpec +-- +-- We calculate the size of each cache entry component +-- by using the recursiveSizeNF function which first evaluates +-- the data structure to Normal Form and then calculates the size. +-- The normal form evaluation is necessary for accurate size +-- calculation because Haskell is lazy and we don't want to count +-- the size of large thunks (unevaluated expressions) +calcCacheEntrySizeInBytes :: (ByteString, AuthResult, TimeSpec) -> IO Int +#ifdef JWT_CACHE_METRIC /* Include this in a non-profiled postgrest build */ +calcCacheEntrySizeInBytes entry = fromIntegral <$> getSize entry + where + getSize :: (ByteString, AuthResult, TimeSpec) -> IO Word + getSize (bs, ar, ts) = do + keySize <- recursiveSizeNF bs + arClaimsSize <- recursiveSizeNF $ authClaims ar + arRoleSize <- recursiveSizeNF $ authRole ar + timeSpecSize <- liftA2 (+) (recursiveSizeNF (sec ts)) (recursiveSizeNF (nsec ts)) + let sizeOfSizeEntryItself = 8 -- a constant 8 bytes size of each size entry in the cache + return (keySize + arClaimsSize + arRoleSize + timeSpecSize + sizeOfSizeEntryItself) +#else /* otherwise set it to 0 for a profiled build (used in memory-tests) */ +calcCacheEntrySizeInBytes _ = return 0 +#endif diff --git a/src/PostgREST/Logger.hs b/src/PostgREST/Logger.hs index dac4092234..0ddfe570ca 100644 --- a/src/PostgREST/Logger.hs +++ b/src/PostgREST/Logger.hs @@ -100,6 +100,9 @@ observationLogger loggerState logLevel obs = case obs of o@PoolRequestFullfilled -> when (logLevel >= LogDebug) $ do logWithZTime loggerState $ observationMessage o + o@(JwtCache _) -> do + when (logLevel >= LogInfo) $ do + logWithZTime loggerState $ observationMessage o o -> logWithZTime loggerState $ 
observationMessage o diff --git a/src/PostgREST/Metrics.hs b/src/PostgREST/Metrics.hs index 3999e43d83..2a63f51cc9 100644 --- a/src/PostgREST/Metrics.hs +++ b/src/PostgREST/Metrics.hs @@ -1,5 +1,5 @@ {-| -Module : PostgREST.Logger +Module : PostgREST.Metrics Description : Metrics based on the Observation module. See Observation.hs. -} module PostgREST.Metrics @@ -19,7 +19,7 @@ import PostgREST.Observation import Protolude data MetricsState = - MetricsState Counter Gauge Gauge Gauge (Vector Label1 Counter) Gauge + MetricsState Counter Gauge Gauge Gauge (Vector Label1 Counter) Gauge Gauge init :: Int -> IO MetricsState init configDbPoolSize = do @@ -29,12 +29,13 @@ init configDbPoolSize = do poolMaxSize <- register $ gauge (Info "pgrst_db_pool_max" "Max pool connections") schemaCacheLoads <- register $ vector "status" $ counter (Info "pgrst_schema_cache_loads_total" "The total number of times the schema cache was loaded") schemaCacheQueryTime <- register $ gauge (Info "pgrst_schema_cache_query_time_seconds" "The query time in seconds of the last schema cache load") + jwtCacheSize <- register $ gauge (Info "pgrst_jwt_cache_size_bytes" "The JWT cache size in bytes") setGauge poolMaxSize (fromIntegral configDbPoolSize) - pure $ MetricsState poolTimeouts poolAvailable poolWaiting poolMaxSize schemaCacheLoads schemaCacheQueryTime + pure $ MetricsState poolTimeouts poolAvailable poolWaiting poolMaxSize schemaCacheLoads schemaCacheQueryTime jwtCacheSize -- Only some observations are used as metrics observationMetrics :: MetricsState -> ObservationHandler -observationMetrics (MetricsState poolTimeouts poolAvailable poolWaiting _ schemaCacheLoads schemaCacheQueryTime) obs = case obs of +observationMetrics (MetricsState poolTimeouts poolAvailable poolWaiting _ schemaCacheLoads schemaCacheQueryTime jwtCacheSize) obs = case obs of (PoolAcqTimeoutObs _) -> do incCounter poolTimeouts (HasqlPoolObs (SQL.ConnectionObservation _ status)) -> case status of @@ -54,6 +55,8 @@ 
observationMetrics (MetricsState poolTimeouts poolAvailable poolWaiting _ schema setGauge schemaCacheQueryTime resTime SchemaCacheErrorObs _ -> do withLabel schemaCacheLoads "FAIL" incCounter + JwtCache cacheSize -> do + setGauge jwtCacheSize (fromIntegral cacheSize) _ -> pure () diff --git a/src/PostgREST/Observation.hs b/src/PostgREST/Observation.hs index 1b3335710f..937f3c08dd 100644 --- a/src/PostgREST/Observation.hs +++ b/src/PostgREST/Observation.hs @@ -59,6 +59,7 @@ data Observation | HasqlPoolObs SQL.Observation | PoolRequest | PoolRequestFullfilled + | JwtCache Int data ObsFatalError = ServerAuthError | ServerPgrstBug | ServerError42P05 | ServerError08P01 @@ -146,6 +147,8 @@ observationMessage = \case "Trying to borrow a connection from pool" PoolRequestFullfilled -> "Borrowed a connection from the pool" + JwtCache sz -> + "The JWT Cache size updated to " <> show sz <> " bytes" where showMillis :: Double -> Text showMillis x = toS $ showFFloat (Just 1) (x * 1000) "" diff --git a/stack-21.7.yaml b/stack-21.7.yaml index de212c4e36..29ae4c84de 100644 --- a/stack-21.7.yaml +++ b/stack-21.7.yaml @@ -19,6 +19,7 @@ nix: extra-deps: - configurator-pg-0.2.10 - fuzzyset-0.2.4 + - ghc-datasize-0.2.7 - hasql-notifications-0.2.2.2 - hasql-pool-1.0.1 - postgresql-libpq-0.10.1.0 diff --git a/stack.yaml b/stack.yaml index 66ab159576..9e629c4a2d 100644 --- a/stack.yaml +++ b/stack.yaml @@ -11,6 +11,7 @@ nix: extra-deps: - fuzzyset-0.2.4 + - ghc-datasize-0.2.7 - hasql-pool-1.0.1 - jose-jwt-0.10.0 - postgresql-libpq-0.10.1.0 diff --git a/test/io/test_io.py b/test/io/test_io.py index bb6ea9e585..b3f58b1c01 100644 --- a/test/io/test_io.py +++ b/test/io/test_io.py @@ -1719,6 +1719,8 @@ def test_admin_metrics(defaultenv): assert "pgrst_db_pool_available" in response.text assert "pgrst_db_pool_timeouts_total" in response.text + assert "pgrst_jwt_cache_size_bytes" in response.text + def test_schema_cache_startup_load_with_in_db_config(defaultenv, metapostgrest): "verify that the 
Schema Cache loads correctly at startup, using the in-db `pgrst.db_schemas` config" @@ -1737,6 +1739,7 @@ def test_schema_cache_startup_load_with_in_db_config(defaultenv, metapostgrest): assert response.status_code == 204 +# TODO: Rewrite this test using jwt cache size metric def test_jwt_cache_purges_expired_entries(defaultenv): "test expired cache entries are purged on cache miss" @@ -1830,3 +1833,36 @@ def test_log_pool_req_observation(level, defaultenv): else: output = postgrest.read_stdout(nlines=4) assert len(output) == 0 + + +def test_jwt_cache_size_bytes_update_log(defaultenv): + "JWT cache size should update on a cache miss" + + env = { + **defaultenv, + "PGRST_LOG_LEVEL": "info", + "PGRST_JWT_CACHE_MAX_LIFETIME": "86400", + "PGRST_JWT_SECRET": SECRET, + } + + headers = jwtauthheader({"role": "postgrest_test_author"}, SECRET) + + with run(env=env) as postgrest: + response = postgrest.session.get("/authors_only", headers=headers) + assert response.status_code == 200 + + output = sorted(postgrest.read_stdout(nlines=2)) + + response = postgrest.admin.get("/metrics") + assert response.status_code == 200 + + # read cache size from metrics + cache_size = float( + re.search(r"pgrst_jwt_cache_size_bytes (\d+)", response.text).group(1) + ) + + assert int(cache_size) > 0 + assert ( + "The JWT Cache size updated to " + str(int(cache_size)) + " bytes" + in output[1] + )