"""SQLite-backed work-queue state for the accession sketcher."""

import os
import sqlite3
import threading
import time
from typing import Any, Dict, List, Optional, Tuple

# WAL + synchronous=NORMAL trades a little durability for much better
# concurrent read/write behaviour; the unique index enables the upsert below.
SCHEMA = """
PRAGMA journal_mode=WAL;
PRAGMA synchronous=NORMAL;
CREATE TABLE IF NOT EXISTS files (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    subdir TEXT NOT NULL,
    filename TEXT NOT NULL,
    url TEXT NOT NULL,
    size INTEGER,
    mtime TEXT,
    status TEXT NOT NULL DEFAULT 'PENDING',
    tries INTEGER NOT NULL DEFAULT 0,
    last_error TEXT,
    out_path TEXT,
    updated_at REAL,
    created_at REAL
);
CREATE INDEX IF NOT EXISTS idx_status ON files(status);
CREATE UNIQUE INDEX IF NOT EXISTS idx_unique_file ON files(subdir, filename);
"""


class DB:
    """Thread-safe job-state store.

    Rows move PENDING -> DOWNLOADING -> SKETCHING -> DONE / EMPTY, or to
    ERROR (retryable) and FAILED (gave up). A single connection is shared
    across threads (check_same_thread=False) and serialized with a lock.
    """

    def __init__(self, path: str):
        self.path = path
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        self._lock = threading.Lock()
        self.conn = sqlite3.connect(self.path, check_same_thread=False, timeout=60.0)
        with self.conn:
            # executescript() would issue an implicit COMMIT, so run the
            # schema statement-by-statement inside a single transaction.
            for stmt in SCHEMA.strip().split(";"):
                if stmt.strip():
                    self.conn.execute(stmt)

    def _claim(self, fid: int, expected_status: str, now: float) -> Optional[Tuple[int, str, str, str]]:
        """Compare-and-swap row *fid* from *expected_status* to DOWNLOADING.

        Must be called with self._lock held and inside a transaction.
        Returns the claimed (id, subdir, filename, url) row, or None when
        another worker won the race.
        """
        cur = self.conn.execute(
            "UPDATE files SET status='DOWNLOADING', updated_at=? WHERE id=? AND status=?",
            (now, fid, expected_status),
        )
        if cur.rowcount != 1:
            return None
        return self.conn.execute(
            "SELECT id, subdir, filename, url FROM files WHERE id=?",
            (fid,),
        ).fetchone()

    def claim_next(
        self,
        error_cooldown_seconds: int = 3600,
        error_max_total_tries: int = 20,
    ) -> Optional[Tuple[int, str, str, str]]:
        """Atomically claim one unit of work: PENDING first; else eligible ERROR.

        Fix: previously duplicated the whole claim_batch() body; delegating
        keeps the two claim paths from drifting apart.
        """
        rows = self.claim_batch(1, error_cooldown_seconds, error_max_total_tries)
        return rows[0] if rows else None

    def claim_batch(
        self,
        limit: int,
        error_cooldown_seconds: int = 3600,
        error_max_total_tries: int = 20,
    ) -> List[Tuple[int, str, str, str]]:
        """Atomically claim up to *limit* rows, flipping them to DOWNLOADING.

        PENDING rows are claimed first; only when none were claimable are
        ERROR rows considered — and then only those still under the retry
        cap whose last update is older than the cooldown.
        """
        if limit <= 0:
            return []
        now = time.time()
        cutoff = now - error_cooldown_seconds
        claimed: List[Tuple[int, str, str, str]] = []
        with self._lock, self.conn:
            rows = self.conn.execute(
                "SELECT id FROM files WHERE status='PENDING' ORDER BY id LIMIT ?",
                (limit,),
            ).fetchall()
            for (fid,) in rows:
                row = self._claim(fid, "PENDING", now)
                if row is not None:
                    claimed.append(row)
            if claimed:
                return claimed

            rows = self.conn.execute(
                "SELECT id FROM files "
                "WHERE status='ERROR' AND tries < ? AND (updated_at IS NULL OR updated_at <= ?) "
                # COALESCE keeps never-updated rows first and, unlike
                # NULLS FIRST, works on SQLite builds older than 3.30.
                "ORDER BY COALESCE(updated_at, 0), id LIMIT ?",
                (error_max_total_tries, cutoff, limit),
            ).fetchall()
            for (fid,) in rows:
                row = self._claim(fid, "ERROR", now)
                if row is not None:
                    claimed.append(row)
        return claimed

    def reset_stuck(self, stale_seconds: int = 3600):
        """Return rows stuck in transient states (e.g. after a crash) to PENDING."""
        threshold = time.time() - stale_seconds
        with self._lock, self.conn:
            self.conn.execute(
                "UPDATE files SET status='PENDING', updated_at=? "
                "WHERE status IN ('DOWNLOADING','SKETCHING') AND (updated_at IS NULL OR updated_at < ?)",
                (time.time(), threshold),
            )

    def upsert_file(
        self,
        subdir: str,
        filename: str,
        url: str,
        size: Optional[int],
        mtime: Optional[str],
    ):
        """Insert a PENDING row, or refresh url/size/mtime of an existing one.

        Status and tries are deliberately NOT reset on conflict, so a
        re-enqueue does not resurrect FAILED or DONE rows.
        """
        ts = time.time()
        with self._lock, self.conn:
            self.conn.execute(
                """INSERT INTO files (subdir, filename, url, size, mtime, status, updated_at, created_at)
                   VALUES (?, ?, ?, ?, ?, 'PENDING', ?, ?)
                   ON CONFLICT(subdir, filename) DO UPDATE SET url=excluded.url, size=excluded.size, mtime=excluded.mtime, updated_at=excluded.updated_at""",
                (subdir, filename, url, size, mtime, ts, ts),
            )

    def mark_status(
        self,
        file_id: int,
        status: str,
        out_path: Optional[str] = None,
        error: Optional[str] = None,
        inc_tries: bool = False,
    ):
        """Set a row's status (plus out_path/last_error); bump tries when asked.

        The two near-identical UPDATE branches of the original are collapsed
        by parameterizing the tries increment.
        """
        ts = time.time()
        bump = 1 if inc_tries else 0
        with self._lock, self.conn:
            self.conn.execute(
                "UPDATE files SET status=?, out_path=?, last_error=?, tries=tries+?, updated_at=? WHERE id=?",
                (status, out_path, error, bump, ts, file_id),
            )

    def update_size(self, file_id: int, size: Optional[int]):
        """Record the downloaded byte count; a None size is a no-op."""
        if size is None:
            return
        ts = time.time()
        with self._lock, self.conn:
            self.conn.execute(
                "UPDATE files SET size=?, updated_at=? WHERE id=?",
                (size, ts, file_id),
            )

    def get_tries(self, file_id: int) -> int:
        """Return the current retry count for a row (0 if the row is gone)."""
        with self._lock:
            cur = self.conn.execute("SELECT tries FROM files WHERE id=?", (file_id,))
            row = cur.fetchone()
            return int(row[0]) if row else 0

    def existing_filenames(self, filenames: List[str]) -> set[str]:
        """Return the subset of *filenames* already present in the table.

        Queried in chunks of 999 to stay under SQLite's bound-parameter limit.
        """
        if not filenames:
            return set()
        existing: set[str] = set()
        with self._lock:
            for i in range(0, len(filenames), 999):
                chunk = filenames[i:i + 999]
                placeholders = ",".join("?" for _ in chunk)
                cur = self.conn.execute(
                    f"SELECT filename FROM files WHERE filename IN ({placeholders})",
                    chunk,
                )
                existing.update(row[0] for row in cur.fetchall())
        return existing

    def stats(self) -> Dict[str, Any]:
        """Return {'total': n, 'by_status': {status: count}}."""
        with self._lock:
            cur = self.conn.execute("SELECT status, COUNT(*) FROM files GROUP BY status")
            by_status = {row[0]: row[1] for row in cur.fetchall()}
            cur2 = self.conn.execute("SELECT COUNT(*) FROM files")
            total = cur2.fetchone()[0]
        return {"total": total, "by_status": by_status}

    def count_claimable(self, error_cooldown_seconds: int = 3600, error_max_total_tries: int = 20) -> int:
        """Count rows a worker could claim right now (PENDING + aged ERROR)."""
        cutoff = time.time() - error_cooldown_seconds
        with self._lock:
            cur = self.conn.execute(
                "SELECT COUNT(*) FROM files WHERE status='PENDING'"
            )
            pending = cur.fetchone()[0]
            cur = self.conn.execute(
                "SELECT COUNT(*) FROM files WHERE status='ERROR' AND tries < ? AND (updated_at IS NULL OR updated_at <= ?)",
                (error_max_total_tries, cutoff),
            )
            retryable = cur.fetchone()[0]
        return pending + retryable

    def iter_pending(self, batch_size: int = 1000):
        """Yield (id, subdir, filename, url) for every PENDING row.

        Fix: uses keyset pagination (id > last seen) instead of OFFSET.
        With OFFSET, rows leaving PENDING between batches shifted later
        pages and silently skipped still-pending rows.
        """
        last_id = 0
        while True:
            with self._lock:
                rows = self.conn.execute(
                    "SELECT id, subdir, filename, url FROM files "
                    "WHERE status='PENDING' AND id > ? ORDER BY id LIMIT ?",
                    (last_id, batch_size),
                ).fetchall()
            if not rows:
                break
            for row in rows:
                yield row
            last_id = rows[-1][0]
