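"""Recover failed CDS Videos flow tasks for a batch of migrated records.

For each record listed in a redirections file, re-run any failed flow
tasks, copy the resulting metadata and master file tags from the deposit
to the published record, run ExtractChapterFramesTask when it has not
yet succeeded, and append the final task statuses to a log file.
"""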
import datetime
import json
import time

from cds.modules.deposit.api import (
    deposit_video_resolver,
    get_master_object,
    record_video_resolver,
)
from cds.modules.flows.models import FlowMetadata, FlowTaskMetadata, FlowTaskStatus
from cds.modules.flows.tasks import (
    ExtractChapterFramesTask,
    ExtractFramesTask,
    ExtractMetadataTask,
    TranscodeVideoTask,
    sync_records_with_deposit_files,
)
from invenio_db import db
from invenio_files_rest.models import ObjectVersion, ObjectVersionTag


def copy_master_tags_between_buckets(src_bucket, dst_bucket):
    """Copy tags of the master ObjectVersion from src_bucket to dst_bucket."""
    # Find the master file in the deposit bucket
    src_master = get_master_object(src_bucket)

    # Find the corresponding master file in the record bucket
    dst_master = ObjectVersion.get(dst_bucket, src_master.key)

    # Copy the tags manually because they are not updated during publish
    for tag in src_master.tags:
        ObjectVersionTag.create_or_update(dst_master, tag.key, tag.value)

    db.session.commit()


def _find_celery_task_by_name(name):
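    """Return the Celery task class matching ``name``, or None if not found."""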
    for celery_task in [
        ExtractMetadataTask,
        ExtractFramesTask,
        ExtractChapterFramesTask,
        TranscodeVideoTask,
    ]:
        if celery_task.name == name:
            return celery_task
    return None


def find_failed_tasks(deposit_id):
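    """Return the deposit's flow and the (name, id) pairs of its failed tasks."""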
    flow = FlowMetadata.get_by_deposit(deposit_id)
    failed_tasks = []

    for task in flow.tasks:
        # Re-fetch each task so the status reflects the latest DB state
        task = db.session.query(FlowTaskMetadata).get(task.id)
        if task.status == FlowTaskStatus.FAILURE:
            failed_tasks.append((task.name, task.id))

    return flow, failed_tasks


def run_failed_tasks(failed_tasks, flow, deposit_id, record_id):
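    """Re-run the given failed tasks and sync their outputs to the record.

    ExtractMetadataTask and ExtractFramesTask need extra post-processing
    (copying extracted metadata and master tags, syncing deposit files),
    so they are handled first; remaining failures are simply re-dispatched.
    """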
    failed_tasks = failed_tasks.copy()
    payload = flow.payload.copy()
    task_names = [task[0] for task in failed_tasks]
    flow_id = flow.id

    # --- Handle ExtractMetadataTask separately ---
    if ExtractMetadataTask.name in task_names:
        failed_task = next(t for t in failed_tasks if t[0] == ExtractMetadataTask.name)
        task_id = failed_task[1]
        task = db.session.query(FlowTaskMetadata).get(task_id)
        task.status = FlowTaskStatus.PENDING
        db.session.commit()

        print(f"Re-running ExtractMetadataTask for record {record_id}")
        payload["task_id"] = str(task.id)

        celery_task = ExtractMetadataTask()
        celery_task.clean(deposit_id=deposit_id, version_id=payload["version_id"])
        celery_task.s(**payload).apply_async()
        db.session.commit()
        fetch_tasks_status(flow_id, timeout_seconds=60)

        # Remove from the failed list so we don't run it twice
        failed_tasks = [t for t in failed_tasks if t[0] != ExtractMetadataTask.name]

        # Copy the freshly extracted metadata and master tags to the record
        db.session.expire_all()
        flow = db.session.query(FlowMetadata).get(flow_id)
        deposit = deposit_video_resolver(deposit_id)
        extracted_metadata = deposit["_cds"]["extracted_metadata"]
        record = record_video_resolver(record_id)
        record["_cds"]["extracted_metadata"] = extracted_metadata
        record.commit()
        db.session.commit()
        copy_master_tags_between_buckets(
            src_bucket=deposit.bucket,
            dst_bucket=record["_buckets"]["record"],
        )

    # --- Handle ExtractFramesTask separately ---
    if ExtractFramesTask.name in task_names:
        failed_task = next(t for t in failed_tasks if t[0] == ExtractFramesTask.name)
        task_id = failed_task[1]
        task = db.session.query(FlowTaskMetadata).get(task_id)
        task.status = FlowTaskStatus.PENDING
        db.session.commit()

        print(f"Re-running ExtractFramesTask for record {record_id}")
        payload["task_id"] = str(task.id)

        celery_task = ExtractFramesTask()
        celery_task.clean(deposit_id=deposit_id, version_id=payload["version_id"])
        celery_task.s(**payload).apply_async()
        db.session.commit()
        fetch_tasks_status(flow_id, timeout_seconds=60)
        # Sync files between the deposit and the record
        sync_records_with_deposit_files(deposit_id)

        # Remove from the failed list so we don't run it twice
        failed_tasks = [t for t in failed_tasks if t[0] != ExtractFramesTask.name]

    # --- Handle any other failed tasks ---
    for task_name, task_id in failed_tasks:
        print(f"Re-running failed task: {task_name} for record {record_id}")

        task_cls = _find_celery_task_by_name(task_name)
        if not task_cls:
            print(f"No Celery task class found for {task_name}. Skipping.")
            continue

        task = db.session.query(FlowTaskMetadata).get(task_id)
        task.status = FlowTaskStatus.PENDING
        db.session.commit()

        payload["task_id"] = str(task.id)

        celery_task = task_cls()
        celery_task.clean(deposit_id=deposit_id, version_id=payload["version_id"])
        celery_task.s(**payload).apply_async()
        db.session.commit()

    fetch_tasks_status(flow_id, timeout_seconds=60)


def fetch_tasks_status(flow_id, timeout_seconds=30):
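    """Poll the flow's tasks until none is PENDING or STARTED, or the timeout expires."""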
    start_time = time.time()

    while True:
        elapsed_time = time.time() - start_time
        if elapsed_time >= timeout_seconds:
            print(f"Timeout reached after {timeout_seconds} seconds. Exiting.")
            break

        # Force SQLAlchemy to fetch fresh data from the DB
        db.session.expire_all()

        flow = db.session.query(FlowMetadata).get(flow_id)
        all_tasks_finished = True

        for task in flow.tasks:
            task = db.session.query(FlowTaskMetadata).get(task.id)
            if task.status == FlowTaskStatus.PENDING:
                print(f"Task {task.name} is still pending. Waiting...")
                all_tasks_finished = False
            elif task.status == FlowTaskStatus.STARTED:
                print(f"Task {task.name} has started. Waiting...")
                all_tasks_finished = False

        if all_tasks_finished:
            print("✅ All tasks are completed (SUCCESS or FAILURE).")
            break

        time.sleep(5)  # Poll every 5 seconds


def finalize_tasks(deposit_id):
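    """Run ExtractChapterFramesTask for the deposit unless it already succeeded."""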
    # Always work on a clean session to avoid cached data
    db.session.expire_all()

    flow = FlowMetadata.get_by_deposit(deposit_id)
    flow_id = flow.id
    payload = flow.payload.copy()

    # Determine whether ExtractChapterFramesTask needs to run
    run_chapters_task = True
    for task in flow.tasks:
        if (
            task.name == ExtractChapterFramesTask.name
            and task.status == FlowTaskStatus.SUCCESS
        ):
            run_chapters_task = False

    if run_chapters_task:
        print("Running ExtractChapterFramesTask...")

        # Create a FlowTaskMetadata entry for the new task run
        new_task = FlowTaskMetadata(
            flow_id=flow_id,
            name=ExtractChapterFramesTask.name,
            status=FlowTaskStatus.PENDING,
        )
        db.session.add(new_task)
        db.session.commit()

        payload["task_id"] = str(new_task.id)
        ExtractChapterFramesTask().s(**payload).apply_async()

        # Poll for task completion
        fetch_tasks_status(flow_id, timeout_seconds=120)


def fetch_flow_and_log(record_id, deposit_id, flow_id, failed_tasks, log_file_path):
    """Fetch the latest flow and write detailed info to the log file."""
    # Ensure we read the latest DB state
    db.session.expire_all()
    flow = db.session.query(FlowMetadata).get(flow_id)

    with open(log_file_path, "a") as log_file:
        log_file.write("\n" + "=" * 80 + "\n")
        log_file.write(f"Record ID: {record_id}\n")
        log_file.write(f"Deposit ID: {deposit_id}\n")
        log_file.write(f"Flow ID: {flow_id}\n")
        log_file.write("-" * 80 + "\n")

        # Log previously failed tasks
        if failed_tasks:
            log_file.write("Previously failed tasks:\n")
            for task_name, task_id in failed_tasks:
                log_file.write(f"  - {task_name} (ID: {task_id})\n")
        else:
            log_file.write("No previously failed tasks.\n")

        log_file.write("-" * 80 + "\n")
        log_file.write("Latest task statuses:\n")

        # Iterate over all tasks in the flow and log their current statuses
        for task in flow.tasks:
            task_obj = db.session.query(FlowTaskMetadata).get(task.id)
            log_file.write(f"  • {task_obj.name:<30} | Status: {task_obj.status}\n")

        log_file.write("=" * 80 + "\n\n")


def load_record_ids(redirections_file_path):
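    """Read the redirections JSON file and return all cds_videos_id values."""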
    with open(redirections_file_path, "r") as f:
        data = json.load(f)

    # Extract all cds_videos_id values
    record_ids = [item["cds_videos_id"] for item in data]
    return record_ids


def main():
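    """Recover failed flow tasks for a batch of records and log their final statuses."""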
    # Create an empty log file
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file_path = f"/tmp/task_recovery_log_{timestamp}.txt"
    with open(log_file_path, "w"):
        pass

    redirections_file_path = "/tmp/record_redirections.json"
    all_record_ids = load_record_ids(redirections_file_path)
    record_ids = all_record_ids[:100]  # process any subset
    for record_id in record_ids:
        record = record_video_resolver(record_id)
        deposit_id = record["_deposit"]["id"]

        flow, failed_tasks = find_failed_tasks(deposit_id)
        flow_id = flow.id
        if not failed_tasks:
            print(f"No failed tasks found for record {record_id}.")
        else:
            run_failed_tasks(failed_tasks, flow, deposit_id, record_id)

        finalize_tasks(deposit_id)

        fetch_flow_and_log(record_id, deposit_id, flow_id, failed_tasks, log_file_path)


if __name__ == "__main__":
    main()