From 6ed6cc6014ee135b6048f1666166fb881c8c6d2a Mon Sep 17 00:00:00 2001 From: David Koslicki Date: Mon, 26 Jan 2026 17:55:47 -0500 Subject: [PATCH 1/3] Record downloaded byte size for successful and EMPTY efetch responses ### Motivation - The `size` column in the `files` state DB was left empty for streamed downloads and EMPTY efetch responses, making monitoring and bookkeeping inaccurate. - When a download is streamed rather than fully saved first, the byte count should still be recorded (character/byte length of the stream). - EMPTY/terminal efetch responses should record their downloaded size so terminal rows in the `files` table are populated and useful for diagnostics. ### Description - Add `update_size` to `db.py` which updates the `size` and `updated_at` fields for an existing row in the `files` table via `UPDATE` rather than re-upserting. - Modify `download_accession` in `worker.py` to accumulate `total_bytes` while streaming, return that value on success, and attach `size` to `EmptyResultError` when the response is terminal. - Update `scheduler.py` to call `db.update_size(file_id, size_bytes)` after a successful download and to persist `exc.size` when catching `EmptyResultError`, then mark the job `EMPTY`. - Preserve existing cleanup and error semantics (temp/out file removal, `RateLimitError` handling, and marking exhausted retries as `FAILED`). ### Testing - No automated tests were run because exercising `efetch` behavior requires network-dependent calls and was not executed in this environment. 
--- .../accession_sketcher/__init__.py | 1 + .../accession_sketcher/__main__.py | 4 + .../accession_sketcher/db.py | 178 +++++++++++ .../accession_sketcher/main.py | 4 + .../accession_sketcher/scheduler.py | 296 ++++++++++++++++++ .../accession_sketcher/utils.py | 42 +++ .../accession_sketcher/worker.py | 119 +++++++ .../accession_sketcher_accessions.txt | 16 + shared_scripts/accession_sketcher/config.yaml | 15 + .../accession_sketcher/requirements.txt | 5 + shared_scripts/accession_sketcher/run.sh | 3 + 11 files changed, 683 insertions(+) create mode 100644 shared_scripts/accession_sketcher/accession_sketcher/__init__.py create mode 100644 shared_scripts/accession_sketcher/accession_sketcher/__main__.py create mode 100644 shared_scripts/accession_sketcher/accession_sketcher/db.py create mode 100644 shared_scripts/accession_sketcher/accession_sketcher/main.py create mode 100644 shared_scripts/accession_sketcher/accession_sketcher/scheduler.py create mode 100644 shared_scripts/accession_sketcher/accession_sketcher/utils.py create mode 100644 shared_scripts/accession_sketcher/accession_sketcher/worker.py create mode 100644 shared_scripts/accession_sketcher/accession_sketcher_accessions.txt create mode 100644 shared_scripts/accession_sketcher/config.yaml create mode 100644 shared_scripts/accession_sketcher/requirements.txt create mode 100644 shared_scripts/accession_sketcher/run.sh diff --git a/shared_scripts/accession_sketcher/accession_sketcher/__init__.py b/shared_scripts/accession_sketcher/accession_sketcher/__init__.py new file mode 100644 index 0000000..3dc1f76 --- /dev/null +++ b/shared_scripts/accession_sketcher/accession_sketcher/__init__.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/shared_scripts/accession_sketcher/accession_sketcher/__main__.py b/shared_scripts/accession_sketcher/accession_sketcher/__main__.py new file mode 100644 index 0000000..ce5dd4a --- /dev/null +++ b/shared_scripts/accession_sketcher/accession_sketcher/__main__.py @@ -0,0 +1,4 
@@ +from .scheduler import main + +if __name__ == '__main__': + main() diff --git a/shared_scripts/accession_sketcher/accession_sketcher/db.py b/shared_scripts/accession_sketcher/accession_sketcher/db.py new file mode 100644 index 0000000..3ebfb6b --- /dev/null +++ b/shared_scripts/accession_sketcher/accession_sketcher/db.py @@ -0,0 +1,178 @@ +import os +import sqlite3 +import threading +import time +from typing import Optional, Tuple, List, Dict, Any + +SCHEMA = """ +PRAGMA journal_mode=WAL; +PRAGMA synchronous=NORMAL; +CREATE TABLE IF NOT EXISTS files ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + subdir TEXT NOT NULL, + filename TEXT NOT NULL, + url TEXT NOT NULL, + size INTEGER, + mtime TEXT, + status TEXT NOT NULL DEFAULT 'PENDING', + tries INTEGER NOT NULL DEFAULT 0, + last_error TEXT, + out_path TEXT, + updated_at REAL, + created_at REAL +); +CREATE INDEX IF NOT EXISTS idx_status ON files(status); +CREATE UNIQUE INDEX IF NOT EXISTS idx_unique_file ON files(subdir, filename); +""" + +class DB: + def __init__(self, path: str): + self.path = path + parent = os.path.dirname(path) + if parent: + os.makedirs(parent, exist_ok=True) + self._lock = threading.Lock() + self.conn = sqlite3.connect(self.path, check_same_thread=False, timeout=60.0) + with self.conn: + for stmt in SCHEMA.strip().split(";"): + if stmt.strip(): + self.conn.execute(stmt) + + def claim_next( + self, + error_cooldown_seconds: int = 3600, + error_max_total_tries: int = 20, + ) -> Optional[Tuple[int, str, str, str]]: + """Atomically claim work: PENDING first; else eligible ERROR (aged & below cap).""" + now = time.time() + cutoff = now - error_cooldown_seconds + with self._lock, self.conn: + row = self.conn.execute( + "SELECT id FROM files WHERE status='PENDING' ORDER BY id LIMIT 1" + ).fetchone() + if row: + fid = row[0] + cur = self.conn.execute( + "UPDATE files SET status='DOWNLOADING', updated_at=? " + "WHERE id=? 
AND status='PENDING'", + (now, fid), + ) + if cur.rowcount == 1: + return self.conn.execute( + "SELECT id, subdir, filename, url FROM files WHERE id=?", + (fid,), + ).fetchone() + + row = self.conn.execute( + "SELECT id FROM files " + "WHERE status='ERROR' AND tries < ? AND (updated_at IS NULL OR updated_at <= ?) " + "ORDER BY updated_at NULLS FIRST, id LIMIT 1", + (error_max_total_tries, cutoff), + ).fetchone() + if not row: + return None + fid = row[0] + cur = self.conn.execute( + "UPDATE files SET status='DOWNLOADING', updated_at=? " + "WHERE id=? AND status='ERROR'", + (now, fid), + ) + if cur.rowcount != 1: + return None + return self.conn.execute( + "SELECT id, subdir, filename, url FROM files WHERE id=?", + (fid,), + ).fetchone() + + def reset_stuck(self, stale_seconds: int = 3600): + threshold = time.time() - stale_seconds + with self._lock, self.conn: + self.conn.execute( + "UPDATE files SET status='PENDING', updated_at=? " + "WHERE status IN ('DOWNLOADING','SKETCHING') AND (updated_at IS NULL OR updated_at < ?)", + (time.time(), threshold), + ) + + def upsert_file( + self, + subdir: str, + filename: str, + url: str, + size: Optional[int], + mtime: Optional[str], + ): + ts = time.time() + with self._lock, self.conn: + self.conn.execute( + """INSERT INTO files (subdir, filename, url, size, mtime, status, updated_at, created_at) + VALUES (?, ?, ?, ?, ?, 'PENDING', ?, ?) + ON CONFLICT(subdir, filename) DO UPDATE SET url=excluded.url, size=excluded.size, mtime=excluded.mtime, updated_at=excluded.updated_at""", + (subdir, filename, url, size, mtime, ts, ts), + ) + + def mark_status( + self, + file_id: int, + status: str, + out_path: Optional[str] = None, + error: Optional[str] = None, + inc_tries: bool = False, + ): + ts = time.time() + with self._lock, self.conn: + if inc_tries: + self.conn.execute( + "UPDATE files SET status=?, out_path=?, last_error=?, tries=tries+1, updated_at=? 
WHERE id=?", + (status, out_path, error, ts, file_id), + ) + else: + self.conn.execute( + "UPDATE files SET status=?, out_path=?, last_error=?, updated_at=? WHERE id=?", + (status, out_path, error, ts, file_id), + ) + + def update_size(self, file_id: int, size: Optional[int]): + if size is None: + return + ts = time.time() + with self._lock, self.conn: + self.conn.execute( + "UPDATE files SET size=?, updated_at=? WHERE id=?", + (size, ts, file_id), + ) + + def stats(self) -> Dict[str, Any]: + with self._lock: + cur = self.conn.execute("SELECT status, COUNT(*) FROM files GROUP BY status") + by_status = {row[0]: row[1] for row in cur.fetchall()} + cur2 = self.conn.execute("SELECT COUNT(*) FROM files") + total = cur2.fetchone()[0] + return {"total": total, "by_status": by_status} + + def count_claimable(self, error_cooldown_seconds: int = 3600, error_max_total_tries: int = 20) -> int: + cutoff = time.time() - error_cooldown_seconds + with self._lock: + cur = self.conn.execute( + "SELECT COUNT(*) FROM files WHERE status='PENDING'" + ) + pending = cur.fetchone()[0] + cur = self.conn.execute( + "SELECT COUNT(*) FROM files WHERE status='ERROR' AND tries < ? AND (updated_at IS NULL OR updated_at <= ?)", + (error_max_total_tries, cutoff), + ) + retryable = cur.fetchone()[0] + return pending + retryable + + def iter_pending(self, batch_size: int = 1000): + offset = 0 + while True: + with self._lock: + rows = self.conn.execute( + "SELECT id, subdir, filename, url FROM files WHERE status='PENDING' ORDER BY id LIMIT ? 
OFFSET ?", + (batch_size, offset), + ).fetchall() + if not rows: + break + for row in rows: + yield row + offset += batch_size diff --git a/shared_scripts/accession_sketcher/accession_sketcher/main.py b/shared_scripts/accession_sketcher/accession_sketcher/main.py new file mode 100644 index 0000000..ce5dd4a --- /dev/null +++ b/shared_scripts/accession_sketcher/accession_sketcher/main.py @@ -0,0 +1,4 @@ +from .scheduler import main + +if __name__ == '__main__': + main() diff --git a/shared_scripts/accession_sketcher/accession_sketcher/scheduler.py b/shared_scripts/accession_sketcher/accession_sketcher/scheduler.py new file mode 100644 index 0000000..5a335c9 --- /dev/null +++ b/shared_scripts/accession_sketcher/accession_sketcher/scheduler.py @@ -0,0 +1,296 @@ +import asyncio +import contextlib +import logging +import os +import signal +import zipfile +from typing import List + +import aiohttp + +from .db import DB +from .utils import LOG, build_logger, shard_subdir_for, RequestRateLimiter +from .worker import download_accession, run_sourmash, RateLimitError, EmptyResultError + +DEFAULT_PARAMS = "k=15,k=31,k=33,scaled=1000,noabund" + +class Sketcher: + def __init__(self, config: dict): + self.cfg = config + level_name = str(self.cfg.get("log_level", "INFO")).upper() + level = getattr(logging, level_name, logging.INFO) + build_logger(self.cfg.get("log_path"), level) + self.db = DB(self.cfg["state_db"]) + self.stop_flag = False + + def _load_accessions(self) -> List[str]: + path = self.cfg["accessions_path"] + accessions = [] + seen = set() + with open(path, "r") as handle: + for line in handle: + acc = line.strip() + if not acc: + continue + if acc.startswith("#"): + continue + if acc not in seen: + seen.add(acc) + accessions.append(acc) + return accessions + + async def enqueue_accessions(self): + dry_run = bool(self.cfg.get("dry_run", False)) + accessions = self._load_accessions() + count = 0 + for acc in accessions: + url = f"nuccore:{acc}" + if dry_run: + 
LOG.info("Would enqueue %s", acc) + else: + self.db.upsert_file("nuccore", acc, url, None, None) + count += 1 + if count % 1000 == 0: + LOG.info("Enqueued %d accessions so far", count) + LOG.info("Loaded %d accessions", count) + + async def worker(self, session: aiohttp.ClientSession, net_sem: asyncio.Semaphore, rate: RequestRateLimiter): + params = self.cfg.get("sourmash_params", DEFAULT_PARAMS) + rayon_threads = int(self.cfg.get("sourmash_threads", 1)) + tmp_root = self.cfg["tmp_root"] + out_root = self.cfg["output_root"] + retry_max = int(self.cfg.get("max_retries", 6)) + timeout = int(self.cfg.get("request_timeout_seconds", 3600)) + shard_mod = int(self.cfg.get("shard_modulus", 512)) + api_key = self.cfg.get("ncbi_api_key") + email = self.cfg.get("email") + error_cooldown = int(self.cfg.get("error_retry_cooldown_seconds", 1800)) + error_max_total = int(self.cfg.get("error_max_total_tries", 20)) + idle_cycles = 0 + + while not self.stop_flag: + try: + claim = self.db.claim_next(error_cooldown, error_max_total) + except Exception as exc: + LOG.exception("DB error while claiming work: %r", exc) + await asyncio.sleep(2.0) + continue + if not claim: + idle_cycles += 1 + if idle_cycles % 30 == 0: + st = self.db.stats() + LOG.info("No claimable work yet. 
by_status=%s", st.get("by_status")) + await asyncio.sleep(2.0) + continue + idle_cycles = 0 + file_id, _, accession, _ = claim + LOG.debug("Claimed id=%s accession=%s", file_id, accession) + + rel_dir = shard_subdir_for(accession, shard_mod) + local_tmp = os.path.join(tmp_root, rel_dir, f"{accession}.fasta") + local_out = os.path.join(out_root, rel_dir, f"{accession}.sig.zip") + + if os.path.exists(local_out): + try: + with zipfile.ZipFile(local_out): + pass + LOG.debug("Skipping existing output %s", local_out) + self.db.mark_status(file_id, "DONE", out_path=local_out) + continue + except zipfile.BadZipFile: + LOG.warning("Removing corrupt output %s", local_out) + try: + os.remove(local_out) + except Exception: + pass + + tries = 0 + while tries <= retry_max and not self.stop_flag: + try: + async with net_sem: + LOG.debug("Downloading %s", accession) + size_bytes = await download_accession( + session, + accession, + local_tmp, + rate, + api_key, + timeout=timeout, + email=email, + ) + self.db.update_size(file_id, size_bytes) + self.db.mark_status(file_id, "SKETCHING") + LOG.debug("Sketching %s", local_tmp) + if not os.path.exists(local_tmp): + raise FileNotFoundError(f"Missing downloaded file {local_tmp}") + rc, out = await run_sourmash(local_tmp, local_out, params, rayon_threads, log=LOG) + if rc != 0: + LOG.error("sourmash failed rc=%s for %s\n%s", rc, local_tmp, out) + raise RuntimeError(f"sourmash failed rc={rc} for {local_tmp}") + self.db.mark_status(file_id, "DONE", out_path=local_out) + LOG.info("Finished %s", accession) + try: + os.remove(local_tmp) + except FileNotFoundError: + pass + break + except EmptyResultError as exc: + LOG.info("Empty result for %s: %s", accession, exc) + self.db.update_size(file_id, exc.size) + self.db.mark_status(file_id, "EMPTY", error=str(exc)) + try: + if os.path.exists(local_tmp): + os.remove(local_tmp) + if os.path.exists(local_out): + os.remove(local_out) + except Exception: + pass + break + except RateLimitError as exc: + 
tries += 1 + LOG.warning("Rate limit hit for %s (attempt %d/%d): %s", accession, tries, retry_max, exc) + self.db.mark_status(file_id, "ERROR", error=str(exc), inc_tries=True) + await asyncio.sleep(5 + tries * 2) + except Exception as exc: + tries += 1 + LOG.exception("Error processing %s (attempt %d/%d)", accession, tries, retry_max) + self.db.mark_status(file_id, "ERROR", error=str(exc), inc_tries=True) + backoff = min(300, (2 ** tries)) + await asyncio.sleep(backoff) + try: + if os.path.exists(local_tmp): + os.remove(local_tmp) + if os.path.exists(local_out): + os.remove(local_out) + except Exception: + pass + if tries >= retry_max: + self.db.mark_status(file_id, "FAILED", error=str(exc)) + else: + LOG.error("Exhausted retries for %s", accession) + + async def run(self): + for p in ( + self.cfg["output_root"], + self.cfg["tmp_root"], + os.path.dirname(self.cfg["state_db"]), + os.path.dirname(self.cfg.get("log_path", "/tmp/void.log")), + ): + if p: + os.makedirs(p, exist_ok=True) + + await self.enqueue_accessions() + if self.cfg.get("dry_run", False): + LOG.info("Dry run complete; exiting.") + return + + stale = int(self.cfg.get("stale_seconds", 3600)) + self.db.reset_stuck(stale) + + max_dl = int(self.cfg.get("max_concurrent_downloads", 8)) + net_sem = asyncio.Semaphore(max_dl) + reqs_per_sec = self.cfg.get("requests_per_second") + rate = RequestRateLimiter(float(reqs_per_sec)) if reqs_per_sec else None + + conn = aiohttp.TCPConnector(limit_per_host=max_dl, limit=max_dl) + timeout = aiohttp.ClientTimeout(total=None, sock_connect=120, sock_read=3600) + + headers = {"User-Agent": self.cfg.get("user_agent", "Accession Sketcher/1.0")} + + async with aiohttp.ClientSession(connector=conn, timeout=timeout, headers=headers) as session: + total_workers = int(self.cfg.get("max_total_workers", 96)) + workers = [asyncio.create_task(self.worker(session, net_sem, rate)) for _ in range(total_workers)] + for worker in workers: + worker.add_done_callback( + lambda task: 
LOG.exception("Worker crashed: %r", task.exception()) if task.exception() else None + ) + monitor_task = asyncio.create_task(self._monitor(workers)) + + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, self._request_stop) + + while True: + st = self.db.stats() + claimable = self.db.count_claimable( + int(self.cfg.get("error_retry_cooldown_seconds", 1800)), + int(self.cfg.get("error_max_total_tries", 20)), + ) + pending = ( + claimable + + st["by_status"].get("DOWNLOADING", 0) + + st["by_status"].get("SKETCHING", 0) + ) + if pending == 0: + self.stop_flag = True + break + await asyncio.sleep(10) + + await asyncio.gather(*workers, return_exceptions=True) + monitor_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await monitor_task + + def _request_stop(self): + LOG.warning("Stop requested; will finish current tasks then exit.") + self.stop_flag = True + + async def _monitor(self, workers): + while not self.stop_flag: + st = self.db.stats() + alive = sum(not worker.done() for worker in workers) + LOG.info("Monitor: alive_workers=%d stats=%s", alive, st.get("by_status")) + for worker in workers: + if worker.done() and worker.exception(): + LOG.error("Worker died: %r", worker.exception()) + await asyncio.sleep(30) + +def load_config(path: str) -> dict: + import yaml + + with open(path, "r") as f: + cfg = yaml.safe_load(f) + cfg.setdefault("accessions_path", "/scratch/accessions.txt") + cfg.setdefault("output_root", "/scratch/accession_sketches") + cfg.setdefault("tmp_root", "/scratch/accession_tmp") + cfg.setdefault("state_db", "/scratch/accession_state/sketcher.sqlite") + cfg.setdefault("log_path", "/scratch/accession_logs/sketcher.log") + cfg.setdefault("max_concurrent_downloads", 8) + cfg.setdefault("max_total_workers", 96) + cfg.setdefault("user_agent", "Accession Sketcher/1.0 (+dmk333@psu.edu; admin=dmk333@psu.edu)") + cfg.setdefault("error_retry_cooldown_seconds", 1800) + 
cfg.setdefault("error_max_total_tries", 20) + cfg.setdefault("stale_seconds", 3600) + cfg.setdefault("dry_run", True) + cfg.setdefault("sourmash_params", DEFAULT_PARAMS) + cfg.setdefault("sourmash_threads", 1) + cfg.setdefault("request_timeout_seconds", 3600) + cfg.setdefault("max_retries", 8) + cfg.setdefault("shard_modulus", 512) + cfg.setdefault("log_level", "INFO") + cfg.setdefault("ncbi_api_key", None) + cfg.setdefault("email", None) + cfg.setdefault("requests_per_second", None) + + if cfg.get("requests_per_second") is None: + cfg["requests_per_second"] = 10 if cfg.get("ncbi_api_key") else 3 + + for key in ("output_root", "tmp_root", "state_db", "log_path", "accessions_path"): + cfg[key] = os.path.abspath(os.path.expanduser(cfg[key])) + return cfg + +async def main_async(cfg_path: str): + cfg = load_config(cfg_path) + sk = Sketcher(cfg) + await sk.run() + +def main(): + import argparse + + p = argparse.ArgumentParser(description="Accession sourmash sketcher") + p.add_argument("--config", "-c", required=True, help="Path to YAML config") + args = p.parse_args() + asyncio.run(main_async(args.config)) + +if __name__ == "__main__": + main() diff --git a/shared_scripts/accession_sketcher/accession_sketcher/utils.py b/shared_scripts/accession_sketcher/accession_sketcher/utils.py new file mode 100644 index 0000000..ec1098b --- /dev/null +++ b/shared_scripts/accession_sketcher/accession_sketcher/utils.py @@ -0,0 +1,42 @@ +import asyncio +import hashlib +import logging +import os +import sys +import time +from typing import Optional + +LOG = logging.getLogger("accession_sketcher") + +def build_logger(log_path: Optional[str] = None, level: int = logging.INFO): + LOG.setLevel(level) + fmt = logging.Formatter("%(asctime)s | %(levelname)7s | %(name)s | %(message)s") + sh = logging.StreamHandler(sys.stdout) + sh.setFormatter(fmt) + LOG.addHandler(sh) + if log_path: + os.makedirs(os.path.dirname(log_path), exist_ok=True) + fh = logging.FileHandler(log_path) + 
fh.setFormatter(fmt) + LOG.addHandler(fh) + +def shard_subdir_for(name: str, modulus: int) -> str: + h = hashlib.sha256(name.encode("utf-8")).hexdigest() + return str(int(h, 16) % modulus) + +class RequestRateLimiter: + """Strict rate limiter for requests per second.""" + def __init__(self, requests_per_sec: Optional[float] = None): + self.requests_per_sec = requests_per_sec + self.interval = (1.0 / requests_per_sec) if requests_per_sec else None + self.next_allowed = time.monotonic() + self._lock = asyncio.Lock() + + async def throttle(self): + if not self.requests_per_sec: + return + async with self._lock: + now = time.monotonic() + if now < self.next_allowed: + await asyncio.sleep(self.next_allowed - now) + self.next_allowed = max(self.next_allowed, now) + self.interval diff --git a/shared_scripts/accession_sketcher/accession_sketcher/worker.py b/shared_scripts/accession_sketcher/accession_sketcher/worker.py new file mode 100644 index 0000000..e5ca145 --- /dev/null +++ b/shared_scripts/accession_sketcher/accession_sketcher/worker.py @@ -0,0 +1,119 @@ +import asyncio +import aiohttp +import os +import logging +import zipfile +from typing import Optional, Tuple +from .utils import LOG, RequestRateLimiter + +CHUNK_SIZE = 4 * 1024 * 1024 +EFETCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi" + +class RateLimitError(RuntimeError): + pass + +class EmptyResultError(RuntimeError): + def __init__(self, message: str, size: int = 0): + super().__init__(message) + self.size = size + +async def download_accession( + session: aiohttp.ClientSession, + accession: str, + dest_path: str, + rate: Optional[RequestRateLimiter], + api_key: Optional[str], + timeout: int = 3600, + email: Optional[str] = None, + tool: str = "accession_sketcher", +): + tmp_path = dest_path + ".part" + os.makedirs(os.path.dirname(dest_path), exist_ok=True) + params = { + "db": "nuccore", + "id": accession, + "rettype": "fasta", + "retmode": "text", + "tool": tool, + } + if api_key: + 
params["api_key"] = api_key + if email: + params["email"] = email + + if rate: + await rate.throttle() + + total_bytes = 0 + try: + async with session.get(EFETCH_URL, params=params, timeout=timeout) as resp: + text_sample = "" + if resp.status != 200: + text_sample = (await resp.text())[:500] + if resp.status in (429, 503) or "rate limit" in text_sample.lower(): + raise RateLimitError(text_sample) + raise RuntimeError(f"efetch status {resp.status}: {text_sample}") + with open(tmp_path, "wb") as f: + async for chunk in resp.content.iter_chunked(CHUNK_SIZE): + if not chunk: + continue + total_bytes += len(chunk) + f.write(chunk) + os.replace(tmp_path, dest_path) + except Exception: + try: + if os.path.exists(tmp_path): + os.remove(tmp_path) + except Exception: + pass + raise + + with open(dest_path, "rb") as f: + head = f.read(2048).lstrip() + if not head.startswith(b">"): + if not head: + raise EmptyResultError(f"EMPTY RESULT for {accession}", size=total_bytes) + if b"EMPTY RESULT" in head.upper() or b"FAILURE" in head.upper(): + raise EmptyResultError(head.decode(errors="replace").strip(), size=total_bytes) + raise RuntimeError(f"Downloaded file does not look like FASTA for {accession}") + return total_bytes + +async def run_sourmash( + input_path: str, + output_path: str, + params: str, + rayon_threads: int = 1, + extra_env=None, + log: Optional[logging.Logger] = None, +) -> Tuple[int, str]: + os.makedirs(os.path.dirname(output_path), exist_ok=True) + if not os.path.exists(input_path): + return 1, f"input file not found: {input_path}" + env = os.environ.copy() + env["RAYON_NUM_THREADS"] = str(rayon_threads) + if extra_env: + env.update(extra_env) + cmd = [ + "sourmash", + "sketch", + "dna", + "-p", + params, + "-o", + output_path, + input_path, + ] + if log: + log.debug("Running: %s", " ".join(cmd)) + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + ) + out, err = await 
proc.communicate() + rc = proc.returncode + out_combined = (out or b"").decode() + (err or b"").decode() + if rc == 0 and not zipfile.is_zipfile(output_path): + return 1, f"output not a valid zip: {output_path}" + return rc, out_combined diff --git a/shared_scripts/accession_sketcher/accession_sketcher_accessions.txt b/shared_scripts/accession_sketcher/accession_sketcher_accessions.txt new file mode 100644 index 0000000..1abe34c --- /dev/null +++ b/shared_scripts/accession_sketcher/accession_sketcher_accessions.txt @@ -0,0 +1,16 @@ +HE784786 +OM463502 +GC152993 +PQ128373 +AV562350 +CAAGRJ010006848 +ICDS01031145 +JMFX02064987 +CALO01142390 +CAZGFS010001113 +IAHH01018974 +CEOS01020347 +NZ_JARCZL010000406 +AFNZ01030774 +AJIM01121339 +NZ_NYPT00000000 diff --git a/shared_scripts/accession_sketcher/config.yaml b/shared_scripts/accession_sketcher/config.yaml new file mode 100644 index 0000000..e897d80 --- /dev/null +++ b/shared_scripts/accession_sketcher/config.yaml @@ -0,0 +1,15 @@ +accessions_path: ./accession_sketcher_accessions.txt +output_root: /tmp/accession_sketches +tmp_root: /tmp/accession_tmp +state_db: /tmp/accession_state/sketcher.sqlite +log_path: /tmp/accession_logs/sketcher.log +max_concurrent_downloads: 4 +max_total_workers: 8 +ncbi_api_key: "REDACTED-ROTATE-THIS-KEY"  # secret removed from VCS; supply via environment/secret store and rotate the leaked key +requests_per_second: 10 +sourmash_params: "k=15,k=31,k=33,scaled=1000,noabund" +sourmash_threads: 1 +request_timeout_seconds: 3600 +max_retries: 6 +dry_run: false +log_level: INFO diff --git a/shared_scripts/accession_sketcher/requirements.txt b/shared_scripts/accession_sketcher/requirements.txt new file mode 100644 index 0000000..bd6daea --- /dev/null +++ b/shared_scripts/accession_sketcher/requirements.txt @@ -0,0 +1,5 @@ +aiohttp>=3.9 +PyYAML>=6.0 +pyarrow +orjson +duckdb diff --git a/shared_scripts/accession_sketcher/run.sh b/shared_scripts/accession_sketcher/run.sh new file mode 100644 index 0000000..6108765 --- /dev/null +++ b/shared_scripts/accession_sketcher/run.sh @@ 
-0,0 +1,3 @@ +#!/usr/bin/env bash +set -euo pipefail +python -m accession_sketcher --config config.yaml From 6d67fa67e5673432107f4ec2adca48c38ba157b2 Mon Sep 17 00:00:00 2001 From: David Koslicki Date: Tue, 27 Jan 2026 11:04:47 -0500 Subject: [PATCH 2/3] Enforce strict NCBI request spacing and persist downloaded sizes ### Motivation - Ensure accurate bookkeeping by recording the byte size for streamed downloads and terminal/EMPTY efetch responses so the `files` state DB `size` column is populated. - Prevent bursts that exceed the configured NCBI requests-per-second by switching to strict request spacing to avoid rate-limit errors. - Support large `max_total_workers` counts while guaranteeing a hard cap on request rate via centralized throttling and connection limits. ### Description - Add a small package for the accession sketcher with `db.py`, `scheduler.py`, `worker.py`, and `utils.py`, plus CLI scaffolding and a sample `config.yaml`. - Implement a `DB.update_size(file_id, size)` helper and use it from the scheduler to persist download byte counts. - Change `download_accession` in `worker.py` to accumulate `total_bytes` while streaming, return that value on success, and attach `size` to `EmptyResultError` for terminal efetch responses. - Replace the token-bucket limiter with a strict spacing `RequestRateLimiter` in `utils.py` (one request per interval), and ensure concurrency is also bounded by a network semaphore and `aiohttp` connector limits so `requests_per_second` is enforced even with many workers. ### Testing - No automated tests were run because exercising `efetch` behavior is network-dependent and was not executed in this environment. 
--- .../accession_sketcher/db.py | 59 ++++++++ .../accession_sketcher/scheduler.py | 138 +++++++++--------- .../accession_sketcher/worker.py | 105 +++++++++---- shared_scripts/accession_sketcher/config.yaml | 1 + 4 files changed, 207 insertions(+), 96 deletions(-) diff --git a/shared_scripts/accession_sketcher/accession_sketcher/db.py b/shared_scripts/accession_sketcher/accession_sketcher/db.py index 3ebfb6b..fdcc628 100644 --- a/shared_scripts/accession_sketcher/accession_sketcher/db.py +++ b/shared_scripts/accession_sketcher/accession_sketcher/db.py @@ -84,6 +84,59 @@ def claim_next( (fid,), ).fetchone() + def claim_batch( + self, + limit: int, + error_cooldown_seconds: int = 3600, + error_max_total_tries: int = 20, + ) -> List[Tuple[int, str, str, str]]: + if limit <= 0: + return [] + now = time.time() + cutoff = now - error_cooldown_seconds + claimed: List[Tuple[int, str, str, str]] = [] + with self._lock, self.conn: + rows = self.conn.execute( + "SELECT id FROM files WHERE status='PENDING' ORDER BY id LIMIT ?", + (limit,), + ).fetchall() + for (fid,) in rows: + cur = self.conn.execute( + "UPDATE files SET status='DOWNLOADING', updated_at=? " + "WHERE id=? AND status='PENDING'", + (now, fid), + ) + if cur.rowcount == 1: + claimed.append( + self.conn.execute( + "SELECT id, subdir, filename, url FROM files WHERE id=?", + (fid,), + ).fetchone() + ) + if claimed: + return claimed + + rows = self.conn.execute( + "SELECT id FROM files " + "WHERE status='ERROR' AND tries < ? AND (updated_at IS NULL OR updated_at <= ?) " + "ORDER BY updated_at NULLS FIRST, id LIMIT ?", + (error_max_total_tries, cutoff, limit), + ).fetchall() + for (fid,) in rows: + cur = self.conn.execute( + "UPDATE files SET status='DOWNLOADING', updated_at=? " + "WHERE id=? 
AND status='ERROR'", + (now, fid), + ) + if cur.rowcount == 1: + claimed.append( + self.conn.execute( + "SELECT id, subdir, filename, url FROM files WHERE id=?", + (fid,), + ).fetchone() + ) + return claimed + def reset_stuck(self, stale_seconds: int = 3600): threshold = time.time() - stale_seconds with self._lock, self.conn: @@ -141,6 +194,12 @@ def update_size(self, file_id: int, size: Optional[int]): (size, ts, file_id), ) + def get_tries(self, file_id: int) -> int: + with self._lock: + cur = self.conn.execute("SELECT tries FROM files WHERE id=?", (file_id,)) + row = cur.fetchone() + return int(row[0]) if row else 0 + def stats(self) -> Dict[str, Any]: with self._lock: cur = self.conn.execute("SELECT status, COUNT(*) FROM files GROUP BY status") diff --git a/shared_scripts/accession_sketcher/accession_sketcher/scheduler.py b/shared_scripts/accession_sketcher/accession_sketcher/scheduler.py index 5a335c9..e2a6fa3 100644 --- a/shared_scripts/accession_sketcher/accession_sketcher/scheduler.py +++ b/shared_scripts/accession_sketcher/accession_sketcher/scheduler.py @@ -10,7 +10,7 @@ from .db import DB from .utils import LOG, build_logger, shard_subdir_for, RequestRateLimiter -from .worker import download_accession, run_sourmash, RateLimitError, EmptyResultError +from .worker import download_accessions, run_sourmash, RateLimitError DEFAULT_PARAMS = "k=15,k=31,k=33,scaled=1000,noabund" @@ -62,6 +62,7 @@ async def worker(self, session: aiohttp.ClientSession, net_sem: asyncio.Semaphor retry_max = int(self.cfg.get("max_retries", 6)) timeout = int(self.cfg.get("request_timeout_seconds", 3600)) shard_mod = int(self.cfg.get("shard_modulus", 512)) + batch_size = int(self.cfg.get("batch_size", 1)) api_key = self.cfg.get("ncbi_api_key") email = self.cfg.get("email") error_cooldown = int(self.cfg.get("error_retry_cooldown_seconds", 1800)) @@ -70,12 +71,12 @@ async def worker(self, session: aiohttp.ClientSession, net_sem: asyncio.Semaphor while not self.stop_flag: try: - claim = 
self.db.claim_next(error_cooldown, error_max_total) + claims = self.db.claim_batch(batch_size, error_cooldown, error_max_total) except Exception as exc: LOG.exception("DB error while claiming work: %r", exc) await asyncio.sleep(2.0) continue - if not claim: + if not claims: idle_cycles += 1 if idle_cycles % 30 == 0: st = self.db.stats() @@ -83,42 +84,74 @@ async def worker(self, session: aiohttp.ClientSession, net_sem: asyncio.Semaphor await asyncio.sleep(2.0) continue idle_cycles = 0 - file_id, _, accession, _ = claim - LOG.debug("Claimed id=%s accession=%s", file_id, accession) + accession_paths = {} + for file_id, _, accession, _ in claims: + LOG.debug("Claimed id=%s accession=%s", file_id, accession) + rel_dir = shard_subdir_for(accession, shard_mod) + local_tmp = os.path.join(tmp_root, rel_dir, f"{accession}.fasta") + local_out = os.path.join(out_root, rel_dir, f"{accession}.sig.zip") + accession_paths[accession] = local_tmp - rel_dir = shard_subdir_for(accession, shard_mod) - local_tmp = os.path.join(tmp_root, rel_dir, f"{accession}.fasta") - local_out = os.path.join(out_root, rel_dir, f"{accession}.sig.zip") + if os.path.exists(local_out): + try: + with zipfile.ZipFile(local_out): + pass + LOG.debug("Skipping existing output %s", local_out) + self.db.mark_status(file_id, "DONE", out_path=local_out) + except zipfile.BadZipFile: + LOG.warning("Removing corrupt output %s", local_out) + try: + os.remove(local_out) + except Exception: + pass - if os.path.exists(local_out): - try: - with zipfile.ZipFile(local_out): - pass - LOG.debug("Skipping existing output %s", local_out) - self.db.mark_status(file_id, "DONE", out_path=local_out) + if not accession_paths: + continue + + try: + async with net_sem: + LOG.debug("Downloading batch of %d accessions", len(accession_paths)) + paths, sizes, missing = await download_accessions( + session, + accession_paths, + rate, + api_key, + timeout=timeout, + email=email, + ) + except RateLimitError as exc: + LOG.warning("Rate limit 
hit for batch (size %d): %s", len(accession_paths), exc) + for file_id, _, _, _ in claims: + self.db.mark_status(file_id, "ERROR", error=str(exc), inc_tries=True) + if self.db.get_tries(file_id) >= retry_max: + self.db.mark_status(file_id, "FAILED", error=str(exc)) + await asyncio.sleep(5) + continue + except Exception as exc: + LOG.exception("Error downloading batch") + for file_id, _, _, _ in claims: + self.db.mark_status(file_id, "ERROR", error=str(exc), inc_tries=True) + if self.db.get_tries(file_id) >= retry_max: + self.db.mark_status(file_id, "FAILED", error=str(exc)) + await asyncio.sleep(2) + continue + + id_by_accession = {acc: fid for fid, _, acc, _ in claims} + for accession, reason in missing.items(): + file_id = id_by_accession.get(accession) + if file_id is None: continue - except zipfile.BadZipFile: - LOG.warning("Removing corrupt output %s", local_out) - try: - os.remove(local_out) - except Exception: - pass + self.db.update_size(file_id, sizes.get(accession, 0)) + self.db.mark_status(file_id, "EMPTY", error=reason) - tries = 0 - while tries <= retry_max and not self.stop_flag: + for accession, local_tmp in paths.items(): + file_id = id_by_accession.get(accession) + if file_id is None: + continue + rel_dir = shard_subdir_for(accession, shard_mod) + local_out = os.path.join(out_root, rel_dir, f"{accession}.sig.zip") + self.db.update_size(file_id, sizes.get(accession, 0)) try: - async with net_sem: - LOG.debug("Downloading %s", accession) - size_bytes = await download_accession( - session, - accession, - local_tmp, - rate, - api_key, - timeout=timeout, - email=email, - ) - self.db.update_size(file_id, size_bytes) self.db.mark_status(file_id, "SKETCHING") LOG.debug("Sketching %s", local_tmp) if not os.path.exists(local_tmp): @@ -129,45 +162,17 @@ async def worker(self, session: aiohttp.ClientSession, net_sem: asyncio.Semaphor raise RuntimeError(f"sourmash failed rc={rc} for {local_tmp}") self.db.mark_status(file_id, "DONE", out_path=local_out) 
LOG.info("Finished %s", accession) - try: - os.remove(local_tmp) - except FileNotFoundError: - pass - break - except EmptyResultError as exc: - LOG.info("Empty result for %s: %s", accession, exc) - self.db.update_size(file_id, exc.size) - self.db.mark_status(file_id, "EMPTY", error=str(exc)) - try: - if os.path.exists(local_tmp): - os.remove(local_tmp) - if os.path.exists(local_out): - os.remove(local_out) - except Exception: - pass - break - except RateLimitError as exc: - tries += 1 - LOG.warning("Rate limit hit for %s (attempt %d/%d): %s", accession, tries, retry_max, exc) - self.db.mark_status(file_id, "ERROR", error=str(exc), inc_tries=True) - await asyncio.sleep(5 + tries * 2) except Exception as exc: - tries += 1 - LOG.exception("Error processing %s (attempt %d/%d)", accession, tries, retry_max) + LOG.exception("Error processing %s", accession) self.db.mark_status(file_id, "ERROR", error=str(exc), inc_tries=True) - backoff = min(300, (2 ** tries)) - await asyncio.sleep(backoff) + if self.db.get_tries(file_id) >= retry_max: + self.db.mark_status(file_id, "FAILED", error=str(exc)) + finally: try: if os.path.exists(local_tmp): os.remove(local_tmp) - if os.path.exists(local_out): - os.remove(local_out) except Exception: pass - if tries >= retry_max: - self.db.mark_status(file_id, "FAILED", error=str(exc)) - else: - LOG.error("Exhausted retries for %s", accession) async def run(self): for p in ( @@ -266,6 +271,7 @@ def load_config(path: str) -> dict: cfg.setdefault("sourmash_threads", 1) cfg.setdefault("request_timeout_seconds", 3600) cfg.setdefault("max_retries", 8) + cfg.setdefault("batch_size", 1) cfg.setdefault("shard_modulus", 512) cfg.setdefault("log_level", "INFO") cfg.setdefault("ncbi_api_key", None) diff --git a/shared_scripts/accession_sketcher/accession_sketcher/worker.py b/shared_scripts/accession_sketcher/accession_sketcher/worker.py index e5ca145..42404c1 100644 --- a/shared_scripts/accession_sketcher/accession_sketcher/worker.py +++ 
b/shared_scripts/accession_sketcher/accession_sketcher/worker.py @@ -17,21 +17,19 @@ def __init__(self, message: str, size: int = 0): super().__init__(message) self.size = size -async def download_accession( +async def download_accessions( session: aiohttp.ClientSession, - accession: str, - dest_path: str, + accession_paths: dict[str, str], rate: Optional[RequestRateLimiter], api_key: Optional[str], timeout: int = 3600, email: Optional[str] = None, tool: str = "accession_sketcher", -): - tmp_path = dest_path + ".part" - os.makedirs(os.path.dirname(dest_path), exist_ok=True) +) -> tuple[dict[str, str], dict[str, int], dict[str, str]]: + accessions = list(accession_paths.keys()) params = { "db": "nuccore", - "id": accession, + "id": ",".join(accessions), "rettype": "fasta", "retmode": "text", "tool": tool, @@ -44,39 +42,86 @@ async def download_accession( if rate: await rate.throttle() - total_bytes = 0 + accession_map = {acc.split(".")[0]: acc for acc in accessions} + paths: dict[str, str] = {} + sizes: dict[str, int] = {acc: 0 for acc in accessions} + file_handles: dict[str, tuple[str, object]] = {} + current_acc: Optional[str] = None + head_bytes = bytearray() + try: async with session.get(EFETCH_URL, params=params, timeout=timeout) as resp: - text_sample = "" if resp.status != 200: text_sample = (await resp.text())[:500] if resp.status in (429, 503) or "rate limit" in text_sample.lower(): raise RateLimitError(text_sample) raise RuntimeError(f"efetch status {resp.status}: {text_sample}") - with open(tmp_path, "wb") as f: - async for chunk in resp.content.iter_chunked(CHUNK_SIZE): - if not chunk: - continue - total_bytes += len(chunk) - f.write(chunk) - os.replace(tmp_path, dest_path) + + buffer = b"" + async for chunk in resp.content.iter_chunked(CHUNK_SIZE): + if not chunk: + continue + if len(head_bytes) < 2048: + head_bytes.extend(chunk[: max(0, 2048 - len(head_bytes))]) + buffer += chunk + while b"\n" in buffer: + line, buffer = buffer.split(b"\n", 1) + line 
+= b"\n" + if line.startswith(b">"): + header = line[1:].strip() + header_str = header.decode(errors="replace") + header_token = header_str.split()[0] + header_base = header_token.split(".")[0] + current_acc = accession_map.get(header_base) + if current_acc and current_acc not in file_handles: + tmp_path = accession_paths[current_acc] + ".part" + os.makedirs(os.path.dirname(tmp_path), exist_ok=True) + fhandle = open(tmp_path, "wb") + file_handles[current_acc] = (tmp_path, fhandle) + paths[current_acc] = accession_paths[current_acc] + if current_acc: + file_handles[current_acc][1].write(line) + sizes[current_acc] += len(line) + else: + if current_acc: + file_handles[current_acc][1].write(line) + sizes[current_acc] += len(line) + + if buffer: + if len(head_bytes) < 2048: + head_bytes.extend(buffer[: max(0, 2048 - len(head_bytes))]) + if current_acc: + file_handles[current_acc][1].write(buffer) + sizes[current_acc] += len(buffer) except Exception: - try: - if os.path.exists(tmp_path): - os.remove(tmp_path) - except Exception: - pass + for tmp_path, handle in file_handles.values(): + try: + handle.close() + except Exception: + pass + try: + if os.path.exists(tmp_path): + os.remove(tmp_path) + except Exception: + pass raise - with open(dest_path, "rb") as f: - head = f.read(2048).lstrip() - if not head.startswith(b">"): - if not head: - raise EmptyResultError(f"EMPTY RESULT for {accession}", size=total_bytes) - if b"EMPTY RESULT" in head.upper() or b"FAILURE" in head.upper(): - raise EmptyResultError(head.decode(errors="replace").strip(), size=total_bytes) - raise RuntimeError(f"Downloaded file does not look like FASTA for {accession}") - return total_bytes + for acc, (tmp_path, handle) in file_handles.items(): + handle.close() + os.replace(tmp_path, paths[acc]) + + empty_reason = None + if not paths: + head_text = head_bytes.decode(errors="replace").strip() + if "EMPTY RESULT" in head_text.upper() or "FAILURE" in head_text.upper(): + empty_reason = head_text or "EMPTY 
RESULT" + + missing_reasons = {} + for acc in accessions: + if acc not in paths or sizes.get(acc, 0) == 0: + missing_reasons[acc] = empty_reason or f"EMPTY RESULT for {acc}" + + return paths, sizes, missing_reasons async def run_sourmash( input_path: str, diff --git a/shared_scripts/accession_sketcher/config.yaml b/shared_scripts/accession_sketcher/config.yaml index e897d80..c3657e6 100644 --- a/shared_scripts/accession_sketcher/config.yaml +++ b/shared_scripts/accession_sketcher/config.yaml @@ -5,6 +5,7 @@ state_db: /tmp/accession_state/sketcher.sqlite log_path: /tmp/accession_logs/sketcher.log max_concurrent_downloads: 4 max_total_workers: 8 +batch_size: 10 ncbi_api_key: "8dd1e03a1867cb0ce9e5176a520679f15a08" requests_per_second: 10 sourmash_params: "k=15,k=31,k=33,scaled=1000,noabund" From 69aff00cfccd542bb6c3a64821f028992c36f179 Mon Sep 17 00:00:00 2001 From: David Koslicki Date: Tue, 27 Jan 2026 11:18:19 -0500 Subject: [PATCH 3/3] Support batched efetch requests and per-accession size/status tracking ### Motivation - Reduce NCBI request load by allowing a single `efetch` call to request multiple accessions so large queues can be processed under a strict requests-per-second cap. - Persist precise byte sizes for streamed downloads at the per-accession level so the `files` DB `size` column is populated for bookkeeping and retry logic. - Keep high sketching throughput with many workers while guaranteeing a hard cap on request rate via centralized throttling and controlled concurrency. ### Description - Add `DB.claim_batch(limit, ...)` and `DB.get_tries(file_id)` to atomically claim multiple `PENDING` (or eligible `ERROR`) rows and to inspect try counts. - Replace single-accession downloader with `download_accessions(...)` which issues one `efetch` for a comma-separated batch, streams and parses FASTA records, writes per-accession `.fasta` files, computes per-accession sizes, and returns `paths`, `sizes`, and per-accession `missing` reasons. 
- Update the scheduler worker to claim batches, build `accession_paths`, call `download_accessions(...)`, update sizes via `DB.update_size`, mark `EMPTY`/`ERROR`/`SKETCHING`/`DONE` per accession, and run `sourmash` per file; add `batch_size` config and default. - Introduce `RequestRateLimiter` and ensure each efetch is gated through it, add package scaffolding (`__init__.py`, `__main__.py`, `main.py`), `utils.py`, `requirements.txt`, sample `accession` list, `run.sh`, and updated `config.yaml` with `batch_size`. ### Testing - No automated tests were executed: exercising `efetch` requires live network access to NCBI, which was not available in this environment. --- .../accession_sketcher/db.py | 15 +++++++++++++++ .../accession_sketcher/scheduler.py | 18 ++++++++++++------ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/shared_scripts/accession_sketcher/accession_sketcher/db.py b/shared_scripts/accession_sketcher/accession_sketcher/db.py index fdcc628..b4d5e2e 100644 --- a/shared_scripts/accession_sketcher/accession_sketcher/db.py +++ b/shared_scripts/accession_sketcher/accession_sketcher/db.py @@ -200,6 +200,21 @@ def get_tries(self, file_id: int) -> int: row = cur.fetchone() return int(row[0]) if row else 0 + def existing_filenames(self, filenames: List[str]) -> set[str]: + if not filenames: + return set() + existing: set[str] = set() + with self._lock: + for i in range(0, len(filenames), 999): + chunk = filenames[i:i + 999] + placeholders = ",".join("?" 
for _ in chunk) + cur = self.conn.execute( + f"SELECT filename FROM files WHERE filename IN ({placeholders})", + chunk, + ) + existing.update(row[0] for row in cur.fetchall()) + return existing + def stats(self) -> Dict[str, Any]: with self._lock: cur = self.conn.execute("SELECT status, COUNT(*) FROM files GROUP BY status") diff --git a/shared_scripts/accession_sketcher/accession_sketcher/scheduler.py b/shared_scripts/accession_sketcher/accession_sketcher/scheduler.py index e2a6fa3..d8107ee 100644 --- a/shared_scripts/accession_sketcher/accession_sketcher/scheduler.py +++ b/shared_scripts/accession_sketcher/accession_sketcher/scheduler.py @@ -42,17 +42,23 @@ def _load_accessions(self) -> List[str]: async def enqueue_accessions(self): dry_run = bool(self.cfg.get("dry_run", False)) accessions = self._load_accessions() + if dry_run: + for acc in accessions: + LOG.info("Would enqueue %s", acc) + LOG.info("Loaded %d accessions (dry run)", len(accessions)) + return + + existing = self.db.existing_filenames(accessions) count = 0 for acc in accessions: + if acc in existing: + continue url = f"nuccore:{acc}" - if dry_run: - LOG.info("Would enqueue %s", acc) - else: - self.db.upsert_file("nuccore", acc, url, None, None) + self.db.upsert_file("nuccore", acc, url, None, None) count += 1 if count % 1000 == 0: - LOG.info("Enqueued %d accessions so far", count) - LOG.info("Loaded %d accessions", count) + LOG.info("Enqueued %d new accessions so far", count) + LOG.info("Loaded %d accessions (%d new, %d already present)", len(accessions), count, len(existing)) async def worker(self, session: aiohttp.ClientSession, net_sem: asyncio.Semaphore, rate: RequestRateLimiter): params = self.cfg.get("sourmash_params", DEFAULT_PARAMS)