diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0b4c5b833f..3e2a334a9d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
  - #1536, Add string comparison feature for jwt-role-claim-key - @taimoorzaeem
  - #3747, Allow `not_null` value for the `is` operator - @taimoorzaeem
  - #2255, Apply `to_tsvector()` explicitly to the full-text search filtered column (excluding `tsvector` types) - @laurenceisla
+ - #3802, Add metric `pgrst_jwt_cache_size_bytes` in admin server - @taimoorzaeem
 
 ### Fixed
 
diff --git a/docs/references/observability.rst b/docs/references/observability.rst
index 04a261ba48..f3a69487ed 100644
--- a/docs/references/observability.rst
+++ b/docs/references/observability.rst
@@ -169,6 +169,20 @@ pgrst_db_pool_max
 
 Max pool connections.
 
+JWT Cache Metric
+----------------
+
+Related to the :ref:`jwt_caching`.
+
+pgrst_jwt_cache_size_bytes
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+======== =======
+**Type** Gauge
+======== =======
+
+The JWT cache size in bytes.
+
 Traces
 ======
 
diff --git a/postgrest.cabal b/postgrest.cabal
index 1377640ec2..cfdfbcdea6 100644
--- a/postgrest.cabal
+++ b/postgrest.cabal
@@ -89,7 +89,8 @@ library
                       PostgREST.Response.GucHeader
                       PostgREST.Response.Performance
                       PostgREST.Version
-  other-modules:      Paths_postgrest
+  other-modules:      PostgREST.Internal
+                      Paths_postgrest
   build-depends:      base                      >= 4.9 && < 4.20
                     , HTTP                      >= 4000.3.7 && < 4000.5
                     , Ranged-sets               >= 0.3 && < 0.5
@@ -108,6 +109,7 @@ library
                     , either                    >= 4.4.1 && < 5.1
                     , extra                     >= 1.7.0 && < 2.0
                     , fuzzyset                  >= 0.2.4 && < 0.3
+                    , ghc-heap                  >= 9.4 && < 9.9
                     , hasql                     >= 1.6.1.1 && < 1.7
                     , hasql-dynamic-statements  >= 0.3.1 && < 0.4
                     , hasql-notifications       >= 0.2.2.0 && < 0.3
diff --git a/src/PostgREST/App.hs b/src/PostgREST/App.hs
index 99febebaca..81207eb386 100644
--- a/src/PostgREST/App.hs
+++ b/src/PostgREST/App.hs
@@ -15,7 +15,6 @@ module PostgREST.App
   , run
   ) where
 
-
 import Control.Monad.Except     (liftEither)
 import Data.Either.Combinators  (mapLeft)
 import Data.Maybe               (fromJust)
@@ -41,8 +40,7 @@ import qualified PostgREST.Response   as Response
 import qualified PostgREST.Unix       as Unix (installSignalHandlers)
 
 import PostgREST.ApiRequest           (ApiRequest (..))
-import PostgREST.AppState             (AppState)
-import PostgREST.Auth                 (AuthResult (..))
+import PostgREST.AppState             (AppState, AuthResult (..))
 import PostgREST.Config               (AppConfig (..), LogLevel (..))
 import PostgREST.Config.PgVersion     (PgVersion (..))
 import PostgREST.Error                (Error)
diff --git a/src/PostgREST/Auth.hs b/src/PostgREST/Auth.hs
index c8c3b720cb..b35f12f22b 100644
--- a/src/PostgREST/Auth.hs
+++ b/src/PostgREST/Auth.hs
@@ -39,17 +39,19 @@ import qualified Network.Wai.Middleware.HttpAuth as Wai
 import Control.Monad.Except    (liftEither)
 import Data.Either.Combinators (mapLeft)
 import Data.List               (lookup)
 import Data.Time.Clock         (UTCTime, nominalDiffTimeToSeconds)
 import Data.Time.Clock.POSIX   (utcTimeToPOSIXSeconds)
 import System.Clock            (TimeSpec (..))
 import System.IO.Unsafe        (unsafePerformIO)
 import System.TimeIt           (timeItT)
 