"""Scheduler: drives the enqueue -> download -> sourmash-sketch pipeline."""

import asyncio
import contextlib
import logging
import os
import signal
import zipfile
from typing import Iterable, List

import aiohttp

from .db import DB
from .utils import LOG, build_logger, shard_subdir_for, RequestRateLimiter
from .worker import download_accessions, run_sourmash, RateLimitError

DEFAULT_PARAMS = "k=15,k=31,k=33,scaled=1000,noabund"


class Sketcher:
    """Owns the DB, a pool of async workers, and the run/shutdown lifecycle."""

    def __init__(self, config: dict):
        self.cfg = config
        level_name = str(self.cfg.get("log_level", "INFO")).upper()
        level = getattr(logging, level_name, logging.INFO)
        build_logger(self.cfg.get("log_path"), level)
        self.db = DB(self.cfg["state_db"])
        self.stop_flag = False

    def _load_accessions(self) -> List[str]:
        """Read the accession list: skip blanks and '#' comments, de-duplicate, keep order."""
        path = self.cfg["accessions_path"]
        accessions: List[str] = []
        seen = set()
        with open(path, "r") as handle:
            for line in handle:
                acc = line.strip()
                if not acc or acc.startswith("#"):
                    continue
                if acc not in seen:
                    seen.add(acc)
                    accessions.append(acc)
        return accessions

    async def enqueue_accessions(self):
        """Insert any accessions from the input list the DB has not seen yet."""
        dry_run = bool(self.cfg.get("dry_run", False))
        accessions = self._load_accessions()
        if dry_run:
            for acc in accessions:
                LOG.info("Would enqueue %s", acc)
            LOG.info("Loaded %d accessions (dry run)", len(accessions))
            return

        existing = self.db.existing_filenames(accessions)
        count = 0
        for acc in accessions:
            if acc in existing:
                continue
            url = f"nuccore:{acc}"
            self.db.upsert_file("nuccore", acc, url, None, None)
            count += 1
            if count % 1000 == 0:
                LOG.info("Enqueued %d new accessions so far", count)
        LOG.info("Loaded %d accessions (%d new, %d already present)", len(accessions), count, len(existing))

    def _fail_batch(self, file_ids: Iterable[int], exc: BaseException, retry_max: int):
        """Record a download failure for the attempted rows; promote past the cap to FAILED."""
        msg = str(exc)
        for file_id in file_ids:
            self.db.mark_status(file_id, "ERROR", error=msg, inc_tries=True)
            if self.db.get_tries(file_id) >= retry_max:
                self.db.mark_status(file_id, "FAILED", error=msg)

    async def worker(self, session: aiohttp.ClientSession, net_sem: asyncio.Semaphore, rate):
        """One consumer coroutine: claim a batch, download it, sketch each file, record results."""
        params = self.cfg.get("sourmash_params", DEFAULT_PARAMS)
        rayon_threads = int(self.cfg.get("sourmash_threads", 1))
        tmp_root = self.cfg["tmp_root"]
        out_root = self.cfg["output_root"]
        retry_max = int(self.cfg.get("max_retries", 6))
        timeout = int(self.cfg.get("request_timeout_seconds", 3600))
        shard_mod = int(self.cfg.get("shard_modulus", 512))
        batch_size = int(self.cfg.get("batch_size", 1))
        api_key = self.cfg.get("ncbi_api_key")
        email = self.cfg.get("email")
        error_cooldown = int(self.cfg.get("error_retry_cooldown_seconds", 1800))
        error_max_total = int(self.cfg.get("error_max_total_tries", 20))
        idle_cycles = 0

        while not self.stop_flag:
            try:
                claims = self.db.claim_batch(batch_size, error_cooldown, error_max_total)
            except Exception as exc:
                LOG.exception("DB error while claiming work: %r", exc)
                await asyncio.sleep(2.0)
                continue
            if not claims:
                idle_cycles += 1
                if idle_cycles % 30 == 0:
                    st = self.db.stats()
                    LOG.info("No claimable work yet. by_status=%s", st.get("by_status"))
                await asyncio.sleep(2.0)
                continue
            idle_cycles = 0

            accession_paths = {}
            active_ids = {}  # accession -> file_id, only for rows we will actually download
            for file_id, _, accession, _ in claims:
                LOG.debug("Claimed id=%s accession=%s", file_id, accession)
                rel_dir = shard_subdir_for(accession, shard_mod)
                local_tmp = os.path.join(tmp_root, rel_dir, f"{accession}.fasta")
                local_out = os.path.join(out_root, rel_dir, f"{accession}.sig.zip")

                if os.path.exists(local_out):
                    try:
                        with zipfile.ZipFile(local_out):
                            pass
                        LOG.debug("Skipping existing output %s", local_out)
                        self.db.mark_status(file_id, "DONE", out_path=local_out)
                        # Fix: a valid output already exists -- do NOT keep the
                        # accession in the batch; previously it was marked DONE
                        # yet still re-downloaded and re-sketched every claim.
                        continue
                    except zipfile.BadZipFile:
                        LOG.warning("Removing corrupt output %s", local_out)
                        with contextlib.suppress(Exception):
                            os.remove(local_out)

                accession_paths[accession] = local_tmp
                active_ids[accession] = file_id

            if not accession_paths:
                continue

            try:
                async with net_sem:
                    LOG.debug("Downloading batch of %d accessions", len(accession_paths))
                    paths, sizes, missing = await download_accessions(
                        session,
                        accession_paths,
                        rate,
                        api_key,
                        timeout=timeout,
                        email=email,
                    )
            except RateLimitError as exc:
                LOG.warning("Rate limit hit for batch (size %d): %s", len(accession_paths), exc)
                # Fix: only fail the rows we actually attempted; rows already
                # marked DONE above must not be flipped back to ERROR.
                self._fail_batch(active_ids.values(), exc, retry_max)
                await asyncio.sleep(5)
                continue
            except Exception as exc:
                LOG.exception("Error downloading batch")
                self._fail_batch(active_ids.values(), exc, retry_max)
                await asyncio.sleep(2)
                continue

            for accession, reason in missing.items():
                file_id = active_ids.get(accession)
                if file_id is None:
                    continue
                self.db.update_size(file_id, sizes.get(accession, 0))
                self.db.mark_status(file_id, "EMPTY", error=reason)

            for accession, local_tmp in paths.items():
                file_id = active_ids.get(accession)
                if file_id is None:
                    continue
                rel_dir = shard_subdir_for(accession, shard_mod)
                local_out = os.path.join(out_root, rel_dir, f"{accession}.sig.zip")
                self.db.update_size(file_id, sizes.get(accession, 0))
                try:
                    self.db.mark_status(file_id, "SKETCHING")
                    LOG.debug("Sketching %s", local_tmp)
                    if not os.path.exists(local_tmp):
                        raise FileNotFoundError(f"Missing downloaded file {local_tmp}")
                    rc, out = await run_sourmash(local_tmp, local_out, params, rayon_threads, log=LOG)
                    if rc != 0:
                        LOG.error("sourmash failed rc=%s for %s\n%s", rc, local_tmp, out)
                        raise RuntimeError(f"sourmash failed rc={rc} for {local_tmp}")
                    self.db.mark_status(file_id, "DONE", out_path=local_out)
                    LOG.info("Finished %s", accession)
                except Exception as exc:
                    LOG.exception("Error processing %s", accession)
                    self.db.mark_status(file_id, "ERROR", error=str(exc), inc_tries=True)
                    if self.db.get_tries(file_id) >= retry_max:
                        self.db.mark_status(file_id, "FAILED", error=str(exc))
                finally:
                    with contextlib.suppress(Exception):
                        if os.path.exists(local_tmp):
                            os.remove(local_tmp)

    async def run(self):
        """Top-level entry: enqueue, spawn workers, wait for drain or a stop signal."""
        for p in (
            self.cfg["output_root"],
            self.cfg["tmp_root"],
            os.path.dirname(self.cfg["state_db"]),
            os.path.dirname(self.cfg.get("log_path", "/tmp/void.log")),
        ):
            if p:
                os.makedirs(p, exist_ok=True)

        await self.enqueue_accessions()
        if self.cfg.get("dry_run", False):
            LOG.info("Dry run complete; exiting.")
            return

        stale = int(self.cfg.get("stale_seconds", 3600))
        self.db.reset_stuck(stale)

        max_dl = int(self.cfg.get("max_concurrent_downloads", 8))
        net_sem = asyncio.Semaphore(max_dl)
        reqs_per_sec = self.cfg.get("requests_per_second")
        rate = RequestRateLimiter(float(reqs_per_sec)) if reqs_per_sec else None

        conn = aiohttp.TCPConnector(limit_per_host=max_dl, limit=max_dl)
        timeout = aiohttp.ClientTimeout(total=None, sock_connect=120, sock_read=3600)
        headers = {"User-Agent": self.cfg.get("user_agent", "Accession Sketcher/1.0")}

        async with aiohttp.ClientSession(connector=conn, timeout=timeout, headers=headers) as session:
            total_workers = int(self.cfg.get("max_total_workers", 96))
            workers = [
                asyncio.create_task(self.worker(session, net_sem, rate))
                for _ in range(total_workers)
            ]

            def _log_crash(task: asyncio.Task) -> None:
                # Fix: task.exception() raises CancelledError on a cancelled
                # task, so guard before inspecting; LOG.error (not .exception)
                # because we are not inside an except block here.
                if not task.cancelled() and task.exception():
                    LOG.error("Worker crashed: %r", task.exception())

            for w in workers:
                w.add_done_callback(_log_crash)
            monitor_task = asyncio.create_task(self._monitor(workers))

            loop = asyncio.get_running_loop()
            for sig in (signal.SIGINT, signal.SIGTERM):
                loop.add_signal_handler(sig, self._request_stop)

            # Fix: also exit on stop_flag -- previously a SIGINT stopped the
            # workers but this loop only broke when pending hit zero, which
            # it never would after workers stopped claiming, hanging run().
            while not self.stop_flag:
                st = self.db.stats()
                claimable = self.db.count_claimable(
                    int(self.cfg.get("error_retry_cooldown_seconds", 1800)),
                    int(self.cfg.get("error_max_total_tries", 20)),
                )
                pending = (
                    claimable
                    + st["by_status"].get("DOWNLOADING", 0)
                    + st["by_status"].get("SKETCHING", 0)
                )
                if pending == 0:
                    break
                await asyncio.sleep(10)
            self.stop_flag = True

            await asyncio.gather(*workers, return_exceptions=True)
            monitor_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await monitor_task

    def _request_stop(self):
        """Signal handler hook: ask workers to stop after their current item."""
        LOG.warning("Stop requested; will finish current tasks then exit.")
        self.stop_flag = True

    async def _monitor(self, workers):
        """Periodically log worker liveness and queue statistics."""
        while not self.stop_flag:
            st = self.db.stats()
            alive = sum(not worker.done() for worker in workers)
            LOG.info("Monitor: alive_workers=%d stats=%s", alive, st.get("by_status"))
            for worker in workers:
                if worker.done() and not worker.cancelled() and worker.exception():
                    LOG.error("Worker died: %r", worker.exception())
            await asyncio.sleep(30)


