Skip to content

Commit d340716

Browse files
committed
Allow providing alternate statement cache implementations to SqlBackends
1 parent b6d092d commit d340716

File tree

5 files changed

+69
-44
lines changed

5 files changed

+69
-44
lines changed

persistent-postgresql/Database/Persist/Postgresql.hs

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
{-# LANGUAGE OverloadedStrings #-}
44
{-# LANGUAGE RecordWildCards #-}
55
{-# LANGUAGE ScopedTypeVariables #-}
6-
{-# LANGUAGE TupleSections #-}
76
{-# LANGUAGE TypeFamilies #-}
87
{-# LANGUAGE ViewPatterns #-}
98
{-# OPTIONS_GHC -fno-warn-deprecations #-} -- Pattern match 'PersistDbSpecific'
@@ -126,7 +125,7 @@ withPostgresqlPool :: (MonadLoggerIO m, MonadUnliftIO m)
126125
-- ^ Action to be executed that uses the
127126
-- connection pool.
128127
-> m a
129-
withPostgresqlPool ci = withPostgresqlPoolWithVersion getServerVersion ci
128+
withPostgresqlPool = withPostgresqlPoolWithVersion getServerVersion
130129

131130
-- | Same as 'withPostgresPool', but takes a callback for obtaining
132131
-- the server version (to work around an Amazon Redshift bug).
@@ -146,7 +145,7 @@ withPostgresqlPoolWithVersion :: (MonadUnliftIO m, MonadLoggerIO m)
146145
-> m a
147146
withPostgresqlPoolWithVersion getVerDouble ci = do
148147
let getVer = oldGetVersionToNew getVerDouble
149-
withSqlPool $ open' (const $ return ()) getVer ci
148+
withSqlPool $ open' (defaultPostgresConfHooks { pgConfHooksGetServerVersion = getVer }) ci
150149

151150
-- | Same as 'withPostgresqlPool', but can be configured with 'PostgresConf' and 'PostgresConfHooks'.
152151
--
@@ -159,9 +158,7 @@ withPostgresqlPoolWithConf :: (MonadUnliftIO m, MonadLoggerIO m)
159158
-- connection pool.
160159
-> m a
161160
withPostgresqlPoolWithConf conf hooks = do
162-
let getVer = pgConfHooksGetServerVersion hooks
163-
modConn = pgConfHooksAfterCreate hooks
164-
let logFuncToBackend = open' modConn getVer (pgConnStr conf)
161+
let logFuncToBackend = open' hooks (pgConnStr conf)
165162
withSqlPoolWithConfig logFuncToBackend (postgresConfToConnectionPoolConfig conf)
166163

167164
-- | Create a PostgreSQL connection pool. Note that it's your
@@ -207,7 +204,11 @@ createPostgresqlPoolModifiedWithVersion
207204
-> m (Pool SqlBackend)
208205
createPostgresqlPoolModifiedWithVersion getVerDouble modConn ci = do
209206
let getVer = oldGetVersionToNew getVerDouble
210-
createSqlPool $ open' modConn getVer ci
207+
hooks = defaultPostgresConfHooks
208+
{ pgConfHooksAfterCreate = modConn
209+
, pgConfHooksGetServerVersion = getVer
210+
}
211+
createSqlPool $ open' hooks ci
211212

212213
-- | Same as 'createPostgresqlPool', but can be configured with 'PostgresConf' and 'PostgresConfHooks'.
213214
--
@@ -218,9 +219,7 @@ createPostgresqlPoolWithConf
218219
-> PostgresConfHooks -- ^ Record of callback functions
219220
-> m (Pool SqlBackend)
220221
createPostgresqlPoolWithConf conf hooks = do
221-
let getVer = pgConfHooksGetServerVersion hooks
222-
modConn = pgConfHooksAfterCreate hooks
223-
createSqlPoolWithConfig (open' modConn getVer (pgConnStr conf)) (postgresConfToConnectionPoolConfig conf)
222+
createSqlPoolWithConfig (open' hooks (pgConnStr conf)) (postgresConfToConnectionPoolConfig conf)
224223

225224
postgresConfToConnectionPoolConfig :: PostgresConf -> ConnectionPoolConfig
226225
postgresConfToConnectionPoolConfig conf =
@@ -249,17 +248,18 @@ withPostgresqlConnWithVersion :: (MonadUnliftIO m, MonadLoggerIO m)
249248
-> m a
250249
withPostgresqlConnWithVersion getVerDouble = do
251250
let getVer = oldGetVersionToNew getVerDouble
252-
withSqlConn . open' (const $ return ()) getVer
251+
withSqlConn . open' (defaultPostgresConfHooks { pgConfHooksGetServerVersion = getVer })
253252

254253
open'
255-
:: (PG.Connection -> IO ())
256-
-> (PG.Connection -> IO (NonEmpty Word))
257-
-> ConnectionString -> LogFunc -> IO SqlBackend
258-
open' modConn getVer cstr logFunc = do
254+
:: PostgresConfHooks
255+
-> ConnectionString
256+
-> LogFunc
257+
-> IO SqlBackend
258+
open' PostgresConfHooks{..} cstr logFunc = do
259259
conn <- PG.connectPostgreSQL cstr
260-
modConn conn
261-
ver <- getVer conn
262-
smap <- newIORef $ Map.empty
260+
pgConfHooksAfterCreate conn
261+
ver <- pgConfHooksGetServerVersion conn
262+
smap <- pgConfHooksCreateStatementCache
263263
return $ createBackend logFunc ver smap conn
264264

265265
-- | Gets the PostgreSQL server version
@@ -295,10 +295,9 @@ getServerVersionNonEmpty conn = do
295295
-- so depending upon that we have to choose how the sql query is generated.
296296
-- upsertFunction :: Double -> Maybe (EntityDef -> Text -> Text)
297297
upsertFunction :: a -> NonEmpty Word -> Maybe a
298-
upsertFunction f version = if (version >= postgres9dot5)
298+
upsertFunction f version = if version >= postgres9dot5
299299
then Just f
300300
else Nothing
301-
where
302301

303302
postgres9dot5 :: NonEmpty Word
304303
postgres9dot5 = 9 NEL.:| [5]
@@ -310,7 +309,7 @@ minimumPostgresVersion :: NonEmpty Word
310309
minimumPostgresVersion = 9 NEL.:| [4]
311310

312311
oldGetVersionToNew :: (PG.Connection -> IO (Maybe Double)) -> (PG.Connection -> IO (NonEmpty Word))
313-
oldGetVersionToNew oldFn = \conn -> do
312+
oldGetVersionToNew oldFn conn = do
314313
mDouble <- oldFn conn
315314
case mDouble of
316315
Nothing -> pure minimumPostgresVersion
@@ -328,14 +327,14 @@ openSimpleConn = openSimpleConnWithVersion getServerVersion
328327
-- @since 2.9.1
329328
openSimpleConnWithVersion :: (PG.Connection -> IO (Maybe Double)) -> LogFunc -> PG.Connection -> IO SqlBackend
330329
openSimpleConnWithVersion getVerDouble logFunc conn = do
331-
smap <- newIORef $ Map.empty
330+
smap <- makeSimpleStatementCache
332331
serverVersion <- oldGetVersionToNew getVerDouble conn
333332
return $ createBackend logFunc serverVersion smap conn
334333

335334
-- | Create the backend given a logging function, server version, mutable statement cell,
336335
-- and connection.
337336
createBackend :: LogFunc -> NonEmpty Word
338-
-> IORef (Map.Map Text Statement) -> PG.Connection -> SqlBackend
337+
-> StatementCache -> PG.Connection -> SqlBackend
339338
createBackend logFunc serverVersion smap conn = do
340339
SqlBackend
341340
{ connPrepare = prepare' conn
@@ -422,7 +421,7 @@ upsertSql' ent uniqs updateVal =
422421
wher = T.intercalate " AND " $ map (singleClause . snd) $ NEL.toList uniqs
423422

424423
singleClause :: FieldNameDB -> Text
425-
singleClause field = escapeE (entityDB ent) <> "." <> (escapeF field) <> " =?"
424+
singleClause field = escapeE (entityDB ent) <> "." <> escapeF field <> " =?"
426425

427426
-- | SQL for inserting multiple rows at once and returning their primary keys.
428427
insertManySql' :: EntityDef -> [[PersistValue]] -> InsertSqlResult
@@ -608,7 +607,7 @@ instance PGFF.FromField PgInterval where
608607
nominalDiffTime :: P.Parser NominalDiffTime
609608
nominalDiffTime = do
610609
(s, h, m, ss) <- interval
611-
let pico = ss + 60 * (fromIntegral m) + 60 * 60 * (fromIntegral (abs h))
610+
let pico = ss + 60 * fromIntegral m + 60 * 60 * fromIntegral (abs h)
612611
return . fromRational . toRational $ if s then (-pico) else pico
613612

614613
fromPersistValueError :: Text -- ^ Haskell type, should match Haskell name exactly, e.g. "Int64"
@@ -799,7 +798,7 @@ migrate' allDefs getter entity = fmap (fmap $ map showAlterDb) $ do
799798
-- for https://github.com/yesodweb/persistent/issues/152
800799

801800
createText newcols fdefs_ udspair =
802-
(addTable newcols entity) : uniques ++ references ++ foreignsAlt
801+
addTable newcols entity : uniques ++ references ++ foreignsAlt
803802
where
804803
uniques = flip concatMap udspair $ \(uname, ucols) ->
805804
[AlterTable name $ AddUniqueConstraint uname ucols]
@@ -1076,7 +1075,7 @@ getColumn getter tableName' [ PersistText columnName
10761075

10771076
let cname = FieldNameDB columnName
10781077

1079-
ref <- lift $ fmap join $ traverse (getRef cname) refName_
1078+
ref <- lift $ join <$> traverse (getRef cname) refName_
10801079

10811080
return Column
10821081
{ cName = cname
@@ -1538,9 +1537,9 @@ instance FromJSON PostgresConf where
15381537
port <- o .:? "port" .!= 5432
15391538
user <- o .: "user"
15401539
password <- o .: "password"
1541-
poolSize <- o .:? "poolsize" .!= (connectionPoolConfigSize defaultPoolConfig)
1542-
poolStripes <- o .:? "stripes" .!= (connectionPoolConfigStripes defaultPoolConfig)
1543-
poolIdleTimeout <- o .:? "idleTimeout" .!= (floor $ connectionPoolConfigIdleTimeout defaultPoolConfig)
1540+
poolSize <- o .:? "poolsize" .!= connectionPoolConfigSize defaultPoolConfig
1541+
poolStripes <- o .:? "stripes" .!= connectionPoolConfigStripes defaultPoolConfig
1542+
poolIdleTimeout <- o .:? "idleTimeout" .!= floor (connectionPoolConfigIdleTimeout defaultPoolConfig)
15441543
let ci = PG.ConnectInfo
15451544
{ PG.connectHost = host
15461545
, PG.connectPort = port
@@ -1605,6 +1604,7 @@ data PostgresConfHooks = PostgresConfHooks
16051604
-- The default implementation does nothing.
16061605
--
16071606
-- @since 2.11.0
1607+
, pgConfHooksCreateStatementCache :: IO StatementCache
16081608
}
16091609

16101610
-- | Default settings for 'PostgresConfHooks'. See the individual fields of 'PostgresConfHooks' for the default values.
@@ -1614,6 +1614,7 @@ defaultPostgresConfHooks :: PostgresConfHooks
16141614
defaultPostgresConfHooks = PostgresConfHooks
16151615
{ pgConfHooksGetServerVersion = getServerVersionNonEmpty
16161616
, pgConfHooksAfterCreate = const $ pure ()
1617+
, pgConfHooksCreateStatementCache = makeSimpleStatementCache
16171618
}
16181619

16191620

@@ -1695,7 +1696,7 @@ mockMigrate allDefs _ entity = fmap (fmap $ map showAlterDb) $ do
16951696
-- with the difference that an actual database is not needed.
16961697
mockMigration :: Migration -> IO ()
16971698
mockMigration mig = do
1698-
smap <- newIORef $ Map.empty
1699+
smap <- makeSimpleStatementCache
16991700
let sqlbackend = SqlBackend { connPrepare = \_ -> do
17001701
return Statement
17011702
{ stmtFinalize = return ()

persistent/Database/Persist/Sql/Raw.hs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ import Control.Monad.Trans.Resource (MonadResource,release)
99
import Data.Acquire (allocateAcquire, Acquire, mkAcquire, with)
1010
import Data.Conduit
1111
import Data.IORef (writeIORef, readIORef, newIORef)
12-
import qualified Data.Map as Map
1312
import Data.Int (Int64)
1413
import Data.Text (Text, pack)
1514
import qualified Data.Text as T
1615

1716
import Database.Persist
1817
import Database.Persist.Sql.Types
1918
import Database.Persist.Sql.Class
19+
import Database.Persist.Sql.Types.Internal (statementCacheLookup, StatementCache (statementCacheInsert))
2020

2121
rawQuery :: (MonadResource m, MonadReader env m, BackendCompatible SqlBackend env)
2222
=> Text
@@ -74,8 +74,8 @@ getStmt sql = do
7474

7575
getStmtConn :: SqlBackend -> Text -> IO Statement
7676
getStmtConn conn sql = do
77-
smap <- liftIO $ readIORef $ connStmtMap conn
78-
case Map.lookup sql smap of
77+
smap <- liftIO $ statementCacheLookup (connStmtMap conn) sql
78+
case smap of
7979
Just stmt -> connStatementMiddleware conn sql stmt
8080
Nothing -> do
8181
stmt' <- liftIO $ connPrepare conn sql
@@ -99,7 +99,7 @@ getStmtConn conn sql = do
9999
then stmtQuery stmt' x
100100
else liftIO $ throwIO $ StatementAlreadyFinalized sql
101101
}
102-
liftIO $ writeIORef (connStmtMap conn) $ Map.insert sql stmt smap
102+
liftIO $ statementCacheInsert (connStmtMap conn) sql stmt
103103
connStatementMiddleware conn sql stmt
104104

105105
-- | Execute a raw SQL statement and return its results as a

persistent/Database/Persist/Sql/Run.hs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,13 @@ import qualified Control.Monad.Reader as MonadReader
99
import Control.Monad.Trans.Reader hiding (local)
1010
import Control.Monad.Trans.Resource
1111
import Data.Acquire (Acquire, ReleaseType(..), mkAcquireType, with)
12-
import Data.IORef (readIORef)
1312
import Data.Pool (Pool)
1413
import Data.Pool as P
15-
import qualified Data.Map as Map
1614
import qualified Data.Text as T
1715

1816
import Database.Persist.Class.PersistStore
1917
import Database.Persist.Sql.Types
20-
import Database.Persist.Sql.Types.Internal (IsolationLevel)
18+
import Database.Persist.Sql.Types.Internal (IsolationLevel, StatementCache (..))
2119
import Database.Persist.Sql.Raw
2220

2321
-- | Get a connection from the pool, run the given action, and then return the
@@ -184,7 +182,7 @@ withSqlPool
184182
-> Int -- ^ connection count
185183
-> (Pool backend -> m a)
186184
-> m a
187-
withSqlPool mkConn connCount f = withSqlPoolWithConfig mkConn (defaultConnectionPoolConfig { connectionPoolConfigSize = connCount } ) f
185+
withSqlPool mkConn connCount = withSqlPoolWithConfig mkConn (defaultConnectionPoolConfig { connectionPoolConfigSize = connCount } )
188186

189187
-- | Creates a pool of connections to a SQL database which can be used by the @Pool backend -> m a@ function.
190188
-- After the function completes, the connections are destroyed.
@@ -297,5 +295,6 @@ withSqlConn open f = do
297295

298296
close' :: (BackendCompatible SqlBackend backend) => backend -> IO ()
299297
close' conn = do
300-
readIORef (connStmtMap $ projectBackend conn) >>= mapM_ stmtFinalize . Map.elems
301-
connClose $ projectBackend conn
298+
let backend = projectBackend conn
299+
statementCacheClear $ connStmtMap backend
300+
connClose backend

persistent/Database/Persist/Sql/Types.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ module Database.Persist.Sql.Types
66
, SqlBackendCanRead, SqlBackendCanWrite, SqlReadT, SqlWriteT, IsSqlBackend
77
, OverflowNatural(..)
88
, ConnectionPoolConfig(..)
9+
, StatementCache(..)
10+
, makeSimpleStatementCache
911
) where
1012

1113
import Database.Persist.Types.Base (FieldCascade)

persistent/Database/Persist/Sql/Types/Internal.hs

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ module Database.Persist.Sql.Types.Internal
1919
, SqlReadT
2020
, SqlWriteT
2121
, IsSqlBackend
22+
, StatementCache(..)
23+
, makeSimpleStatementCache
2224
) where
2325

2426
import Data.List.NonEmpty (NonEmpty(..))
@@ -29,8 +31,8 @@ import Control.Monad.Trans.Reader (ReaderT, runReaderT, ask)
2931
import Data.Acquire (Acquire)
3032
import Data.Conduit (ConduitM)
3133
import Data.Int (Int64)
32-
import Data.IORef (IORef)
33-
import Data.Map (Map)
34+
import Data.IORef
35+
import qualified Data.Map as Map
3436
import Data.Monoid ((<>))
3537
import Data.String (IsString)
3638
import Data.Text (Text)
@@ -45,6 +47,7 @@ import Database.Persist.Class
4547
)
4648
import Database.Persist.Class.PersistStore (IsPersistBackend (..))
4749
import Database.Persist.Types
50+
import Data.Foldable (traverse_)
4851

4952
type LogFunc = Loc -> LogSource -> LogLevel -> LogStr -> IO ()
5053

@@ -76,6 +79,26 @@ makeIsolationLevelStatement l = "SET TRANSACTION ISOLATION LEVEL " <> case l of
7679
RepeatableRead -> "REPEATABLE READ"
7780
Serializable -> "SERIALIZABLE"
7881

82+
data StatementCache = StatementCache
83+
{ statementCacheLookup :: Text -> IO (Maybe Statement)
84+
, statementCacheInsert :: Text -> Statement -> IO ()
85+
, statementCacheClear :: IO ()
86+
, statementCacheSize :: IO Int
87+
}
88+
89+
makeSimpleStatementCache :: IO StatementCache
90+
makeSimpleStatementCache = do
91+
stmtMap <- newIORef Map.empty
92+
pure $ StatementCache
93+
{ statementCacheLookup = \sql -> Map.lookup sql <$> readIORef stmtMap
94+
, statementCacheInsert = \sql stmt ->
95+
modifyIORef' stmtMap (Map.insert sql stmt)
96+
, statementCacheClear = do
97+
oldStatements <- atomicModifyIORef' stmtMap (\oldStatements -> (Map.empty, oldStatements))
98+
traverse_ stmtFinalize oldStatements
99+
, statementCacheSize = Map.size <$> readIORef stmtMap
100+
}
101+
79102
-- | A 'SqlBackend' represents a handle or connection to a database. It
80103
-- contains functions and values that allow databases to have more
81104
-- optimized implementations, as well as references that benefit
@@ -127,7 +150,7 @@ data SqlBackend = SqlBackend
127150
-- When left as 'Nothing', we default to using 'defaultPutMany'.
128151
--
129152
-- @since 2.8.1
130-
, connStmtMap :: IORef (Map Text Statement)
153+
, connStmtMap :: StatementCache
131154
-- ^ A reference to the cache of statements. 'Statement's are keyed by
132155
-- the 'Text' queries that generated them.
133156
, connClose :: IO ()

0 commit comments

Comments
 (0)