-import PostgREST.AppState (AppState, AuthResult (..), getConfig,
-                           getJwtCache, getTime)
-import PostgREST.Config   (AppConfig (..), FilterExp (..), JSPath,
-                           JSPathExp (..))
-import PostgREST.Error    (Error (..))
+import PostgREST.AppState    (AppState, AuthResult (..), getConfig,
+                              getJwtCache, getObserver, getTime)
+import PostgREST.Config      (AppConfig (..), FilterExp (..), JSPath,
+                              JSPathExp (..))
+import PostgREST.Error       (Error (..))
+import PostgREST.Internal    (recursiveSizeNF)
+import PostgREST.Observation (Observation (..))
 
 import Protolude
 
@@ -153,7 +155,7 @@ middleware appState app req respond = do
   let token  = fromMaybe "" $ Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req)
       parseJwt = runExceptT $ parseToken conf token time >>= parseClaims conf
 
--- If DbPlanEnabled -> calculate JWT validation time
+-- If ServerTimingEnabled -> calculate JWT validation time
 -- If JwtCacheMaxLifetime -> cache JWT validation result
   req' <- case (configServerTimingEnabled conf, configJwtCacheMaxLifetime conf) of
     (True, 0)  -> do
@@ -177,24 +179,73 @@ middleware appState app req respond = do
 -- | Used to retrieve and insert JWT to JWT Cache
 getJWTFromCache :: AppState -> ByteString -> Int -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult)
 getJWTFromCache appState token maxLifetime parseJwt utc = do
-  checkCache <- C.lookup (getJwtCache appState) token
+
+  checkCache <- C.lookup jwtCache token
   authResult <- maybe parseJwt (pure . Right) checkCache
 
+  -- if token not found, add to cache and update the cache size metric
   case (authResult,checkCache) of
-    (Right res, Nothing) -> C.insert' (getJwtCache appState) (getTimeSpec res maxLifetime utc) token res
+    (Right res, Nothing) -> do
+      let tSpec = getTimeSpec res maxLifetime utc
+      C.insert' jwtCache (Just tSpec) token res
+      cacheSize <- calcCacheSizeInBytes jwtCache
+      observer $ JWTCache cacheSize
+
     _                    -> pure ()
 
   return authResult
+  where
+    observer = getObserver appState
+    jwtCache = getJwtCache appState
 
 -- Used to extract JWT exp claim and add to JWT Cache
-getTimeSpec :: AuthResult -> Int -> UTCTime -> Maybe TimeSpec
+getTimeSpec :: AuthResult -> Int -> UTCTime -> TimeSpec
 getTimeSpec res maxLifetime utc = do
   let expireJSON = KM.lookup "exp" (authClaims res)
       utcToSecs = floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds
       sciToInt = fromMaybe 0 . Sci.toBoundedInteger
   case expireJSON of