def load_config(path: str) -> dict:
    """Load the YAML config and fill in every default the pipeline relies on."""
    import yaml

    with open(path, "r") as f:
        cfg = yaml.safe_load(f)
    cfg.setdefault("accessions_path", "/scratch/accessions.txt")
    cfg.setdefault("output_root", "/scratch/accession_sketches")
    cfg.setdefault("tmp_root", "/scratch/accession_tmp")
    cfg.setdefault("state_db", "/scratch/accession_state/sketcher.sqlite")
    cfg.setdefault("log_path", "/scratch/accession_logs/sketcher.log")
    cfg.setdefault("max_concurrent_downloads", 8)
    cfg.setdefault("max_total_workers", 96)
    cfg.setdefault("user_agent", "Accession Sketcher/1.0 (+dmk333@psu.edu; admin=dmk333@psu.edu)")
    cfg.setdefault("error_retry_cooldown_seconds", 1800)
    cfg.setdefault("error_max_total_tries", 20)
    cfg.setdefault("stale_seconds", 3600)
    cfg.setdefault("dry_run", True)
    cfg.setdefault("sourmash_params", DEFAULT_PARAMS)
    cfg.setdefault("sourmash_threads", 1)
    cfg.setdefault("request_timeout_seconds", 3600)
    cfg.setdefault("max_retries", 8)
    cfg.setdefault("batch_size", 1)
    cfg.setdefault("shard_modulus", 512)
    cfg.setdefault("log_level", "INFO")
    cfg.setdefault("ncbi_api_key", None)
    cfg.setdefault("email", None)
    cfg.setdefault("requests_per_second", None)

    # NCBI allows 10 req/s with an API key, 3 req/s without.
    if cfg.get("requests_per_second") is None:
        cfg["requests_per_second"] = 10 if cfg.get("ncbi_api_key") else 3

    for key in ("output_root", "tmp_root", "state_db", "log_path", "accessions_path"):
        cfg[key] = os.path.abspath(os.path.expanduser(cfg[key]))
    return cfg


