diff --git a/graphgen/configs/search_dna_config.yaml b/graphgen/configs/search_dna_config.yaml index f53a5eb8..82368754 100644 --- a/graphgen/configs/search_dna_config.yaml +++ b/graphgen/configs/search_dna_config.yaml @@ -14,4 +14,6 @@ pipeline: tool: GraphGen # tool name for NCBI API use_local_blast: true # whether to use local blast for DNA search local_blast_db: refseq_release/refseq_release # path to local BLAST database (without .nhr extension) + blast_num_threads: 2 # number of threads for BLAST search (reduce to save memory) + max_concurrent: 5 # maximum number of concurrent search tasks (reduce to prevent OOM, default: unlimited) diff --git a/graphgen/configs/search_protein_config.yaml b/graphgen/configs/search_protein_config.yaml index bfbf84eb..ed04ff12 100644 --- a/graphgen/configs/search_protein_config.yaml +++ b/graphgen/configs/search_protein_config.yaml @@ -13,3 +13,5 @@ pipeline: use_local_blast: true # whether to use local blast for uniprot search local_blast_db: /your_path/2024_01/uniprot_sprot # format: /path/to/${RELEASE}/uniprot_sprot # options: uniprot_sprot (recommended, high quality), uniprot_trembl, or uniprot_${RELEASE} (merged database) + blast_num_threads: 2 # number of threads for BLAST search (reduce to save memory) + max_concurrent: 5 # maximum number of concurrent search tasks (reduce to prevent OOM, default: unlimited) diff --git a/graphgen/configs/search_rna_config.yaml b/graphgen/configs/search_rna_config.yaml index 10422988..83bbca7d 100644 --- a/graphgen/configs/search_rna_config.yaml +++ b/graphgen/configs/search_rna_config.yaml @@ -12,3 +12,5 @@ pipeline: rnacentral_params: use_local_blast: true # whether to use local blast for RNA search local_blast_db: rnacentral_ensembl_gencode_YYYYMMDD/ensembl_gencode_YYYYMMDD # path to local BLAST database (without .nhr extension) + blast_num_threads: 2 # number of threads for BLAST search (reduce to save memory) + max_concurrent: 5 # maximum number of concurrent search tasks (reduce to prevent OOM, default: unlimited) diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index bc7e7742..6cff74b1 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -1,3 +1,4 @@ +import hashlib import os import time from typing import Dict @@ -87,24 +88,45 @@ def __init__( @async_to_sync_method async def read(self, read_config: Dict): """ - read files from input sources + read files from input sources with batch processing """ + # Get batch_size from config, default to 10000 + batch_size = read_config.pop("batch_size", 10000) + doc_stream = read_files(**read_config, cache_dir=self.working_dir) batch = {} + total_processed = 0 + for doc in doc_stream: doc_id = compute_mm_hash(doc, prefix="doc-") batch[doc_id] = doc + + # Process batch when it reaches batch_size + if len(batch) >= batch_size: + _add_doc_keys = self.full_docs_storage.filter_keys(list(batch.keys())) + new_docs = {k: v for k, v in batch.items() if k in _add_doc_keys} + if new_docs: + self.full_docs_storage.upsert(new_docs) + total_processed += len(new_docs) + logger.info("Processed batch: %d new documents (total: %d)", len(new_docs), total_processed) + batch.clear() # TODO: configurable whether to use coreference resolution - _add_doc_keys = self.full_docs_storage.filter_keys(list(batch.keys())) - new_docs = {k: v for k, v in batch.items() if k in _add_doc_keys} - if len(new_docs) == 0: + # Process remaining documents in batch + if batch: + _add_doc_keys = self.full_docs_storage.filter_keys(list(batch.keys())) + new_docs = {k: v for k, v in batch.items() 
if k in _add_doc_keys} + if new_docs: + self.full_docs_storage.upsert(new_docs) + total_processed += len(new_docs) + logger.info("Processed final batch: %d new documents (total: %d)", len(new_docs), total_processed) + + if total_processed == 0: logger.warning("All documents are already in the storage") - return - self.full_docs_storage.upsert(new_docs) - self.full_docs_storage.index_done_callback() + else: + self.full_docs_storage.index_done_callback() @async_to_sync_method async def chunk(self, chunk_config: Dict): @@ -169,24 +191,68 @@ async def build_kg(self): async def search(self, search_config: Dict): logger.info("[Search] %s ...", ", ".join(search_config["data_sources"])) - seeds = self.full_docs_storage.get_all() - if len(seeds) == 0: - logger.warning("All documents are already been searched") + # Get search_batch_size from config (default: 10000) + search_batch_size = search_config.get("search_batch_size", 10000) + + # Get save_interval from config (default: 1000, 0 to disable) + save_interval = search_config.get("save_interval", 1000) + + # Process in batches to avoid OOM + all_flattened_results = {} + batch_num = 0 + + for seeds_batch in self.full_docs_storage.iter_batches(batch_size=search_batch_size): + if len(seeds_batch) == 0: + continue + + batch_num += 1 + logger.info("Processing search batch %d with %d documents", batch_num, len(seeds_batch)) + + search_results = await search_all( + seed_data=seeds_batch, + search_config=search_config, + search_storage=self.search_storage if save_interval > 0 else None, + save_interval=save_interval, + ) + + # Convert search_results from {data_source: [results]} to {key: result} + # This maintains backward compatibility + for data_source, result_list in search_results.items(): + if not isinstance(result_list, list): + continue + for result in result_list: + if result is None: + continue + # Use _search_query as key if available, otherwise generate a key + if isinstance(result, dict) and "_search_query" in result: + query = result["_search_query"] + key = f"{data_source}:{query}" + else: + # Generate a unique key + result_str = str(result) + key_hash = hashlib.md5(result_str.encode()).hexdigest()[:8] + key = f"{data_source}:{key_hash}" + all_flattened_results[key] = result + + if len(all_flattened_results) == 0: + logger.warning("No search results generated") return - search_results = await search_all( - seed_data=seeds, - search_config=search_config, - ) - _add_search_keys = self.search_storage.filter_keys(list(search_results.keys())) + _add_search_keys = self.search_storage.filter_keys(list(all_flattened_results.keys())) search_results = { - k: v for k, v in search_results.items() if k in _add_search_keys + k: v for k, v in all_flattened_results.items() if k in _add_search_keys } if len(search_results) == 0: logger.warning("All search results are already in the storage") return - self.search_storage.upsert(search_results) - self.search_storage.index_done_callback() + + # Only save if not using periodic saving (to avoid duplicate saves) + if save_interval == 0: + self.search_storage.upsert(search_results) + self.search_storage.index_done_callback() + else: + # Results were already saved periodically, just update index + self.search_storage.index_done_callback() @async_to_sync_method async def quiz_and_judge(self, quiz_and_judge_config: Dict): diff --git a/graphgen/models/reader/jsonl_reader.py b/graphgen/models/reader/jsonl_reader.py index 31bc3195..f84aeadd 100644 --- a/graphgen/models/reader/jsonl_reader.py +++ 
b/graphgen/models/reader/jsonl_reader.py @@ -1,5 +1,6 @@ import json -from typing import Any, Dict, List +import os +from typing import Any, Dict, Iterator, List from graphgen.bases.base_reader import BaseReader from graphgen.utils import logger @@ -28,3 +29,56 @@ def read(self, file_path: str) -> List[Dict[str, Any]]: except json.JSONDecodeError as e: logger.error("Error decoding JSON line: %s. Error: %s", line, e) return self.filter(docs) + + def read_stream(self, file_path: str) -> Iterator[Dict[str, Any]]: + """ + Stream read JSONL files line by line without loading entire file into memory. + Returns an iterator that yields filtered documents. + + :param file_path: Path to the JSONL file. + :return: Iterator of dictionaries containing the data. + """ + with open(file_path, "r", encoding="utf-8") as f: + for line in f: + try: + doc = json.loads(line) + assert "type" in doc, f"Missing 'type' in document: {doc}" + if doc.get("type") == "text" and self.text_column not in doc: + raise ValueError( + f"Missing '{self.text_column}' in document: {doc}" + ) + + # Apply filtering logic inline (similar to BaseReader.filter) + if doc.get("type") == "text": + content = doc.get(self.text_column, "").strip() + if content: + yield doc + elif doc.get("type") in ("image", "table", "equation"): + img_path = doc.get("img_path") + if self._image_exists(img_path): + yield doc + else: + yield doc + except json.JSONDecodeError as e: + logger.error("Error decoding JSON line: %s. Error: %s", line, e) + + @staticmethod + def _image_exists(path_or_url: str, timeout: int = 3) -> bool: + """ + Check if an image exists at the given local path or URL. + :param path_or_url: Local file path or remote URL of the image. + :param timeout: Timeout for remote URL requests in seconds. + :return: True if the image exists, False otherwise. + """ + if not path_or_url: + return False + if not path_or_url.startswith(("http://", "https://", "ftp://")): + path = path_or_url.replace("file://", "", 1) + path = os.path.abspath(path) + return os.path.isfile(path) + try: + import requests + resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout) + return resp.status_code == 200 + except Exception: + return False diff --git a/graphgen/models/searcher/db/ncbi_searcher.py b/graphgen/models/searcher/db/ncbi_searcher.py index f453c700..73b3eba0 100644 --- a/graphgen/models/searcher/db/ncbi_searcher.py +++ b/graphgen/models/searcher/db/ncbi_searcher.py @@ -24,11 +24,11 @@ @lru_cache(maxsize=None) def _get_pool(): - return ThreadPoolExecutor(max_workers=10) + return ThreadPoolExecutor(max_workers=20) # NOTE:can increase for better parallelism # ensure only one NCBI request at a time -_ncbi_lock = asyncio.Lock() +_blast_lock = asyncio.Lock() class NCBISearch(BaseSearcher): @@ -49,6 +49,7 @@ def __init__( email: str = "email@example.com", api_key: str = "", tool: str = "GraphGen", + blast_num_threads: int = 4, ): """ Initialize the NCBI Search client. @@ -59,6 +60,7 @@ def __init__( email (str): Email address for NCBI API requests. api_key (str): API key for NCBI API requests, see https://account.ncbi.nlm.nih.gov/settings/. tool (str): Tool name for NCBI API requests. + blast_num_threads (int): Number of threads for BLAST search. 
""" super().__init__() Entrez.timeout = 60 # 60 seconds timeout @@ -70,9 +72,17 @@ def __init__( Entrez.sleep_between_tries = 5 self.use_local_blast = use_local_blast self.local_blast_db = local_blast_db - if self.use_local_blast and not os.path.isfile(f"{self.local_blast_db}.nhr"): - logger.error("Local BLAST database files not found. Please check the path.") - self.use_local_blast = False + self.blast_num_threads = blast_num_threads + if self.use_local_blast: + # Check for single-file database (.nhr) or multi-file database (.00.nhr) + db_exists = ( + os.path.isfile(f"{self.local_blast_db}.nhr") or + os.path.isfile(f"{self.local_blast_db}.00.nhr") + ) + if not db_exists: + logger.error("Local BLAST database files not found. Please check the path.") + logger.error("Expected: %s.nhr or %s.00.nhr", self.local_blast_db, self.local_blast_db) + self.use_local_blast = False @staticmethod def _nested_get(data: dict, *keys, default=None): @@ -87,14 +97,16 @@ def _nested_get(data: dict, *keys, default=None): def _infer_molecule_type_detail(accession: Optional[str], gene_type: Optional[int] = None) -> Optional[str]: """Infer molecule_type_detail from accession prefix or gene type.""" if accession: - if accession.startswith(("NM_", "XM_")): - return "mRNA" - if accession.startswith(("NC_", "NT_")): - return "genomic DNA" - if accession.startswith(("NR_", "XR_")): - return "RNA" - if accession.startswith("NG_"): - return "genomic region" + # Map accession prefixes to molecule types + prefix_map = { + ("NM_", "XM_"): "mRNA", + ("NC_", "NT_"): "genomic DNA", + ("NR_", "XR_"): "RNA", + ("NG_",): "genomic region", + } + for prefixes, mol_type in prefix_map.items(): + if accession.startswith(prefixes): + return mol_type # Fallback: infer from gene type if available if gene_type is not None: gene_type_map = { @@ -153,7 +165,6 @@ def _gene_record_to_dict(self, gene_record, gene_id: str) -> dict: None, ) # Fallback: if no type 3 accession, try any available accession - # This is needed for genes that don't have mRNA transcripts but have other sequence records if not representative_accession: representative_accession = next( ( @@ -209,6 +220,12 @@ def _gene_record_to_dict(self, gene_record, gene_id: str) -> dict: "_representative_accession": representative_accession, } + @retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RequestException, IncompleteRead)), + reraise=True, + ) def get_by_gene_id(self, gene_id: str, preferred_accession: Optional[str] = None) -> Optional[dict]: """Get gene information by Gene ID.""" def _extract_metadata_from_genbank(result: dict, accession: str): @@ -217,12 +234,7 @@ def _extract_metadata_from_genbank(result: dict, accession: str): record = SeqIO.read(handle, "genbank") result["title"] = record.description - result["molecule_type_detail"] = ( - "mRNA" if accession.startswith(("NM_", "XM_")) else - "genomic DNA" if accession.startswith(("NC_", "NT_")) else - "RNA" if accession.startswith(("NR_", "XR_")) else - "genomic region" if accession.startswith("NG_") else "N/A" - ) + result["molecule_type_detail"] = self._infer_molecule_type_detail(accession) or "N/A" for feature in record.features: if feature.type == "source": @@ -257,25 +269,62 @@ def _extract_sequence_from_fasta(result: dict, accession: str): result["sequence_length"] = None return result + def _extract_sequence(result: dict, accession: str): + """ + Extract sequence using the appropriate method based on configuration. 
+ If use_local_blast=True, use local database. Otherwise, use NCBI API. + Always fetches sequence (no option to skip). + """ + # If using local BLAST, use local database + if self.use_local_blast: + sequence = self._extract_sequence_from_local_db(accession) + + if sequence: + result["sequence"] = sequence + result["sequence_length"] = len(sequence) + else: + # Failed to extract from local DB, set to None (no fallback to API) + result["sequence"] = None + result["sequence_length"] = None + logger.warning( + "Failed to extract sequence from local DB for accession %s. " + "Not falling back to NCBI API as use_local_blast=True.", + accession + ) + else: + # Use NCBI API to fetch sequence + result = _extract_sequence_from_fasta(result, accession) + + return result + try: with Entrez.efetch(db="gene", id=gene_id, retmode="xml") as handle: gene_record = Entrez.read(handle) - if not gene_record: - return None + + if not gene_record: + return None - result = self._gene_record_to_dict(gene_record, gene_id) - if accession := (preferred_accession or result.get("_representative_accession")): - result = _extract_metadata_from_genbank(result, accession) - result = _extract_sequence_from_fasta(result, accession) + result = self._gene_record_to_dict(gene_record, gene_id) + + if accession := (preferred_accession or result.get("_representative_accession")): + result = _extract_metadata_from_genbank(result, accession) + # Extract sequence using appropriate method + result = _extract_sequence(result, accession) - result.pop("_representative_accession", None) - return result + result.pop("_representative_accession", None) + return result except (RequestException, IncompleteRead): raise except Exception as exc: logger.error("Gene ID %s not found: %s", gene_id, exc) return None + @retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RequestException, IncompleteRead)), + reraise=True, + ) def get_by_accession(self, accession: str) -> Optional[dict]: """Get sequence information by accession number.""" def _extract_gene_id(link_handle): @@ -301,9 +350,11 @@ def _extract_gene_id(link_handle): return None result = self.get_by_gene_id(gene_id, preferred_accession=accession) + if result: result["id"] = accession result["url"] = f"https://www.ncbi.nlm.nih.gov/nuccore/{accession}" + return result except (RequestException, IncompleteRead): raise @@ -311,6 +362,12 @@ def _extract_gene_id(link_handle): logger.error("Accession %s not found: %s", accession, exc) return None + @retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RequestException, IncompleteRead)), + reraise=True, + ) def get_best_hit(self, keyword: str) -> Optional[dict]: """Search NCBI Gene database with a keyword and return the best hit.""" if not keyword.strip(): @@ -320,31 +377,87 @@ def get_best_hit(self, keyword: str) -> Optional[dict]: for search_term in [f"{keyword}[Gene] OR {keyword}[All Fields]", keyword]: with Entrez.esearch(db="gene", term=search_term, retmax=1, sort="relevance") as search_handle: search_results = Entrez.read(search_handle) - if len(gene_id := search_results.get("IdList", [])) > 0: - return self.get_by_gene_id(gene_id) + + if len(gene_id := search_results.get("IdList", [])) > 0: + result = self.get_by_gene_id(gene_id) + return result except (RequestException, IncompleteRead): raise except Exception as e: logger.error("Keyword %s not found: %s", keyword, e) return None + def 
_extract_sequence_from_local_db(self, accession: str) -> Optional[str]: + """Extract sequence from local BLAST database using blastdbcmd.""" + try: + cmd = [ + "blastdbcmd", + "-db", self.local_blast_db, + "-entry", accession, + "-outfmt", "%s" # Only sequence, no header + ] + sequence = subprocess.check_output( + cmd, + text=True, + timeout=10, # 10 second timeout for local extraction + stderr=subprocess.DEVNULL + ).strip() + return sequence if sequence else None + except subprocess.TimeoutExpired: + logger.warning("Timeout extracting sequence from local DB for accession %s", accession) + return None + except Exception as exc: + logger.warning("Failed to extract sequence from local DB for accession %s: %s", accession, exc) + return None + def _local_blast(self, seq: str, threshold: float) -> Optional[str]: - """Perform local BLAST search using local BLAST database.""" + """ + Perform local BLAST search using local BLAST database. + Optimized with multi-threading and faster output format. + """ try: with tempfile.NamedTemporaryFile(mode="w+", suffix=".fa", delete=False) as tmp: tmp.write(f">query\n{seq}\n") tmp_name = tmp.name + # Optimized BLAST command with: + # - num_threads: Use multiple threads for faster search + # - outfmt 6 sacc: Only return accession (minimal output) + # - max_target_seqs 1: Only need the best hit + # - evalue: Threshold for significance cmd = [ "blastn", "-db", self.local_blast_db, "-query", tmp_name, - "-evalue", str(threshold), "-max_target_seqs", "1", "-outfmt", "6 sacc" + "-evalue", str(threshold), + "-max_target_seqs", "1", + "-num_threads", str(self.blast_num_threads), + "-outfmt", "6 sacc" # Only accession, tab-separated ] - logger.debug("Running local blastn: %s", " ".join(cmd)) - out = subprocess.check_output(cmd, text=True).strip() + logger.debug("Running local blastn (threads=%d): %s", + self.blast_num_threads, " ".join(cmd)) + + # Run BLAST with timeout to avoid hanging + try: + out = subprocess.check_output( + cmd, + text=True, + timeout=300, # 5 minute timeout for BLAST search + stderr=subprocess.DEVNULL # Suppress BLAST warnings to reduce I/O + ).strip() + except subprocess.TimeoutExpired: + logger.warning("BLAST search timed out after 5 minutes for sequence") + os.remove(tmp_name) + return None + os.remove(tmp_name) return out.split("\n", maxsplit=1)[0] if out else None except Exception as exc: logger.error("Local blastn failed: %s", exc) + # Clean up temp file if it still exists + try: + if 'tmp_name' in locals(): + os.remove(tmp_name) + except Exception: + pass return None def get_by_fasta(self, sequence: str, threshold: float = 0.01) -> Optional[dict]: @@ -393,15 +506,24 @@ def _process_network_blast_result(blast_record, seq: str, threshold: float) -> O return None # Try local BLAST first if enabled - if self.use_local_blast and (accession := self._local_blast(seq, threshold)): - logger.debug("Local BLAST found accession: %s", accession) - return self.get_by_accession(accession) + if self.use_local_blast: + accession = self._local_blast(seq, threshold) + + if accession: + logger.debug("Local BLAST found accession: %s", accession) + # When using local BLAST, skip sequence fetching by default (faster, fewer API calls) + # Sequence is already known from the query, so we only need metadata + result = self.get_by_accession(accession) + return result + + logger.info("Local BLAST found no match for sequence. 
API fallback disabled when using local database.") + return None - # Fall back to network BLAST + # Fall back to network BLAST only if local BLAST is not enabled logger.debug("Falling back to NCBIWWW.qblast") - with NCBIWWW.qblast("blastn", "nr", seq, hitlist_size=1, expect=threshold) as result_handle: - return _process_network_blast_result(NCBIXML.read(result_handle), seq, threshold) + result_handle = NCBIWWW.qblast("blastn", "nr", seq, hitlist_size=1, expect=threshold) + result = _process_network_blast_result(NCBIXML.read(result_handle), seq, threshold) + return result except (RequestException, IncompleteRead): raise except Exception as e: @@ -425,17 +547,26 @@ async def search(self, query: str, threshold: float = 0.01, **kwargs) -> Optiona loop = asyncio.get_running_loop() - # limit concurrent requests (NCBI rate limit: max 3 requests per second) - async with _ncbi_lock: - # Auto-detect query type and execute in thread pool - if query.startswith(">") or re.fullmatch(r"[ATCGN\s]+", query, re.I): - result = await loop.run_in_executor(_get_pool(), self.get_by_fasta, query, threshold) - elif re.fullmatch(r"^\d+$", query): - result = await loop.run_in_executor(_get_pool(), self.get_by_gene_id, query) - elif re.fullmatch(r"[A-Z]{2}_\d+\.?\d*", query, re.I): - result = await loop.run_in_executor(_get_pool(), self.get_by_accession, query) - else: - result = await loop.run_in_executor(_get_pool(), self.get_best_hit, query) + # Auto-detect query type and execute in thread pool + # All methods need the lock because they all call the NCBI API (rate limit: max 3 requests per second) + # Even if get_by_fasta uses local BLAST, it still calls get_by_accession, which needs the API + async def _execute_with_lock(func, *args): + """Execute function with lock for NCBI API calls.""" + async with _blast_lock: + return await loop.run_in_executor(_get_pool(), func, *args) + + if query.startswith(">") or re.fullmatch(r"[ATCGN\s]+", query, re.I): + # FASTA sequence: always use lock (even with local BLAST, get_by_accession needs the API) + result = await _execute_with_lock(self.get_by_fasta, query, threshold) + elif re.fullmatch(r"^\d+$", query): + # Gene ID: always use lock (network API call) + result = await _execute_with_lock(self.get_by_gene_id, query) + elif re.fullmatch(r"[A-Z]{2}_\d+\.?\d*", query, re.I): + # Accession: always use lock (network API call) + result = await _execute_with_lock(self.get_by_accession, query) + else: + # Keyword: always use lock (network API call) + result = await _execute_with_lock(self.get_best_hit, query) if result: result["_search_query"] = query diff --git a/graphgen/models/searcher/db/rnacentral_searcher.py b/graphgen/models/searcher/db/rnacentral_searcher.py index 58c5e86e..7fcba467 100644 --- a/graphgen/models/searcher/db/rnacentral_searcher.py +++ b/graphgen/models/searcher/db/rnacentral_searcher.py @@ -23,7 +23,7 @@ @lru_cache(maxsize=None) def _get_pool(): - return ThreadPoolExecutor(max_workers=10) + return ThreadPoolExecutor(max_workers=20) # NOTE: can increase for better parallelism class RNACentralSearch(BaseSearcher): """ @@ -35,12 +35,21 @@ class RNACentralSearch(BaseSearcher): API Documentation: https://rnacentral.org/api/v1 """ - def __init__(self, use_local_blast: bool = False, local_blast_db: str = "rna_db"): + def __init__( + self, + use_local_blast: bool = False, + local_blast_db: str = "rna_db", + api_timeout: int = 30, + blast_num_threads: int = 4 + ): super().__init__() self.base_url = "https://rnacentral.org/api/v1" self.headers = {"Accept": "application/json"} self.use_local_blast = use_local_blast self.local_blast_db = local_blast_db + self.api_timeout =
api_timeout + self.blast_num_threads = blast_num_threads # Number of threads for BLAST search + if self.use_local_blast and not os.path.isfile(f"{self.local_blast_db}.nhr"): logger.error("Local BLAST database files not found. Please check the path.") self.use_local_blast = False @@ -58,7 +67,8 @@ def _rna_data_to_dict( acc = xref.get("accession", {}) if s := acc.get("species"): organisms.add(s) - if g := acc.get("gene", "").strip(): + gene_value = acc.get("gene") + if isinstance(gene_value, str) and (g := gene_value.strip()): gene_names.add(g) if m := xref.get("modifications"): modifications.extend(m) @@ -141,6 +151,12 @@ def _calculate_md5(sequence: str) -> str: return hashlib.md5(normalized_seq.encode("ascii")).hexdigest() + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), + retry=retry_if_exception_type((requests.Timeout, requests.RequestException)), + reraise=False, + ) def get_by_rna_id(self, rna_id: str) -> Optional[dict]: """ Get RNA information by RNAcentral ID. @@ -151,12 +167,16 @@ def get_by_rna_id(self, rna_id: str) -> Optional[dict]: url = f"{self.base_url}/rna/{rna_id}" url += "?flat=true" - resp = requests.get(url, headers=self.headers, timeout=30) + resp = requests.get(url, headers=self.headers, timeout=self.api_timeout) resp.raise_for_status() rna_data = resp.json() xrefs_data = rna_data.get("xrefs", []) - return self._rna_data_to_dict(rna_id, rna_data, xrefs_data) + result = self._rna_data_to_dict(rna_id, rna_data, xrefs_data) + return result + except requests.Timeout as e: + logger.warning("Timeout getting RNA ID %s (timeout=%ds): %s", rna_id, self.api_timeout, e) + return None except requests.RequestException as e: logger.error("Network error getting RNA ID %s: %s", rna_id, e) return None @@ -164,6 +184,12 @@ def get_by_rna_id(self, rna_id: str) -> Optional[dict]: logger.error("Unexpected error getting RNA ID %s: %s", rna_id, e) return None + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), + retry=retry_if_exception_type((requests.Timeout, requests.RequestException)), + reraise=False, + ) def get_best_hit(self, keyword: str) -> Optional[dict]: """ Search RNAcentral with a keyword and return the best hit. @@ -178,7 +204,7 @@ def get_best_hit(self, keyword: str) -> Optional[dict]: try: url = f"{self.base_url}/rna" params = {"search": keyword, "format": "json"} - resp = requests.get(url, params=params, headers=self.headers, timeout=30) + resp = requests.get(url, params=params, headers=self.headers, timeout=self.api_timeout) resp.raise_for_status() data = resp.json() @@ -206,22 +232,54 @@ def get_best_hit(self, keyword: str) -> Optional[dict]: return None def _local_blast(self, seq: str, threshold: float) -> Optional[str]: - """Perform local BLAST search using local BLAST database.""" + """ + Perform local BLAST search using local BLAST database. + Optimized with multi-threading and faster output format. 
+ """ try: + # Use temporary file for query sequence with tempfile.NamedTemporaryFile(mode="w+", suffix=".fa", delete=False) as tmp: tmp.write(f">query\n{seq}\n") tmp_name = tmp.name + # Optimized BLAST command with: + # - num_threads: Use multiple threads for faster search + # - outfmt 6 sacc: Only return accession (minimal output) + # - max_target_seqs 1: Only need the best hit + # - evalue: Threshold for significance cmd = [ "blastn", "-db", self.local_blast_db, "-query", tmp_name, - "-evalue", str(threshold), "-max_target_seqs", "1", "-outfmt", "6 sacc" + "-evalue", str(threshold), + "-max_target_seqs", "1", + "-num_threads", str(self.blast_num_threads), + "-outfmt", "6 sacc" # Only accession, tab-separated ] - logger.debug("Running local blastn for RNA: %s", " ".join(cmd)) - out = subprocess.check_output(cmd, text=True).strip() + logger.debug("Running local blastn for RNA (threads=%d): %s", + self.blast_num_threads, " ".join(cmd)) + + # Run BLAST with timeout to avoid hanging + try: + out = subprocess.check_output( + cmd, + text=True, + timeout=300, # 5 minute timeout for BLAST search + stderr=subprocess.DEVNULL # Suppress BLAST warnings to reduce I/O + ).strip() + except subprocess.TimeoutExpired: + logger.warning("BLAST search timed out after 5 minutes for sequence") + os.remove(tmp_name) + return None + os.remove(tmp_name) return out.split("\n", maxsplit=1)[0] if out else None except Exception as exc: logger.error("Local blastn failed: %s", exc) + # Clean up temp file if it still exists + try: + if 'tmp_name' in locals(): + os.remove(tmp_name) + except Exception: + pass return None def get_by_fasta(self, sequence: str, threshold: float = 0.01) -> Optional[dict]: @@ -240,7 +298,8 @@ def _extract_sequence(sequence: str) -> Optional[str]: seq = "".join(seq_lines[1:]) else: seq = sequence.strip().replace(" ", "").replace("\n", "") - return seq if seq and re.fullmatch(r"[AUCGN\s]+", seq, re.I) else None + # Accept both U (original RNA) and T + return seq if seq and re.fullmatch(r"[AUCGTN\s]+", seq, re.I) else None try: seq = _extract_sequence(sequence) @@ -253,9 +312,21 @@ def _extract_sequence(sequence: str) -> Optional[str]: accession = self._local_blast(seq, threshold) if accession: logger.debug("Local BLAST found accession: %s", accession) - return self.get_by_rna_id(accession) + detailed = self.get_by_rna_id(accession) + if detailed: + return detailed + logger.info( + "Local BLAST found accession %s but could not retrieve metadata from API.", + accession + ) + return None + logger.info( + "Local BLAST found no match for sequence. " + "API fallback disabled when using local database." 
+ ) + return None - # Fall back to RNAcentral API if local BLAST didn't find result + # Fall back to RNAcentral API only if local BLAST is not enabled logger.debug("Falling back to RNAcentral API.") md5_hash = self._calculate_md5(seq) @@ -271,11 +342,18 @@ def _extract_sequence(sequence: str) -> Optional[str]: if not results: logger.info("No exact match found in RNAcentral for sequence") return None + rna_id = results[0].get("rnacentral_id") - if not rna_id: - logger.error("No RNAcentral ID found in search results.") - return None - return self.get_by_rna_id(rna_id) + if rna_id: + detailed = self.get_by_rna_id(rna_id) + if detailed: + return detailed + # Fallback: use search result data if get_by_rna_id returns None + logger.debug("Using search result data for %s (get_by_rna_id returned None)", rna_id) + return self._rna_data_to_dict(rna_id, results[0]) + + logger.error("No RNAcentral ID found in search results.") + return None except Exception as e: logger.error("Sequence search failed: %s", e) return None @@ -297,10 +375,13 @@ async def search(self, query: str, threshold: float = 0.1, **kwargs) -> Optional loop = asyncio.get_running_loop() - # check if RNA sequence (AUCG characters, contains U) - if query.startswith(">") or ( - re.fullmatch(r"[AUCGN\s]+", query, re.I) and "U" in query.upper() - ): + # check if RNA sequence (AUCG or ATCG characters, contains U or T) + # Note: Sequences with T are also RNA sequences + is_rna_sequence = query.startswith(">") or ( + re.fullmatch(r"[AUCGTN\s]+", query, re.I) and + ("U" in query.upper() or "T" in query.upper()) + ) + if is_rna_sequence: result = await loop.run_in_executor(_get_pool(), self.get_by_fasta, query, threshold) # check if RNAcentral ID (typically starts with URS) elif re.fullmatch(r"URS\d+", query, re.I): diff --git a/graphgen/models/searcher/db/uniprot_searcher.py b/graphgen/models/searcher/db/uniprot_searcher.py index f5542f8c..d39031d3 100644 --- a/graphgen/models/searcher/db/uniprot_searcher.py +++ b/graphgen/models/searcher/db/uniprot_searcher.py @@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor from functools import lru_cache from io import StringIO -from typing import Dict, Optional +from typing import Dict, Optional, List from Bio import ExPASy, SeqIO, SwissProt, UniProt from Bio.Blast import NCBIWWW, NCBIXML @@ -24,7 +24,7 @@ @lru_cache(maxsize=None) def _get_pool(): - return ThreadPoolExecutor(max_workers=10) + return ThreadPoolExecutor(max_workers=20) # NOTE:can increase for better parallelism # ensure only one BLAST searcher at a time @@ -39,10 +39,17 @@ class UniProtSearch(BaseSearcher): 3) Search with FASTA sequence (BLAST searcher). Note that NCBIWWW does not support async. """ - def __init__(self, use_local_blast: bool = False, local_blast_db: str = "sp_db"): + def __init__( + self, + use_local_blast: bool = False, + local_blast_db: str = "sp_db", + blast_num_threads: int = 4 + ): super().__init__() self.use_local_blast = use_local_blast self.local_blast_db = local_blast_db + self.blast_num_threads = blast_num_threads # Number of threads for BLAST search + if self.use_local_blast and not os.path.isfile(f"{self.local_blast_db}.phr"): logger.error("Local BLAST database files not found. 
Please check the path.") self.use_local_blast = False @@ -124,52 +131,58 @@ def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]: logger.error("Empty FASTA sequence provided.") return None - accession = None if self.use_local_blast: accession = self._local_blast(seq, threshold) if accession: logger.debug("Local BLAST found accession: %s", accession) + return self.get_by_accession(accession) + logger.info( + "Local BLAST found no match for sequence. " + "API fallback disabled when using local database." + ) + return None - if not accession: - logger.debug("Falling back to NCBIWWW.qblast.") + # Fall back to network BLAST only if local BLAST is not enabled + logger.debug("Falling back to NCBIWWW.qblast.") - # UniProtKB/Swiss-Prot BLAST API - try: - logger.debug( - "Performing BLAST searcher for the given sequence: %s", seq - ) - result_handle = NCBIWWW.qblast( - program="blastp", - database="swissprot", - sequence=seq, - hitlist_size=1, - expect=threshold, - ) - blast_record = NCBIXML.read(result_handle) - except RequestException: - raise - except Exception as e: # pylint: disable=broad-except - logger.error("BLAST searcher failed: %s", e) - return None + # UniProtKB/Swiss-Prot BLAST API + try: + logger.debug( + "Performing BLAST searcher for the given sequence: %s", seq + ) + result_handle = NCBIWWW.qblast( + program="blastp", + database="swissprot", + sequence=seq, + hitlist_size=1, + expect=threshold, + ) + blast_record = NCBIXML.read(result_handle) + except RequestException: + raise + except Exception as e: # pylint: disable=broad-except + logger.error("BLAST searcher failed: %s", e) + return None - if not blast_record.alignments: - logger.info("No BLAST hits found for the given sequence.") - return None + if not blast_record.alignments: + logger.info("No BLAST hits found for the given sequence.") + return None - best_alignment = blast_record.alignments[0] - best_hsp = best_alignment.hsps[0] - if best_hsp.expect > threshold: - logger.info("No BLAST hits below the threshold E-value.") - return None - hit_id = best_alignment.hit_id + best_alignment = blast_record.alignments[0] + best_hsp = best_alignment.hsps[0] + if best_hsp.expect > threshold: + logger.info("No BLAST hits below the threshold E-value.") + return None - # like sp|P01308.1|INS_HUMAN - accession = hit_id.split("|")[1].split(".")[0] if "|" in hit_id else hit_id + # like sp|P01308.1|INS_HUMAN + hit_id = best_alignment.hit_id + accession = hit_id.split("|")[1].split(".")[0] if "|" in hit_id else hit_id return self.get_by_accession(accession) def _local_blast(self, seq: str, threshold: float) -> Optional[str]: """ Perform local BLAST search using local BLAST database. + Optimized with multi-threading and faster output format. :param seq: The protein sequence. :param threshold: E-value threshold for BLAST searcher. :return: The accession number of the best hit or None if not found. 
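Review note: the NCBI, RNAcentral, and UniProt searchers now converge on the same local-BLAST call shape (temporary FASTA query, `-num_threads`, `-max_target_seqs 1`, `-outfmt "6 sacc"`, a 300-second timeout, temp-file cleanup). A minimal standalone sketch of that pattern, assuming NCBI BLAST+ is installed and on PATH; `run_local_blast` and its defaults are illustrative, not part of the patch:

```python
# Sketch only: mirrors the thread-limited, timeout-guarded local BLAST call
# used by the three searchers. Names and defaults are illustrative.
import os
import subprocess
import tempfile
from typing import Optional


def run_local_blast(
    seq: str,
    db: str,
    program: str = "blastp",      # "blastn" for the DNA/RNA searchers
    evalue: float = 0.01,
    num_threads: int = 2,         # maps to the new blast_num_threads option
    timeout: int = 300,           # guard so a bad query cannot hang the worker pool
) -> Optional[str]:
    """Return the accession of the best hit, or None if there is no hit."""
    tmp_name = None
    try:
        with tempfile.NamedTemporaryFile(mode="w+", suffix=".fa", delete=False) as tmp:
            tmp.write(f">query\n{seq}\n")
            tmp_name = tmp.name
        cmd = [
            program, "-db", db, "-query", tmp_name,
            "-evalue", str(evalue),
            "-max_target_seqs", "1",          # only the best hit is needed
            "-num_threads", str(num_threads),
            "-outfmt", "6 sacc",              # accession only, minimal parsing
        ]
        out = subprocess.check_output(
            cmd, text=True, timeout=timeout, stderr=subprocess.DEVNULL
        ).strip()
        return out.splitlines()[0] if out else None
    except (subprocess.TimeoutExpired, subprocess.CalledProcessError):
        return None
    finally:
        if tmp_name and os.path.exists(tmp_name):
            os.remove(tmp_name)   # cleanup runs even on timeout or BLAST error
```

A single `finally` block is one way to collapse the duplicated `os.remove` branches that the patch repeats in each searcher's timeout and error paths.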
@@ -181,6 +194,11 @@ def _local_blast(self, seq: str, threshold: float) -> Optional[str]: tmp.write(f">query\n{seq}\n") tmp_name = tmp.name + # Optimized BLAST command with: + # - num_threads: Use multiple threads for faster search + # - outfmt 6 sacc: Only return accession (minimal output) + # - max_target_seqs 1: Only need the best hit + # - evalue: Threshold for significance cmd = [ "blastp", "-db", @@ -191,11 +209,27 @@ def _local_blast(self, seq: str, threshold: float) -> Optional[str]: str(threshold), "-max_target_seqs", "1", + "-num_threads", + str(self.blast_num_threads), "-outfmt", - "6 sacc", # only return accession + "6 sacc", # Only accession, tab-separated ] - logger.debug("Running local blastp: %s", " ".join(cmd)) - out = subprocess.check_output(cmd, text=True).strip() + logger.debug("Running local blastp (threads=%d): %s", + self.blast_num_threads, " ".join(cmd)) + + # Run BLAST with timeout to avoid hanging + try: + out = subprocess.check_output( + cmd, + text=True, + timeout=300, # 5 minute timeout for BLAST search + stderr=subprocess.DEVNULL # Suppress BLAST warnings to reduce I/O + ).strip() + except subprocess.TimeoutExpired: + logger.warning("BLAST search timed out after 5 minutes for sequence") + os.remove(tmp_name) + return None + os.remove(tmp_name) if out: return out.split("\n", maxsplit=1)[0] @@ -234,13 +268,23 @@ async def search( if query.startswith(">") or re.fullmatch( r"[ACDEFGHIKLMNPQRSTVWY\s]+", query, re.I ): - async with _blast_lock: + # Only use lock for network BLAST (NCBIWWW), local BLAST can run in parallel + if self.use_local_blast: + # Local BLAST can run in parallel, no lock needed result = await loop.run_in_executor( _get_pool(), self.get_by_fasta, query, threshold ) + else: + # Network BLAST needs lock to respect rate limits + async with _blast_lock: + result = await loop.run_in_executor( + _get_pool(), self.get_by_fasta, query, threshold + ) # check if accession number - elif re.fullmatch(r"[A-NR-Z0-9]{6,10}", query, re.I): + # UniProt accession IDs: 6-10 characters, must start with a letter + # Format: [A-Z][A-Z0-9]{5,9} (6-10 chars total: 1 letter + 5-9 alphanumeric) + elif re.fullmatch(r"[A-Z][A-Z0-9]{5,9}", query, re.I): result = await loop.run_in_executor( _get_pool(), self.get_by_accession, query ) diff --git a/graphgen/models/storage/json_storage.py b/graphgen/models/storage/json_storage.py index 53962117..ae41fa21 100644 --- a/graphgen/models/storage/json_storage.py +++ b/graphgen/models/storage/json_storage.py @@ -1,5 +1,6 @@ import os from dataclasses import dataclass +from typing import Iterator, Tuple from graphgen.bases.base_storage import BaseKVStorage, BaseListStorage from graphgen.utils import load_json, logger, write_json @@ -42,6 +43,42 @@ def get_by_ids(self, ids, fields=None) -> list: def get_all(self) -> dict[str, dict]: return self._data + def iter_items(self) -> Iterator[Tuple[str, dict]]: + """ + Iterate over all items without loading everything into memory at once. + Returns an iterator of (key, value) tuples. + """ + for key, value in self._data.items(): + yield key, value + + def get_batch(self, keys: list[str]) -> dict[str, dict]: + """ + Get a batch of items by their keys. + + :param keys: List of keys to retrieve. + :return: Dictionary of {key: value} for the requested keys. + """ + return {key: self._data.get(key) for key in keys if key in self._data} + + def iter_batches(self, batch_size: int = 10000) -> Iterator[dict[str, dict]]: + """ + Iterate over items in batches to avoid loading everything into memory. 
+ + :param batch_size: Number of items per batch. + :return: Iterator of dictionaries, each containing up to batch_size items. + """ + batch = {} + count = 0 + for key, value in self._data.items(): + batch[key] = value + count += 1 + if count >= batch_size: + yield batch + batch = {} + count = 0 + if batch: + yield batch + def filter_keys(self, data: list[str]) -> set[str]: return {s for s in data if s not in self._data} diff --git a/graphgen/operators/read/read_files.py b/graphgen/operators/read/read_files.py index d9e7f673..39723e76 100644 --- a/graphgen/operators/read/read_files.py +++ b/graphgen/operators/read/read_files.py @@ -93,7 +93,13 @@ def read_files( suffix = Path(file_path).suffix.lstrip(".").lower() reader = _build_reader(suffix, cache_dir) - yield from reader.read(file_path) + # Prefer stream reading if available (for memory efficiency) + if hasattr(reader, "read_stream"): + yield from reader.read_stream(file_path) + else: + # Fallback to regular read() method + for doc in reader.read(file_path): + yield doc except Exception as e: # pylint: disable=broad-except logger.exception("Error reading %s: %s", file_info.get("path"), e) diff --git a/graphgen/operators/search/search_all.py b/graphgen/operators/search/search_all.py index 6017cfee..85119327 100644 --- a/graphgen/operators/search/search_all.py +++ b/graphgen/operators/search/search_all.py @@ -15,12 +15,16 @@ async def search_all( seed_data: dict, search_config: dict, + search_storage=None, + save_interval: int = 1000, ) -> dict: """ Perform searches across multiple search types and aggregate the results. :param seed_data: A dictionary containing seed data with entity names. :param search_config: A dictionary specifying which data sources to use for searching. - :return: A dictionary with + :param search_storage: Optional storage instance for periodic saving of results. + :param save_interval: Number of search results to accumulate before saving (default: 1000, 0 to disable). 
+ :return: A dictionary with search results """ results = {} @@ -31,11 +35,50 @@ async def search_all( data = [d["content"] for d in data if "content" in d] data = list(set(data)) # Remove duplicates + # Prepare save callback for this data source + def make_save_callback(source_name): + def save_callback(intermediate_results, completed_count): + """Save intermediate search results.""" + if search_storage is None: + return + + # Convert results list to dict format + # Results are tuples of (query, result_dict) or just result_dict + batch_results = {} + for result in intermediate_results: + if result is None: + continue + # Check if result is a dict with _search_query key + if isinstance(result, dict) and "_search_query" in result: + query = result["_search_query"] + # Create a key for the result (using query as key) + key = f"{source_name}:{query}" + batch_results[key] = result + elif isinstance(result, dict): + # If no _search_query, use a generated key + key = f"{source_name}:{completed_count}" + batch_results[key] = result + + if batch_results: + # Filter out already existing keys + new_keys = search_storage.filter_keys(list(batch_results.keys())) + new_results = {k: v for k, v in batch_results.items() if k in new_keys} + if new_results: + search_storage.upsert(new_results) + search_storage.index_done_callback() + logger.debug("Saved %d intermediate results for %s", len(new_results), source_name) + + return save_callback + if data_source == "uniprot": from graphgen.models import UniProtSearch + uniprot_params = search_config.get("uniprot_params", {}).copy() + # Get max_concurrent from config before passing params to constructor + max_concurrent = uniprot_params.pop("max_concurrent", None) + uniprot_search_client = UniProtSearch( - **search_config.get("uniprot_params", {}) + **uniprot_params ) uniprot_results = await run_concurrent( @@ -43,14 +86,21 @@ async def search_all( data, desc="Searching UniProt database", unit="keyword", + save_interval=save_interval if save_interval > 0 else 0, + save_callback=make_save_callback("uniprot") if search_storage and save_interval > 0 else None, + max_concurrent=max_concurrent, ) results[data_source] = uniprot_results elif data_source == "ncbi": from graphgen.models import NCBISearch + ncbi_params = search_config.get("ncbi_params", {}).copy() + # Get max_concurrent from config before passing params to constructor + max_concurrent = ncbi_params.pop("max_concurrent", None) + ncbi_search_client = NCBISearch( - **search_config.get("ncbi_params", {}) + **ncbi_params ) ncbi_results = await run_concurrent( @@ -58,14 +108,21 @@ async def search_all( data, desc="Searching NCBI database", unit="keyword", + save_interval=save_interval if save_interval > 0 else 0, + save_callback=make_save_callback("ncbi") if search_storage and save_interval > 0 else None, + max_concurrent=max_concurrent, ) results[data_source] = ncbi_results elif data_source == "rnacentral": from graphgen.models import RNACentralSearch + rnacentral_params = search_config.get("rnacentral_params", {}).copy() + # Get max_concurrent from config before passing params to constructor + max_concurrent = rnacentral_params.pop("max_concurrent", None) + rnacentral_search_client = RNACentralSearch( - **search_config.get("rnacentral_params", {}) + **rnacentral_params ) rnacentral_results = await run_concurrent( @@ -73,6 +130,9 @@ async def search_all( data, desc="Searching RNAcentral database", unit="keyword", + save_interval=save_interval if save_interval > 0 else 0, + 
save_callback=make_save_callback("rnacentral") if search_storage and save_interval > 0 else None, + max_concurrent=max_concurrent, ) results[data_source] = rnacentral_results diff --git a/graphgen/utils/run_concurrent.py b/graphgen/utils/run_concurrent.py index ac63f87b..2a8c492c 100644 --- a/graphgen/utils/run_concurrent.py +++ b/graphgen/utils/run_concurrent.py @@ -17,11 +17,48 @@ async def run_concurrent( desc: str = "processing", unit: str = "item", progress_bar: Optional[gr.Progress] = None, + save_interval: int = 0, + save_callback: Optional[Callable[[List[R], int], None]] = None, + max_concurrent: Optional[int] = None, ) -> List[R]: - tasks = [asyncio.create_task(coro_fn(it)) for it in items] + """ + Run coroutines concurrently with optional periodic saving. + + :param coro_fn: Coroutine function to run for each item + :param items: List of items to process + :param desc: Description for progress bar + :param unit: Unit name for progress bar + :param progress_bar: Optional Gradio progress bar + :param save_interval: Number of completed tasks before calling save_callback (0 to disable) + :param save_callback: Callback function to save intermediate results (results, completed_count) + :param max_concurrent: Maximum number of concurrent tasks (None for unlimited, default: None) + :return: List of results + """ + if not items: + return [] + + # Use semaphore to limit concurrent tasks if max_concurrent is specified + semaphore = asyncio.Semaphore(max_concurrent) if max_concurrent is not None and max_concurrent > 0 else None + + async def run_with_semaphore(item: T) -> R: + """Wrapper to apply semaphore if needed.""" + if semaphore: + async with semaphore: + return await coro_fn(item) + else: + return await coro_fn(item) + + # Create tasks with concurrency limit + if max_concurrent is not None and max_concurrent > 0: + # Use semaphore-controlled wrapper + tasks = [asyncio.create_task(run_with_semaphore(it)) for it in items] + else: + # Original behavior: create all tasks at once + tasks = [asyncio.create_task(coro_fn(it)) for it in items] completed_count = 0 results = [] + pending_save_results = [] pbar = tqdm_async(total=len(items), desc=desc, unit=unit) @@ -32,6 +69,8 @@ async def run_concurrent( try: result = await future results.append(result) + if save_interval > 0 and save_callback is not None: + pending_save_results.append(result) except Exception as e: # pylint: disable=broad-except logger.exception("Task failed: %s", e) # even if failed, record it to keep results consistent with tasks @@ -44,11 +83,31 @@ async def run_concurrent( progress = completed_count / len(items) progress_bar(progress, desc=f"{desc} ({completed_count}/{len(items)})") + # Periodic save + if save_interval > 0 and save_callback is not None and completed_count % save_interval == 0: + try: + # Filter out exceptions before saving + valid_results = [res for res in pending_save_results if not isinstance(res, Exception)] + save_callback(valid_results, completed_count) + pending_save_results = [] # Clear after saving + logger.info("Saved intermediate results: %d/%d completed", completed_count, len(items)) + except Exception as e: + logger.warning("Failed to save intermediate results: %s", e) + pbar.close() if progress_bar is not None: progress_bar(1.0, desc=f"{desc} (completed)") + # Save remaining results if any + if save_interval > 0 and save_callback is not None and pending_save_results: + try: + valid_results = [res for res in pending_save_results if not isinstance(res, Exception)] + save_callback(valid_results, 
completed_count) + logger.info("Saved final intermediate results: %d completed", completed_count) + except Exception as e: + logger.warning("Failed to save final intermediate results: %s", e) + # filter out exceptions results = [res for res in results if not isinstance(res, Exception)] diff --git a/scripts/search/build_db/build_dna_blast_db.sh b/scripts/search/build_db/build_dna_blast_db.sh index 1928d7d0..21b86141 100755 --- a/scripts/search/build_db/build_dna_blast_db.sh +++ b/scripts/search/build_db/build_dna_blast_db.sh @@ -24,8 +24,8 @@ set -e # - {category}.{number}.genomic.fna.gz (genomic sequences) # - {category}.{number}.rna.fna.gz (RNA sequences) # -# Usage: ./build_dna_blast_db.sh [human_mouse|representative|complete|all] -# human_mouse: Download only Homo sapiens and Mus musculus sequences (minimal, smallest) +# Usage: ./build_dna_blast_db.sh [human_mouse_drosophila_yeast|representative|complete|all] +# human_mouse_drosophila_yeast: Download only Homo sapiens, Mus musculus, Drosophila melanogaster, and Saccharomyces cerevisiae sequences (minimal, smallest) # representative: Download genomic sequences from major categories (recommended, smaller) # Includes: vertebrate_mammalian, vertebrate_other, bacteria, archaea, fungi # complete: Download all complete genomic sequences from complete/ directory (very large) @@ -36,7 +36,7 @@ set -e # For CentOS/RHEL/Fedora: sudo dnf install ncbi-blast+ # Or download from: https://ftp.ncbi.nlm.nih.gov/blast/executables/blast+/LATEST/ -DOWNLOAD_TYPE=${1:-human_mouse} +DOWNLOAD_TYPE=${1:-human_mouse_drosophila_yeast} # Better to use a stable DOWNLOAD_TMP name to support resuming downloads DOWNLOAD_TMP=_downloading_dna @@ -68,7 +68,8 @@ check_file_for_species() { # This should be sufficient to identify the species in most cases if curl -s --max-time 30 --range 0-512000 "${url}" -o "${temp_file}" 2>/dev/null && [ -s "${temp_file}" ]; then # Try to decompress and check for species names - if gunzip -c "${temp_file}" 2>/dev/null | head -2000 | grep -qE "(Homo sapiens|Mus musculus)"; then + # Check for: Homo sapiens (human), Mus musculus (mouse), Drosophila melanogaster (fruit fly), Saccharomyces cerevisiae (yeast) + if gunzip -c "${temp_file}" 2>/dev/null | head -2000 | grep -qE "(Homo sapiens|Mus musculus|Drosophila melanogaster|Saccharomyces cerevisiae)"; then rm -f "${temp_file}" return 0 # Contains target species else @@ -84,39 +85,50 @@ check_file_for_species() { # Download based on type case ${DOWNLOAD_TYPE} in - human_mouse) - echo "Downloading RefSeq sequences for Homo sapiens and Mus musculus only (minimal size)..." - echo "This will check each file to see if it contains human or mouse sequences..." - category="vertebrate_mammalian" - echo "Checking files in ${category} category..." + human_mouse_drosophila_yeast) + echo "Downloading RefSeq sequences for Homo sapiens, Mus musculus, Drosophila melanogaster, and Saccharomyces cerevisiae (minimal size)..." + echo "This will check each file to see if it contains target species sequences..."
- # Get list of files and save to temp file to avoid subshell issues - curl -s "https://ftp.ncbi.nlm.nih.gov/refseq/release/${category}/" | \ - grep -oE 'href="[^"]*\.genomic\.fna\.gz"' | \ - sed 's/href="\(.*\)"/\1/' > /tmp/refseq_files.txt - - file_count=0 - download_count=0 + # Check multiple categories: vertebrate_mammalian (human, mouse), invertebrate (fruit fly), fungi (yeast) + categories="vertebrate_mammalian invertebrate fungi" + total_file_count=0 + total_download_count=0 - while read filename; do - file_count=$((file_count + 1)) - url="https://ftp.ncbi.nlm.nih.gov/refseq/release/${category}/${filename}" - echo -n "[${file_count}] Checking ${filename}... " + for category in ${categories}; do + echo "Checking files in ${category} category..." - if check_file_for_species "${url}" "${filename}"; then - echo "✓ contains target species, downloading..." - download_count=$((download_count + 1)) - wget -c -q --show-progress "${url}" || { - echo "Warning: Failed to download ${filename}" - } - else - echo "✗ skipping (no human/mouse data)" - fi - done < /tmp/refseq_files.txt + # Get list of files and save to temp file to avoid subshell issues + curl -s "https://ftp.ncbi.nlm.nih.gov/refseq/release/${category}/" | \ + grep -oE 'href="[^"]*\.genomic\.fna\.gz"' | \ + sed 's/href="\(.*\)"/\1/' > /tmp/refseq_files_${category}.txt + + file_count=0 + download_count=0 + + while read filename; do + file_count=$((file_count + 1)) + total_file_count=$((total_file_count + 1)) + url="https://ftp.ncbi.nlm.nih.gov/refseq/release/${category}/${filename}" + echo -n "[${total_file_count}] Checking ${category}/${filename}... " + + if check_file_for_species "${url}" "${filename}"; then + echo "✓ contains target species, downloading..." + download_count=$((download_count + 1)) + total_download_count=$((total_download_count + 1)) + wget -c -q --show-progress "${url}" || { + echo "Warning: Failed to download ${filename}" + } + else + echo "✗ skipping (no target species data)" + fi + done < /tmp/refseq_files_${category}.txt + + rm -f /tmp/refseq_files_${category}.txt + echo " ${category}: Checked ${file_count} files, downloaded ${download_count} files." + done - rm -f /tmp/refseq_files.txt echo "" - echo "Summary: Checked ${file_count} files, downloaded ${download_count} files containing human or mouse sequences." + echo "Summary: Checked ${total_file_count} files total, downloaded ${total_download_count} files containing target species (human, mouse, fruit fly, yeast)." ;; representative) echo "Downloading RefSeq representative sequences (recommended, smaller size)..."
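Review note: `check_file_for_species` decides whether a multi-gigabyte RefSeq file is worth downloading by range-fetching only its first ~512 KB, decompressing that prefix, and grepping for the target species names. A rough Python equivalent of that probe, for readers following the logic outside bash; the function name and probe size are illustrative, not part of the patch:

```python
# Sketch only: Python rendering of the bash species probe. Assumes the server
# honors (or at least tolerates) Range requests; we cap the read either way.
import zlib
import requests

TARGET_SPECIES = (
    b"Homo sapiens", b"Mus musculus",
    b"Drosophila melanogaster", b"Saccharomyces cerevisiae",
)


def file_mentions_target_species(url: str, probe_bytes: int = 512_000) -> bool:
    """Fetch only the first probe_bytes of a remote .fna.gz and search the
    decompressed prefix for the target species names."""
    resp = requests.get(
        url, headers={"Range": f"bytes=0-{probe_bytes}"}, timeout=30, stream=True
    )
    resp.raise_for_status()
    chunk = resp.raw.read(probe_bytes + 1)  # bounded even if Range is ignored
    # wbits=47 auto-detects the gzip header; a truncated stream still yields
    # whatever prefix could be decompressed, which is all the probe needs.
    text = zlib.decompressobj(wbits=47).decompress(chunk)
    return any(name in text for name in TARGET_SPECIES)
```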
@@ -168,8 +180,8 @@ case ${DOWNLOAD_TYPE} in ;; *) echo "Error: Unknown download type '${DOWNLOAD_TYPE}'" - echo "Usage: $0 [human_mouse|representative|complete|all]" - echo " human_mouse: Download only Homo sapiens and Mus musculus (minimal)" + echo "Usage: $0 [human_mouse_drosophila_yeast|representative|complete|all]" + echo " human_mouse_drosophila_yeast: Download only Homo sapiens, Mus musculus, Drosophila melanogaster, and Saccharomyces cerevisiae (minimal)" echo " representative: Download major categories (recommended)" echo " complete: Download all complete genomic sequences (very large)" echo " all: Download all genomic sequences (extremely large)" diff --git a/scripts/search/build_db/build_protein_blast_db.sh b/scripts/search/build_db/build_protein_blast_db.sh index 9292875a..a9169959 100755 --- a/scripts/search/build_db/build_protein_blast_db.sh +++ b/scripts/search/build_db/build_protein_blast_db.sh @@ -9,48 +9,137 @@ set -e # For CentOS/RHEL/Fedora: sudo dnf install ncbi-blast+ # Or download from: https://ftp.ncbi.nlm.nih.gov/blast/executables/blast+/LATEST/ +# NOTE: UniProt mirror +# Available mirrors: +# - UK/EBI: ftp://ftp.ebi.ac.uk/pub/databases/uniprot (current, recommended) +# - US: ftp://ftp.uniprot.org/pub/databases/uniprot +# - CH: ftp://ftp.expasy.org/databases/uniprot +UNIPROT_BASE="ftp://ftp.ebi.ac.uk/pub/databases/uniprot" + +# Parse command line arguments +DOWNLOAD_MODE="sprot" # sprot (Swiss-Prot) or full (sprot + trembl) + +usage() { + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Options:" + echo " -s, --sprot-only Download only Swiss-Prot database (recommended, high quality)" + echo " -f, --full Download full release (Swiss-Prot + TrEMBL, merged as uniprot_\${RELEASE})" + echo " -h, --help Show this help message" + echo "" + echo "Examples:" + echo " $0 --sprot-only # Download only uniprot_sprot" + echo " $0 --full # Download uniprot_\${RELEASE} (Swiss-Prot + TrEMBL)" +} + +while [[ $# -gt 0 ]]; do + case $1 in + -s|--sprot-only) + DOWNLOAD_MODE="sprot" + shift + ;; + -f|--full) + DOWNLOAD_MODE="full" + shift + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Unknown option: $1" + usage + exit 1 + ;; + esac +done + +echo "Download mode: ${DOWNLOAD_MODE}" +if [ "${DOWNLOAD_MODE}" = "sprot" ]; then + echo " - Will download: uniprot_sprot only" +else + echo " - Will download: uniprot_\${RELEASE} (Swiss-Prot + TrEMBL merged)" +fi +echo "Using mirror: ${UNIPROT_BASE} (EBI/UK - fast for Asia/Europe)" +echo "" + # Better to use a stable DOWNLOAD_TMP name to support resuming downloads DOWNLOAD_TMP=_downloading mkdir -p ${DOWNLOAD_TMP} cd ${DOWNLOAD_TMP} -wget -c "ftp://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/RELEASE.metalink" +echo "Downloading RELEASE.metalink..." 
+wget -c "${UNIPROT_BASE}/current_release/knowledgebase/complete/RELEASE.metalink" # Extract the release name (like 2017_10 or 2017_1) # Use sed for cross-platform compatibility (works on both macOS and Linux) RELEASE=$(sed -n 's/.*\([0-9]\{4\}_[0-9]\{1,2\}\)<\/version>.*/\1/p' RELEASE.metalink | head -1) -wget -c "ftp://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.fasta.gz" -wget -c "ftp://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_trembl.fasta.gz" -wget -c "ftp://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/reldate.txt" -wget -c "ftp://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/README" -wget -c "ftp://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/LICENSE" +echo "UniProt release: ${RELEASE}" +echo "" + +# Download Swiss-Prot (always needed) +echo "Downloading uniprot_sprot.fasta.gz..." +wget -c "${UNIPROT_BASE}/current_release/knowledgebase/complete/uniprot_sprot.fasta.gz" + +# Download TrEMBL only if full mode +if [ "${DOWNLOAD_MODE}" = "full" ]; then + echo "Downloading uniprot_trembl.fasta.gz..." + wget -c "${UNIPROT_BASE}/current_release/knowledgebase/complete/uniprot_trembl.fasta.gz" +fi + +# Download metadata files +echo "Downloading metadata files..." +wget -c "${UNIPROT_BASE}/current_release/knowledgebase/complete/reldate.txt" +wget -c "${UNIPROT_BASE}/current_release/knowledgebase/complete/README" +wget -c "${UNIPROT_BASE}/current_release/knowledgebase/complete/LICENSE" cd .. -mkdir ${RELEASE} +mkdir -p ${RELEASE} mv ${DOWNLOAD_TMP}/* ${RELEASE} rmdir ${DOWNLOAD_TMP} cd ${RELEASE} +echo "" +echo "Extracting files..." gunzip uniprot_sprot.fasta.gz -gunzip uniprot_trembl.fasta.gz -cat uniprot_sprot.fasta uniprot_trembl.fasta >uniprot_${RELEASE}.fasta +if [ "${DOWNLOAD_MODE}" = "full" ]; then + gunzip uniprot_trembl.fasta.gz + echo "Merging Swiss-Prot and TrEMBL..." + cat uniprot_sprot.fasta uniprot_trembl.fasta >uniprot_${RELEASE}.fasta +fi + +echo "" +echo "Building BLAST databases..." -makeblastdb -in uniprot_${RELEASE}.fasta -out uniprot_${RELEASE} -dbtype prot -parse_seqids -title uniprot_${RELEASE} +# Always build Swiss-Prot database makeblastdb -in uniprot_sprot.fasta -out uniprot_sprot -dbtype prot -parse_seqids -title uniprot_sprot -makeblastdb -in uniprot_trembl.fasta -out uniprot_trembl -dbtype prot -parse_seqids -title uniprot_trembl + +# Build full release database only if in full mode +if [ "${DOWNLOAD_MODE}" = "full" ]; then + makeblastdb -in uniprot_${RELEASE}.fasta -out uniprot_${RELEASE} -dbtype prot -parse_seqids -title uniprot_${RELEASE} + makeblastdb -in uniprot_trembl.fasta -out uniprot_trembl -dbtype prot -parse_seqids -title uniprot_trembl +fi cd .. +echo "" echo "BLAST databases created successfully!" 
echo "Database locations:" -echo " - Combined: $(pwd)/${RELEASE}/uniprot_${RELEASE}" -echo " - Swiss-Prot: $(pwd)/${RELEASE}/uniprot_sprot" -echo " - TrEMBL: $(pwd)/${RELEASE}/uniprot_trembl" -echo "" -echo "To use these databases, set in your config:" -echo " local_blast_db: $(pwd)/${RELEASE}/uniprot_sprot # or uniprot_${RELEASE} or uniprot_trembl" +if [ "${DOWNLOAD_MODE}" = "sprot" ]; then + echo " - Swiss-Prot: $(pwd)/${RELEASE}/uniprot_sprot" + echo "" + echo "To use this database, set in your config:" + echo " local_blast_db: $(pwd)/${RELEASE}/uniprot_sprot" +else + echo " - Combined: $(pwd)/${RELEASE}/uniprot_${RELEASE}" + echo " - Swiss-Prot: $(pwd)/${RELEASE}/uniprot_sprot" + echo " - TrEMBL: $(pwd)/${RELEASE}/uniprot_trembl" + echo "" + echo "To use these databases, set in your config:" + echo " local_blast_db: $(pwd)/${RELEASE}/uniprot_sprot # or uniprot_${RELEASE} or uniprot_trembl" +fi diff --git a/scripts/search/build_db/build_rna_blast_db.sh b/scripts/search/build_db/build_rna_blast_db.sh index 26e1cd33..503c654b 100755 --- a/scripts/search/build_db/build_rna_blast_db.sh +++ b/scripts/search/build_db/build_rna_blast_db.sh @@ -10,16 +10,20 @@ set -e # RNAcentral is a comprehensive database of non-coding RNA sequences that # integrates data from multiple expert databases including RefSeq, Rfam, etc. # -# Usage: ./build_rna_blast_db.sh [all|list|database_name] +# Usage: ./build_rna_blast_db.sh [all|list|selected|database_name...] # all (default): Download complete active database (~8.4G compressed) # list: List all available database subsets +# selected: Download predefined database subsets (ensembl_gencode, mirbase, gtrnadb, refseq, lncbase) # database_name: Download specific database subset (e.g., refseq, rfam, mirbase) +# database_name1 database_name2 ...: Download multiple database subsets # # Available database subsets (examples): # - refseq.fasta (~98M): RefSeq RNA sequences # - rfam.fasta (~1.5G): Rfam RNA families # - mirbase.fasta (~10M): microRNA sequences -# - ensembl.fasta (~2.9G): Ensembl annotations +# - ensembl_gencode.fasta (~337M): Ensembl/GENCODE annotations (human) +# - gtrnadb.fasta (~38M): tRNA sequences +# - lncbase.fasta (~106K): Human lncRNA database # - See "list" option for complete list # # The complete "active" database contains all sequences from all expert databases. @@ -30,20 +34,24 @@ set -e # For CentOS/RHEL/Fedora: sudo dnf install ncbi-blast+ # Or download from: https://ftp.ncbi.nlm.nih.gov/blast/executables/blast+/LATEST/ -# RNAcentral HTTP base URL (using HTTPS for better reliability) +# RNAcentral base URL (using EBI HTTPS) +# NOTE: RNAcentral only has one official mirror at EBI RNACENTRAL_BASE="https://ftp.ebi.ac.uk/pub/databases/RNAcentral" RNACENTRAL_RELEASE_URL="${RNACENTRAL_BASE}/current_release" RNACENTRAL_SEQUENCES_URL="${RNACENTRAL_RELEASE_URL}/sequences" RNACENTRAL_BY_DB_URL="${RNACENTRAL_SEQUENCES_URL}/by-database" -# Parse command line argument +# Parse command line arguments DB_SELECTION=${1:-all} +# Predefined database list for "selected" option +SELECTED_DATABASES=("ensembl_gencode" "mirbase" "gtrnadb" "refseq" "lncbase") + # List available databases if requested if [ "${DB_SELECTION}" = "list" ]; then echo "Available RNAcentral database subsets:" echo "" - echo "Fetching list from RNAcentral FTP..." + echo "Fetching list from RNAcentral..." 
     listing=$(curl -s "${RNACENTRAL_BY_DB_URL}/")
     echo "${listing}" | \
         grep -oE '' | \
@@ -54,30 +62,41 @@ if [ "${DB_SELECTION}" = "list" ]; then
         echo "  - ${db%.fasta}: ${size}"
     done
     echo ""
-    echo "Usage: $0 [database_name]"
+    echo "Usage: $0 [all|list|selected|database_name...]"
     echo "  Example: $0 refseq    # Download only RefSeq sequences (~98M)"
     echo "  Example: $0 rfam      # Download only Rfam sequences (~1.5G)"
+    echo "  Example: $0 selected  # Download predefined databases (ensembl_gencode, mirbase, gtrnadb, refseq, lncbase)"
+    echo "  Example: $0 refseq mirbase  # Download multiple databases"
    echo "  Example: $0 all       # Download complete active database (~8.4G)"
     exit 0
 fi
-# Better to use a stable DOWNLOAD_TMP name to support resuming downloads
-DOWNLOAD_TMP=_downloading_rnacentral
-mkdir -p ${DOWNLOAD_TMP}
-cd ${DOWNLOAD_TMP}
+# Determine which databases to download
+if [ "${DB_SELECTION}" = "selected" ]; then
+    # Use predefined database list
+    DATABASES=("${SELECTED_DATABASES[@]}")
+    echo "Downloading selected databases: ${DATABASES[*]}"
+elif [ "${DB_SELECTION}" = "all" ]; then
+    # Single database mode (all)
+    DATABASES=("all")
+else
+    # Multiple databases provided as arguments
+    DATABASES=("$@")
+fi
-# Get RNAcentral release version from release notes
+# Get RNAcentral release version from release notes (once for all databases)
 echo "Getting RNAcentral release information..."
 RELEASE_NOTES_URL="${RNACENTRAL_RELEASE_URL}/release_notes.txt"
-RELEASE_NOTES="release_notes.txt"
-wget -q "${RELEASE_NOTES_URL}" 2>/dev/null || {
+RELEASE_NOTES_TMP=$(mktemp)
+wget -q "${RELEASE_NOTES_URL}" -O "${RELEASE_NOTES_TMP}" 2>/dev/null || {
     echo "Warning: Could not download release notes, using current date as release identifier"
     RELEASE=$(date +%Y%m%d)
 }
-if [ -f "${RELEASE_NOTES}" ]; then
+if [ -f "${RELEASE_NOTES_TMP}" ] && [ -s "${RELEASE_NOTES_TMP}" ]; then
     # Try to extract version from release notes (first line usually contains version info)
-    RELEASE=$(head -1 "${RELEASE_NOTES}" | grep -oE '[0-9]+\.[0-9]+' | head -1 | tr -d '.')
+    RELEASE=$(head -1 "${RELEASE_NOTES_TMP}" | grep -oE '[0-9]+\.[0-9]+' | head -1 | tr -d '.')
+    rm -f "${RELEASE_NOTES_TMP}"
 fi
 if [ -z "${RELEASE}" ]; then
@@ -87,133 +106,328 @@ else
     echo "RNAcentral release: ${RELEASE}"
 fi
-# Download RNAcentral FASTA file
-if [ "${DB_SELECTION}" = "all" ]; then
-    # Download complete active database
-    FASTA_FILE="rnacentral_active.fasta.gz"
-    DB_NAME="rnacentral"
-    echo "Downloading RNAcentral active sequences (~8.4G)..."
-    echo "  Contains sequences currently present in at least one expert database"
-    echo "  Uses standard URS IDs (e.g., URS000149A9AF)"
-    echo "  ⭐ MATCHES the online RNAcentral API database - ensures consistency"
-    FASTA_URL="${RNACENTRAL_SEQUENCES_URL}/${FASTA_FILE}"
-    IS_COMPRESSED=true
-else
-    # Download specific database subset
-    DB_NAME="${DB_SELECTION}"
-    FASTA_FILE="${DB_SELECTION}.fasta"
-    echo "Downloading RNAcentral database subset: ${DB_SELECTION}"
-    echo "  This is a subset of the active database from a specific expert database"
-    echo "  File: ${FASTA_FILE}"
-    FASTA_URL="${RNACENTRAL_BY_DB_URL}/${FASTA_FILE}"
-    IS_COMPRESSED=false
-
-    # Check if database exists
-    if ! curl -s -o /dev/null -w "%{http_code}" "${FASTA_URL}" | grep -q "200"; then
-        echo "Error: Database '${DB_SELECTION}' not found"
-        echo "Run '$0 list' to see available databases"
+# Process each database
+DB_COUNT=${#DATABASES[@]}
+DB_INDEX=0
+
+for DB_SELECTION in "${DATABASES[@]}"; do
+    DB_INDEX=$((DB_INDEX + 1))
+    echo ""
+    echo "=========================================="
+    echo "Processing database ${DB_INDEX}/${DB_COUNT}: ${DB_SELECTION}"
+    echo "=========================================="
+    echo ""
+
+    # Check if database already exists and is complete
+    # First check with current release version
+    if [ "${DB_SELECTION}" = "all" ]; then
+        OUTPUT_DIR="rnacentral_${RELEASE}"
+        DB_NAME="rnacentral"
+        DB_OUTPUT_NAME="${DB_NAME}_${RELEASE}"
+    else
+        OUTPUT_DIR="rnacentral_${DB_SELECTION}_${RELEASE}"
+        DB_NAME="${DB_SELECTION}"
+        DB_OUTPUT_NAME="${DB_NAME}_${RELEASE}"
+    fi
+
+    # Check if BLAST database already exists with current release
+    if [ -d "${OUTPUT_DIR}" ] && [ -f "${OUTPUT_DIR}/${DB_OUTPUT_NAME}.nhr" ] && [ -f "${OUTPUT_DIR}/${DB_OUTPUT_NAME}.nin" ]; then
+        echo "✓ Database ${DB_SELECTION} already exists and appears complete: ${OUTPUT_DIR}/"
+        echo "  BLAST database: ${OUTPUT_DIR}/${DB_OUTPUT_NAME}"
+        echo "  Skipping download and database creation..."
+        continue
+    fi
+
+    # Also check for any existing version of this database (e.g., different release dates)
+    EXISTING_DIR=$(ls -d rnacentral_${DB_SELECTION}_* 2>/dev/null | head -1)
+    if [ -n "${EXISTING_DIR}" ] && [ "${DB_SELECTION}" != "all" ]; then
+        EXISTING_DB_NAME=$(basename "${EXISTING_DIR}" | sed "s/rnacentral_${DB_SELECTION}_//")
+        if [ -f "${EXISTING_DIR}/${DB_SELECTION}_${EXISTING_DB_NAME}.nhr" ] && [ -f "${EXISTING_DIR}/${DB_SELECTION}_${EXISTING_DB_NAME}.nin" ]; then
+            echo "✓ Database ${DB_SELECTION} already exists (version ${EXISTING_DB_NAME}): ${EXISTING_DIR}/"
+            echo "  BLAST database: ${EXISTING_DIR}/${DB_SELECTION}_${EXISTING_DB_NAME}"
+            echo "  Skipping download and database creation..."
+            echo "  Note: Using existing version ${EXISTING_DB_NAME} instead of ${RELEASE}"
+            continue
+        fi
+    fi
+
+    # Better to use a stable DOWNLOAD_TMP name to support resuming downloads
+    DOWNLOAD_TMP="_downloading_rnacentral_${DB_SELECTION}"
+    mkdir -p ${DOWNLOAD_TMP}
+    cd ${DOWNLOAD_TMP}
+
+    # Download RNAcentral FASTA file
+    if [ "${DB_SELECTION}" = "all" ]; then
+        # Download complete active database
+        FASTA_FILE="rnacentral_active.fasta.gz"
+        DB_NAME="rnacentral"
+        echo "Downloading RNAcentral active sequences (~8.4G)..."
+        echo "  Contains sequences currently present in at least one expert database"
+        echo "  Uses standard URS IDs (e.g., URS000149A9AF)"
+        echo "  ⭐ MATCHES the online RNAcentral API database - ensures consistency"
+        FASTA_URL="${RNACENTRAL_SEQUENCES_URL}/${FASTA_FILE}"
+        IS_COMPRESSED=true
+    else
+        # Download specific database subset
+        DB_NAME="${DB_SELECTION}"
+        FASTA_FILE="${DB_SELECTION}.fasta"
+        echo "Downloading RNAcentral database subset: ${DB_SELECTION}"
+        echo "  This is a subset of the active database from a specific expert database"
+        echo "  File: ${FASTA_FILE}"
+        FASTA_URL="${RNACENTRAL_BY_DB_URL}/${FASTA_FILE}"
+        IS_COMPRESSED=false
+
+        # Check if database exists (use HTTP status code check for HTTPS)
+        HTTP_CODE=$(curl -s --max-time 10 -o /dev/null -w "%{http_code}" "${FASTA_URL}" 2>/dev/null | tail -1 || echo "000")
+        if ! echo "${HTTP_CODE}" | grep -q "^200$"; then
+            echo "Error: Database '${DB_SELECTION}' not found (HTTP code: ${HTTP_CODE})"
+            echo "Run '$0 list' to see available databases"
+            cd ..
+            rm -rf ${DOWNLOAD_TMP}
+            exit 1
+        fi
+    fi
+
+    echo "Downloading from: ${FASTA_URL}"
+    echo "This may take a while depending on your internet connection..."
+    if [ "${DB_SELECTION}" = "all" ]; then
+        echo "File size is approximately 8-9GB, please be patient..."
+    else
+        echo "Downloading database subset..."
+    fi
+
+    wget -c "${FASTA_URL}" || {
+        echo "Error: Failed to download RNAcentral FASTA file"
+        echo "Please check your internet connection and try again"
+        echo "URL: ${FASTA_URL}"
+        cd ..
+        rm -rf ${DOWNLOAD_TMP}
+        exit 1
+    }
+
+    if [ ! -f "${FASTA_FILE}" ]; then
+        echo "Error: Downloaded file not found"
+        cd ..
+        rm -rf ${DOWNLOAD_TMP}
         exit 1
     fi
-fi
-
-echo "Downloading from: ${FASTA_URL}"
-echo "This may take a while depending on your internet connection..."
-if [ "${DB_SELECTION}" = "all" ]; then
-    echo "File size is approximately 8-9GB, please be patient..."
-else
-    echo "Downloading database subset..."
-fi
-wget -c --progress=bar:force "${FASTA_URL}" 2>&1 || {
-    echo "Error: Failed to download RNAcentral FASTA file"
-    echo "Please check your internet connection and try again"
-    echo "You can also try downloading manually from: ${FASTA_URL}"
-    exit 1
-}
-
-if [ ! -f "${FASTA_FILE}" ]; then
-    echo "Error: Downloaded file not found"
-    exit 1
-fi
+
+    cd ..
+
+    # Create release directory
+    if [ "${DB_SELECTION}" = "all" ]; then
+        OUTPUT_DIR="rnacentral_${RELEASE}"
+    else
+        OUTPUT_DIR="rnacentral_${DB_NAME}_${RELEASE}"
+    fi
+    mkdir -p ${OUTPUT_DIR}
+    mv ${DOWNLOAD_TMP}/* ${OUTPUT_DIR}/ 2>/dev/null || true
+    rmdir ${DOWNLOAD_TMP} 2>/dev/null || true
+
+    cd ${OUTPUT_DIR}
+
+    # Extract FASTA file if compressed
+    echo "Preparing RNAcentral sequences..."
+    if [ -f "${FASTA_FILE}" ]; then
+        if [ "${IS_COMPRESSED}" = "true" ]; then
+            echo "Decompressing ${FASTA_FILE}..."
+            OUTPUT_FASTA="${DB_NAME}_${RELEASE}.fasta"
+            gunzip -c "${FASTA_FILE}" > "${OUTPUT_FASTA}" || {
+                echo "Error: Failed to decompress FASTA file"
+                cd ..
+                exit 1
+            }
+            # Optionally remove the compressed file to save space
+            # rm "${FASTA_FILE}"
+        else
+            # File is not compressed, just copy/rename
+            OUTPUT_FASTA="${DB_NAME}_${RELEASE}.fasta"
+            cp "${FASTA_FILE}" "${OUTPUT_FASTA}" || {
+                echo "Error: Failed to copy FASTA file"
+                cd ..
+                exit 1
+            }
+        fi
+    else
+        echo "Error: FASTA file not found"
+        cd ..
+        exit 1
+    fi
+
+    # Check if we have sequences
+    if [ ! -s "${OUTPUT_FASTA}" ]; then
+        echo "Error: FASTA file is empty"
+        cd ..
+        exit 1
+    fi
+
+    # Get file size for user information
+    FILE_SIZE=$(du -h "${OUTPUT_FASTA}" | cut -f1)
+    echo "FASTA file size: ${FILE_SIZE}"
+
+    echo "Creating BLAST database..."
+    # Create BLAST database for RNA sequences (use -dbtype nucl for nucleotide)
+    # Note: RNAcentral uses RNAcentral IDs (URS...) as sequence identifiers,
+    # which matches the format expected by the RNACentralSearch class
+    DB_OUTPUT_NAME="${DB_NAME}_${RELEASE}"
+    makeblastdb -in "${OUTPUT_FASTA}" \
+        -out "${DB_OUTPUT_NAME}" \
+        -dbtype nucl \
+        -parse_seqids \
+        -title "RNAcentral_${DB_NAME}_${RELEASE}"
+
+    echo ""
+    echo "BLAST database created successfully!"
+ echo "Database location: $(pwd)/${DB_OUTPUT_NAME}" + echo "" + echo "To use this database, set in your config (search_rna_config.yaml):" + echo " rnacentral_params:" + echo " use_local_blast: true" + echo " local_blast_db: $(pwd)/${DB_OUTPUT_NAME}" + echo "" + echo "Note: The database files are:" + ls -lh ${DB_OUTPUT_NAME}.* | head -5 + echo "" + if [ "${DB_SELECTION}" = "all" ]; then + echo "This database uses RNAcentral IDs (URS...), which matches the online" + echo "RNAcentral search API, ensuring consistent results between local and online searches." + else + echo "This is a subset database from ${DB_SELECTION} expert database." + echo "For full coverage matching online API, use 'all' option." + fi + + cd .. +done -cd .. +echo "" +echo "==========================================" +echo "All databases processed successfully!" +echo "==========================================" +echo "" -# Create release directory -if [ "${DB_SELECTION}" = "all" ]; then - OUTPUT_DIR="rnacentral_${RELEASE}" -else - OUTPUT_DIR="rnacentral_${DB_NAME}_${RELEASE}" -fi -mkdir -p ${OUTPUT_DIR} -mv ${DOWNLOAD_TMP}/* ${OUTPUT_DIR}/ 2>/dev/null || true -rmdir ${DOWNLOAD_TMP} 2>/dev/null || true - -cd ${OUTPUT_DIR} - -# Extract FASTA file if compressed -echo "Preparing RNAcentral sequences..." -if [ -f "${FASTA_FILE}" ]; then - if [ "${IS_COMPRESSED}" = "true" ]; then - echo "Decompressing ${FASTA_FILE}..." - OUTPUT_FASTA="${DB_NAME}_${RELEASE}.fasta" - gunzip -c "${FASTA_FILE}" > "${OUTPUT_FASTA}" || { - echo "Error: Failed to decompress FASTA file" - exit 1 +# If multiple databases were downloaded, offer to merge them +if [ ${#DATABASES[@]} -gt 1 ] && [ "${DATABASES[0]}" != "all" ]; then + echo "Multiple databases downloaded. Creating merged database for unified search..." + MERGED_DIR="rnacentral_merged_${RELEASE}" + mkdir -p ${MERGED_DIR} + cd ${MERGED_DIR} + + MERGED_FASTA="rnacentral_merged_${RELEASE}.fasta" + MERGED_FASTA_TMP="${MERGED_FASTA}.tmp" + echo "Combining FASTA files from all databases..." + echo " Note: Duplicate sequence IDs will be removed (keeping first occurrence)..." + + # Combine all FASTA files into a temporary file + # Find actual database directories (may have different release versions) + FOUND_ANY=false + for DB_SELECTION in "${DATABASES[@]}"; do + [ "${DB_SELECTION}" = "all" ] && continue + + # Try current release version first, then search for any existing version + OUTPUT_FASTA="../rnacentral_${DB_SELECTION}_${RELEASE}/${DB_SELECTION}_${RELEASE}.fasta" + [ ! -f "${OUTPUT_FASTA}" ] && { + EXISTING_DIR=$(ls -d ../rnacentral_${DB_SELECTION}_* 2>/dev/null | head -1) + [ -n "${EXISTING_DIR}" ] && { + EXISTING_VERSION=$(basename "${EXISTING_DIR}" | sed "s/rnacentral_${DB_SELECTION}_//") + OUTPUT_FASTA="${EXISTING_DIR}/${DB_SELECTION}_${EXISTING_VERSION}.fasta" + } } - # Optionally remove the compressed file to save space - # rm "${FASTA_FILE}" - else - # File is not compressed, just copy/rename - OUTPUT_FASTA="${DB_NAME}_${RELEASE}.fasta" - cp "${FASTA_FILE}" "${OUTPUT_FASTA}" || { - echo "Error: Failed to copy FASTA file" - exit 1 + + if [ -f "${OUTPUT_FASTA}" ]; then + echo " Adding ${DB_SELECTION} sequences..." + cat "${OUTPUT_FASTA}" >> "${MERGED_FASTA_TMP}" + FOUND_ANY=true + else + echo " Warning: Could not find FASTA file for ${DB_SELECTION}" + fi + done + + # Validate that we have files to merge + if [ "${FOUND_ANY}" = "false" ] || [ ! -s "${MERGED_FASTA_TMP}" ]; then + echo "Error: No FASTA files found to merge" + cd .. 
+        rm -rf ${MERGED_DIR}
+        exit 1
+    fi
+
+    # Remove duplicates based on sequence ID (keeping first occurrence)
+    echo "  Removing duplicate sequence IDs..."
+    awk '
+    /^>/ {
+        # Process previous sequence if we have one
+        if (current_id != "" && !seen[current_id]) {
+            print current_header ORS current_seq
+            seen[current_id] = 1
+        }
+        # Start new sequence
+        current_header = $0
+        current_id = substr($0, 2)
+        sub(/[ \t].*/, "", current_id)  # Extract ID up to first space/tab
+        current_seq = ""
+        next
+    }
+    {
+        # Accumulate sequence data (preserve newlines)
+        current_seq = (current_seq == "" ? $0 : current_seq "\n" $0)
+    }
+    END {
+        # Process last sequence
+        if (current_id != "" && !seen[current_id]) {
+            print current_header ORS current_seq
         }
+    }
+    ' "${MERGED_FASTA_TMP}" > "${MERGED_FASTA}"
+    rm -f "${MERGED_FASTA_TMP}"
+
+    # Check if merged file was created and has content
+    if [ ! -s "${MERGED_FASTA}" ]; then
+        echo "Warning: Merged FASTA file is empty or not created"
+        cd ..
+        rm -rf ${MERGED_DIR}
+    else
+        FILE_SIZE=$(du -h "${MERGED_FASTA}" | cut -f1)
+        echo "Merged FASTA file size: ${FILE_SIZE}"
+
+        echo "Creating merged BLAST database..."
+        MERGED_DB_NAME="rnacentral_merged_${RELEASE}"
+        makeblastdb -in "${MERGED_FASTA}" \
+            -out "${MERGED_DB_NAME}" \
+            -dbtype nucl \
+            -parse_seqids \
+            -title "RNAcentral_Merged_${RELEASE}"
+
+        echo ""
+        echo "✓ Merged BLAST database created successfully!"
+        echo "Database location: $(pwd)/${MERGED_DB_NAME}"
+        echo ""
+        echo "To use the merged database, set in your config (search_rna_config.yaml):"
+        echo "  rnacentral_params:"
+        echo "    use_local_blast: true"
+        echo "    local_blast_db: $(pwd)/${MERGED_DB_NAME}"
+        echo ""
+        echo "Note: The merged database includes: ${DATABASES[*]}"
+        cd ..
     fi
-else
-    echo "Error: FASTA file not found"
-    exit 1
 fi
-# Check if we have sequences
-if [ ! -s "${OUTPUT_FASTA}" ]; then
-    echo "Error: FASTA file is empty"
-    exit 1
-fi
-
-# Get file size for user information
-FILE_SIZE=$(du -h "${OUTPUT_FASTA}" | cut -f1)
-echo "FASTA file size: ${FILE_SIZE}"
-
-echo "Creating BLAST database..."
-# Create BLAST database for RNA sequences (use -dbtype nucl for nucleotide)
-# Note: RNAcentral uses RNAcentral IDs (URS...) as sequence identifiers,
-# which matches the format expected by the RNACentralSearch class
-DB_OUTPUT_NAME="${DB_NAME}_${RELEASE}"
-makeblastdb -in "${OUTPUT_FASTA}" \
-    -out "${DB_OUTPUT_NAME}" \
-    -dbtype nucl \
-    -parse_seqids \
-    -title "RNAcentral_${DB_NAME}_${RELEASE}"
 echo ""
-echo "BLAST database created successfully!"
-echo "Database location: $(pwd)/${DB_OUTPUT_NAME}"
-echo ""
-echo "To use this database, set in your config (search_rna_config.yaml):"
-echo "  rnacentral_params:"
-echo "    use_local_blast: true"
-echo "    local_blast_db: $(pwd)/${DB_OUTPUT_NAME}"
-echo ""
-echo "Note: The database files are:"
-ls -lh ${DB_OUTPUT_NAME}.* | head -5
-echo ""
-if [ "${DB_SELECTION}" = "all" ]; then
-    echo "This database uses RNAcentral IDs (URS...), which matches the online"
-    echo "RNAcentral search API, ensuring consistent results between local and online searches."
-else
-    echo "This is a subset database from ${DB_SELECTION} expert database."
-    echo "For full coverage matching online API, use 'all' option."
-fi
+echo "Summary of downloaded databases:"
+for DB_SELECTION in "${DATABASES[@]}"; do
+    if [ "${DB_SELECTION}" = "all" ]; then
+        OUTPUT_DIR="rnacentral_${RELEASE}"
+        DB_NAME="rnacentral"
+    else
+        OUTPUT_DIR="rnacentral_${DB_SELECTION}_${RELEASE}"
+        DB_NAME="${DB_SELECTION}"
+    fi
+    if [ -d "${OUTPUT_DIR}" ]; then
+        echo "  - ${DB_NAME}: ${OUTPUT_DIR}/"
+    fi
+done
-cd ..
+if [ -d "rnacentral_merged_${RELEASE}" ]; then
+    echo "  - merged (all databases): rnacentral_merged_${RELEASE}/"
+    echo ""
+    echo "💡 Recommendation: Use the merged database for searching across all databases."
+fi
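
Usage sketch (illustrative only, not part of the patch): the commands below assume the build scripts above are run from the repository root with BLAST+ installed; the release directory "2017_10" is a placeholder, since each script prints the exact local_blast_db path for the release it actually downloaded.

# Protein: Swiss-Prot only (recommended) or the full Swiss-Prot + TrEMBL release
./scripts/search/build_db/build_protein_blast_db.sh --sprot-only
# ./scripts/search/build_db/build_protein_blast_db.sh --full

# RNA: predefined subsets, specific named subsets, or the complete active database
./scripts/search/build_db/build_rna_blast_db.sh selected
# ./scripts/search/build_db/build_rna_blast_db.sh refseq mirbase
# ./scripts/search/build_db/build_rna_blast_db.sh all

# Sanity-check a finished database with BLAST+ (path is a placeholder; use the path printed by the script)
blastdbcmd -db 2017_10/uniprot_sprot -info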