-    Just (JSON.Number seconds) -> Just $ TimeSpec (sciToInt seconds - utcToSecs utc) 0
-    _                          -> Just $ TimeSpec (fromIntegral maxLifetime :: Int64) 0
+    Just (JSON.Number seconds) -> TimeSpec (sciToInt seconds - utcToSecs utc) 0
+    _                          -> TimeSpec (fromIntegral maxLifetime :: Int64) 0
+
+-- | Calculate the total size of the JWT Cache in bytes
+--
+-- The whole cache is traversed and the sizes of all of its entries
+-- are summed; callers use the result as the new value of the metric.
+--
+-- The cache entry consists of
+--   key          :: ByteString
+--   value        :: AuthResult
+--   expire value :: TimeSpec
+--
+-- We calculate the size of each cache entry component
+-- by using recursiveSizeNF function which first evaluates
+-- the data structure to Normal Form and then calculates its size.
+-- The normal form evaluation is necessary for accurate size
+-- calculation because Haskell is lazy and we don't want to count
+-- the size of large thunks (unevaluated expressions)
+--
+-- The sizes are measured directly in IO: wrapping the measurement in
+-- unsafePerformIO would leave it up to lazy evaluation when (and whether)
+-- the garbage collection done by recursiveSizeNF actually runs.
+-- Entries without an expiration time contribute no TimeSpec size.
+calcCacheSizeInBytes :: C.Cache ByteString AuthResult -> IO Int
+calcCacheSizeInBytes jwtCache = do
+  cacheList <- C.toList jwtCache
+  entrySizes <- traverse getSize cacheList
+  return $ fromIntegral (sum entrySizes)
+  where
+    getSize :: (ByteString, AuthResult, Maybe TimeSpec) -> IO Word
+    getSize (bs, ar, ts) = do
+      keySize      <- recursiveSizeNF bs
+      arClaimsSize <- recursiveSizeNF $ authClaims ar
+      arRoleSize   <- recursiveSizeNF $ authRole ar
+      timeSpecSize <- maybe (pure 0) timeSpecBytes ts
+
+      return (keySize + arClaimsSize + arRoleSize + timeSpecSize)
+
+    timeSpecBytes :: TimeSpec -> IO Word
+    timeSpecBytes (TimeSpec s ns) = liftA2 (+) (recursiveSizeNF s) (recursiveSizeNF ns)
 
 authResultKey :: Vault.Key (Either Error AuthResult)
 authResultKey = unsafePerformIO Vault.newKey