async def main_async(cfg_path: str):
    """Build a Sketcher from the config file and run it to completion."""
    cfg = load_config(cfg_path)
    sk = Sketcher(cfg)
    await sk.run()


def main():
    """CLI entry point: parse --config and run the pipeline."""
    import argparse

    p = argparse.ArgumentParser(description="Accession sourmash sketcher")
    p.add_argument("--config", "-c", required=True, help="Path to YAML config")
    args = p.parse_args()
    asyncio.run(main_async(args.config))


if __name__ == "__main__":
    main()
+import os +import sys +import time +from typing import Optional + +LOG = logging.getLogger("accession_sketcher") + +def build_logger(log_path: Optional[str] = None, level: int = logging.INFO): + LOG.setLevel(level) + fmt = logging.Formatter("%(asctime)s | %(levelname)7s | %(name)s | %(message)s") + sh = logging.StreamHandler(sys.stdout) + sh.setFormatter(fmt) + LOG.addHandler(sh) + if log_path: + os.makedirs(os.path.dirname(log_path), exist_ok=True) + fh = logging.FileHandler(log_path) + fh.setFormatter(fmt) + LOG.addHandler(fh) + +def shard_subdir_for(name: str, modulus: int) -> str: + h = hashlib.sha256(name.encode("utf-8")).hexdigest() + return str(int(h, 16) % modulus) + +class RequestRateLimiter: + """Strict rate limiter for requests per second.""" + def __init__(self, requests_per_sec: Optional[float] = None): + self.requests_per_sec = requests_per_sec + self.interval = (1.0 / requests_per_sec) if requests_per_sec else None + self.next_allowed = time.monotonic() + self._lock = asyncio.Lock() + + async def throttle(self): + if not self.requests_per_sec: + return + async with self._lock: + now = time.monotonic() + if now < self.next_allowed: + await asyncio.sleep(self.next_allowed - now) + self.next_allowed = max(self.next_allowed, now) + self.interval diff --git a/shared_scripts/accession_sketcher/accession_sketcher/worker.py b/shared_scripts/accession_sketcher/accession_sketcher/worker.py new file mode 100644 index 0000000..42404c1 --- /dev/null +++ b/shared_scripts/accession_sketcher/accession_sketcher/worker.py @@ -0,0 +1,164 @@ +import asyncio +import aiohttp +import os +import logging +import zipfile +from typing import Optional, Tuple +from .utils import LOG, RequestRateLimiter + +CHUNK_SIZE = 4 * 1024 * 1024 +EFETCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi" + +class RateLimitError(RuntimeError): + pass + +class EmptyResultError(RuntimeError): + def __init__(self, message: str, size: int = 0): + super().__init__(message) + 
"""Network + subprocess workers: batched NCBI efetch download and sourmash sketching."""

