diff --git a/CHANGELOG.md b/CHANGELOG.md index 062e6194a5..d552adb80b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ All notable changes to this project will be documented in this file. From versio - Remove automatic transaction retries on `40001 (serialization_failure)` errors to prevent replication lag by @laurenceisla in #3673 - Fix unexpected results when embedding and filtering the same table more than once by @laurenceisla in #4075 - If the schema cache fails to reload, PostgREST will no longer stop serving requests and will continue doing so in a "best effort" basis by @mkleczek in #4873 #4869 +- Limit concurrent schema cache loads by @mkleczek in #4643 ### Changed diff --git a/src/PostgREST/AppState.hs b/src/PostgREST/AppState.hs index 93dff41342..1c3018edb1 100644 --- a/src/PostgREST/AppState.hs +++ b/src/PostgREST/AppState.hs @@ -1,7 +1,9 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE RecursiveDo #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE RecursiveDo #-} +{-# LANGUAGE TypeApplications #-} module PostgREST.AppState ( AppState @@ -31,7 +33,8 @@ import qualified Data.ByteString.Char8 as BS import Data.Either.Combinators (whenLeft) import qualified Hasql.Pool as SQL import qualified Hasql.Pool.Config as SQL -import qualified Hasql.Session as SQL +import qualified Hasql.Session as SQL hiding (statement) +import qualified Hasql.Transaction as SQL hiding (sql) import qualified Hasql.Transaction.Sessions as SQL import qualified Network.HTTP.Types.Status as HTTP import qualified PostgREST.Auth.JwtCache as JwtCache @@ -61,11 +64,17 @@ import PostgREST.Config.Database (queryDbSettings, import PostgREST.Config.PgVersion (PgVersion (..), minimumPgVersion) import PostgREST.Debounce (makeDebouncer) +import PostgREST.Metrics (MetricsState (connTrack)) import PostgREST.SchemaCache 
(SchemaCache (..), querySchemaCache, showSummary) import PostgREST.SchemaCache.Identifiers (quoteQi) +import qualified Hasql.Decoders as HD +import qualified Hasql.Encoders as HE +import qualified Hasql.Statement as SQL +import NeatInterpolation (trimming) + import Protolude data AppState = AppState @@ -292,7 +301,7 @@ getObserver = stateObserver -- + Because connections cache the pg catalog(see #2620) -- + For rapid recovery. Otherwise, the pool idle or lifetime timeout would have to be reached for new healthy connections to be acquired. retryingSchemaCacheLoad :: AppState -> IO () -retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThreadId=mainThreadId} = +retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThreadId=mainThreadId, stateMetrics} = void $ retrying retryPolicy shouldRetry (\RetryStatus{rsIterNumber, rsPreviousDelay} -> do when (rsIterNumber > 0) $ do let delay = fromMaybe 0 rsPreviousDelay `div` oneSecondInUs @@ -331,8 +340,22 @@ retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThrea qSchemaCache :: IO (Maybe SchemaCache) qSchemaCache = do conf@AppConfig{..} <- getConfig appState + -- Throttle concurrent schema cache loads, guarded by advisory locks. + -- This is to prevent thundering herd problem on startup or when many PostgREST + -- instances receive "reload schema" notifications at the same time + -- See get_lock_sql for details of the algorithm. + -- Here we calculate the number of open connections passed to the query. 
+ Metrics.ConnStats connected inUse <- Metrics.connectionCounts $ connTrack stateMetrics + -- Determine whether schema cache loading will create a new session + let + -- if all connections in use but pool not full - schema cache loading will create session + scLoadingSessions = if connected <= inUse && inUse < configDbPoolSize then 1 else 0 + withTxLock = SQL.statement + (fromIntegral $ connected + scLoadingSessions) + (SQL.Statement get_lock_sql get_lock_params HD.noResult configDbPreparedStatements) + (resultTime, result) <- - timeItT $ usePool appState (SQL.transactionNoRetry SQL.ReadCommitted SQL.Read $ querySchemaCache conf) + timeItT $ usePool appState (SQL.transactionNoRetry SQL.ReadCommitted SQL.Read $ withTxLock *> querySchemaCache conf) case result of Left e -> do markSchemaCachePending appState @@ -353,6 +376,43 @@ retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThrea observer $ SchemaCacheLoadedObs loadTime summary markSchemaCacheLoaded appState return $ Just sCache + where + -- Recursive query that tries acquiring locks in order + -- and waits for randomly selected lock if no attempt succeeded. + -- It has a single parameter: this node open connection count. + -- It is used to estimate the number of nodes + -- by counting the number of active sessions for current session_user + -- and dividing it by this node open connections. + -- Assuming load is uniform among cluster nodes, all should have + -- statistically the same number of open connections. 
+ -- Once the number of nodes is known we calculate the number of locks as + -- round(log(2, number_of_nodes)); divisor and estimate are clamped to >= 1 so log() never receives 0 + get_lock_sql = encodeUtf8 [trimming| + WITH RECURSIVE attempts AS ( + SELECT 1 AS lock_number, pg_try_advisory_xact_lock(lock_id, 1) AS success FROM parameters + UNION ALL + SELECT next_lock_number AS lock_number, pg_try_advisory_xact_lock(lock_id, next_lock_number) AS success + FROM + parameters CROSS JOIN LATERAL ( + SELECT lock_number + 1 AS next_lock_number FROM attempts + WHERE NOT success AND lock_number < locks_count + ORDER BY lock_number DESC + LIMIT 1 + ) AS previous_attempt + ), + counts AS ( + SELECT round(log(2, greatest(round(count(*)::double precision/greatest($$1, 1)), 1)::numeric))::int AS locks_count + FROM + pg_stat_activity WHERE usename = SESSION_USER + ), + parameters AS ( + SELECT locks_count, 50168275 AS lock_id FROM counts WHERE locks_count > 0 + ) + SELECT pg_advisory_xact_lock(lock_id, floor(random() * locks_count)::int + 1) + FROM + parameters WHERE NOT EXISTS (SELECT 1 FROM attempts WHERE success) |] + + get_lock_params = HE.param (HE.nonNullable HE.int4) shouldRetry :: RetryStatus -> (Maybe PgVersion, Maybe SchemaCache) -> IO Bool shouldRetry _ (pgVer, sCache) = do diff --git a/test/io/test_io.py b/test/io/test_io.py index 1fde6332a5..78c88d960b 100644 --- a/test/io/test_io.py +++ b/test/io/test_io.py @@ -1,5 +1,6 @@ "Unit tests for Input/Ouput of PostgREST seen as a black box." 
+import contextlib import os import re import signal @@ -28,6 +29,7 @@ sleep_until_postgrest_full_reload, sleep_until_postgrest_scache_reload, wait_until_exit, + wait_until_status_code, ) @@ -1422,6 +1424,93 @@ def test_schema_cache_concurrent_notifications(slow_schema_cache_env): assert response.status_code == 200 +@pytest.mark.parametrize( + "instance_count, expected_concurrency", [(2, 2), (4, 3), (6, 4), (8, 4), (16, 5)] +) +def test_schema_cache_reload_throttled_with_advisory_locks( + instance_count, expected_concurrency, slow_schema_cache_env +): + "schema cache reloads should be throttled across instances" + + internal_sleep_ms = int( + slow_schema_cache_env["PGRST_INTERNAL_SCHEMA_CACHE_QUERY_SLEEP"] + ) + lock_wait_threshold_ms = internal_sleep_ms * 2 + query_log_pattern = re.compile(r"Schema cache queried in ([\d.]+) milliseconds") + + def read_available_output_lines(postgrest): + try: + output = postgrest.process.stdout.read() + except BlockingIOError: + return [] + + if not output: + return [] + return output.decode().splitlines() + + with contextlib.ExitStack() as stack: + instances = [ + stack.enter_context( + run( + env=slow_schema_cache_env, + wait_for=None, + wait_max_seconds=10, + ) + ) + for _ in range(instance_count) + ] + + for postgrest in instances: + wait_until_status_code( + postgrest.admin.baseurl + "/ready", max_seconds=10, status_code=200 + ) + + # Drain and discard everything logged before the notification, so only later logs are parsed. + for postgrest in instances: + read_available_output_lines(postgrest) + + response = instances[0].session.get("/rpc/notify_pgrst") + assert response.status_code == 204 + + # Sleep for twice internal_sleep_ms before starting to poll each instance's output. 
+ time.sleep((internal_sleep_ms / 1000) * 2) + + reload_durations_ms = [] + for postgrest in instances: + output_lines = [] + for _ in range(instance_count * 2): + output_lines.extend(read_available_output_lines(postgrest)) + if any(query_log_pattern.search(line) for line in output_lines): + break + time.sleep(0.2) + + durations = [] + for line in output_lines: + match = query_log_pattern.search(line) + if match: + durations.append(float(match.group(1))) + + assert durations + reload_durations_ms.append(max(durations)) + + assert len(reload_durations_ms) == instance_count + + # An instance counts as throttled when its slowest logged query duration + # exceeds lock_wait_threshold_ms (twice internal_sleep_ms); + # the remaining instances must number exactly expected_concurrency + assert ( + instance_count + - len( + [ + duration + for duration in reload_durations_ms + if duration > lock_wait_threshold_ms + ] + ) + == expected_concurrency + ) + + def test_schema_cache_query_sleep_logs(defaultenv): """Schema cache sleep should be reflected in the logged query duration."""