|
3 | 3 | """
|
4 | 4 |
|
5 | 5 | import os
|
6 |
| -import asyncio |
7 |
| -import subprocess |
8 |
| -import time |
9 |
| -from typing import AsyncGenerator, Optional |
| 6 | +from typing import AsyncGenerator |
10 | 7 | from urllib.parse import quote_plus as urlquote
|
11 |
| -from pathlib import Path |
12 | 8 |
|
13 | 9 | from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
14 | 10 | from sqlalchemy.orm import sessionmaker
|
15 | 11 | from sqlalchemy.schema import CreateSchema
|
16 | 12 | from fastapi import Depends
|
17 |
| -from alembic.config import Config |
18 | 13 |
|
19 | 14 | from .models import Base, SCHEMA_NAME
|
20 | 15 |
|
|
40 | 35 | engine, class_=AsyncSession, expire_on_commit=False
|
41 | 36 | )
|
42 | 37 |
|
43 |
| -async def run_migrations_with_lock(redis_client=None, lock_timeout: int = 120, max_wait_time: int = 300) -> bool: |
44 |
| - """ |
45 |
| - Run database migrations using Alembic with a Redis distributed lock. |
46 |
| - All workers will wait for the migration to complete before proceeding. |
47 |
| - |
48 |
| - Args: |
49 |
| - redis_client: Redis client instance |
50 |
| - lock_timeout: How long the lock should be held (seconds) |
51 |
| - max_wait_time: Maximum time to wait for migrations to complete (seconds) |
52 |
| - |
53 |
| - Returns: |
54 |
| - bool: True if migrations were run successfully or completed by another instance, False on timeout or error |
55 |
| - """ |
56 |
| - if redis_client is None: |
57 |
| - # Import here to avoid circular imports |
58 |
| - from config import get_redis_client |
59 |
| - redis_client = get_redis_client() |
60 |
| - |
61 |
| - # Keys for Redis coordination |
62 |
| - lock_name = "alembic_migration_lock" |
63 |
| - status_key = "alembic_migration_status" |
64 |
| - lock_value = f"instance_{time.time()}" |
65 |
| - |
66 |
| - # Check if migrations are already completed |
67 |
| - migration_status = redis_client.get(status_key) |
68 |
| - if migration_status == "completed": |
69 |
| - print("Migrations already completed - continuing startup") |
70 |
| - return True |
71 |
| - |
72 |
| - # Try to acquire the lock - non-blocking |
73 |
| - lock_acquired = redis_client.set( |
74 |
| - lock_name, |
75 |
| - lock_value, |
76 |
| - nx=True, # Only set if key doesn't exist |
77 |
| - ex=lock_timeout # Expiry in seconds |
78 |
| - ) |
79 |
| - |
80 |
| - if lock_acquired: |
81 |
| - print("This instance will run migrations") |
82 |
| - try: |
83 |
| - # Set status to in-progress |
84 |
| - redis_client.set(status_key, "in_progress", ex=lock_timeout) |
85 |
| - |
86 |
| - # Run migrations |
87 |
| - success = await run_migrations_subprocess() |
88 |
| - |
89 |
| - if success: |
90 |
| - # Set status to completed with a longer expiry (1 hour) |
91 |
| - redis_client.set(status_key, "completed", ex=3600) |
92 |
| - print("Migration completed successfully - signaling other instances") |
93 |
| - return True |
94 |
| - else: |
95 |
| - # Set status to failed |
96 |
| - redis_client.set(status_key, "failed", ex=3600) |
97 |
| - print("Migration failed - signaling other instances") |
98 |
| - return False |
99 |
| - finally: |
100 |
| - # Release the lock only if we're the owner |
101 |
| - current_value = redis_client.get(lock_name) |
102 |
| - if current_value == lock_value: |
103 |
| - redis_client.delete(lock_name) |
104 |
| - else: |
105 |
| - print("Another instance is running migrations - waiting for completion") |
106 |
| - |
107 |
| - # Wait for the migration to complete |
108 |
| - start_time = time.time() |
109 |
| - while time.time() - start_time < max_wait_time: |
110 |
| - # Check migration status |
111 |
| - status = redis_client.get(status_key) |
112 |
| - |
113 |
| - if status == "completed": |
114 |
| - print("Migrations completed by another instance - continuing startup") |
115 |
| - return True |
116 |
| - elif status == "failed": |
117 |
| - print("Migrations failed in another instance - continuing startup with caution") |
118 |
| - return False |
119 |
| - elif status is None: |
120 |
| - # No status yet, might be a stale lock or not started |
121 |
| - # Check if lock exists |
122 |
| - if not redis_client.exists(lock_name): |
123 |
| - # Lock released but no status - try to acquire the lock ourselves |
124 |
| - print("No active migration lock - attempting to acquire") |
125 |
| - return await run_migrations_with_lock(redis_client, lock_timeout, max_wait_time) |
126 |
| - |
127 |
| - # Wait before checking again |
128 |
| - await asyncio.sleep(1) |
129 |
| - |
130 |
| - # Timeout waiting for migration |
131 |
| - print(f"Timeout waiting for migrations after {max_wait_time} seconds") |
132 |
| - return False |
133 |
| - |
134 |
| -async def run_migrations_subprocess() -> bool: |
135 |
| - """ |
136 |
| - Run Alembic migrations using a subprocess |
137 |
| - |
138 |
| - Returns: |
139 |
| - bool: True if migrations were successful, False otherwise |
140 |
| - """ |
141 |
| - try: |
142 |
| - # Get the path to the database directory |
143 |
| - db_dir = Path(__file__).parent |
144 |
| - |
145 |
| - # Create a subprocess to run alembic |
146 |
| - process = await asyncio.create_subprocess_exec( |
147 |
| - 'alembic', 'upgrade', 'head', |
148 |
| - stdout=asyncio.subprocess.PIPE, |
149 |
| - stderr=asyncio.subprocess.PIPE, |
150 |
| - cwd=str(db_dir) # Run in the database directory |
151 |
| - ) |
152 |
| - |
153 |
| - # Wait for the process to complete with a timeout |
154 |
| - try: |
155 |
| - stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=60) |
156 |
| - |
157 |
| - if process.returncode == 0: |
158 |
| - print("Database migrations completed successfully") |
159 |
| - if stdout: |
160 |
| - print(stdout.decode()) |
161 |
| - return True |
162 |
| - else: |
163 |
| - print(f"Migration failed with error code {process.returncode}") |
164 |
| - if stderr: |
165 |
| - print(stderr.decode()) |
166 |
| - return False |
167 |
| - |
168 |
| - except asyncio.TimeoutError: |
169 |
| - print("Migration timed out after 60 seconds") |
170 |
| - # Try to terminate the process |
171 |
| - process.terminate() |
172 |
| - return False |
173 |
| - |
174 |
| - except Exception as e: |
175 |
| - print(f"Migration failed: {str(e)}") |
176 |
| - return False |
177 |
| - |
178 |
| -async def run_migrations() -> None: |
179 |
| - """ |
180 |
| - Legacy function to run migrations directly (without lock) |
181 |
| - This is kept for backward compatibility |
182 |
| - """ |
183 |
| - await run_migrations_subprocess() |
184 | 38 |
|
185 | 39 | async def init_db() -> None:
|
186 | 40 | """Initialize the database with required tables"""
|
187 |
| - # Only create tables, let Alembic handle schema creation |
188 |
| - async with engine.begin() as conn: |
189 |
| - await conn.run_sync(Base.metadata.create_all) |
| 41 | + try: |
| 42 | + # Create schema if it doesn't exist |
| 43 | + async with engine.begin() as conn: |
| 44 | + await conn.execute(CreateSchema(SCHEMA_NAME, if_not_exists=True)) |
| 45 | + |
| 46 | + # Create tables |
| 47 | + async with engine.begin() as conn: |
| 48 | + await conn.run_sync(Base.metadata.create_all) |
| 49 | + |
| 50 | + except Exception as e: |
| 51 | + print(f"Error initializing database: {str(e)}") |
| 52 | + raise |
| 53 | + |
190 | 54 |
|
191 | 55 | async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
192 | 56 | """Get a database session"""
|
|
0 commit comments