import asyncio
import logging
import os
import zipfile
from typing import Optional, Tuple

import aiohttp

from .utils import LOG, RequestRateLimiter

CHUNK_SIZE = 4 * 1024 * 1024
EFETCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"


class RateLimitError(RuntimeError):
    """Raised when NCBI signals throttling (HTTP 429/503 or a rate-limit body)."""


class EmptyResultError(RuntimeError):
    """An empty efetch result; carries the byte count seen. Kept for API compatibility."""

    def __init__(self, message: str, size: int = 0):
        super().__init__(message)
        self.size = size


async def download_accessions(
    session: aiohttp.ClientSession,
    accession_paths: dict[str, str],
    rate: Optional[RequestRateLimiter],
    api_key: Optional[str],
    timeout: int = 3600,
    email: Optional[str] = None,
    tool: str = "accession_sketcher",
) -> tuple[dict[str, str], dict[str, int], dict[str, str]]:
    """Stream one batched efetch request into per-accession FASTA files.

    Returns (paths, sizes, missing_reasons):
      paths   -- accession -> final file path for records that arrived
      sizes   -- accession -> bytes streamed (headers included; 0 if absent)
      missing_reasons -- accession -> reason string for absent/empty records

    Raises RateLimitError on HTTP 429/503 (or a rate-limit body), and
    RuntimeError on other non-200 responses. Partial ".part" files are
    cleaned up on any failure.
    """
    accessions = list(accession_paths.keys())
    params = {
        "db": "nuccore",
        "id": ",".join(accessions),
        "rettype": "fasta",
        "retmode": "text",
        "tool": tool,
    }
    if api_key:
        params["api_key"] = api_key
    if email:
        params["email"] = email

    if rate:
        await rate.throttle()

    # efetch echoes headers like ">PQ128373.1 ..."; map the unversioned base
    # back to the accession the caller asked for. NOTE(review): assumes no two
    # requested accessions share a base -- later entries would clobber earlier.
    accession_map = {acc.split(".")[0]: acc for acc in accessions}
    paths: dict[str, str] = {}
    sizes: dict[str, int] = {acc: 0 for acc in accessions}
    file_handles: dict[str, tuple[str, object]] = {}
    current_acc: Optional[str] = None
    head_bytes = bytearray()  # first ~2KB of the response, for error reporting

    def _write(data: bytes) -> None:
        # Append to the record currently being streamed; bytes seen before the
        # first recognized header are discarded (only sampled into head_bytes).
        if current_acc:
            file_handles[current_acc][1].write(data)
            sizes[current_acc] += len(data)

    try:
        # Explicit ClientTimeout: bare numeric timeouts are a legacy form in aiohttp.
        async with session.get(
            EFETCH_URL, params=params, timeout=aiohttp.ClientTimeout(total=timeout)
        ) as resp:
            if resp.status != 200:
                text_sample = (await resp.text())[:500]
                if resp.status in (429, 503) or "rate limit" in text_sample.lower():
                    raise RateLimitError(text_sample)
                raise RuntimeError(f"efetch status {resp.status}: {text_sample}")

            buffer = b""
            async for chunk in resp.content.iter_chunked(CHUNK_SIZE):
                if not chunk:
                    continue
                if len(head_bytes) < 2048:
                    head_bytes.extend(chunk[: max(0, 2048 - len(head_bytes))])
                buffer += chunk
                while b"\n" in buffer:
                    line, buffer = buffer.split(b"\n", 1)
                    line += b"\n"
                    if line.startswith(b">"):
                        tokens = line[1:].strip().decode(errors="replace").split()
                        # Fix: a bare ">" line previously raised IndexError here.
                        base = tokens[0].split(".")[0] if tokens else ""
                        current_acc = accession_map.get(base)
                        if current_acc and current_acc not in file_handles:
                            tmp_path = accession_paths[current_acc] + ".part"
                            os.makedirs(os.path.dirname(tmp_path), exist_ok=True)
                            file_handles[current_acc] = (tmp_path, open(tmp_path, "wb"))
                            paths[current_acc] = accession_paths[current_acc]
                    _write(line)

            # Flush any trailing bytes that did not end in a newline.
            if buffer:
                if len(head_bytes) < 2048:
                    head_bytes.extend(buffer[: max(0, 2048 - len(head_bytes))])
                _write(buffer)
    except Exception:
        # Best-effort cleanup of partial files; the original error propagates.
        for tmp_path, handle in file_handles.values():
            try:
                handle.close()
            except Exception:
                pass
            try:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            except Exception:
                pass
        raise

    # Success: atomically promote each ".part" file to its final path.
    for acc, (tmp_path, handle) in file_handles.items():
        handle.close()
        os.replace(tmp_path, paths[acc])

    empty_reason = None
    if not paths:
        head_text = head_bytes.decode(errors="replace").strip()
        if "EMPTY RESULT" in head_text.upper() or "FAILURE" in head_text.upper():
            empty_reason = head_text or "EMPTY RESULT"

    missing_reasons = {
        acc: (empty_reason or f"EMPTY RESULT for {acc}")
        for acc in accessions
        if acc not in paths or sizes.get(acc, 0) == 0
    }
    return paths, sizes, missing_reasons