diff --git a/src/PostgREST/Internal.hs b/src/PostgREST/Internal.hs
new file mode 100644
index 0000000000..c73f25edaf
--- /dev/null
+++ b/src/PostgREST/Internal.hs
@@ -0,0 +1,93 @@
+{-# LANGUAGE BangPatterns  #-}
+{-# LANGUAGE MagicHash     #-}
+{-# LANGUAGE UnboxedTuples #-}
+{- |
+Module      : PostgREST.Internal
+Copyright   : (c) Dennis Felsing
+License     : 3-Clause BSD-style
+Maintainer  : dennis@felsing.org
+
+https://hackage.haskell.org/package/ghc-datasize
+
+This vendored dependency can be removed once https://github.com/PostgREST/postgrest/issues/3881 is solved.
+-}
+module PostgREST.Internal
+  ( recursiveSizeNF )
+  where
+
+import GHC.Exts
+import GHC.Exts.Heap           hiding (size)
+import GHC.Exts.Heap.Constants (wORD_SIZE)
+
+import System.Mem
+
+import Protolude
+
+-- Inspired by Simon Marlow:
+-- https://ghcmutterings.wordpress.com/2009/02/12/53/
+
+-- | Calculate size of GHC objects in Bytes. Note that an object may not be
+--   evaluated yet and only the size of the initial closure is returned.
+closureSize :: a -> IO Word
+closureSize x = do
+  rawWds <- getClosureRawWords x
+  return . fromIntegral $ length rawWds * wORD_SIZE
+
+-- | Calculate the recursive size of GHC objects in Bytes. Note that the actual
+--   size in memory is calculated, so shared values are only counted once.
+--
+--   Call with
+--   @
+--    recursiveSize $! 2
+--   @
+--   to force evaluation to WHNF before calculating the size.
+--
+--   Call with
+--   @
+--    recursiveSize $!! \"foobar\"
+--   @
+--   ($!! from Control.DeepSeq) to force full evaluation before calculating the
+--   size.
+--
+--   A garbage collection is performed before the size is calculated, because
+--   the garbage collector would make heap walks difficult.
+--
+--   This function works very quickly on small data structures, but can be slow
+--   on large and complex ones. If speed is an issue it's probably possible to
+--   get the exact size of a small portion of the data structure and then
+--   estimate the total size from that.
+
+recursiveSize :: a -> IO Word
+recursiveSize x = do
+  performGC
+  fmap snd $ go ([], 0) $ asBox x
+  where
+    go (!vs, !acc) b@(Box y) = do
+      isElem <- or <$> mapM (areBoxesEqual b) vs
+      if isElem
+        then return (vs, acc)
+        else do
+          size    <- closureSize y
+          closure <- getClosureData y
+          foldM go (b : vs, acc + size) $ allClosures closure
+
+-- | Calculate the recursive size of GHC objects in Bytes after calling
+-- Control.DeepSeq.force on the data structure to force it into Normal Form.
+-- Using this function requires that the data structure has an `NFData`
+-- typeclass instance.
+
+recursiveSizeNF :: NFData a => a -> IO Word
+recursiveSizeNF x = recursiveSize $!! x
+
+-- | Adapted from 'GHC.Exts.Heap.getClosureRaw' which isn't exported.
+--
+-- This returns the raw words of the closure on the heap. Once back in the
+-- Haskell world, the raw words that hold pointers may be outdated after a
+-- garbage collector run.
+getClosureRawWords :: a -> IO [Word]
+getClosureRawWords x = do
+  case unpackClosure# x of
+    (# _iptr, dat, _pointers #) -> do
+      let nelems = I# (sizeofByteArray# dat) `div` wORD_SIZE
+          end = nelems - 1
+      pure [W# (indexWordArray# dat i) | I# i <- [0.. end] ]
diff --git a/src/PostgREST/Logger.hs b/src/PostgREST/Logger.hs
index c224f74c79..dee387201f 100644
--- a/src/PostgREST/Logger.hs
+++ b/src/PostgREST/Logger.hs
@@ -88,6 +88,9 @@ observationLogger loggerState logLevel obs = case obs of
   o@(HasqlPoolObs _) -> do
     when (logLevel >= LogDebug) $ do
       logWithZTime loggerState $ observationMessage o
+  o@(JWTCache _) -> do
+    when (logLevel >= LogDebug) $ do
+      logWithZTime loggerState $ observationMessage o
   PoolRequest ->
     pure ()
   PoolRequestFullfilled ->
diff --git a/src/PostgREST/Metrics.hs b/src/PostgREST/Metrics.hs
index 3999e43d83..0a94ad899b 100644
--- a/src/PostgREST/Metrics.hs
+++ b/src/PostgREST/Metrics.hs
@@ -1,5 +1,5 @@
 {-|
-Module      : PostgREST.Logger
+Module      : PostgREST.Metrics
 Description : Metrics based on the Observation module. See Observation.hs.
 -}
 module PostgREST.Metrics
@@ -19,7 +19,7 @@ import PostgREST.Observation
 import Protolude
 
 data MetricsState =
-  MetricsState Counter Gauge Gauge Gauge (Vector Label1 Counter) Gauge
+  MetricsState Counter Gauge Gauge Gauge (Vector Label1 Counter) Gauge Gauge
 
 init :: Int -> IO MetricsState
 init configDbPoolSize = do
@@ -29,12 +29,13 @@ init configDbPoolSize = do
   poolMaxSize <- register $ gauge (Info "pgrst_db_pool_max" "Max pool connections")
   schemaCacheLoads <- register $ vector "status" $ counter (Info "pgrst_schema_cache_loads_total" "The total number of times the schema cache was loaded")
   schemaCacheQueryTime <- register $ gauge (Info "pgrst_schema_cache_query_time_seconds" "The query time in seconds of the last schema cache load")
+  jwtCacheSize <- register $ gauge (Info "pgrst_jwt_cache_size_bytes" "The JWT cache size in bytes")
   setGauge poolMaxSize (fromIntegral configDbPoolSize)
-  pure $ MetricsState poolTimeouts poolAvailable poolWaiting poolMaxSize schemaCacheLoads schemaCacheQueryTime
+  pure $ MetricsState poolTimeouts poolAvailable poolWaiting poolMaxSize schemaCacheLoads schemaCacheQueryTime jwtCacheSize
 
 -- Only some observations are used as metrics
 observationMetrics :: MetricsState -> ObservationHandler
-observationMetrics (MetricsState poolTimeouts poolAvailable poolWaiting _ schemaCacheLoads schemaCacheQueryTime) obs = case obs of
+observationMetrics (MetricsState poolTimeouts poolAvailable poolWaiting _ schemaCacheLoads schemaCacheQueryTime jwtCacheSize) obs = case obs of
   (PoolAcqTimeoutObs _) -> do
     incCounter poolTimeouts
   (HasqlPoolObs (SQL.ConnectionObservation _ status)) -> case status of
@@ -54,6 +55,8 @@ observationMetrics (MetricsState poolTimeouts poolAvailable poolWaiting _ schema
     setGauge schemaCacheQueryTime resTime
   SchemaCacheErrorObs _ -> do
     withLabel schemaCacheLoads "FAIL" incCounter
+  JWTCache cacheSize -> do
+    setGauge jwtCacheSize (fromIntegral cacheSize)
   _ ->
     pure ()
 
diff --git a/src/PostgREST/Observation.hs b/src/PostgREST/Observation.hs
index 18fbf558d7..841ff16458 100644
--- a/src/PostgREST/Observation.hs
+++ b/src/PostgREST/Observation.hs
@@ -57,6 +57,7 @@ data Observation
   | HasqlPoolObs SQL.Observation
   | PoolRequest
   | PoolRequestFullfilled
+  | JWTCache Int
 
 data ObsFatalError = ServerAuthError | ServerPgrstBug | ServerError42P05 | ServerError08P01
 
@@ -138,6 +139,8 @@ observationMessage = \case
           SQL.ReleaseConnectionTerminationReason -> "release"
           SQL.NetworkErrorConnectionTerminationReason _ -> "network error" -- usage error is already logged, no need to repeat the same message.
         )
+  JWTCache sz -> "The JWT Cache size increased to " <> show sz <> " bytes"
+  -- TODO: refactor to remove mempty
   _ -> mempty
   where
     showMillis :: Double -> Text
diff --git a/test/io/test_io.py b/test/io/test_io.py
index e61e074cbc..9b99da31cb 100644
--- a/test/io/test_io.py
+++ b/test/io/test_io.py
@@ -1632,6 +1632,8 @@ def test_admin_metrics(defaultenv):
         assert "pgrst_db_pool_available" in response.text
         assert "pgrst_db_pool_timeouts_total" in response.text
 
+        assert "pgrst_jwt_cache_size_bytes" in response.text
+
 
 def test_schema_cache_startup_load_with_in_db_config(defaultenv, metapostgrest):
     "verify that the Schema Cache loads correctly at startup, using the in-db `pgrst.db_schemas` config"
@@ -1648,3 +1650,35 @@ def test_schema_cache_startup_load_with_in_db_config(defaultenv, metapostgrest):
     response = metapostgrest.session.post("/rpc/reset_db_schemas_config")
     assert response.text == ""
     assert response.status_code == 204
+
+
+def test_jwt_cache_size_increase_log(defaultenv):
+    "JWT cache size should increase on every new cache entry"
+
+    env = {
+        **defaultenv,
+        "PGRST_LOG_LEVEL": "debug",
+        "PGRST_JWT_CACHE_MAX_LIFETIME": "86400",
+        "PGRST_JWT_SECRET": SECRET,
+    }
+
+    headers = jwtauthheader({"role": "postgrest_test_author"}, SECRET)
+
+    with run(env=env) as postgrest:
+        response = postgrest.session.get("/authors_only", headers=headers)
+        assert response.status_code == 200
+
+        output = sorted(postgrest.read_stdout(nlines=3))
+
+        response = postgrest.admin.get("/metrics")
+        assert response.status_code == 200
+
+        # read cache size from metrics
+        cache_size = float(
+            re.search(r"pgrst_jwt_cache_size_bytes (\d+)", response.text).group(1)
+        )
+
+        assert (
+            "The JWT Cache size increased to " + str(int(cache_size)) + " bytes"
+            in output[2]
+        )