From 8295af296f7459e417ed89976eaab0806380feae Mon Sep 17 00:00:00 2001 From: Harry Garrood Date: Wed, 3 Dec 2025 21:19:32 +0000 Subject: [PATCH 01/14] wip --- .../Database/Persist/Postgresql/Internal.hs | 3 + .../Persist/Postgresql/Internal/Migration.hs | 405 ++++++++++++++++++ .../persistent-postgresql.cabal | 1 + 3 files changed, 409 insertions(+) create mode 100644 persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 1581d1f0a..1ceaf6971 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -910,6 +910,9 @@ getColumns -> EntityDef -> [Column] -> IO [Either Text (Either Column (ConstraintNameDB, [FieldNameDB]))] + -- ^ Left Text: error + -- Right (Left Column): a column + -- Right (Right ...): a constraint getColumns getter def cols = do let sqlv = diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs new file mode 100644 index 000000000..734092478 --- /dev/null +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs @@ -0,0 +1,405 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TupleSections #-} + +-- | Generate postgresql migrations for a set of EntityDefs, either from scratch +-- or based on the current state of a database. +module Database.Persist.Postgresql.Internal.Migration where + +import Control.Arrow +import Control.Monad.Except +import Control.Monad.IO.Class +import Data.Acquire (with) +import Data.Conduit +import qualified Data.Conduit.List as CL +import Data.Either (partitionEithers) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Maybe +import Data.Text (Text) +import qualified Data.Text as T +import qualified Data.Text.Encoding as T +import Data.Traversable +import Database.Persist.Sql + +-- | In order to ensure that generating migrations is fast and avoids N+1 +-- queries, we split it into two phases. The first phase involves querying the +-- database to gather all of the information we need about the existing schema. +-- The second phase then generates migrations based on the information from the +-- first phase. This data type represents all of the data that's gathered during +-- the first phase: information about the current state of the entities we're +-- migrating in the database. +newtype SchemaState = SchemaState (Map EntityNameDB EntitySchemaState) + deriving (Eq, Show) + +-- | The state of a particular entity in the database; we generate migrations +-- based on the diff of this versus an EntityDef. +data EntitySchemaState = EntitySchemaState + { essColumns :: [Column] + -- ^ The columns in this entity + , essConstraints :: Map ConstraintNameDB [FieldNameDB] + -- ^ A map of constraint names to the columns that are affected by those + -- constraints. Primary key and foreign key constraints are not included + -- here, since they are part of the 'Column'. + } + deriving (Eq, Show) + +-- | Query a database in order to assemble a SchemaState containing information +-- about each of the entities in the given list. +collectSchemaState + :: (Text -> IO Statement) -> [EntityNameDB] -> IO (Either Text SchemaState) +collectSchemaState getStmt entityNames = runExceptT $ do + columns <- getColumnsWithoutReferences getStmt entityNames + constraints <- getConstraints getStmt entityNames + foreignKeyReferences <- getForeignKeyReferences getStmt entityNames + + fmap (SchemaState . Map.fromList) $ + for entityNames $ \entityNameDB -> do + let + addColumnReference column = + column + { cReference = Map.lookup (cName column) =<< Map.lookup entityNameDB foreignKeyReferences + } + + essColumns <- case Map.lookup entityNameDB columns of + Just cols -> + pure (map addColumnReference cols) + Nothing -> + throwError + ("Missing entity name from columns map: " <> unEntityNameDB entityNameDB) + + let + essConstraints = fromMaybe Map.empty (Map.lookup entityNameDB constraints) + pure + ( entityNameDB + , EntitySchemaState{essColumns, essConstraints} + ) + +runStmt + :: (Show a) + => (Text -> IO Statement) + -> Text + -> [PersistValue] + -> ([PersistValue] -> a) + -> IO [a] +runStmt getStmt sql values process = do + stmt <- getStmt sql + results <- + with + (stmtQuery stmt values) + (\src -> runConduit $ src .| CL.map process .| CL.consume) + pure results + +-- | Get all columns for the listed tables from the database, ignoring foreign +-- key references (those are filled in later). +getColumnsWithoutReferences + :: (Text -> IO Statement) + -> [EntityNameDB] + -> ExceptT Text IO (Map EntityNameDB [Column]) +getColumnsWithoutReferences getStmt entityNames = do + results <- + liftIO $ + runStmt + getStmt + getColumnsSql + [PersistArray (map (PersistText . unEntityNameDB) entityNames)] + processColumn + case partitionEithers results of + ([], xs) -> pure $ Map.fromListWith (++) $ map (second (: [])) xs + (errs, _) -> throwError (T.intercalate "\n" errs) + where + getColumnsSql = + T.concat + [ "SELECT " + , "table_name " + , ",column_name " + , ",is_nullable " + , ",COALESCE(domain_name, udt_name)" -- See DOMAINS below + , ",column_default " + , ",generation_expression " + , ",numeric_precision " + , ",numeric_scale " + , ",character_maximum_length " + , "FROM information_schema.columns " + , "WHERE table_catalog=current_database() " + , "AND table_schema=current_schema() " + , "AND table_name=ANY (?) " + ] + + -- DOMAINS Postgres supports the concept of domains, which are data types + -- with optional constraints. An app might make an "email" domain over the + -- varchar type, with a CHECK that the emails are valid In this case the + -- generated SQL should use the domain name: ALTER TABLE users ALTER COLUMN + -- foo TYPE email This code exists to use the domain name (email), instead + -- of the underlying type (varchar). This is tested in + -- EquivalentTypeTest.hs + processColumn :: [PersistValue] -> Either Text (EntityNameDB, Column) + processColumn resultRow = do + case resultRow of + [ PersistText tableName + , PersistText columnName + , PersistText isNullable + , PersistText typeName + , defaultValue + , generationExpression + , numericPrecision + , numericScale + , maxlen + ] -> mapLeft (addErrorContext tableName columnName) $ do + defaultValue' <- + case defaultValue of + PersistNull -> + pure Nothing + PersistText t -> + pure $ Just t + _ -> + throwError $ T.pack $ "Invalid default column: " ++ show defaultValue + generationExpression' <- + case generationExpression of + PersistNull -> + pure Nothing + PersistText t -> + pure $ Just t + _ -> + throwError $ T.pack $ "Invalid generated column: " ++ show generationExpression + let + typeStr = + case maxlen of + PersistInt64 n -> + T.concat [typeName, "(", T.pack (show n), ")"] + _ -> + typeName + + t <- getType numericPrecision numericScale typeStr + + pure + ( EntityNameDB tableName + , Column + { cName = FieldNameDB columnName + , cNull = isNullable == "YES" + , cSqlType = t + , cDefault = fmap stripSuffixes defaultValue' + , cGenerated = fmap stripSuffixes generationExpression' + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + ) + other -> + Left $ + T.pack $ + "Invalid result from information_schema: " ++ show other + + stripSuffixes t = + loop' + [ "::character varying" + , "::text" + ] + where + loop' [] = t + loop' (p : ps) = + case T.stripSuffix p t of + Nothing -> loop' ps + Just t' -> t' + + getType _ _ "int4" = pure SqlInt32 + getType _ _ "int8" = pure SqlInt64 + getType _ _ "varchar" = pure SqlString + getType _ _ "text" = pure SqlString + getType _ _ "date" = pure SqlDay + getType _ _ "bool" = pure SqlBool + getType _ _ "timestamptz" = pure SqlDayTime + getType _ _ "float4" = pure SqlReal + getType _ _ "float8" = pure SqlReal + getType _ _ "bytea" = pure SqlBlob + getType _ _ "time" = pure SqlTime + getType precision scale "numeric" = getNumeric precision scale + getType _ _ a = pure $ SqlOther a + + getNumeric (PersistInt64 a) (PersistInt64 b) = + pure $ SqlNumeric (fromIntegral a) (fromIntegral b) + getNumeric PersistNull PersistNull = + throwError $ + T.concat + [ "No precision and scale were specified. " + , "Postgres defaults to a maximum scale of 147,455 and precision of 16383," + , " which is probably not what you intended." + , " Specify the values as numeric(total_digits, digits_after_decimal_place)." + ] + getNumeric a b = + throwError $ + T.concat + [ "Can not get numeric field precision. " + , "Expected an integer for both precision and scale, " + , "got: " + , T.pack $ show a + , " and " + , T.pack $ show b + , ", respectively." + , " Specify the values as numeric(total_digits, digits_after_decimal_place)." + ] + +-- cyclist putting a stick into his own wheel meme +addErrorContext :: Text -> Text -> Text -> Text +addErrorContext tableName columnName originalMsg = + T.concat + [ "Error in column " + , tableName + , "." + , columnName + , ": " + , originalMsg + ] + +-- | Get all constraints for the listed tables from the database, except for foreign +-- keys and primary keys (those go in the Column data type) +getConstraints + :: (Text -> IO Statement) + -> [EntityNameDB] + -> ExceptT Text IO (Map EntityNameDB (Map ConstraintNameDB [FieldNameDB])) +getConstraints getStmt entityNames = do + results <- + liftIO $ + runStmt + getStmt + getConstraintsSql + [PersistArray (map (PersistText . unEntityNameDB) entityNames)] + processConstraint + case partitionEithers results of + ([], xs) -> pure $ Map.unionsWith (Map.unionWith (<>)) xs + (errs, _) -> throwError (T.intercalate "\n" errs) + where + getConstraintsSql = + T.concat + [ "SELECT " + , "c.table_name, " + , "c.constraint_name, " + , "c.column_name " + , "FROM information_schema.key_column_usage AS c, " + , "information_schema.table_constraints AS k " + , "WHERE c.table_catalog=current_database() " + , "AND c.table_catalog=k.table_catalog " + , "AND c.table_schema=current_schema() " + , "AND c.table_schema=k.table_schema " + , "AND c.table_name=ANY (?) " + , "AND c.table_name=k.table_name " + , "AND c.constraint_name=k.constraint_name " + , "AND NOT k.constraint_type IN ('PRIMARY KEY', 'FOREIGN KEY') " + , "ORDER BY c.constraint_name, c.column_name" + ] + + processConstraint + :: [PersistValue] + -> Either Text (Map EntityNameDB (Map ConstraintNameDB [FieldNameDB])) + processConstraint resultRow = do + (tableName, constraintName, columnName) <- case resultRow of + [PersistText tab, PersistText con, PersistText col] -> + pure (tab, con, col) + [PersistByteString tab, PersistByteString con, PersistByteString col] -> + pure (T.decodeUtf8 tab, T.decodeUtf8 con, T.decodeUtf8 col) + o -> + throwError $ T.pack $ "unexpected datatype returned for postgres o=" ++ show o + + pure $ + Map.singleton + (EntityNameDB tableName) + (Map.singleton (ConstraintNameDB constraintName) [FieldNameDB columnName]) + +-- | Get foreign key reference information for all columns in the supplied +-- tables from the database. +getForeignKeyReferences + :: (Text -> IO Statement) + -> [EntityNameDB] + -> ExceptT Text IO (Map EntityNameDB (Map FieldNameDB ColumnReference)) +getForeignKeyReferences getStmt entityNames = do + results <- + liftIO $ + runStmt + getStmt + getForeignKeyReferencesSql + [PersistArray (map (PersistText . unEntityNameDB) entityNames)] + processForeignKeyReference + case partitionEithers results of + ([], xs) -> pure $ Map.unionsWith Map.union xs + (errs, _) -> throwError (T.intercalate "\n" errs) + where + -- TODO: should this filter by schema? + getForeignKeyReferencesSql = + T.concat + [ "SELECT DISTINCT " + , "kcu.table_name, " + , "kcu.column_name, " + , "ccu.table_name, " + , "tc.constraint_name, " + , "rc.update_rule, " + , "rc.delete_rule " + , "FROM information_schema.constraint_column_usage ccu " + , "INNER JOIN information_schema.key_column_usage kcu " + , " ON ccu.constraint_name = kcu.constraint_name " + , "INNER JOIN information_schema.table_constraints tc " + , " ON tc.constraint_name = kcu.constraint_name " + , "LEFT JOIN information_schema.referential_constraints AS rc" + , " ON rc.constraint_name = ccu.constraint_name " + , "WHERE tc.constraint_type='FOREIGN KEY' " + , "AND kcu.ordinal_position=1 " + , "AND kcu.table_name=ANY (?) " + ] + + processForeignKeyReference + :: [PersistValue] + -> Either Text (Map EntityNameDB (Map FieldNameDB ColumnReference)) + processForeignKeyReference resultRow = do + (sourceTableName, sourceColumnName, refTableName, constraintName, updRule, delRule) <- + case resultRow of + [ PersistText srcTable + , PersistText srcColumn + , PersistText refTable + , PersistText constraint + , PersistText updRule + , PersistText delRule + ] -> + pure + ( EntityNameDB srcTable + , FieldNameDB srcColumn + , EntityNameDB refTable + , ConstraintNameDB constraint + , updRule + , delRule + ) + other -> + throwError $ T.pack $ "unexpected row returned for postgres: " ++ show other + + fcOnUpdate <- parseCascade updRule + fcOnDelete <- parseCascade delRule + + let columnRef = ColumnReference + { crTableName = refTableName + , crConstraintName = constraintName + , crFieldCascade = FieldCascade + { fcOnUpdate = Just fcOnUpdate + , fcOnDelete = Just fcOnDelete + } + } + + pure $ Map.singleton sourceTableName (Map.singleton sourceColumnName columnRef) + +parseCascade :: Text -> Either Text CascadeAction +parseCascade txt = + case txt of + "NO ACTION" -> + Right NoAction + "CASCADE" -> + Right Cascade + "SET NULL" -> + Right SetNull + "SET DEFAULT" -> + Right SetDefault + "RESTRICT" -> + Right Restrict + _ -> + Left $ "Unexpected value in parseCascade: " <> txt + +mapLeft :: (a1 -> a2) -> Either a1 b -> Either a2 b +mapLeft _ (Right x) = Right x +mapLeft f (Left x) = Left (f x) diff --git a/persistent-postgresql/persistent-postgresql.cabal b/persistent-postgresql/persistent-postgresql.cabal index f2f71969f..d9ae9d93e 100644 --- a/persistent-postgresql/persistent-postgresql.cabal +++ b/persistent-postgresql/persistent-postgresql.cabal @@ -41,6 +41,7 @@ library exposed-modules: Database.Persist.Postgresql Database.Persist.Postgresql.Internal + Database.Persist.Postgresql.Internal.Migration Database.Persist.Postgresql.JSON ghc-options: -Wall From df645294ad8ca3d13d795efa7f46f30a1bf6b6ce Mon Sep 17 00:00:00 2001 From: Harry Garrood Date: Thu, 4 Dec 2025 11:08:17 +0000 Subject: [PATCH 02/14] Handle existence + test --- .../Persist/Postgresql/Internal/Migration.hs | 128 +++++++-- .../persistent-postgresql.cabal | 1 + persistent-postgresql/test/MigrationSpec.hs | 259 ++++++++++++++++++ persistent-postgresql/test/main.hs | 2 + 4 files changed, 362 insertions(+), 28 deletions(-) create mode 100644 persistent-postgresql/test/MigrationSpec.hs diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs index 734092478..d20df0827 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs @@ -17,6 +17,7 @@ import Data.Either (partitionEithers) import Data.Map (Map) import qualified Data.Map as Map import Data.Maybe +import qualified Data.Set as Set import Data.Text (Text) import qualified Data.Text as T import qualified Data.Text.Encoding as T @@ -33,9 +34,17 @@ import Database.Persist.Sql newtype SchemaState = SchemaState (Map EntityNameDB EntitySchemaState) deriving (Eq, Show) --- | The state of a particular entity in the database; we generate migrations --- based on the diff of this versus an EntityDef. -data EntitySchemaState = EntitySchemaState +-- | The state of a particular entity (i.e. table) in the database; we generate +-- migrations based on the diff of this versus an EntityDef. +data EntitySchemaState + = -- | The table does not exist in the database + EntityDoesNotExist + | -- | The table does exist in the database + EntityExists ExistingEntitySchemaState + deriving (Eq, Show) + +-- | Information about an existing table in the database +data ExistingEntitySchemaState = ExistingEntitySchemaState { essColumns :: [Column] -- ^ The columns in this entity , essConstraints :: Map ConstraintNameDB [FieldNameDB] @@ -46,35 +55,51 @@ data EntitySchemaState = EntitySchemaState deriving (Eq, Show) -- | Query a database in order to assemble a SchemaState containing information --- about each of the entities in the given list. +-- about each of the entities in the given list. Every entity name in the input +-- should be present in the returned Map. collectSchemaState :: (Text -> IO Statement) -> [EntityNameDB] -> IO (Either Text SchemaState) collectSchemaState getStmt entityNames = runExceptT $ do + existence <- getTableExistence getStmt entityNames columns <- getColumnsWithoutReferences getStmt entityNames constraints <- getConstraints getStmt entityNames foreignKeyReferences <- getForeignKeyReferences getStmt entityNames fmap (SchemaState . Map.fromList) $ for entityNames $ \entityNameDB -> do - let - addColumnReference column = - column - { cReference = Map.lookup (cName column) =<< Map.lookup entityNameDB foreignKeyReferences - } - - essColumns <- case Map.lookup entityNameDB columns of - Just cols -> - pure (map addColumnReference cols) + tableExists <- case Map.lookup entityNameDB existence of + Just e -> pure e Nothing -> throwError - ("Missing entity name from columns map: " <> unEntityNameDB entityNameDB) + ("Missing entity name from existence map: " <> unEntityNameDB entityNameDB) - let - essConstraints = fromMaybe Map.empty (Map.lookup entityNameDB constraints) - pure - ( entityNameDB - , EntitySchemaState{essColumns, essConstraints} - ) + if tableExists + then do + let + addColumnReference column = + column + { cReference = + Map.lookup (cName column) =<< Map.lookup entityNameDB foreignKeyReferences + } + + essColumns <- case Map.lookup entityNameDB columns of + Just cols -> + pure (map addColumnReference cols) + Nothing -> + throwError + ("Missing entity name from columns map: " <> unEntityNameDB entityNameDB) + + let + essConstraints = fromMaybe Map.empty (Map.lookup entityNameDB constraints) + pure + ( entityNameDB + , EntityExists $ ExistingEntitySchemaState{essColumns, essConstraints} + ) + else + pure + ( entityNameDB + , EntityDoesNotExist + ) runStmt :: (Show a) @@ -91,6 +116,44 @@ runStmt getStmt sql values process = do (\src -> runConduit $ src .| CL.map process .| CL.consume) pure results +-- | Check for the existence of each of the input tables. The keys in the +-- returned Map are exactly the entity names in the argument; True means the +-- table exists. +getTableExistence + :: (Text -> IO Statement) + -> [EntityNameDB] + -> ExceptT Text IO (Map EntityNameDB Bool) +getTableExistence getStmt entityNames = do + results <- + liftIO $ + runStmt + getStmt + getTableExistenceSql + [PersistArray (map (PersistText . unEntityNameDB) entityNames)] + processTable + case partitionEithers results of + ([], xs) -> + let + existing = Set.fromList xs + in + pure $ Map.fromList $ map (\n -> (n, Set.member n existing)) entityNames + (errs, _) -> throwError (T.intercalate "\n" errs) + where + getTableExistenceSql = + "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog'" + <> " AND schemaname != 'information_schema' AND tablename=ANY (?)" + + processTable :: [PersistValue] -> Either Text EntityNameDB + processTable resultRow = do + fmap EntityNameDB $ + case resultRow of + [PersistText tableName] -> + pure tableName + [PersistByteString tableName] -> + pure (T.decodeUtf8 tableName) + other -> + throwError $ T.pack $ "Invalid result from information_schema: " ++ show other + -- | Get all columns for the listed tables from the database, ignoring foreign -- key references (those are filled in later). getColumnsWithoutReferences @@ -350,7 +413,13 @@ getForeignKeyReferences getStmt entityNames = do :: [PersistValue] -> Either Text (Map EntityNameDB (Map FieldNameDB ColumnReference)) processForeignKeyReference resultRow = do - (sourceTableName, sourceColumnName, refTableName, constraintName, updRule, delRule) <- + ( sourceTableName + , sourceColumnName + , refTableName + , constraintName + , updRule + , delRule + ) <- case resultRow of [ PersistText srcTable , PersistText srcColumn @@ -373,14 +442,17 @@ getForeignKeyReferences getStmt entityNames = do fcOnUpdate <- parseCascade updRule fcOnDelete <- parseCascade delRule - let columnRef = ColumnReference - { crTableName = refTableName - , crConstraintName = constraintName - , crFieldCascade = FieldCascade - { fcOnUpdate = Just fcOnUpdate - , fcOnDelete = Just fcOnDelete + let + columnRef = + ColumnReference + { crTableName = refTableName + , crConstraintName = constraintName + , crFieldCascade = + FieldCascade + { fcOnUpdate = Just fcOnUpdate + , fcOnDelete = Just fcOnDelete + } } - } pure $ Map.singleton sourceTableName (Map.singleton sourceColumnName columnRef) diff --git a/persistent-postgresql/persistent-postgresql.cabal b/persistent-postgresql/persistent-postgresql.cabal index d9ae9d93e..c0aa85701 100644 --- a/persistent-postgresql/persistent-postgresql.cabal +++ b/persistent-postgresql/persistent-postgresql.cabal @@ -62,6 +62,7 @@ test-suite test ImplicitUuidSpec JSONTest MigrationReferenceSpec + MigrationSpec PgInit PgIntervalTest UpsertWhere diff --git a/persistent-postgresql/test/MigrationSpec.hs b/persistent-postgresql/test/MigrationSpec.hs new file mode 100644 index 000000000..2b88f5674 --- /dev/null +++ b/persistent-postgresql/test/MigrationSpec.hs @@ -0,0 +1,259 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE OverloadedStrings #-} + +module MigrationSpec where + +import PgInit + +import qualified Data.Map as Map +import Database.Persist.Postgresql.Internal.Migration +import qualified Database.Persist.SqlBackend.Internal as SqlBackend + +runConnPrepare + :: (MonadIO m) => ((Text -> IO Statement) -> a -> IO b) -> a -> SqlPersistT m b +runConnPrepare inner arg = do + backend <- ask + liftIO $ inner (SqlBackend.connPrepare backend) arg + +spec :: Spec +spec = describe "MigrationSpec" $ do + it "works" $ runConnAssert $ do + let + rawEx sql = rawExecute sql [] + rawEx + "CREATE TABLE users(id serial primary key, name text not null, title text);" + rawEx + "CREATE TABLE user_friendships(id serial primary key, user_1_id int references users(id), user_2_id int references users(id));" + rawEx + "CREATE TABLE passwords(id serial primary key, password_hash text, user_id int unique references users(id));" + rawEx + "CREATE TABLE passwords_2(id serial primary key, password_hash text, user_id int unique references users(id));" + rawEx "CREATE TABLE ignored(id serial primary key);" + + actual <- + runConnPrepare collectSchemaState $ + map + EntityNameDB + [ "users" + , "user_friendships" + , "passwords" + , "passwords_2" + , "nonexistent" + ] + + let + expected = + SchemaState $ + Map.fromList + [ (EntityNameDB{unEntityNameDB = "nonexistent"}, EntityDoesNotExist) + , + ( EntityNameDB{unEntityNameDB = "passwords"} + , EntityExists + ( ExistingEntitySchemaState + { essColumns = + [ Column + { cName = FieldNameDB{unFieldNameDB = "user_id"} + , cNull = True + , cSqlType = SqlInt32 + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = + Just + ( ColumnReference + { crTableName = EntityNameDB{unEntityNameDB = "users"} + , crConstraintName = + ConstraintNameDB{unConstraintNameDB = "passwords_user_id_fkey"} + , crFieldCascade = + FieldCascade{fcOnUpdate = Just NoAction, fcOnDelete = Just NoAction} + } + ) + } + , Column + { cName = FieldNameDB{unFieldNameDB = "password_hash"} + , cNull = True + , cSqlType = SqlString + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Column + { cName = FieldNameDB{unFieldNameDB = "id"} + , cNull = False + , cSqlType = SqlInt32 + , cDefault = Just "nextval('passwords_id_seq'::regclass)" + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + ] + , essConstraints = + Map.fromList + [ + ( ConstraintNameDB{unConstraintNameDB = "passwords_user_id_key"} + , [FieldNameDB{unFieldNameDB = "user_id"}] + ) + ] + } + ) + ) + , + ( EntityNameDB{unEntityNameDB = "passwords_2"} + , EntityExists + ( ExistingEntitySchemaState + { essColumns = + [ Column + { cName = FieldNameDB{unFieldNameDB = "user_id"} + , cNull = True + , cSqlType = SqlInt32 + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = + Just + ( ColumnReference + { crTableName = EntityNameDB{unEntityNameDB = "users"} + , crConstraintName = + ConstraintNameDB{unConstraintNameDB = "passwords_2_user_id_fkey"} + , crFieldCascade = + FieldCascade{fcOnUpdate = Just NoAction, fcOnDelete = Just NoAction} + } + ) + } + , Column + { cName = FieldNameDB{unFieldNameDB = "password_hash"} + , cNull = True + , cSqlType = SqlString + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Column + { cName = FieldNameDB{unFieldNameDB = "id"} + , cNull = False + , cSqlType = SqlInt32 + , cDefault = Just "nextval('passwords_2_id_seq'::regclass)" + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + ] + , essConstraints = + Map.fromList + [ + ( ConstraintNameDB{unConstraintNameDB = "passwords_2_user_id_key"} + , [FieldNameDB{unFieldNameDB = "user_id"}] + ) + ] + } + ) + ) + , + ( EntityNameDB{unEntityNameDB = "user_friendships"} + , EntityExists + ( ExistingEntitySchemaState + { essColumns = + [ Column + { cName = FieldNameDB{unFieldNameDB = "user_2_id"} + , cNull = True + , cSqlType = SqlInt32 + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = + Just + ( ColumnReference + { crTableName = EntityNameDB{unEntityNameDB = "users"} + , crConstraintName = + ConstraintNameDB{unConstraintNameDB = "user_friendships_user_2_id_fkey"} + , crFieldCascade = + FieldCascade{fcOnUpdate = Just NoAction, fcOnDelete = Just NoAction} + } + ) + } + , Column + { cName = FieldNameDB{unFieldNameDB = "user_1_id"} + , cNull = True + , cSqlType = SqlInt32 + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = + Just + ( ColumnReference + { crTableName = EntityNameDB{unEntityNameDB = "users"} + , crConstraintName = + ConstraintNameDB{unConstraintNameDB = "user_friendships_user_1_id_fkey"} + , crFieldCascade = + FieldCascade{fcOnUpdate = Just NoAction, fcOnDelete = Just NoAction} + } + ) + } + , Column + { cName = FieldNameDB{unFieldNameDB = "id"} + , cNull = False + , cSqlType = SqlInt32 + , cDefault = Just "nextval('user_friendships_id_seq'::regclass)" + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + ] + , essConstraints = Map.fromList [] + } + ) + ) + , + ( EntityNameDB{unEntityNameDB = "users"} + , EntityExists + ( ExistingEntitySchemaState + { essColumns = + [ Column + { cName = FieldNameDB{unFieldNameDB = "title"} + , cNull = True + , cSqlType = SqlString + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Column + { cName = FieldNameDB{unFieldNameDB = "name"} + , cNull = False + , cSqlType = SqlString + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Column + { cName = FieldNameDB{unFieldNameDB = "id"} + , cNull = False + , cSqlType = SqlInt32 + , cDefault = Just "nextval('users_id_seq'::regclass)" + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + ] + , essConstraints = Map.fromList [] + } + ) + ) + ] + + actual `shouldBe` Right expected diff --git a/persistent-postgresql/test/main.hs b/persistent-postgresql/test/main.hs index 25d83a428..05c34dc2c 100644 --- a/persistent-postgresql/test/main.hs +++ b/persistent-postgresql/test/main.hs @@ -45,6 +45,7 @@ import qualified MaybeFieldDefsTest import qualified MigrationColumnLengthTest import qualified MigrationOnlyTest import qualified MigrationReferenceSpec +import qualified MigrationSpec import qualified MigrationTest import qualified MpsCustomPrefixTest import qualified MpsNoPrefixTest @@ -151,6 +152,7 @@ main = do hspec $ do ImplicitUuidSpec.spec MigrationReferenceSpec.spec + MigrationSpec.spec RenameTest.specsWith runConnAssert DataTypeTest.specsWith runConnAssert From 4f9c32d7698a56888ee0b916833e473f46289a5b Mon Sep 17 00:00:00 2001 From: Harry Garrood Date: Thu, 4 Dec 2025 11:53:28 +0000 Subject: [PATCH 03/14] Move diffing logic over --- .../Persist/Postgresql/Internal/Migration.hs | 613 ++++++++++++++++++ 1 file changed, 613 insertions(+) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs index d20df0827..0d53fb86e 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs @@ -1,4 +1,5 @@ {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE ViewPatterns #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TupleSections #-} @@ -8,6 +9,7 @@ module Database.Persist.Postgresql.Internal.Migration where import Control.Arrow +import Control.Monad import Control.Monad.Except import Control.Monad.IO.Class import Data.Acquire (with) @@ -23,6 +25,9 @@ import qualified Data.Text as T import qualified Data.Text.Encoding as T import Data.Traversable import Database.Persist.Sql +import qualified Database.Persist.Sql.Util as Util +import qualified Data.List.NonEmpty as NEL +import Data.List as List -- | In order to ensure that generating migrations is fast and avoids N+1 -- queries, we split it into two phases. The first phase involves querying the @@ -475,3 +480,611 @@ parseCascade txt = mapLeft :: (a1 -> a2) -> Either a1 b -> Either a2 b mapLeft _ (Right x) = Right x mapLeft f (Left x) = Left (f x) + +-- | Returns a structured representation of all of the +-- DB changes required to migrate the Entity from its +-- current state in the database to the state described in +-- Haskell. +migrateEntitiesStructured + :: (Text -> IO Statement) + -> [EntityDef] + -> [EntityDef] + -> IO (Either [Text] [AlterDB]) +migrateEntitiesStructured getStmt allDefs defsToMigrate = do + r <- collectSchemaState getStmt (map getEntityDBName defsToMigrate) + pure $ case r of + Right schemaState -> + migrateEntitiesFromSchemaState schemaState allDefs defsToMigrate + Left err -> + Left [err] + +migrateEntitiesFromSchemaState + :: SchemaState + -> [EntityDef] + -> [EntityDef] + -> Either [Text] [AlterDB] +migrateEntitiesFromSchemaState (SchemaState schemaStateMap) allDefs defsToMigrate = + let go :: EntityDef -> Either Text [AlterDB] + go entity = do + let name = getEntityDBName entity + case Map.lookup name schemaStateMap of + Just entityState -> + Right $ migrateEntityFromSchemaState entityState allDefs entity + Nothing -> + Left $ T.pack $ "No entry for entity in schemaState: " <> show name + in + case partitionEithers (map go defsToMigrate) of + ([], xs) -> Right (concat xs) + (errs, _) -> Left errs + +migrateEntityFromSchemaState + :: EntitySchemaState + -> [EntityDef] + -> EntityDef + -> [AlterDB] +migrateEntityFromSchemaState schemaState allDefs entity = + case schemaState of + EntityDoesNotExist -> + (addTable newcols entity) : uniques ++ references ++ foreignsAlt + EntityExists ExistingEntitySchemaState { essColumns, essConstraints } -> + let + (acs, ats) = + getAlters + allDefs + entity + (newcols, udspair) + (essColumns, Map.toList essConstraints) + acs' = map (AlterColumn name) acs + ats' = map (AlterTable name) ats + in + acs' ++ ats' + + where + name = getEntityDBName entity + (newcols', udefs, fdefs) = postgresMkColumns allDefs entity + newcols = filter (not . safeToRemove entity . cName) newcols' + udspair = map udToPair udefs + + uniques = flip concatMap udspair $ \(uname, ucols) -> + [AlterTable name $ AddUniqueConstraint uname ucols] + references = + mapMaybe + ( \Column{cName, cReference} -> + getAddReference allDefs entity cName =<< cReference + ) + newcols + foreignsAlt = mapMaybe (mkForeignAlt entity) fdefs + + +-- | Indicates whether a Postgres Column is safe to drop. +-- +-- @since 2.17.1.0 +newtype SafeToRemove = SafeToRemove Bool + deriving (Show, Eq) + +-- | Represents a change to a Postgres column in a DB statement. +-- +-- @since 2.17.1.0 +data AlterColumn + = ChangeType Column SqlType Text + | IsNull Column + | NotNull Column + | AddColumn Column + | Drop Column SafeToRemove + | Default Column Text + | NoDefault Column + | UpdateNullToValue Column Text + | AddReference + EntityNameDB + ConstraintNameDB + (NEL.NonEmpty FieldNameDB) + [Text] + FieldCascade + | DropReference ConstraintNameDB + deriving (Show, Eq) + +-- | Represents a change to a Postgres table in a DB statement. +-- +-- @since 2.17.1.0 +data AlterTable + = AddUniqueConstraint ConstraintNameDB [FieldNameDB] + | DropConstraint ConstraintNameDB + deriving (Show, Eq) + +-- | Represents a change to a Postgres DB in a statement. +-- +-- @since 2.17.1.0 +data AlterDB + = AddTable EntityNameDB EntityIdDef [Column] + | AlterColumn EntityNameDB AlterColumn + | AlterTable EntityNameDB AlterTable + deriving (Show, Eq) + +-- | Create a table if it doesn't exist. +-- +-- @since 2.17.1.0 +addTable :: [Column] -> EntityDef -> AlterDB +addTable cols entity = + AddTable name entityId nonIdCols + where + nonIdCols = + case entityPrimary entity of + Just _ -> + cols + _ -> + filter keepField cols + where + keepField c = + Just (cName c) /= fmap fieldDB (getEntityIdField entity) + && not (safeToRemove entity (cName c)) + entityId = getEntityId entity + name = getEntityDBName entity + +maySerial :: SqlType -> Maybe Text -> Text +maySerial SqlInt64 Nothing = " SERIAL8 " +maySerial sType _ = " " <> showSqlType sType + +mayDefault :: Maybe Text -> Text +mayDefault def = case def of + Nothing -> "" + Just d -> " DEFAULT " <> d + +getAlters + :: [EntityDef] + -> EntityDef + -> ([Column], [(ConstraintNameDB, [FieldNameDB])]) + -> ([Column], [(ConstraintNameDB, [FieldNameDB])]) + -> ([AlterColumn], [AlterTable]) +getAlters defs def (c1, u1) (c2, u2) = + (getAltersC c1 c2, getAltersU u1 u2) + where + getAltersC [] old = + map (\x -> Drop x $ SafeToRemove $ safeToRemove def $ cName x) old + getAltersC (new : news) old = + let + (alters, old') = findAlters defs def new old + in + alters ++ getAltersC news old' + + getAltersU + :: [(ConstraintNameDB, [FieldNameDB])] + -> [(ConstraintNameDB, [FieldNameDB])] + -> [AlterTable] + getAltersU [] old = + map DropConstraint $ filter (not . isManual) $ map fst old + getAltersU ((name, cols) : news) old = + case lookup name old of + Nothing -> + AddUniqueConstraint name cols : getAltersU news old + Just ocols -> + let + old' = filter (\(x, _) -> x /= name) old + in + if sort cols == sort ocols + then getAltersU news old' + else + DropConstraint name + : AddUniqueConstraint name cols + : getAltersU news old' + + -- Don't drop constraints which were manually added. + isManual (ConstraintNameDB x) = "__manual_" `T.isPrefixOf` x + +-- | Postgres' default maximum identifier length in bytes +-- (You can re-compile Postgres with a new limit, but I'm assuming that virtually noone does this). +-- See https://www.postgresql.org/docs/11/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +maximumIdentifierLength :: Int +maximumIdentifierLength = 63 + +-- | Intelligent comparison of SQL types, to account for SqlInt32 vs SqlOther integer +sqlTypeEq :: SqlType -> SqlType -> Bool +sqlTypeEq x y = + let + -- Non exhaustive helper to map postgres aliases to the same name. Based on + -- https://www.postgresql.org/docs/9.5/datatype.html. + -- This prevents needless `ALTER TYPE`s when the type is the same. + normalize "int8" = "bigint" + normalize "serial8" = "bigserial" + normalize v = v + in + normalize (T.toCaseFold (showSqlType x)) + == normalize (T.toCaseFold (showSqlType y)) + +-- We check if we should alter a foreign key. This is almost an equality check, +-- except we consider 'Nothing' and 'Just Restrict' equivalent. +equivalentRef :: Maybe ColumnReference -> Maybe ColumnReference -> Bool +equivalentRef Nothing Nothing = True +equivalentRef (Just cr1) (Just cr2) = + crTableName cr1 == crTableName cr2 + && crConstraintName cr1 == crConstraintName cr2 + && eqCascade (fcOnUpdate $ crFieldCascade cr1) (fcOnUpdate $ crFieldCascade cr2) + && eqCascade (fcOnDelete $ crFieldCascade cr1) (fcOnDelete $ crFieldCascade cr2) + where + eqCascade :: Maybe CascadeAction -> Maybe CascadeAction -> Bool + eqCascade Nothing Nothing = True + eqCascade Nothing (Just Restrict) = True + eqCascade (Just Restrict) Nothing = True + eqCascade (Just cs1) (Just cs2) = cs1 == cs2 + eqCascade _ _ = False +equivalentRef _ _ = False + +refName :: EntityNameDB -> FieldNameDB -> ConstraintNameDB +refName (EntityNameDB table) (FieldNameDB column) = + let + overhead = T.length $ T.concat ["_", "_fkey"] + (fromTable, fromColumn) = shortenNames overhead (T.length table, T.length column) + in + ConstraintNameDB $ + T.concat [T.take fromTable table, "_", T.take fromColumn column, "_fkey"] + where + -- Postgres automatically truncates too long foreign keys to a combination of + -- truncatedTableName + "_" + truncatedColumnName + "_fkey" + -- This works fine for normal use cases, but it creates an issue for Persistent + -- Because after running the migrations, Persistent sees the truncated foreign key constraint + -- doesn't have the expected name, and suggests that you migrate again + -- To workaround this, we copy the Postgres truncation approach before sending foreign key constraints to it. + -- + -- I believe this will also be an issue for extremely long table names, + -- but it's just much more likely to exist with foreign key constraints because they're usually tablename * 2 in length + + -- Approximation of the algorithm Postgres uses to truncate identifiers + -- See makeObjectName https://github.com/postgres/postgres/blob/5406513e997f5ee9de79d4076ae91c04af0c52f6/src/backend/commands/indexcmds.c#L2074-L2080 + shortenNames :: Int -> (Int, Int) -> (Int, Int) + shortenNames overhead (x, y) + | x + y + overhead <= maximumIdentifierLength = (x, y) + | x > y = shortenNames overhead (x - 1, y) + | otherwise = shortenNames overhead (x, y - 1) + +postgresMkColumns + :: [EntityDef] -> EntityDef -> ([Column], [UniqueDef], [ForeignDef]) +postgresMkColumns allDefs t = + mkColumns allDefs t $ + setBackendSpecificForeignKeyName refName emptyBackendSpecificOverrides + +-- | Check if a column name is listed as the "safe to remove" in the entity +-- list. +safeToRemove :: EntityDef -> FieldNameDB -> Bool +safeToRemove def (FieldNameDB colName) = + any (elem FieldAttrSafeToRemove . fieldAttrs) $ + filter ((== FieldNameDB colName) . fieldDB) $ + allEntityFields + where + allEntityFields = + getEntityFieldsDatabase def <> case getEntityId def of + EntityIdField fdef -> + [fdef] + _ -> + [] + +udToPair :: UniqueDef -> (ConstraintNameDB, [FieldNameDB]) +udToPair ud = (uniqueDBName ud, map snd $ NEL.toList $ uniqueFields ud) + +-- | Get the references to be added to a table for the given column. +getAddReference + :: [EntityDef] + -> EntityDef + -> FieldNameDB + -> ColumnReference + -> Maybe AlterDB +getAddReference allDefs entity cname cr@ColumnReference{crTableName = s, crConstraintName = constraintName} = do + guard $ Just cname /= fmap fieldDB (getEntityIdField entity) + pure $ + AlterColumn + table + (AddReference s constraintName (cname NEL.:| []) id_ (crFieldCascade cr)) + where + table = getEntityDBName entity + id_ = + fromMaybe + (error $ "Could not find ID of entity " ++ show s) + $ do + entDef <- find ((== s) . getEntityDBName) allDefs + return $ NEL.toList $ Util.dbIdColumnsEsc escapeF entDef + +mkForeignAlt + :: EntityDef + -> ForeignDef + -> Maybe AlterDB +mkForeignAlt entity fdef = case NEL.nonEmpty childfields of + Nothing -> Nothing + Just childfields' -> Just $ AlterColumn tableName_ addReference + where + addReference = + AddReference + (foreignRefTableDBName fdef) + constraintName + childfields' + escapedParentFields + (foreignFieldCascade fdef) + where + tableName_ = getEntityDBName entity + constraintName = + foreignConstraintNameDBName fdef + (childfields, parentfields) = + unzip (map (\((_, b), (_, d)) -> (b, d)) (foreignFields fdef)) + escapedParentFields = + map escapeF parentfields + +escapeC :: ConstraintNameDB -> Text +escapeC = escapeWith escape + +escapeE :: EntityNameDB -> Text +escapeE = escapeWith escape + +escapeF :: FieldNameDB -> Text +escapeF = escapeWith escape + +escape :: Text -> Text +escape s = + T.pack $ '"' : go (T.unpack s) ++ "\"" + where + go "" = "" + go ('"' : xs) = "\"\"" ++ go xs + go (x : xs) = x : go xs + +showAlterDb :: AlterDB -> (Bool, Text) +showAlterDb (AddTable name entityId nonIdCols) = (False, rawText) + where + idtxt = + case entityId of + EntityIdNaturalKey pdef -> + T.concat + [ " PRIMARY KEY (" + , T.intercalate "," $ map (escapeF . fieldDB) $ NEL.toList $ compositeFields pdef + , ")" + ] + EntityIdField field -> + let + defText = defaultAttribute $ fieldAttrs field + sType = fieldSqlType field + in + T.concat + [ escapeF $ fieldDB field + , maySerial sType defText + , " PRIMARY KEY UNIQUE" + , mayDefault defText + ] + rawText = + T.concat + -- Lower case e: see Database.Persist.Sql.Migration + [ "CREATe TABLE " -- DO NOT FIX THE CAPITALIZATION! + , escapeE name + , "(" + , idtxt + , if null nonIdCols then "" else "," + , T.intercalate "," $ map showColumn nonIdCols + , ")" + ] +showAlterDb (AlterColumn t ac) = + (isUnsafe ac, showAlter t ac) + where + isUnsafe (Drop _ (SafeToRemove safeRemove)) = not safeRemove + isUnsafe _ = False +showAlterDb (AlterTable t at) = (False, showAlterTable t at) + +showAlterTable :: EntityNameDB -> AlterTable -> Text +showAlterTable table (AddUniqueConstraint cname cols) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ADD CONSTRAINT " + , escapeC cname + , " UNIQUE(" + , T.intercalate "," $ map escapeF cols + , ")" + ] +showAlterTable table (DropConstraint cname) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " DROP CONSTRAINT " + , escapeC cname + ] + +showAlter :: EntityNameDB -> AlterColumn -> Text +showAlter table (ChangeType c t extra) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ALTER COLUMN " + , escapeF (cName c) + , " TYPE " + , showSqlType t + , extra + ] +showAlter table (IsNull c) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ALTER COLUMN " + , escapeF (cName c) + , " DROP NOT NULL" + ] +showAlter table (NotNull c) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ALTER COLUMN " + , escapeF (cName c) + , " SET NOT NULL" + ] +showAlter table (AddColumn col) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ADD COLUMN " + , showColumn col + ] +showAlter table (Drop c _) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " DROP COLUMN " + , escapeF (cName c) + ] +showAlter table (Default c s) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ALTER COLUMN " + , escapeF (cName c) + , " SET DEFAULT " + , s + ] +showAlter table (NoDefault c) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ALTER COLUMN " + , escapeF (cName c) + , " DROP DEFAULT" + ] +showAlter table (UpdateNullToValue c s) = + T.concat + [ "UPDATE " + , escapeE table + , " SET " + , escapeF (cName c) + , "=" + , s + , " WHERE " + , escapeF (cName c) + , " IS NULL" + ] +showAlter table (AddReference reftable fkeyname t2 id2 cascade) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ADD CONSTRAINT " + , escapeC fkeyname + , " FOREIGN KEY(" + , T.intercalate "," $ map escapeF $ NEL.toList t2 + , ") REFERENCES " + , escapeE reftable + , "(" + , T.intercalate "," id2 + , ")" + ] + <> renderFieldCascade cascade +showAlter table (DropReference cname) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " DROP CONSTRAINT " + , escapeC cname + ] + +showColumn :: Column -> Text +showColumn (Column n nu sqlType' def gen _defConstraintName _maxLen _ref) = + T.concat + [ escapeF n + , " " + , showSqlType sqlType' + , " " + , if nu then "NULL" else "NOT NULL" + , case def of + Nothing -> "" + Just s -> " DEFAULT " <> s + , case gen of + Nothing -> "" + Just s -> " GENERATED ALWAYS AS (" <> s <> ") STORED" + ] + +showSqlType :: SqlType -> Text +showSqlType SqlString = "VARCHAR" +showSqlType SqlInt32 = "INT4" +showSqlType SqlInt64 = "INT8" +showSqlType SqlReal = "DOUBLE PRECISION" +showSqlType (SqlNumeric s prec) = T.concat ["NUMERIC(", T.pack (show s), ",", T.pack (show prec), ")"] +showSqlType SqlDay = "DATE" +showSqlType SqlTime = "TIME" +showSqlType SqlDayTime = "TIMESTAMP WITH TIME ZONE" +showSqlType SqlBlob = "BYTEA" +showSqlType SqlBool = "BOOLEAN" +-- Added for aliasing issues re: https://github.com/yesodweb/yesod/issues/682 +showSqlType (SqlOther (T.toLower -> "integer")) = "INT4" +showSqlType (SqlOther t) = t + +findAlters + :: [EntityDef] + -- ^ The list of all entity definitions that persistent is aware of. + -> EntityDef + -- ^ The entity definition for the entity that we're working on. + -> Column + -- ^ The column that we're searching for potential alterations for. + -> [Column] + -> ([AlterColumn], [Column]) +findAlters defs edef col@(Column name isNull sqltype def _gen _defConstraintName _maxLen ref) cols = + case List.find (\c -> cName c == name) cols of + Nothing -> + ([AddColumn col], cols) + Just + (Column _oldName isNull' sqltype' def' _gen' _defConstraintName' _maxLen' ref') -> + let + refDrop Nothing = [] + refDrop (Just ColumnReference{crConstraintName = cname}) = + [DropReference cname] + + refAdd Nothing = [] + refAdd (Just colRef) = + case find ((== crTableName colRef) . getEntityDBName) defs of + Just refdef + | Just _oldName /= fmap fieldDB (getEntityIdField edef) -> + [ AddReference + (crTableName colRef) + (crConstraintName colRef) + (name NEL.:| []) + (NEL.toList $ Util.dbIdColumnsEsc escapeF refdef) + (crFieldCascade colRef) + ] + Just _ -> [] + Nothing -> + error $ + "could not find the entityDef for reftable[" + ++ show (crTableName colRef) + ++ "]" + modRef = + if equivalentRef ref ref' + then [] + else refDrop ref' ++ refAdd ref + modNull = case (isNull, isNull') of + (True, False) -> do + guard $ Just name /= fmap fieldDB (getEntityIdField edef) + pure (IsNull col) + (False, True) -> + let + up = case def of + Nothing -> id + Just s -> (:) (UpdateNullToValue col s) + in + up [NotNull col] + _ -> [] + modType + | sqlTypeEq sqltype sqltype' = [] + -- When converting from Persistent pre-2.0 databases, we + -- need to make sure that TIMESTAMP WITHOUT TIME ZONE is + -- treated as UTC. + | sqltype == SqlDayTime && sqltype' == SqlOther "timestamp" = + [ ChangeType col sqltype $ + T.concat + [ " USING " + , escapeF name + , " AT TIME ZONE 'UTC'" + ] + ] + | otherwise = [ChangeType col sqltype ""] + modDef = + if def == def' + || isJust (T.stripPrefix "nextval" =<< def') + then [] + else case def of + Nothing -> [NoDefault col] + Just s -> [Default col s] + dropSafe = + if safeToRemove edef name + then error "wtf" [Drop col (SafeToRemove True)] + else [] + in + ( modRef ++ modDef ++ modNull ++ modType ++ dropSafe + , filter (\c -> cName c /= name) cols + ) From b441e183950fae9da13d9eb304314dbef5e7726a Mon Sep 17 00:00:00 2001 From: Harry Garrood Date: Thu, 4 Dec 2025 11:58:28 +0000 Subject: [PATCH 04/14] Replace the impl --- .../Database/Persist/Postgresql/Internal.hs | 947 +----------------- 1 file changed, 7 insertions(+), 940 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 1ceaf6971..7e99a29e1 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -42,38 +42,20 @@ import qualified Database.PostgreSQL.Simple.TypeInfo.Static as PS import qualified Database.PostgreSQL.Simple.Types as PG import qualified Blaze.ByteString.Builder.Char8 as BBB -import Control.Arrow import Control.Monad -import Control.Monad.Except -import Control.Monad.IO.Unlift (MonadIO (..)) -import Control.Monad.Trans.Class (lift) -import Data.Acquire (with) -import Data.Bits (toIntegralSized) import Data.ByteString (ByteString) import qualified Data.ByteString.Builder as BB -import Data.Conduit -import qualified Data.Conduit.List as CL -import Data.Data (Typeable) -import Data.Either (partitionEithers) -import Data.Fixed (Fixed (..), Micro, Pico) -import Data.Function (on) +import Data.Fixed (Pico) import qualified Data.IntMap as I -import Data.List as List (find, foldl', groupBy, sort) -import qualified Data.List.NonEmpty as NEL -import qualified Data.Map as Map import Data.Maybe import Data.Text (Text) -import qualified Data.Text as T -import qualified Data.Text.Encoding as T import Data.Time ( NominalDiffTime , localTimeToUTC - , nominalDiffTimeToSeconds - , secondsToNominalDiffTime , utc ) import Database.Persist.Sql -import qualified Database.Persist.Sql.Util as Util +import Database.Persist.Postgresql.Internal.Migration -- | Newtype used to avoid orphan instances for @postgresql-simple@ classes. -- @@ -127,7 +109,7 @@ instance PGTF.ToField Unknown where toField (Unknown a) = PGTF.Escape a newtype UnknownLiteral = UnknownLiteral {unUnknownLiteral :: ByteString} - deriving (Eq, Show, Read, Ord, Typeable) + deriving (Eq, Show, Read, Ord) instance PGFF.FromField UnknownLiteral where fromField f mdata = @@ -280,50 +262,6 @@ intervalToPgInterval interval = then Just $ PgInterval nominalDiffTime else Nothing --- | Indicates whether a Postgres Column is safe to drop. --- --- @since 2.17.1.0 -newtype SafeToRemove = SafeToRemove Bool - deriving (Show, Eq) - --- | Represents a change to a Postgres column in a DB statement. --- --- @since 2.17.1.0 -data AlterColumn - = ChangeType Column SqlType Text - | IsNull Column - | NotNull Column - | AddColumn Column - | Drop Column SafeToRemove - | Default Column Text - | NoDefault Column - | UpdateNullToValue Column Text - | AddReference - EntityNameDB - ConstraintNameDB - (NEL.NonEmpty FieldNameDB) - [Text] - FieldCascade - | DropReference ConstraintNameDB - deriving (Show, Eq) - --- | Represents a change to a Postgres table in a DB statement. --- --- @since 2.17.1.0 -data AlterTable - = AddUniqueConstraint ConstraintNameDB [FieldNameDB] - | DropConstraint ConstraintNameDB - deriving (Show, Eq) - --- | Represents a change to a Postgres DB in a statement. --- --- @since 2.17.1.0 -data AlterDB - = AddTable EntityNameDB EntityIdDef [Column] - | AlterColumn EntityNameDB AlterColumn - | AlterTable EntityNameDB AlterTable - deriving (Show, Eq) - -- | Returns a structured representation of all of the -- DB changes required to migrate the Entity from its -- current state in the database to the state described in @@ -335,49 +273,8 @@ migrateStructured -> (Text -> IO Statement) -> EntityDef -> IO (Either [Text] [AlterDB]) -migrateStructured allDefs getter entity = do - old <- getColumns getter entity newcols' - case partitionEithers old of - ([], old'') -> do - exists' <- - if null old - then doesTableExist getter name - else return True - return $ Right $ migrationText exists' old'' - (errs, _) -> return $ Left errs - where - name = getEntityDBName entity - (newcols', udefs, fdefs) = postgresMkColumns allDefs entity - migrationText exists' old'' - | not exists' = - createText newcols fdefs udspair - | otherwise = - let - (acs, ats) = - getAlters allDefs entity (newcols, udspair) old' - acs' = map (AlterColumn name) acs - ats' = map (AlterTable name) ats - in - acs' ++ ats' - where - old' = partitionEithers old'' - newcols = filter (not . safeToRemove entity . cName) newcols' - udspair = map udToPair udefs - -- Check for table existence if there are no columns, workaround - -- for https://github.com/yesodweb/persistent/issues/152 - - createText newcols fdefs_ udspair = - (addTable newcols entity) : uniques ++ references ++ foreignsAlt - where - uniques = flip concatMap udspair $ \(uname, ucols) -> - [AlterTable name $ AddUniqueConstraint uname ucols] - references = - mapMaybe - ( \Column{cName, cReference} -> - getAddReference allDefs entity cName =<< cReference - ) - newcols - foreignsAlt = mapMaybe (mkForeignAlt entity) fdefs_ +migrateStructured allDefs getter entity = + migrateEntitiesStructured getter allDefs [entity] -- | Returns a structured representation of all of the -- DB changes required to migrate the Entity to the state @@ -389,835 +286,5 @@ mockMigrateStructured :: [EntityDef] -> EntityDef -> [AlterDB] -mockMigrateStructured allDefs entity = migrationText - where - name = getEntityDBName entity - migrationText = createText newcols fdefs udspair - where - (newcols', udefs, fdefs) = postgresMkColumns allDefs entity - newcols = filter (not . safeToRemove entity . cName) newcols' - udspair = map udToPair udefs - -- Check for table existence if there are no columns, workaround - -- for https://github.com/yesodweb/persistent/issues/152 - - createText newcols fdefs udspair = - (addTable newcols entity) : uniques ++ references ++ foreignsAlt - where - uniques = flip concatMap udspair $ \(uname, ucols) -> - [AlterTable name $ AddUniqueConstraint uname ucols] - references = - mapMaybe - ( \Column{cName, cReference} -> - getAddReference allDefs entity cName =<< cReference - ) - newcols - foreignsAlt = mapMaybe (mkForeignAlt entity) fdefs - --- | Returns a structured representation of all of the --- DB changes required to migrate the Entity from its current state --- in the database to the state described in Haskell. --- --- @since 2.17.1.0 -addTable :: [Column] -> EntityDef -> AlterDB -addTable cols entity = - AddTable name entityId nonIdCols - where - nonIdCols = - case entityPrimary entity of - Just _ -> - cols - _ -> - filter keepField cols - where - keepField c = - Just (cName c) /= fmap fieldDB (getEntityIdField entity) - && not (safeToRemove entity (cName c)) - entityId = getEntityId entity - name = getEntityDBName entity - -maySerial :: SqlType -> Maybe Text -> Text -maySerial SqlInt64 Nothing = " SERIAL8 " -maySerial sType _ = " " <> showSqlType sType - -mayDefault :: Maybe Text -> Text -mayDefault def = case def of - Nothing -> "" - Just d -> " DEFAULT " <> d - -getAlters - :: [EntityDef] - -> EntityDef - -> ([Column], [(ConstraintNameDB, [FieldNameDB])]) - -> ([Column], [(ConstraintNameDB, [FieldNameDB])]) - -> ([AlterColumn], [AlterTable]) -getAlters defs def (c1, u1) (c2, u2) = - (getAltersC c1 c2, getAltersU u1 u2) - where - getAltersC [] old = - map (\x -> Drop x $ SafeToRemove $ safeToRemove def $ cName x) old - getAltersC (new : news) old = - let - (alters, old') = findAlters defs def new old - in - alters ++ getAltersC news old' - - getAltersU - :: [(ConstraintNameDB, [FieldNameDB])] - -> [(ConstraintNameDB, [FieldNameDB])] - -> [AlterTable] - getAltersU [] old = - map DropConstraint $ filter (not . isManual) $ map fst old - getAltersU ((name, cols) : news) old = - case lookup name old of - Nothing -> - AddUniqueConstraint name cols : getAltersU news old - Just ocols -> - let - old' = filter (\(x, _) -> x /= name) old - in - if sort cols == sort ocols - then getAltersU news old' - else - DropConstraint name - : AddUniqueConstraint name cols - : getAltersU news old' - - -- Don't drop constraints which were manually added. - isManual (ConstraintNameDB x) = "__manual_" `T.isPrefixOf` x - --- | Postgres' default maximum identifier length in bytes --- (You can re-compile Postgres with a new limit, but I'm assuming that virtually noone does this). --- See https://www.postgresql.org/docs/11/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS -maximumIdentifierLength :: Int -maximumIdentifierLength = 63 - --- | Intelligent comparison of SQL types, to account for SqlInt32 vs SqlOther integer -sqlTypeEq :: SqlType -> SqlType -> Bool -sqlTypeEq x y = - let - -- Non exhaustive helper to map postgres aliases to the same name. Based on - -- https://www.postgresql.org/docs/9.5/datatype.html. - -- This prevents needless `ALTER TYPE`s when the type is the same. - normalize "int8" = "bigint" - normalize "serial8" = "bigserial" - normalize v = v - in - normalize (T.toCaseFold (showSqlType x)) - == normalize (T.toCaseFold (showSqlType y)) - --- We check if we should alter a foreign key. This is almost an equality check, --- except we consider 'Nothing' and 'Just Restrict' equivalent. -equivalentRef :: Maybe ColumnReference -> Maybe ColumnReference -> Bool -equivalentRef Nothing Nothing = True -equivalentRef (Just cr1) (Just cr2) = - crTableName cr1 == crTableName cr2 - && crConstraintName cr1 == crConstraintName cr2 - && eqCascade (fcOnUpdate $ crFieldCascade cr1) (fcOnUpdate $ crFieldCascade cr2) - && eqCascade (fcOnDelete $ crFieldCascade cr1) (fcOnDelete $ crFieldCascade cr2) - where - eqCascade :: Maybe CascadeAction -> Maybe CascadeAction -> Bool - eqCascade Nothing Nothing = True - eqCascade Nothing (Just Restrict) = True - eqCascade (Just Restrict) Nothing = True - eqCascade (Just cs1) (Just cs2) = cs1 == cs2 - eqCascade _ _ = False -equivalentRef _ _ = False - -refName :: EntityNameDB -> FieldNameDB -> ConstraintNameDB -refName (EntityNameDB table) (FieldNameDB column) = - let - overhead = T.length $ T.concat ["_", "_fkey"] - (fromTable, fromColumn) = shortenNames overhead (T.length table, T.length column) - in - ConstraintNameDB $ - T.concat [T.take fromTable table, "_", T.take fromColumn column, "_fkey"] - where - -- Postgres automatically truncates too long foreign keys to a combination of - -- truncatedTableName + "_" + truncatedColumnName + "_fkey" - -- This works fine for normal use cases, but it creates an issue for Persistent - -- Because after running the migrations, Persistent sees the truncated foreign key constraint - -- doesn't have the expected name, and suggests that you migrate again - -- To workaround this, we copy the Postgres truncation approach before sending foreign key constraints to it. - -- - -- I believe this will also be an issue for extremely long table names, - -- but it's just much more likely to exist with foreign key constraints because they're usually tablename * 2 in length - - -- Approximation of the algorithm Postgres uses to truncate identifiers - -- See makeObjectName https://github.com/postgres/postgres/blob/5406513e997f5ee9de79d4076ae91c04af0c52f6/src/backend/commands/indexcmds.c#L2074-L2080 - shortenNames :: Int -> (Int, Int) -> (Int, Int) - shortenNames overhead (x, y) - | x + y + overhead <= maximumIdentifierLength = (x, y) - | x > y = shortenNames overhead (x - 1, y) - | otherwise = shortenNames overhead (x, y - 1) - -postgresMkColumns - :: [EntityDef] -> EntityDef -> ([Column], [UniqueDef], [ForeignDef]) -postgresMkColumns allDefs t = - mkColumns allDefs t $ - setBackendSpecificForeignKeyName refName emptyBackendSpecificOverrides - --- | Check if a column name is listed as the "safe to remove" in the entity --- list. -safeToRemove :: EntityDef -> FieldNameDB -> Bool -safeToRemove def (FieldNameDB colName) = - any (elem FieldAttrSafeToRemove . fieldAttrs) $ - filter ((== FieldNameDB colName) . fieldDB) $ - allEntityFields - where - allEntityFields = - getEntityFieldsDatabase def <> case getEntityId def of - EntityIdField fdef -> - [fdef] - _ -> - [] - -udToPair :: UniqueDef -> (ConstraintNameDB, [FieldNameDB]) -udToPair ud = (uniqueDBName ud, map snd $ NEL.toList $ uniqueFields ud) - --- | Get the references to be added to a table for the given column. -getAddReference - :: [EntityDef] - -> EntityDef - -> FieldNameDB - -> ColumnReference - -> Maybe AlterDB -getAddReference allDefs entity cname cr@ColumnReference{crTableName = s, crConstraintName = constraintName} = do - guard $ Just cname /= fmap fieldDB (getEntityIdField entity) - pure $ - AlterColumn - table - (AddReference s constraintName (cname NEL.:| []) id_ (crFieldCascade cr)) - where - table = getEntityDBName entity - id_ = - fromMaybe - (error $ "Could not find ID of entity " ++ show s) - $ do - entDef <- find ((== s) . getEntityDBName) allDefs - return $ NEL.toList $ Util.dbIdColumnsEsc escapeF entDef - -mkForeignAlt - :: EntityDef - -> ForeignDef - -> Maybe AlterDB -mkForeignAlt entity fdef = case NEL.nonEmpty childfields of - Nothing -> Nothing - Just childfields' -> Just $ AlterColumn tableName_ addReference - where - addReference = - AddReference - (foreignRefTableDBName fdef) - constraintName - childfields' - escapedParentFields - (foreignFieldCascade fdef) - where - tableName_ = getEntityDBName entity - constraintName = - foreignConstraintNameDBName fdef - (childfields, parentfields) = - unzip (map (\((_, b), (_, d)) -> (b, d)) (foreignFields fdef)) - escapedParentFields = - map escapeF parentfields - -escapeC :: ConstraintNameDB -> Text -escapeC = escapeWith escape - -escapeE :: EntityNameDB -> Text -escapeE = escapeWith escape - -escapeF :: FieldNameDB -> Text -escapeF = escapeWith escape - -escape :: Text -> Text -escape s = - T.pack $ '"' : go (T.unpack s) ++ "\"" - where - go "" = "" - go ('"' : xs) = "\"\"" ++ go xs - go (x : xs) = x : go xs - -showAlterDb :: AlterDB -> (Bool, Text) -showAlterDb (AddTable name entityId nonIdCols) = (False, rawText) - where - idtxt = - case entityId of - EntityIdNaturalKey pdef -> - T.concat - [ " PRIMARY KEY (" - , T.intercalate "," $ map (escapeF . fieldDB) $ NEL.toList $ compositeFields pdef - , ")" - ] - EntityIdField field -> - let - defText = defaultAttribute $ fieldAttrs field - sType = fieldSqlType field - in - T.concat - [ escapeF $ fieldDB field - , maySerial sType defText - , " PRIMARY KEY UNIQUE" - , mayDefault defText - ] - rawText = - T.concat - -- Lower case e: see Database.Persist.Sql.Migration - [ "CREATe TABLE " -- DO NOT FIX THE CAPITALIZATION! - , escapeE name - , "(" - , idtxt - , if null nonIdCols then "" else "," - , T.intercalate "," $ map showColumn nonIdCols - , ")" - ] -showAlterDb (AlterColumn t ac) = - (isUnsafe ac, showAlter t ac) - where - isUnsafe (Drop _ (SafeToRemove safeRemove)) = not safeRemove - isUnsafe _ = False -showAlterDb (AlterTable t at) = (False, showAlterTable t at) - -showAlterTable :: EntityNameDB -> AlterTable -> Text -showAlterTable table (AddUniqueConstraint cname cols) = - T.concat - [ "ALTER TABLE " - , escapeE table - , " ADD CONSTRAINT " - , escapeC cname - , " UNIQUE(" - , T.intercalate "," $ map escapeF cols - , ")" - ] -showAlterTable table (DropConstraint cname) = - T.concat - [ "ALTER TABLE " - , escapeE table - , " DROP CONSTRAINT " - , escapeC cname - ] - -showAlter :: EntityNameDB -> AlterColumn -> Text -showAlter table (ChangeType c t extra) = - T.concat - [ "ALTER TABLE " - , escapeE table - , " ALTER COLUMN " - , escapeF (cName c) - , " TYPE " - , showSqlType t - , extra - ] -showAlter table (IsNull c) = - T.concat - [ "ALTER TABLE " - , escapeE table - , " ALTER COLUMN " - , escapeF (cName c) - , " DROP NOT NULL" - ] -showAlter table (NotNull c) = - T.concat - [ "ALTER TABLE " - , escapeE table - , " ALTER COLUMN " - , escapeF (cName c) - , " SET NOT NULL" - ] -showAlter table (AddColumn col) = - T.concat - [ "ALTER TABLE " - , escapeE table - , " ADD COLUMN " - , showColumn col - ] -showAlter table (Drop c _) = - T.concat - [ "ALTER TABLE " - , escapeE table - , " DROP COLUMN " - , escapeF (cName c) - ] -showAlter table (Default c s) = - T.concat - [ "ALTER TABLE " - , escapeE table - , " ALTER COLUMN " - , escapeF (cName c) - , " SET DEFAULT " - , s - ] -showAlter table (NoDefault c) = - T.concat - [ "ALTER TABLE " - , escapeE table - , " ALTER COLUMN " - , escapeF (cName c) - , " DROP DEFAULT" - ] -showAlter table (UpdateNullToValue c s) = - T.concat - [ "UPDATE " - , escapeE table - , " SET " - , escapeF (cName c) - , "=" - , s - , " WHERE " - , escapeF (cName c) - , " IS NULL" - ] -showAlter table (AddReference reftable fkeyname t2 id2 cascade) = - T.concat - [ "ALTER TABLE " - , escapeE table - , " ADD CONSTRAINT " - , escapeC fkeyname - , " FOREIGN KEY(" - , T.intercalate "," $ map escapeF $ NEL.toList t2 - , ") REFERENCES " - , escapeE reftable - , "(" - , T.intercalate "," id2 - , ")" - ] - <> renderFieldCascade cascade -showAlter table (DropReference cname) = - T.concat - [ "ALTER TABLE " - , escapeE table - , " DROP CONSTRAINT " - , escapeC cname - ] - -showColumn :: Column -> Text -showColumn (Column n nu sqlType' def gen _defConstraintName _maxLen _ref) = - T.concat - [ escapeF n - , " " - , showSqlType sqlType' - , " " - , if nu then "NULL" else "NOT NULL" - , case def of - Nothing -> "" - Just s -> " DEFAULT " <> s - , case gen of - Nothing -> "" - Just s -> " GENERATED ALWAYS AS (" <> s <> ") STORED" - ] - -showSqlType :: SqlType -> Text -showSqlType SqlString = "VARCHAR" -showSqlType SqlInt32 = "INT4" -showSqlType SqlInt64 = "INT8" -showSqlType SqlReal = "DOUBLE PRECISION" -showSqlType (SqlNumeric s prec) = T.concat ["NUMERIC(", T.pack (show s), ",", T.pack (show prec), ")"] -showSqlType SqlDay = "DATE" -showSqlType SqlTime = "TIME" -showSqlType SqlDayTime = "TIMESTAMP WITH TIME ZONE" -showSqlType SqlBlob = "BYTEA" -showSqlType SqlBool = "BOOLEAN" --- Added for aliasing issues re: https://github.com/yesodweb/yesod/issues/682 -showSqlType (SqlOther (T.toLower -> "integer")) = "INT4" -showSqlType (SqlOther t) = t - -findAlters - :: [EntityDef] - -- ^ The list of all entity definitions that persistent is aware of. - -> EntityDef - -- ^ The entity definition for the entity that we're working on. - -> Column - -- ^ The column that we're searching for potential alterations for. - -> [Column] - -> ([AlterColumn], [Column]) -findAlters defs edef col@(Column name isNull sqltype def _gen _defConstraintName _maxLen ref) cols = - case List.find (\c -> cName c == name) cols of - Nothing -> - ([AddColumn col], cols) - Just - (Column _oldName isNull' sqltype' def' _gen' _defConstraintName' _maxLen' ref') -> - let - refDrop Nothing = [] - refDrop (Just ColumnReference{crConstraintName = cname}) = - [DropReference cname] - - refAdd Nothing = [] - refAdd (Just colRef) = - case find ((== crTableName colRef) . getEntityDBName) defs of - Just refdef - | Just _oldName /= fmap fieldDB (getEntityIdField edef) -> - [ AddReference - (crTableName colRef) - (crConstraintName colRef) - (name NEL.:| []) - (NEL.toList $ Util.dbIdColumnsEsc escapeF refdef) - (crFieldCascade colRef) - ] - Just _ -> [] - Nothing -> - error $ - "could not find the entityDef for reftable[" - ++ show (crTableName colRef) - ++ "]" - modRef = - if equivalentRef ref ref' - then [] - else refDrop ref' ++ refAdd ref - modNull = case (isNull, isNull') of - (True, False) -> do - guard $ Just name /= fmap fieldDB (getEntityIdField edef) - pure (IsNull col) - (False, True) -> - let - up = case def of - Nothing -> id - Just s -> (:) (UpdateNullToValue col s) - in - up [NotNull col] - _ -> [] - modType - | sqlTypeEq sqltype sqltype' = [] - -- When converting from Persistent pre-2.0 databases, we - -- need to make sure that TIMESTAMP WITHOUT TIME ZONE is - -- treated as UTC. - | sqltype == SqlDayTime && sqltype' == SqlOther "timestamp" = - [ ChangeType col sqltype $ - T.concat - [ " USING " - , escapeF name - , " AT TIME ZONE 'UTC'" - ] - ] - | otherwise = [ChangeType col sqltype ""] - modDef = - if def == def' - || isJust (T.stripPrefix "nextval" =<< def') - then [] - else case def of - Nothing -> [NoDefault col] - Just s -> [Default col s] - dropSafe = - if safeToRemove edef name - then error "wtf" [Drop col (SafeToRemove True)] - else [] - in - ( modRef ++ modDef ++ modNull ++ modType ++ dropSafe - , filter (\c -> cName c /= name) cols - ) - --- | Returns all of the columns in the given table currently in the database. -getColumns - :: (Text -> IO Statement) - -> EntityDef - -> [Column] - -> IO [Either Text (Either Column (ConstraintNameDB, [FieldNameDB]))] - -- ^ Left Text: error - -- Right (Left Column): a column - -- Right (Right ...): a constraint -getColumns getter def cols = do - let - sqlv = - T.concat - [ "SELECT " - , "column_name " - , ",is_nullable " - , ",COALESCE(domain_name, udt_name)" -- See DOMAINS below - , ",column_default " - , ",generation_expression " - , ",numeric_precision " - , ",numeric_scale " - , ",character_maximum_length " - , "FROM information_schema.columns " - , "WHERE table_catalog=current_database() " - , "AND table_schema=current_schema() " - , "AND table_name=? " - ] - - -- DOMAINS Postgres supports the concept of domains, which are data types - -- with optional constraints. An app might make an "email" domain over the - -- varchar type, with a CHECK that the emails are valid In this case the - -- generated SQL should use the domain name: ALTER TABLE users ALTER COLUMN - -- foo TYPE email This code exists to use the domain name (email), instead - -- of the underlying type (varchar). This is tested in - -- EquivalentTypeTest.hs - - stmt <- getter sqlv - let - vals = - [ PersistText $ unEntityNameDB $ getEntityDBName def - ] - columns <- - with - (stmtQuery stmt vals) - (\src -> runConduit $ src .| processColumns .| CL.consume) - let - sqlc = - T.concat - [ "SELECT " - , "c.constraint_name, " - , "c.column_name " - , "FROM information_schema.key_column_usage AS c, " - , "information_schema.table_constraints AS k " - , "WHERE c.table_catalog=current_database() " - , "AND c.table_catalog=k.table_catalog " - , "AND c.table_schema=current_schema() " - , "AND c.table_schema=k.table_schema " - , "AND c.table_name=? " - , "AND c.table_name=k.table_name " - , "AND c.constraint_name=k.constraint_name " - , "AND NOT k.constraint_type IN ('PRIMARY KEY', 'FOREIGN KEY') " - , "ORDER BY c.constraint_name, c.column_name" - ] - - stmt' <- getter sqlc - - us <- with (stmtQuery stmt' vals) (\src -> runConduit $ src .| helperU) - return $ columns ++ us - where - refMap = - fmap (\cr -> (crTableName cr, crConstraintName cr)) $ - Map.fromList $ - List.foldl' ref [] cols - where - ref rs c = - maybe rs (\r -> (unFieldNameDB $ cName c, r) : rs) (cReference c) - getAll = - CL.mapM $ \x -> - pure $ case x of - [PersistText con, PersistText col] -> - (con, col) - [PersistByteString con, PersistByteString col] -> - (T.decodeUtf8 con, T.decodeUtf8 col) - o -> - error $ "unexpected datatype returned for postgres o=" ++ show o - helperU = do - rows <- getAll .| CL.consume - return - $ map - (Right . Right . (ConstraintNameDB . fst . head &&& map (FieldNameDB . snd))) - $ groupBy ((==) `on` fst) rows - processColumns = - CL.mapM $ \x'@((PersistText cname) : _) -> do - col <- - liftIO $ getColumn getter (getEntityDBName def) x' (Map.lookup cname refMap) - pure $ case col of - Left e -> Left e - Right c -> Right $ Left c - -getColumn - :: (Text -> IO Statement) - -> EntityNameDB - -> [PersistValue] - -> Maybe (EntityNameDB, ConstraintNameDB) - -> IO (Either Text Column) -getColumn - getter - tableName' - [ PersistText columnName - , PersistText isNullable - , PersistText typeName - , defaultValue - , generationExpression - , numericPrecision - , numericScale - , maxlen - ] - refName_ = runExceptT $ do - defaultValue' <- - case defaultValue of - PersistNull -> - pure Nothing - PersistText t -> - pure $ Just t - _ -> - throwError $ T.pack $ "Invalid default column: " ++ show defaultValue - - generationExpression' <- - case generationExpression of - PersistNull -> - pure Nothing - PersistText t -> - pure $ Just t - _ -> - throwError $ T.pack $ "Invalid generated column: " ++ show generationExpression - - let - typeStr = - case maxlen of - PersistInt64 n -> - T.concat [typeName, "(", T.pack (show n), ")"] - _ -> - typeName - - t <- getType typeStr - - let - cname = FieldNameDB columnName - - ref <- lift $ fmap join $ traverse (getRef cname) refName_ - - return - Column - { cName = cname - , cNull = isNullable == "YES" - , cSqlType = t - , cDefault = fmap stripSuffixes defaultValue' - , cGenerated = fmap stripSuffixes generationExpression' - , cDefaultConstraintName = Nothing - , cMaxLen = Nothing - , cReference = fmap (\(a, b, c, d) -> ColumnReference a b (mkCascade c d)) ref - } - where - mkCascade updText delText = - FieldCascade - { fcOnUpdate = parseCascade updText - , fcOnDelete = parseCascade delText - } - - parseCascade txt = - case txt of - "NO ACTION" -> - Just NoAction - "CASCADE" -> - Just Cascade - "SET NULL" -> - Just SetNull - "SET DEFAULT" -> - Just SetDefault - "RESTRICT" -> - Just Restrict - _ -> - error $ "Unexpected value in parseCascade: " <> show txt - - stripSuffixes t = - loop' - [ "::character varying" - , "::text" - ] - where - loop' [] = t - loop' (p : ps) = - case T.stripSuffix p t of - Nothing -> loop' ps - Just t' -> t' - - getRef cname (_, refName') = do - let - sql = - T.concat - [ "SELECT DISTINCT " - , "ccu.table_name, " - , "tc.constraint_name, " - , "rc.update_rule, " - , "rc.delete_rule " - , "FROM information_schema.constraint_column_usage ccu " - , "INNER JOIN information_schema.key_column_usage kcu " - , " ON ccu.constraint_name = kcu.constraint_name " - , "INNER JOIN information_schema.table_constraints tc " - , " ON tc.constraint_name = kcu.constraint_name " - , "LEFT JOIN information_schema.referential_constraints AS rc" - , " ON rc.constraint_name = ccu.constraint_name " - , "WHERE tc.constraint_type='FOREIGN KEY' " - , "AND kcu.ordinal_position=1 " - , "AND kcu.table_name=? " - , "AND kcu.column_name=? " - , "AND tc.constraint_name=?" - ] - stmt <- getter sql - cntrs <- - with - ( stmtQuery - stmt - [ PersistText $ unEntityNameDB tableName' - , PersistText $ unFieldNameDB cname - , PersistText $ unConstraintNameDB refName' - ] - ) - (\src -> runConduit $ src .| CL.consume) - case cntrs of - [] -> - return Nothing - [ [ PersistText table - , PersistText constraint - , PersistText updRule - , PersistText delRule - ] - ] -> - return $ - Just (EntityNameDB table, ConstraintNameDB constraint, updRule, delRule) - xs -> - error $ - mconcat - [ "Postgresql.getColumn: error fetching constraints. Expected a single result for foreign key query for table: " - , T.unpack (unEntityNameDB tableName') - , " and column: " - , T.unpack (unFieldNameDB cname) - , " but got: " - , show xs - ] - - getType "int4" = pure SqlInt32 - getType "int8" = pure SqlInt64 - getType "varchar" = pure SqlString - getType "text" = pure SqlString - getType "date" = pure SqlDay - getType "bool" = pure SqlBool - getType "timestamptz" = pure SqlDayTime - getType "float4" = pure SqlReal - getType "float8" = pure SqlReal - getType "bytea" = pure SqlBlob - getType "time" = pure SqlTime - getType "numeric" = getNumeric numericPrecision numericScale - getType a = pure $ SqlOther a - - getNumeric (PersistInt64 a) (PersistInt64 b) = - pure $ SqlNumeric (fromIntegral a) (fromIntegral b) - getNumeric PersistNull PersistNull = - throwError $ - T.concat - [ "No precision and scale were specified for the column: " - , columnName - , " in table: " - , unEntityNameDB tableName' - , ". Postgres defaults to a maximum scale of 147,455 and precision of 16383," - , " which is probably not what you intended." - , " Specify the values as numeric(total_digits, digits_after_decimal_place)." - ] - getNumeric a b = - throwError $ - T.concat - [ "Can not get numeric field precision for the column: " - , columnName - , " in table: " - , unEntityNameDB tableName' - , ". Expected an integer for both precision and scale, " - , "got: " - , T.pack $ show a - , " and " - , T.pack $ show b - , ", respectively." - , " Specify the values as numeric(total_digits, digits_after_decimal_place)." - ] -getColumn _ _ columnName _ = - return $ - Left $ - T.pack $ - "Invalid result from information_schema: " ++ show columnName - -doesTableExist - :: (Text -> IO Statement) - -> EntityNameDB - -> IO Bool -doesTableExist getter (EntityNameDB name) = do - stmt <- getter sql - with (stmtQuery stmt vals) (\src -> runConduit $ src .| start) - where - sql = - "SELECT COUNT(*) FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog'" - <> " AND schemaname != 'information_schema' AND tablename=?" - vals = [PersistText name] - - start = await >>= maybe (error "No results when checking doesTableExist") start' - start' [PersistInt64 0] = finish False - start' [PersistInt64 1] = finish True - start' res = error $ "doesTableExist returned unexpected result: " ++ show res - finish x = await >>= maybe (return x) (error "Too many rows returned in doesTableExist") +mockMigrateStructured allDefs entity = + migrateEntityFromSchemaState EntityDoesNotExist allDefs entity \ No newline at end of file From d2a57a590c35bc6ca047b733d2839ad02e662b93 Mon Sep 17 00:00:00 2001 From: Harry Garrood Date: Thu, 4 Dec 2025 13:11:35 +0000 Subject: [PATCH 05/14] expand test --- .../Database/Persist/Postgresql/Internal.hs | 4 +- .../Persist/Postgresql/Internal/Migration.hs | 20 +- persistent-postgresql/test/MigrationSpec.hs | 227 ++++++++++++++---- 3 files changed, 188 insertions(+), 63 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 7e99a29e1..80fc7c25e 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -54,8 +54,8 @@ import Data.Time , localTimeToUTC , utc ) -import Database.Persist.Sql import Database.Persist.Postgresql.Internal.Migration +import Database.Persist.Sql -- | Newtype used to avoid orphan instances for @postgresql-simple@ classes. -- @@ -287,4 +287,4 @@ mockMigrateStructured -> EntityDef -> [AlterDB] mockMigrateStructured allDefs entity = - migrateEntityFromSchemaState EntityDoesNotExist allDefs entity \ No newline at end of file + migrateEntityFromSchemaState EntityDoesNotExist allDefs entity diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs index 0d53fb86e..7f2ea4d3f 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs @@ -1,8 +1,8 @@ {-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE ViewPatterns #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TupleSections #-} +{-# LANGUAGE ViewPatterns #-} -- | Generate postgresql migrations for a set of EntityDefs, either from scratch -- or based on the current state of a database. @@ -16,6 +16,8 @@ import Data.Acquire (with) import Data.Conduit import qualified Data.Conduit.List as CL import Data.Either (partitionEithers) +import Data.List as List +import qualified Data.List.NonEmpty as NEL import Data.Map (Map) import qualified Data.Map as Map import Data.Maybe @@ -26,8 +28,6 @@ import qualified Data.Text.Encoding as T import Data.Traversable import Database.Persist.Sql import qualified Database.Persist.Sql.Util as Util -import qualified Data.List.NonEmpty as NEL -import Data.List as List -- | In order to ensure that generating migrations is fast and avoids N+1 -- queries, we split it into two phases. The first phase involves querying the @@ -504,15 +504,17 @@ migrateEntitiesFromSchemaState -> [EntityDef] -> Either [Text] [AlterDB] migrateEntitiesFromSchemaState (SchemaState schemaStateMap) allDefs defsToMigrate = - let go :: EntityDef -> Either Text [AlterDB] + let + go :: EntityDef -> Either Text [AlterDB] go entity = do - let name = getEntityDBName entity + let + name = getEntityDBName entity case Map.lookup name schemaStateMap of Just entityState -> Right $ migrateEntityFromSchemaState entityState allDefs entity Nothing -> Left $ T.pack $ "No entry for entity in schemaState: " <> show name - in + in case partitionEithers (map go defsToMigrate) of ([], xs) -> Right (concat xs) (errs, _) -> Left errs @@ -526,7 +528,7 @@ migrateEntityFromSchemaState schemaState allDefs entity = case schemaState of EntityDoesNotExist -> (addTable newcols entity) : uniques ++ references ++ foreignsAlt - EntityExists ExistingEntitySchemaState { essColumns, essConstraints } -> + EntityExists ExistingEntitySchemaState{essColumns, essConstraints} -> let (acs, ats) = getAlters @@ -538,8 +540,7 @@ migrateEntityFromSchemaState schemaState allDefs entity = ats' = map (AlterTable name) ats in acs' ++ ats' - - where + where name = getEntityDBName entity (newcols', udefs, fdefs) = postgresMkColumns allDefs entity newcols = filter (not . safeToRemove entity . cName) newcols' @@ -555,7 +556,6 @@ migrateEntityFromSchemaState schemaState allDefs entity = newcols foreignsAlt = mapMaybe (mkForeignAlt entity) fdefs - -- | Indicates whether a Postgres Column is safe to drop. -- -- @since 2.17.1.0 diff --git a/persistent-postgresql/test/MigrationSpec.hs b/persistent-postgresql/test/MigrationSpec.hs index 2b88f5674..73f8031f3 100644 --- a/persistent-postgresql/test/MigrationSpec.hs +++ b/persistent-postgresql/test/MigrationSpec.hs @@ -1,50 +1,157 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} module MigrationSpec where import PgInit import qualified Data.Map as Map +import Data.Proxy +import qualified Data.Text as T import Database.Persist.Postgresql.Internal.Migration import qualified Database.Persist.SqlBackend.Internal as SqlBackend -runConnPrepare - :: (MonadIO m) => ((Text -> IO Statement) -> a -> IO b) -> a -> SqlPersistT m b -runConnPrepare inner arg = do +getConnPrepare + :: (Monad m) => SqlPersistT m (Text -> IO Statement) +getConnPrepare = do backend <- ask - liftIO $ inner (SqlBackend.connPrepare backend) arg + pure (SqlBackend.connPrepare backend) + +-- NB: we do not perform these migrations in main.hs +share + [mkPersist persistSettings{mpsGeneric = False}, mkMigrate "migrate"] + [persistLowerCase| +User sql=users + name Text + title Text Maybe + deriving Show Eq + +UserFriendship sql=user_friendships + user1Id UserId Maybe + user2Id UserId Maybe + deriving Show Eq + +Password sql=passwords + passwordHash Text + userId UserId Maybe + UniqueUserId userId !force + +Password2 sql=passwords_2 + passwordHash Text + userId UserId Maybe OnDeleteCascade OnUpdateSetNull + UniqueUserId2 userId !force +|] + +userEntityDef :: EntityDef +userEntityDef = entityDef (Proxy :: Proxy User) + +userFriendshipEntityDef :: EntityDef +userFriendshipEntityDef = entityDef (Proxy :: Proxy UserFriendship) + +passwordEntityDef :: EntityDef +passwordEntityDef = entityDef (Proxy :: Proxy Password) + +password2EntityDef :: EntityDef +password2EntityDef = entityDef (Proxy :: Proxy Password2) + +allEntityDefs :: [EntityDef] +allEntityDefs = + [ userEntityDef + , userFriendshipEntityDef + , passwordEntityDef + , password2EntityDef + ] + +migrateManually :: (HasCallStack, MonadIO m) => SqlPersistT m () +migrateManually = do + cleanDB + let + rawEx sql = rawExecute sql [] + rawEx + "CREATE TABLE users(id int8 primary key, name text not null, title text);" + rawEx $ + T.concat + [ "CREATE TABLE user_friendships(" + , " id int8 primary key," + , " user1_id int8 references users(id) on delete restrict on update restrict," + , " user2_id int8 references users(id) on delete restrict on update restrict" + , ");" + ] + rawEx $ + T.concat + [ "CREATE TABLE passwords(" + , " id int8 primary key," + , " password_hash text not null," + , " user_id int8 references users(id) on delete restrict on update restrict" + , ");" + ] + rawEx $ + T.concat + [ "ALTER TABLE passwords" + , " ADD CONSTRAINT unique_user_id" + , " UNIQUE(user_id);" + ] + rawEx $ + T.concat + [ "CREATE TABLE passwords_2(" + , " id int8 primary key," + , " password_hash text not null," + , " user_id int8 references users(id) on delete cascade on update set null" + , ");" + ] + rawEx $ + T.concat + [ "ALTER TABLE passwords_2" + , " ADD CONSTRAINT unique_user_id2" + , " UNIQUE(user_id);" + ] + rawEx "CREATE TABLE ignored(id int8 primary key);" + +cleanDB :: (HasCallStack, MonadIO m) => SqlPersistT m () +cleanDB = do + let + rawEx sql = rawExecute sql [] + rawEx "DROP TABLE IF EXISTS user_friendships;" + rawEx "DROP TABLE IF EXISTS passwords;" + rawEx "DROP TABLE IF EXISTS passwords_2;" + rawEx "DROP TABLE IF EXISTS ignored;" + rawEx "DROP TABLE IF EXISTS users;" spec :: Spec spec = describe "MigrationSpec" $ do - it "works" $ runConnAssert $ do - let - rawEx sql = rawExecute sql [] - rawEx - "CREATE TABLE users(id serial primary key, name text not null, title text);" - rawEx - "CREATE TABLE user_friendships(id serial primary key, user_1_id int references users(id), user_2_id int references users(id));" - rawEx - "CREATE TABLE passwords(id serial primary key, password_hash text, user_id int unique references users(id));" - rawEx - "CREATE TABLE passwords_2(id serial primary key, password_hash text, user_id int unique references users(id));" - rawEx "CREATE TABLE ignored(id serial primary key);" + it "gathers schema state" $ runConnAssert $ do + migrateManually + connPrepare <- getConnPrepare actual <- - runConnPrepare collectSchemaState $ - map - EntityNameDB - [ "users" - , "user_friendships" - , "passwords" - , "passwords_2" - , "nonexistent" - ] + liftIO $ + collectSchemaState connPrepare $ + map + EntityNameDB + [ "users" + , "user_friendships" + , "passwords" + , "passwords_2" + , "nonexistent" + ] + + cleanDB let expected = - SchemaState $ - Map.fromList + SchemaState + ( Map.fromList [ (EntityNameDB{unEntityNameDB = "nonexistent"}, EntityDoesNotExist) , ( EntityNameDB{unEntityNameDB = "passwords"} @@ -54,7 +161,7 @@ spec = describe "MigrationSpec" $ do [ Column { cName = FieldNameDB{unFieldNameDB = "user_id"} , cNull = True - , cSqlType = SqlInt32 + , cSqlType = SqlInt64 , cDefault = Nothing , cGenerated = Nothing , cDefaultConstraintName = Nothing @@ -66,13 +173,13 @@ spec = describe "MigrationSpec" $ do , crConstraintName = ConstraintNameDB{unConstraintNameDB = "passwords_user_id_fkey"} , crFieldCascade = - FieldCascade{fcOnUpdate = Just NoAction, fcOnDelete = Just NoAction} + FieldCascade{fcOnUpdate = Just Restrict, fcOnDelete = Just Restrict} } ) } , Column { cName = FieldNameDB{unFieldNameDB = "password_hash"} - , cNull = True + , cNull = False , cSqlType = SqlString , cDefault = Nothing , cGenerated = Nothing @@ -83,8 +190,8 @@ spec = describe "MigrationSpec" $ do , Column { cName = FieldNameDB{unFieldNameDB = "id"} , cNull = False - , cSqlType = SqlInt32 - , cDefault = Just "nextval('passwords_id_seq'::regclass)" + , cSqlType = SqlInt64 + , cDefault = Nothing , cGenerated = Nothing , cDefaultConstraintName = Nothing , cMaxLen = Nothing @@ -94,7 +201,7 @@ spec = describe "MigrationSpec" $ do , essConstraints = Map.fromList [ - ( ConstraintNameDB{unConstraintNameDB = "passwords_user_id_key"} + ( ConstraintNameDB{unConstraintNameDB = "unique_user_id"} , [FieldNameDB{unFieldNameDB = "user_id"}] ) ] @@ -109,7 +216,7 @@ spec = describe "MigrationSpec" $ do [ Column { cName = FieldNameDB{unFieldNameDB = "user_id"} , cNull = True - , cSqlType = SqlInt32 + , cSqlType = SqlInt64 , cDefault = Nothing , cGenerated = Nothing , cDefaultConstraintName = Nothing @@ -121,13 +228,13 @@ spec = describe "MigrationSpec" $ do , crConstraintName = ConstraintNameDB{unConstraintNameDB = "passwords_2_user_id_fkey"} , crFieldCascade = - FieldCascade{fcOnUpdate = Just NoAction, fcOnDelete = Just NoAction} + FieldCascade{fcOnUpdate = Just SetNull, fcOnDelete = Just Cascade} } ) } , Column { cName = FieldNameDB{unFieldNameDB = "password_hash"} - , cNull = True + , cNull = False , cSqlType = SqlString , cDefault = Nothing , cGenerated = Nothing @@ -138,8 +245,8 @@ spec = describe "MigrationSpec" $ do , Column { cName = FieldNameDB{unFieldNameDB = "id"} , cNull = False - , cSqlType = SqlInt32 - , cDefault = Just "nextval('passwords_2_id_seq'::regclass)" + , cSqlType = SqlInt64 + , cDefault = Nothing , cGenerated = Nothing , cDefaultConstraintName = Nothing , cMaxLen = Nothing @@ -149,7 +256,7 @@ spec = describe "MigrationSpec" $ do , essConstraints = Map.fromList [ - ( ConstraintNameDB{unConstraintNameDB = "passwords_2_user_id_key"} + ( ConstraintNameDB{unConstraintNameDB = "unique_user_id2"} , [FieldNameDB{unFieldNameDB = "user_id"}] ) ] @@ -162,9 +269,9 @@ spec = describe "MigrationSpec" $ do ( ExistingEntitySchemaState { essColumns = [ Column - { cName = FieldNameDB{unFieldNameDB = "user_2_id"} + { cName = FieldNameDB{unFieldNameDB = "user2_id"} , cNull = True - , cSqlType = SqlInt32 + , cSqlType = SqlInt64 , cDefault = Nothing , cGenerated = Nothing , cDefaultConstraintName = Nothing @@ -174,16 +281,16 @@ spec = describe "MigrationSpec" $ do ( ColumnReference { crTableName = EntityNameDB{unEntityNameDB = "users"} , crConstraintName = - ConstraintNameDB{unConstraintNameDB = "user_friendships_user_2_id_fkey"} + ConstraintNameDB{unConstraintNameDB = "user_friendships_user2_id_fkey"} , crFieldCascade = - FieldCascade{fcOnUpdate = Just NoAction, fcOnDelete = Just NoAction} + FieldCascade{fcOnUpdate = Just Restrict, fcOnDelete = Just Restrict} } ) } , Column - { cName = FieldNameDB{unFieldNameDB = "user_1_id"} + { cName = FieldNameDB{unFieldNameDB = "user1_id"} , cNull = True - , cSqlType = SqlInt32 + , cSqlType = SqlInt64 , cDefault = Nothing , cGenerated = Nothing , cDefaultConstraintName = Nothing @@ -193,17 +300,17 @@ spec = describe "MigrationSpec" $ do ( ColumnReference { crTableName = EntityNameDB{unEntityNameDB = "users"} , crConstraintName = - ConstraintNameDB{unConstraintNameDB = "user_friendships_user_1_id_fkey"} + ConstraintNameDB{unConstraintNameDB = "user_friendships_user1_id_fkey"} , crFieldCascade = - FieldCascade{fcOnUpdate = Just NoAction, fcOnDelete = Just NoAction} + FieldCascade{fcOnUpdate = Just Restrict, fcOnDelete = Just Restrict} } ) } , Column { cName = FieldNameDB{unFieldNameDB = "id"} , cNull = False - , cSqlType = SqlInt32 - , cDefault = Just "nextval('user_friendships_id_seq'::regclass)" + , cSqlType = SqlInt64 + , cDefault = Nothing , cGenerated = Nothing , cDefaultConstraintName = Nothing , cMaxLen = Nothing @@ -242,8 +349,8 @@ spec = describe "MigrationSpec" $ do , Column { cName = FieldNameDB{unFieldNameDB = "id"} , cNull = False - , cSqlType = SqlInt32 - , cDefault = Just "nextval('users_id_seq'::regclass)" + , cSqlType = SqlInt64 + , cDefault = Nothing , cGenerated = Nothing , cDefaultConstraintName = Nothing , cMaxLen = Nothing @@ -255,5 +362,23 @@ spec = describe "MigrationSpec" $ do ) ) ] + ) actual `shouldBe` Right expected + + it "no-ops on a migrated DB" $ runConnAssert $ do + migrateManually + + connPrepare <- getConnPrepare + result <- + liftIO $ migrateEntitiesStructured connPrepare allEntityDefs allEntityDefs + + cleanDB + + case result of + Right [] -> + pure () + Left err -> + expectationFailure $ show err + Right alters -> + map (snd . showAlterDb) alters `shouldBe` [] From 4a131486764a0cf877ba55615c7398463d909b56 Mon Sep 17 00:00:00 2001 From: Harry Garrood Date: Thu, 4 Dec 2025 13:36:46 +0000 Subject: [PATCH 06/14] what da hell --- .../Persist/Postgresql/Internal/Migration.hs | 1 + persistent-postgresql/test/MigrationSpec.hs | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs index 7f2ea4d3f..eb74fe420 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs @@ -492,6 +492,7 @@ migrateEntitiesStructured -> IO (Either [Text] [AlterDB]) migrateEntitiesStructured getStmt allDefs defsToMigrate = do r <- collectSchemaState getStmt (map getEntityDBName defsToMigrate) + putStrLn $ "collectSchemaState: " <> show r pure $ case r of Right schemaState -> migrateEntitiesFromSchemaState schemaState allDefs defsToMigrate diff --git a/persistent-postgresql/test/MigrationSpec.hs b/persistent-postgresql/test/MigrationSpec.hs index 73f8031f3..59a30f4d6 100644 --- a/persistent-postgresql/test/MigrationSpec.hs +++ b/persistent-postgresql/test/MigrationSpec.hs @@ -382,3 +382,23 @@ spec = describe "MigrationSpec" $ do expectationFailure $ show err Right alters -> map (snd . showAlterDb) alters `shouldBe` [] + + it "migrates a clean DB" $ runConnAssert $ do + cleanDB + + connPrepare <- getConnPrepare + result <- + liftIO $ migrateEntitiesStructured connPrepare allEntityDefs allEntityDefs + + cleanDB + + case result of + Right [] -> + pure () + Left err -> + expectationFailure $ show err + Right alters -> do + traverse (flip rawExecute [] . snd . showAlterDb) alters + result2 <- + liftIO $ migrateEntitiesStructured connPrepare allEntityDefs allEntityDefs + result2 `shouldBe` Right [] From 59e7d00306fc9e7fd0595a994c8a8fbe0e17ee9c Mon Sep 17 00:00:00 2001 From: Harry Garrood Date: Thu, 4 Dec 2025 16:37:55 +0000 Subject: [PATCH 07/14] tracing --- .../Database/Persist/Postgresql/Internal/Migration.hs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs index eb74fe420..7778e49eb 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs @@ -13,6 +13,7 @@ import Control.Monad import Control.Monad.Except import Control.Monad.IO.Class import Data.Acquire (with) +import Debug.Trace import Data.Conduit import qualified Data.Conduit.List as CL import Data.Either (partitionEithers) @@ -492,7 +493,7 @@ migrateEntitiesStructured -> IO (Either [Text] [AlterDB]) migrateEntitiesStructured getStmt allDefs defsToMigrate = do r <- collectSchemaState getStmt (map getEntityDBName defsToMigrate) - putStrLn $ "collectSchemaState: " <> show r + -- putStrLn $ "collectSchemaState: " <> show r pure $ case r of Right schemaState -> migrateEntitiesFromSchemaState schemaState allDefs defsToMigrate @@ -637,7 +638,10 @@ getAlters -> ([Column], [(ConstraintNameDB, [FieldNameDB])]) -> ([AlterColumn], [AlterTable]) getAlters defs def (c1, u1) (c2, u2) = - (getAltersC c1 c2, getAltersU u1 u2) + if getEntityDBName def == EntityNameDB "child" + then traceShow ((c1, u1), (c2, u2)) $ + traceShowId (getAltersC c1 c2, getAltersU u1 u2) + else (getAltersC c1 c2, getAltersU u1 u2) where getAltersC [] old = map (\x -> Drop x $ SafeToRemove $ safeToRemove def $ cName x) old From a13c912bbdc00fc23fd407f1d304b56b7ae4ed9e Mon Sep 17 00:00:00 2001 From: Harry Garrood Date: Thu, 4 Dec 2025 17:10:22 +0000 Subject: [PATCH 08/14] hell yeah --- .../Persist/Postgresql/Internal/Migration.hs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs index 7778e49eb..a89894b60 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs @@ -537,7 +537,7 @@ migrateEntityFromSchemaState schemaState allDefs entity = allDefs entity (newcols, udspair) - (essColumns, Map.toList essConstraints) + (map dubiouslyRemoveReferences essColumns, Map.toList essConstraints) acs' = map (AlterColumn name) acs ats' = map (AlterTable name) ats in @@ -558,6 +558,22 @@ migrateEntityFromSchemaState schemaState allDefs entity = newcols foreignsAlt = mapMaybe (mkForeignAlt entity) fdefs + -- HACK! This shouldn't really be here; it was added to preserve existing + -- behaviour. The migrator currently expects to only see cReference set in + -- the old columns if it is also set in the new ones. This means that the + -- migrator sometimes behaves incorrectly for standalone Foreign + -- declarations, like Child in the ForeignKey test in persistent-test. + -- + -- See https://github.com/yesodweb/persistent/issues/1611#issuecomment-3613251095 for + -- more info + dubiouslyRemoveReferences oldCol = + case List.find (\c -> cName c == cName oldCol) newcols of + Just new | isNothing (cReference new) -> + oldCol { cReference = Nothing } + _ -> + -- otherwise no-op, `getAlters` will handle dropping this for us. + oldCol + -- | Indicates whether a Postgres Column is safe to drop. -- -- @since 2.17.1.0 From f0eaa71c6c45661c9dcbd8860521204bf322ec64 Mon Sep 17 00:00:00 2001 From: Harry Garrood Date: Thu, 4 Dec 2025 17:13:59 +0000 Subject: [PATCH 09/14] remove tracing --- .../Database/Persist/Postgresql/Internal/Migration.hs | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs index a89894b60..6f414f5a7 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs @@ -13,7 +13,6 @@ import Control.Monad import Control.Monad.Except import Control.Monad.IO.Class import Data.Acquire (with) -import Debug.Trace import Data.Conduit import qualified Data.Conduit.List as CL import Data.Either (partitionEithers) @@ -493,7 +492,6 @@ migrateEntitiesStructured -> IO (Either [Text] [AlterDB]) migrateEntitiesStructured getStmt allDefs defsToMigrate = do r <- collectSchemaState getStmt (map getEntityDBName defsToMigrate) - -- putStrLn $ "collectSchemaState: " <> show r pure $ case r of Right schemaState -> migrateEntitiesFromSchemaState schemaState allDefs defsToMigrate @@ -654,10 +652,7 @@ getAlters -> ([Column], [(ConstraintNameDB, [FieldNameDB])]) -> ([AlterColumn], [AlterTable]) getAlters defs def (c1, u1) (c2, u2) = - if getEntityDBName def == EntityNameDB "child" - then traceShow ((c1, u1), (c2, u2)) $ - traceShowId (getAltersC c1 c2, getAltersU u1 u2) - else (getAltersC c1 c2, getAltersU u1 u2) + (getAltersC c1 c2, getAltersU u1 u2) where getAltersC [] old = map (\x -> Drop x $ SafeToRemove $ safeToRemove def $ cName x) old From 374673e34d7cf02b3941a0368fe20906f431ee94 Mon Sep 17 00:00:00 2001 From: Harry Garrood Date: Thu, 4 Dec 2025 17:14:09 +0000 Subject: [PATCH 10/14] format --- .../Database/Persist/Postgresql/Internal/Migration.hs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs index 6f414f5a7..6c8cdbe47 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs @@ -566,8 +566,9 @@ migrateEntityFromSchemaState schemaState allDefs entity = -- more info dubiouslyRemoveReferences oldCol = case List.find (\c -> cName c == cName oldCol) newcols of - Just new | isNothing (cReference new) -> - oldCol { cReference = Nothing } + Just new + | isNothing (cReference new) -> + oldCol{cReference = Nothing} _ -> -- otherwise no-op, `getAlters` will handle dropping this for us. oldCol @@ -652,7 +653,7 @@ getAlters -> ([Column], [(ConstraintNameDB, [FieldNameDB])]) -> ([AlterColumn], [AlterTable]) getAlters defs def (c1, u1) (c2, u2) = - (getAltersC c1 c2, getAltersU u1 u2) + (getAltersC c1 c2, getAltersU u1 u2) where getAltersC [] old = map (\x -> Drop x $ SafeToRemove $ safeToRemove def $ cName x) old From 4416758985081135050a1dd72dc3354492e5468f Mon Sep 17 00:00:00 2001 From: Harry Garrood Date: Thu, 4 Dec 2025 17:29:53 +0000 Subject: [PATCH 11/14] Move last two migration fns into Internal.Migration --- .../Database/Persist/Postgresql/Internal.hs | 28 ------------------- .../Persist/Postgresql/Internal/Migration.hs | 27 ++++++++++++++++++ 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 80fc7c25e..53de1888b 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -48,7 +48,6 @@ import qualified Data.ByteString.Builder as BB import Data.Fixed (Pico) import qualified Data.IntMap as I import Data.Maybe -import Data.Text (Text) import Data.Time ( NominalDiffTime , localTimeToUTC @@ -261,30 +260,3 @@ intervalToPgInterval interval = if calendarDiffDays == mempty then Just $ PgInterval nominalDiffTime else Nothing - --- | Returns a structured representation of all of the --- DB changes required to migrate the Entity from its --- current state in the database to the state described in --- Haskell. --- --- @since 2.17.1.0 -migrateStructured - :: [EntityDef] - -> (Text -> IO Statement) - -> EntityDef - -> IO (Either [Text] [AlterDB]) -migrateStructured allDefs getter entity = - migrateEntitiesStructured getter allDefs [entity] - --- | Returns a structured representation of all of the --- DB changes required to migrate the Entity to the state --- described in Haskell, assuming it currently does not --- exist in the database. --- --- @since 2.17.1.0 -mockMigrateStructured - :: [EntityDef] - -> EntityDef - -> [AlterDB] -mockMigrateStructured allDefs entity = - migrateEntityFromSchemaState EntityDoesNotExist allDefs entity diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs index 6c8cdbe47..949c21ad0 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs @@ -29,6 +29,33 @@ import Data.Traversable import Database.Persist.Sql import qualified Database.Persist.Sql.Util as Util +-- | Returns a structured representation of all of the +-- DB changes required to migrate the Entity from its +-- current state in the database to the state described in +-- Haskell. +-- +-- @since 2.17.1.0 +migrateStructured + :: [EntityDef] + -> (Text -> IO Statement) + -> EntityDef + -> IO (Either [Text] [AlterDB]) +migrateStructured allDefs getter entity = + migrateEntitiesStructured getter allDefs [entity] + +-- | Returns a structured representation of all of the +-- DB changes required to migrate the Entity to the state +-- described in Haskell, assuming it currently does not +-- exist in the database. +-- +-- @since 2.17.1.0 +mockMigrateStructured + :: [EntityDef] + -> EntityDef + -> [AlterDB] +mockMigrateStructured allDefs entity = + migrateEntityFromSchemaState EntityDoesNotExist allDefs entity + -- | In order to ensure that generating migrations is fast and avoids N+1 -- queries, we split it into two phases. The first phase involves querying the -- database to gather all of the information we need about the existing schema. From 224b8e362ed1033c697027c387f6c1175543800e Mon Sep 17 00:00:00 2001 From: Harry Garrood Date: Thu, 4 Dec 2025 17:33:00 +0000 Subject: [PATCH 12/14] docs + re-export new migrateEntitiesStructured --- .../Database/Persist/Postgresql/Internal.hs | 1 + .../Persist/Postgresql/Internal/Migration.hs | 38 ++++++++++--------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 53de1888b..0220229a3 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -13,6 +13,7 @@ module Database.Persist.Postgresql.Internal , AlterColumn (..) , SafeToRemove , migrateStructured + , migrateEntitiesStructured , mockMigrateStructured , addTable , findAlters diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs index 949c21ad0..9c980343a 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs @@ -43,6 +43,27 @@ migrateStructured migrateStructured allDefs getter entity = migrateEntitiesStructured getter allDefs [entity] +-- | Returns a structured representation of all of the DB changes required to +-- migrate the listed entities from their current state in the database to the +-- state described in Haskell. This function avoids N+1 queries, so if you +-- have a lot of entities to migrate, it's much faster to use this rather than +-- using 'migrateStructured' in a loop. +-- +-- @since 2.14.1.0 +migrateEntitiesStructured + :: (Text -> IO Statement) + -> [EntityDef] + -> [EntityDef] + -> IO (Either [Text] [AlterDB]) +migrateEntitiesStructured getStmt allDefs defsToMigrate = do + r <- collectSchemaState getStmt (map getEntityDBName defsToMigrate) + pure $ case r of + Right schemaState -> + migrateEntitiesFromSchemaState schemaState allDefs defsToMigrate + Left err -> + Left [err] + + -- | Returns a structured representation of all of the -- DB changes required to migrate the Entity to the state -- described in Haskell, assuming it currently does not @@ -508,23 +529,6 @@ mapLeft :: (a1 -> a2) -> Either a1 b -> Either a2 b mapLeft _ (Right x) = Right x mapLeft f (Left x) = Left (f x) --- | Returns a structured representation of all of the --- DB changes required to migrate the Entity from its --- current state in the database to the state described in --- Haskell. -migrateEntitiesStructured - :: (Text -> IO Statement) - -> [EntityDef] - -> [EntityDef] - -> IO (Either [Text] [AlterDB]) -migrateEntitiesStructured getStmt allDefs defsToMigrate = do - r <- collectSchemaState getStmt (map getEntityDBName defsToMigrate) - pure $ case r of - Right schemaState -> - migrateEntitiesFromSchemaState schemaState allDefs defsToMigrate - Left err -> - Left [err] - migrateEntitiesFromSchemaState :: SchemaState -> [EntityDef] From 71fe08c22f43da51d9ebf460036ad82f6522c785 Mon Sep 17 00:00:00 2001 From: Harry Garrood Date: Thu, 4 Dec 2025 17:37:43 +0000 Subject: [PATCH 13/14] bump version + add changelog --- persistent-postgresql/ChangeLog.md | 6 ++++++ persistent-postgresql/persistent-postgresql.cabal | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/persistent-postgresql/ChangeLog.md b/persistent-postgresql/ChangeLog.md index a062f06f1..f2a9fa483 100644 --- a/persistent-postgresql/ChangeLog.md +++ b/persistent-postgresql/ChangeLog.md @@ -1,5 +1,11 @@ # Changelog for persistent-postgresql +# 2.14.1.0 + +* [#1612](https://github.com/yesodweb/persistent/pull/1612) + * Speed up migrations by avoiding N+1 queries. + You can now migrate a large set of entities much faster, by using the new `migrateEntitiesStructured` function. + # 2.14.0.1 * [#1610](https://github.com/yesodweb/persistent/pull/1610) diff --git a/persistent-postgresql/persistent-postgresql.cabal b/persistent-postgresql/persistent-postgresql.cabal index c0aa85701..960207f03 100644 --- a/persistent-postgresql/persistent-postgresql.cabal +++ b/persistent-postgresql/persistent-postgresql.cabal @@ -1,5 +1,5 @@ name: persistent-postgresql -version: 2.14.0.1 +version: 2.14.1.0 license: MIT license-file: LICENSE author: Felipe Lessa, Michael Snoyman From 0eba413a4f0a4b3c650db4297eb0175d274b10e1 Mon Sep 17 00:00:00 2001 From: Harry Garrood Date: Thu, 4 Dec 2025 17:38:53 +0000 Subject: [PATCH 14/14] format --- .../Database/Persist/Postgresql/Internal/Migration.hs | 1 - 1 file changed, 1 deletion(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs index 9c980343a..00f4f912e 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs @@ -63,7 +63,6 @@ migrateEntitiesStructured getStmt allDefs defsToMigrate = do Left err -> Left [err] - -- | Returns a structured representation of all of the -- DB changes required to migrate the Entity to the state -- described in Haskell, assuming it currently does not