async def run_sourmash(
    input_path: str,
    output_path: str,
    params: str,
    rayon_threads: int = 1,
    extra_env=None,
    log: Optional[logging.Logger] = None,
) -> Tuple[int, str]:
    """Run `sourmash sketch dna` on *input_path*, writing a .sig.zip to *output_path*.

    Returns (returncode, combined stdout+stderr text). A zero exit that left
    an invalid zip is reported as rc 1, and the bogus file is removed so it
    cannot later be mistaken for a finished sketch.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    if not os.path.exists(input_path):
        return 1, f"input file not found: {input_path}"
    env = os.environ.copy()
    # sourmash's rust core reads RAYON_NUM_THREADS for its thread pool.
    env["RAYON_NUM_THREADS"] = str(rayon_threads)
    if extra_env:
        env.update(extra_env)
    cmd = [
        "sourmash",
        "sketch",
        "dna",
        "-p",
        params,
        "-o",
        output_path,
        input_path,
    ]
    if log:
        log.debug("Running: %s", " ".join(cmd))
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        env=env,
    )
    out, err = await proc.communicate()
    rc = proc.returncode
    out_combined = (out or b"").decode() + (err or b"").decode()
    if rc == 0 and not zipfile.is_zipfile(output_path):
        # Fix: drop the corrupt artifact instead of leaving it on disk.
        try:
            os.remove(output_path)
        except OSError:
            pass
        return 1, f"output not a valid zip: {output_path}"
    return rc, out_combined
#!/usr/bin/env bash
# Launch the accession sketcher package against the local config file.
set -euo pipefail

python -m accession_sketcher --config config.yaml