diff --git a/src/storage/db_connection.py b/src/storage/db_connection.py index c9b30b767..55a957b79 100644 --- a/src/storage/db_connection.py +++ b/src/storage/db_connection.py @@ -22,14 +22,14 @@ def __init__( address = '/var/run/postgresql' port = config.common.postgres.port - database = db_name if db_name else config.common.postgres.database + self.database = db_name if db_name else config.common.postgres.database engine_url = URL.create( 'postgresql', username=user, password=password, host=address, port=port, - database=database, + database=self.database, ) self.engine = create_engine(engine_url, pool_size=100, future=True, **kwargs) self.session_maker = sessionmaker(bind=self.engine, future=True) # future=True => sqlalchemy 2.0 support diff --git a/src/test/common_helper.py b/src/test/common_helper.py index 400f51adb..9612c6fbc 100644 --- a/src/test/common_helper.py +++ b/src/test/common_helper.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Callable +import config from helperFunctions.fileSystem import get_src_dir from helperFunctions.tag import TagColor from objects.file import FileObject @@ -299,6 +300,7 @@ def setup_test_tables(db_setup): def clear_test_tables(db_setup): + assert db_setup.connection.database == config.common.postgres.test_database db_setup.connection.base.metadata.drop_all(db_setup.connection.engine) diff --git a/src/test/conftest.py b/src/test/conftest.py index 596cf82b5..860d9eef8 100644 --- a/src/test/conftest.py +++ b/src/test/conftest.py @@ -138,6 +138,7 @@ def _database_interfaces(): # noqa: PT005 with pytest.MonkeyPatch.context() as mpk: config.load() # Make sure to match the config here with the one in src/conftest.py:common_config + assert config.common.postgres.database != config.common.postgres.test_database sections = { 'postgres': { 'server': config.common.postgres.server, @@ -178,6 +179,7 @@ def _database_interfaces(): # noqa: PT005 comparison = ComparisonDbInterface(connection=rw_connection) admin = AdminDbInterface(intercom=MockIntercom()) stats_update = StatsUpdateDbInterface(connection=rw_connection) + assert common.connection.database == config.common.postgres.test_database setup_test_tables(db_setup) @@ -195,6 +197,7 @@ def database_interfaces(_database_interfaces) -> DatabaseInterfaces: yield _database_interfaces finally: with _database_interfaces.admin.get_read_write_session() as session: + assert _database_interfaces.admin.connection.database == config.common.postgres.test_database # clear rows from test db between tests for table in reversed(_database_interfaces.admin.connection.base.metadata.sorted_tables): session.execute(table.delete())