-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathbackup.py
More file actions
358 lines (279 loc) · 12.1 KB
/
backup.py
File metadata and controls
358 lines (279 loc) · 12.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
#!/usr/bin/env python3
import sqlite3
import sys
import time
from contextlib import closing
from datetime import datetime
from pathlib import Path
from typing import Callable, Optional
from zoneinfo import ZoneInfo
# ─── Configuration ──────────────────────────────────────────────────────────────
# Folder holding the live databases (a relative path, resolved against the
# current working directory)
SRC_DIR = Path("databases")
# Root folder for backups; each run writes into a sub-folder named after the date
BASE_BACKUP_DIR = Path("backup")
# Databases that are copied in full: every table, every row
SIMPLE_DBS = ["auth.db", "error.db"]
# (db filename, table, id column) triples for the databases backed up with row
# filtering: only rows of `table` whose id-column value is present in all of
# these databases are kept
FILTERED_DBS = [("main.db", "trip", "uid"), ("path.db", "paths", "trip_id")]
# Max ids bound into one `IN (...)` query. SQLite builds older than 3.32.0 allow
# at most 999 bound parameters per statement by default, so stay below that.
CHUNK_SIZE = 900
# ─── Progress Bar Class ─────────────────────────────────────────────────────────
class ProgressBar:
    """Single-line terminal progress bar with an ETA.

    ``update()`` redraws the line in place at most every 0.1 s, and always when
    the bar reaches ``total``. The trailing newline (and the "Done in" line) is
    printed exactly once, the first time the bar completes; later updates only
    move the counter, so callers that overshoot an estimated total do not
    produce repeated "Done" lines.
    """

    def __init__(self, total: int, description: str = "", width: int = 40):
        self.total = total
        self.current = 0
        self.description = description
        self.width = width
        # Monotonic clock: elapsed time / ETA must not jump if the wall clock is adjusted.
        self.start_time = time.monotonic()
        # -inf guarantees the very first update() is drawn.
        self.last_update = float("-inf")
        # Set once the completed bar and its newline have been printed.
        self._finished = False

    def update(self, amount: int = 1):
        """Advance by ``amount`` (clamped to ``total``) and redraw if due.

        Redraws happen at most every 0.1 seconds to avoid flickering, plus once
        when the bar reaches ``total``.
        """
        self.current = min(self.current + amount, self.total)
        if self._finished:
            # Already printed the final line; don't print "Done" again.
            return
        now = time.monotonic()
        if now - self.last_update >= 0.1 or self.current == self.total:
            self._display()
            self.last_update = now

    def _display(self):
        """Draw the bar; on completion also print the final newline (once)."""
        if self.total == 0:
            percentage = 100
            filled_length = self.width
        else:
            percentage = (self.current / self.total) * 100
            filled_length = int(self.width * self.current // self.total)
        bar = "█" * filled_length + "▒" * (self.width - filled_length)

        elapsed = time.monotonic() - self.start_time
        if self.current == self.total:
            time_str = f" | Done in {self._format_time(elapsed)}"
        elif self.current > 0 and elapsed > 0:
            # elapsed can be 0.0 right after start on coarse clocks; dividing by
            # it would raise ZeroDivisionError, hence the guard above.
            remaining = (self.total - self.current) * elapsed / self.current
            time_str = f" | ETA: {self._format_time(remaining)}"
        else:
            time_str = " | Calculating..."

        progress_line = f"\r{self.description} |{bar}| {percentage:5.1f}% ({self.current}/{self.total}){time_str}"
        # Pad so a shorter line overwrites the tail of a previous, longer one.
        print(progress_line.ljust(80), end="", flush=True)
        if self.current == self.total:
            print()  # New line when complete
            self._finished = True

    def _format_time(self, seconds: float) -> str:
        """Format seconds as "12.3s", "4m 5s" or "1h 2m"."""
        if seconds < 60:
            return f"{seconds:.1f}s"
        elif seconds < 3600:
            minutes = int(seconds // 60)
            secs = int(seconds % 60)
            return f"{minutes}m {secs}s"
        else:
            hours = int(seconds // 3600)
            minutes = int((seconds % 3600) // 60)
            return f"{hours}h {minutes}m"
# ─── Helper Functions ────────────────────────────────────────────────────────────
def now_iso_date():
    """Current calendar date in the Europe/Oslo time zone, as a YYYY-MM-DD string."""
    oslo_now = datetime.now(ZoneInfo("Europe/Oslo"))
    return format(oslo_now, "%Y-%m-%d")
def connect_readonly(path: Path):
    """Open the SQLite database at ``path`` read-only, via a URI connection.

    ``mode=ro`` makes every write through this connection fail. ``immutable=1``
    is deliberately NOT used. It tells SQLite the file can never change, so
    SQLite skips locking and ignores the ``-wal`` file. Commits that a live
    application still has in its WAL would then be silently missing from the
    backup, and a concurrent write could produce corrupt reads.

    The path is turned into a proper ``file:`` URI (absolute, percent-encoded)
    so names containing ``#``, ``?``, ``%`` or spaces are not misparsed as URI
    syntax.

    Raises:
        sqlite3.OperationalError: if the file cannot be opened (e.g. it does not exist).
    """
    uri = f"{Path(path).resolve().as_uri()}?mode=ro"
    return sqlite3.connect(uri, uri=True)
def connect_writable(path: Path):
    """Return a read-write connection to ``path``; SQLite creates the file if it is missing."""
    connection = sqlite3.connect(path)
    return connection
def get_table_row_count(conn: sqlite3.Connection, table: str) -> int:
    """Return the number of rows in ``table``.

    Table names usually come straight from ``sqlite_master``, so they may be
    SQL keywords (``order``) or contain spaces or quotes. The name is therefore
    emitted as a double-quoted identifier, with embedded quotes doubled.
    """
    quoted = '"' + table.replace('"', '""') + '"'
    return conn.execute(f"SELECT COUNT(*) FROM {quoted}").fetchone()[0]
def get_all_tables_row_count(conn: sqlite3.Connection) -> int:
    """Sum the row counts of every table whose name does not match ``LIKE 'sqlite_%'``."""
    listing = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
    ).fetchall()
    table_names = [name for (name,) in listing]
    return sum(get_table_row_count(conn, name) for name in table_names)
def copy_schema_and_data(
    src_conn, dst_conn, table_filter=None, progress_callback: Optional[Callable] = None
):
    """
    Copy tables with all their rows, then indexes/triggers/views, from src to dst.

    Objects whose name matches ``LIKE 'sqlite_%'`` (e.g. sqlite_sequence), or
    that have no SQL text (automatic indexes), are skipped. So is any object for
    which ``table_filter(name, obj_type)`` returns a falsy value.

    The copy runs in three phases:
      1. CREATE TABLE statements;
      2. row copy, table by table;
      3. CREATE INDEX / TRIGGER / VIEW statements, in original creation (rowid)
         order so views and INSTEAD OF triggers see their dependencies.
    Creating triggers only after the rows are in place keeps them from firing
    while the copy is populated. Otherwise an AFTER INSERT trigger would write
    extra rows into the backup. Building indexes after the bulk insert is also
    cheaper.

    Args:
        src_conn: connection to read from.
        dst_conn: connection to write into; committed after each phase.
        table_filter: optional predicate ``(name, obj_type) -> bool``.
        progress_callback: optional callable, called with the number of rows
            inserted for each non-empty table.
    """

    def quote_identifier(name):
        # Names come from sqlite_master and may be keywords or contain spaces/quotes.
        return '"' + name.replace('"', '""') + '"'

    def wanted(name, obj_type):
        return not table_filter or table_filter(name, obj_type)

    schema_objects = src_conn.execute("""
        SELECT type, name, sql
        FROM sqlite_master
        WHERE name NOT LIKE 'sqlite_%'
          AND sql IS NOT NULL
        ORDER BY rowid;
    """).fetchall()
    dst = dst_conn.cursor()

    # 1) Tables
    tables = [
        (name, sql)
        for obj_type, name, sql in schema_objects
        if obj_type == "table" and wanted(name, obj_type)
    ]
    for _, sql in tables:
        dst.execute(sql)
    dst_conn.commit()

    # 2) Rows
    for name, _ in tables:
        quoted = quote_identifier(name)
        rows = src_conn.execute(f"SELECT * FROM {quoted}").fetchall()
        if not rows:
            continue
        placeholders = ", ".join(["?"] * len(rows[0]))
        dst.executemany(f"INSERT INTO {quoted} VALUES ({placeholders})", rows)
        if progress_callback:
            progress_callback(len(rows))
    dst_conn.commit()

    # 3) Indexes, triggers and views
    for obj_type, name, sql in schema_objects:
        if obj_type != "table" and wanted(name, obj_type):
            dst.execute(sql)
    dst_conn.commit()
def get_ids(src_path: Path, table: str, column: str) -> set:
    """Return the distinct values of ``column`` in ``table`` of the database at ``src_path``.

    The connection is closed even if the query fails (e.g. the table does not
    exist). The table name is emitted as a quoted identifier. The column is
    left unquoted on purpose: a double-quoted name that does not match a column
    would silently become a string literal in SQLite instead of raising
    "no such column".
    """
    quoted_table = '"' + table.replace('"', '""') + '"'
    with closing(connect_readonly(src_path)) as conn:
        rows = conn.execute(f"SELECT DISTINCT {column} FROM {quoted_table}").fetchall()
    return {value for (value,) in rows}
def chunked(iterable, size):
    """Lazily yield consecutive lists of at most ``size`` items taken from ``iterable``."""
    items = list(iterable)
    yield from (items[start : start + size] for start in range(0, len(items), size))
# ─── Backup Routines ────────────────────────────────────────────────────────────
def backup_simple(db_name: str, dst_folder: Path):
    """Copy schema + all data from a simple DB (no cross-filtering).

    Reads ``SRC_DIR / db_name`` and writes ``dst_folder / db_name``. An empty
    source (no rows in any table) is reported and skipped.

    The copy is built in ``<db_name>.partial`` and renamed over the final name
    only after it has completed. A crash therefore never leaves a half-written
    backup under the real name. A second run on the same day starts from a
    clean file instead of failing with "table ... already exists".

    Connections are closed via ``contextlib.closing``. sqlite3's own
    ``with conn:`` only commits or rolls back; it never closes the connection.
    """
    src = SRC_DIR / db_name
    dst = dst_folder / db_name
    tmp = dst.with_name(dst.name + ".partial")
    print(f"Analyzing {db_name}...")

    # One source connection is used both for the row count and for the copy.
    with closing(connect_readonly(src)) as src_conn:
        total_rows = get_all_tables_row_count(src_conn)
        if total_rows == 0:
            print(f"✅ {db_name} is empty, skipping.")
            return

        progress = ProgressBar(total_rows, f"Backing up {db_name}")
        tmp.unlink(missing_ok=True)  # leftover from an interrupted run
        try:
            with closing(connect_writable(tmp)) as dst_conn:
                copy_schema_and_data(
                    src_conn, dst_conn, progress_callback=progress.update
                )
            # Destination connection is closed before the rename (required on Windows).
            tmp.replace(dst)
        finally:
            # No-op after a successful rename; removes a half-written file otherwise.
            tmp.unlink(missing_ok=True)
def _estimate_filtered_backup_rows(src_conn, table: str, column: str, valid_ids: set) -> int:
    """Estimate how many rows a filtered backup will copy, to size the progress bar.

    Tables other than ``table`` are counted exactly. The rows of ``table`` whose
    ``column`` is in ``valid_ids`` are extrapolated from a sample of at most 100
    ids, which keeps the ``IN (...)`` list short. The result can therefore be
    off in either direction.
    """
    cur = src_conn.cursor()
    cur.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
    )
    total_other_rows = sum(
        get_table_row_count(src_conn, tbl) for (tbl,) in cur.fetchall() if tbl != table
    )
    if not valid_ids:
        return total_other_rows

    sample_ids = list(valid_ids)[:100]
    qmarks = ",".join("?" for _ in sample_ids)
    quoted_table = '"' + table.replace('"', '""') + '"'
    sample_count = src_conn.execute(
        f"SELECT COUNT(*) FROM {quoted_table} WHERE {column} IN ({qmarks})",
        tuple(sample_ids),
    ).fetchone()[0]
    # Extrapolate the sample to the full id set.
    return total_other_rows + int(sample_count * len(valid_ids) / len(sample_ids))


def backup_filtered(
    main_db: str, table: str, column: str, valid_ids: set, dst_folder: Path
):
    """
    Copy schema + all data *except* `table`, then create `table` and copy only rows
    whose `column` is in valid_ids (in chunks of CHUNK_SIZE ids).

    Steps:
      1. every other table, with all its rows (tables only);
      2. the CREATE TABLE for `table`;
      3. the matching rows of `table`, CHUNK_SIZE ids per query;
      4. all indexes, triggers and views, in original creation order.
    Step 4 comes last because an index or trigger on `table` cannot be created
    before `table` exists. Running it after the rows are copied also keeps
    triggers from firing during the copy.

    The copy is built in ``<main_db>.partial`` and renamed to its final name only
    on success, so reruns start clean and failures leave no partial backup.

    Raises:
        ValueError: if `table` does not exist in the source database.
    """
    src = SRC_DIR / main_db
    dst = dst_folder / main_db
    tmp = dst.with_name(dst.name + ".partial")
    # The column is left unquoted: a double-quoted unknown column would silently
    # become a string literal in SQLite and match nothing.
    quoted_table = '"' + table.replace('"', '""') + '"'
    print(f"Analyzing {main_db}...")

    with closing(connect_readonly(src)) as src_conn:
        total_rows = _estimate_filtered_backup_rows(src_conn, table, column, valid_ids)
        if total_rows == 0:
            print(f"✅ {main_db} would be empty after filtering, skipping.")
            return

        schema_row = src_conn.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name=?", (table,)
        ).fetchone()
        if schema_row is None:
            raise ValueError(f"table {table!r} not found in {src}")

        progress = ProgressBar(total_rows, f"Backing up {main_db} (filtered)")
        tmp.unlink(missing_ok=True)  # leftover from an interrupted run
        try:
            with closing(connect_writable(tmp)) as dst_conn:
                # 1) Every other table with all its rows; nothing but tables yet.
                copy_schema_and_data(
                    src_conn,
                    dst_conn,
                    table_filter=lambda name, obj_type: obj_type == "table" and name != table,
                    progress_callback=progress.update,
                )

                # 2) The filtered table's schema
                dst_conn.execute(schema_row[0])

                # 3) Its matching rows, chunked to stay under SQLite's bound-parameter limit
                insert_cur = dst_conn.cursor()
                for chunk in chunked(valid_ids, CHUNK_SIZE):
                    qmarks = ",".join("?" for _ in chunk)
                    rows = src_conn.execute(
                        f"SELECT * FROM {quoted_table} WHERE {column} IN ({qmarks})",
                        tuple(chunk),
                    ).fetchall()
                    if rows:
                        ph = ",".join("?" for _ in rows[0])
                        insert_cur.executemany(
                            f"INSERT INTO {quoted_table} VALUES ({ph})", rows
                        )
                        progress.update(len(rows))
                dst_conn.commit()

                # 4) Indexes, triggers and views (including those on `table`)
                extras = src_conn.execute(
                    "SELECT sql FROM sqlite_master"
                    " WHERE type != 'table' AND name NOT LIKE 'sqlite_%'"
                    " AND sql IS NOT NULL ORDER BY rowid"
                ).fetchall()
                for (sql,) in extras:
                    dst_conn.execute(sql)
                dst_conn.commit()
            # Destination connection is closed before the rename (required on Windows).
            tmp.replace(dst)
        finally:
            # No-op after a successful rename; removes a half-written file otherwise.
            tmp.unlink(missing_ok=True)

    # The bar was sized from an estimate; if fewer rows were copied, finish it.
    if progress.current < progress.total:
        progress.update(progress.total - progress.current)
# ─── Main Script ────────────────────────────────────────────────────────────────
def main():
    """Back up every configured database into ``BASE_BACKUP_DIR/<today>``.

    SIMPLE_DBS are copied in full. For FILTERED_DBS, the set of valid trip ids
    is the intersection of each database's id column. Each filtered database
    then keeps only rows whose id is in that set. FILTERED_DBS is read here,
    rather than repeating the triples inline, so editing the configuration
    actually takes effect.
    """
    print("🔄 Starting database backup...\n")

    # 1) Prepare destination folder (one sub-folder per Europe/Oslo date)
    date_str = now_iso_date()
    dst = BASE_BACKUP_DIR / date_str
    dst.mkdir(parents=True, exist_ok=True)
    print(f"📁 Backing up to folder: {dst}\n")

    # 2) Simple DBs
    for db in SIMPLE_DBS:
        backup_simple(db, dst)
        print()  # Add spacing between databases

    # 3) Valid trip IDs = intersection of the id columns of all filtered DBs
    print("🔍 Computing valid trip IDs...")
    ids_per_db = [
        (db, get_ids(SRC_DIR / db, table, column)) for db, table, column in FILTERED_DBS
    ]
    valid = set.intersection(*(ids for _, ids in ids_per_db)) if ids_per_db else set()
    for db, ids in ids_per_db:
        print(f" Found {len(ids)} trip IDs in {db}")
    print(f" Valid intersection: {len(valid)} trip IDs")
    if not valid:
        db_names = " and ".join(db for db, _, _ in FILTERED_DBS)
        print(f"⚠️ Warning: no matching trip IDs between {db_names}")
    print()

    # 4) Filtered DBs
    for db, table, column in FILTERED_DBS:
        backup_filtered(db, table, column, valid, dst)
        print()

    print("✅ Backup complete!")
if __name__ == "__main__":
    # Script entry point: exit status 1 on Ctrl-C or on any failure.
    exit_code = 0
    try:
        main()
    except KeyboardInterrupt:
        print("\n\n⏹️ Backup interrupted by user.")
        exit_code = 1
    except Exception as err:
        print(f"\n❌ Backup failed: {err}", file=sys.stderr)
        exit_code = 1
    if exit_code:
        sys.exit(exit_code)