diff --git a/src/mysql_to_sqlite3/cli.py b/src/mysql_to_sqlite3/cli.py index cd9bc26..769e22f 100644 --- a/src/mysql_to_sqlite3/cli.py +++ b/src/mysql_to_sqlite3/cli.py @@ -93,6 +93,9 @@ help="Prefix indices with their corresponding tables. " "This ensures that their names remain unique across the SQLite database.", ) +@click.option( + "-D", "--defer-foreign-keys", is_flag=True, help="Defer foreign key constraints until the end of the transfer." +) @click.option("-X", "--without-foreign-keys", is_flag=True, help="Do not transfer foreign keys.") @click.option( "-Z", @@ -164,6 +167,7 @@ def cli( limit_rows: int, collation: t.Optional[str], prefix_indices: bool, + defer_foreign_keys: bool, without_foreign_keys: bool, without_tables: bool, without_data: bool, @@ -212,6 +216,11 @@ def cli( limit_rows=limit_rows, collation=collation, prefix_indices=prefix_indices, + defer_foreign_keys=( + defer_foreign_keys + if not without_foreign_keys and not (mysql_tables is not None and len(mysql_tables) > 0) + else False + ), without_foreign_keys=without_foreign_keys or (mysql_tables is not None and len(mysql_tables) > 0), without_tables=without_tables, without_data=without_data, diff --git a/src/mysql_to_sqlite3/transporter.py b/src/mysql_to_sqlite3/transporter.py index f5fdf01..07f1a61 100644 --- a/src/mysql_to_sqlite3/transporter.py +++ b/src/mysql_to_sqlite3/transporter.py @@ -95,6 +95,16 @@ def __init__(self, **kwargs: tx.Unpack[MySQLtoSQLiteParams]) -> None: else: self._without_foreign_keys = bool(kwargs.get("without_foreign_keys", False)) + if not self._without_foreign_keys and not bool(self._mysql_tables) and not bool(self._exclude_mysql_tables): + self._defer_foreign_keys = bool(kwargs.get("defer_foreign_keys", False)) + if self._defer_foreign_keys and sqlite3.sqlite_version < "3.6.19": + self._logger.warning( + "SQLite %s lacks DEFERRABLE support. Ignoring -D/--defer-foreign-keys.", sqlite3.sqlite_version + ) + self._defer_foreign_keys = False + else: + self._defer_foreign_keys = False + self._without_data = bool(kwargs.get("without_data", False)) self._without_tables = bool(kwargs.get("without_tables", False)) @@ -557,10 +567,12 @@ def _build_create_table_sql(self, table_name: str) -> str: ) for foreign_key in self._mysql_cur_dict.fetchall(): if foreign_key is not None: + deferrable_clause = " DEFERRABLE INITIALLY DEFERRED" if self._defer_foreign_keys else "" sql += ( ',\n\tFOREIGN KEY("{column}") REFERENCES "{ref_table}" ("{ref_column}") ' "ON UPDATE {on_update} " - "ON DELETE {on_delete}".format(**foreign_key) # type: ignore[str-bytes-safe] + "ON DELETE {on_delete}" + "{deferrable}".format(**foreign_key, deferrable=deferrable_clause) # type: ignore[str-bytes-safe] ) sql += "\n);" @@ -755,6 +767,32 @@ def transfer(self) -> None: # re-enable foreign key checking once done transferring self._sqlite_cur.execute("PRAGMA foreign_keys=ON") + # Check for any foreign key constraint violations + self._logger.info("Validating foreign key constraints in SQLite database.") + try: + self._sqlite_cur.execute("PRAGMA foreign_key_check") + fk_violations: t.List[sqlite3.Row] = self._sqlite_cur.fetchall() + + if fk_violations: + self._logger.warning( + "Foreign key constraint violations found (%d violation%s):", + len(fk_violations), + "s" if len(fk_violations) != 1 else "", + ) + for violation in fk_violations: + self._logger.warning( + " → Table '%s' (row %s) references missing key in '%s' (constraint #%s)", + violation[0], + violation[1], + violation[2], + violation[3], + ) + else: + self._logger.info("All foreign key constraints are valid.") + + except sqlite3.Error as err: + self._logger.warning("Failed to validate foreign key constraints: %s", err) + if self._vacuum: self._logger.info("Vacuuming created SQLite database file.\nThis might take a while.") self._sqlite_cur.execute("VACUUM") diff --git a/src/mysql_to_sqlite3/types.py b/src/mysql_to_sqlite3/types.py index 2a28f2a..dcf20e0 100644 --- a/src/mysql_to_sqlite3/types.py +++ b/src/mysql_to_sqlite3/types.py @@ -35,6 +35,7 @@ class MySQLtoSQLiteParams(tx.TypedDict): vacuum: t.Optional[bool] without_tables: t.Optional[bool] without_data: t.Optional[bool] + defer_foreign_keys: t.Optional[bool] without_foreign_keys: t.Optional[bool] @@ -71,4 +72,5 @@ class MySQLtoSQLiteAttributes: _sqlite_json1_extension_enabled: bool _vacuum: bool _without_data: bool + _defer_foreign_keys: bool _without_foreign_keys: bool diff --git a/tests/unit/test_transporter.py b/tests/unit/test_transporter.py index 3ff5b91..5434ea1 100644 --- a/tests/unit/test_transporter.py +++ b/tests/unit/test_transporter.py @@ -150,7 +150,7 @@ def test_transfer_exception_handling(self, mock_sqlite_connect: MagicMock, mock_ assert "Test exception" in str(excinfo.value) # Verify that foreign keys are re-enabled in the finally block - mock_sqlite_cursor.execute.assert_called_with("PRAGMA foreign_keys=ON") + mock_sqlite_cursor.execute.assert_called_with("PRAGMA foreign_key_check") def test_constructor_missing_mysql_database(self) -> None: """Test constructor raises ValueError if mysql_database is missing.""" @@ -225,3 +225,27 @@ def test_translate_default_from_mysql_to_sqlite_bytes(self) -> None: """Test _translate_default_from_mysql_to_sqlite with bytes default.""" result = MySQLtoSQLite._translate_default_from_mysql_to_sqlite(b"abc", column_type="BLOB") assert result.startswith("DEFAULT x'") + + def test_translate_default_from_mysql_to_sqlite_curtime(self) -> None: + """Test _translate_default_from_mysql_to_sqlite with curtime().""" + assert MySQLtoSQLite._translate_default_from_mysql_to_sqlite("curtime()") == "DEFAULT CURRENT_TIME" + assert MySQLtoSQLite._translate_default_from_mysql_to_sqlite("CURTIME()") == "DEFAULT CURRENT_TIME" + + def test_translate_default_from_mysql_to_sqlite_curdate(self) -> None: + """Test _translate_default_from_mysql_to_sqlite with curdate().""" + assert MySQLtoSQLite._translate_default_from_mysql_to_sqlite("curdate()") == "DEFAULT CURRENT_DATE" + assert MySQLtoSQLite._translate_default_from_mysql_to_sqlite("CURDATE()") == "DEFAULT CURRENT_DATE" + + def test_translate_default_from_mysql_to_sqlite_current_timestamp_with_parentheses(self) -> None: + """Test _translate_default_from_mysql_to_sqlite with current_timestamp().""" + assert ( + MySQLtoSQLite._translate_default_from_mysql_to_sqlite("current_timestamp()") == "DEFAULT CURRENT_TIMESTAMP" + ) + assert ( + MySQLtoSQLite._translate_default_from_mysql_to_sqlite("CURRENT_TIMESTAMP()") == "DEFAULT CURRENT_TIMESTAMP" + ) + + def test_translate_default_from_mysql_to_sqlite_now(self) -> None: + """Test _translate_default_from_mysql_to_sqlite with now().""" + assert MySQLtoSQLite._translate_default_from_mysql_to_sqlite("now()") == "DEFAULT CURRENT_TIMESTAMP" + assert MySQLtoSQLite._translate_default_from_mysql_to_sqlite("NOW()") == "DEFAULT CURRENT_TIMESTAMP"