diff --git a/src/input_adapters/jensenlab/tinx.py b/src/input_adapters/jensenlab/tinx.py index 6c44282..95bba3a 100644 --- a/src/input_adapters/jensenlab/tinx.py +++ b/src/input_adapters/jensenlab/tinx.py @@ -14,6 +14,8 @@ class TINXAdapter(InputAdapter): + progress_every = 1000 + def __init__( self, protein_mentions_file_path: str, @@ -56,34 +58,30 @@ def get_version(self) -> DatasourceVersionInfo: ) def get_all(self) -> Generator[List[Union[Node, Relationship]], None, None]: - protein_pmids, pmid_to_protein_count, pmid_to_proteins = self._load_protein_mentions() - disease_pmids, pmid_to_disease_count = self._load_disease_mentions() + print("TIN-X: building protein PMID maps") + pmid_to_protein_count, pmid_to_proteins = self._build_protein_pmid_maps() + print( + f"TIN-X: built protein PMID maps for {len(pmid_to_proteins)} pmids " + f"across {sum(pmid_to_protein_count.values())} protein-pmid mentions" + ) - proteins = [ - Protein( - id=EquivalentId(id=protein_id, type=Prefix.ENSEMBL).id_str(), - novelty=[novelty], - ) - for protein_id, pmids in protein_pmids.items() - for novelty in [self._compute_novelty(pmids, pmid_to_protein_count)] - if novelty is not None - ] - - diseases = [ - Disease( - id=doid, - novelty=[novelty], - ) - for doid, pmids in disease_pmids.items() - for novelty in [self._compute_novelty(pmids, pmid_to_disease_count)] - if novelty is not None - ] + print("TIN-X: building disease PMID counts") + pmid_to_disease_count = self._build_disease_pmid_counts() + print( + f"TIN-X: built disease PMID counts for {len(pmid_to_disease_count)} pmids " + f"across {sum(pmid_to_disease_count.values())} disease-pmid mentions" + ) + + print("TIN-X: emitting protein novelty") + yield from self._yield_protein_novelty(pmid_to_protein_count) + print("TIN-X: emitting disease novelty") + yield from self._yield_disease_novelty(pmid_to_disease_count) + print("TIN-X: emitting protein-disease importance edges") - yield proteins - yield diseases batch: List[TINXImportanceEdge] = [] pair_count = 0 - for doid, disease_pmid_set in disease_pmids.items(): + processed_diseases = 0 + for doid, disease_pmid_set in self._iter_disease_mentions(): protein_scores: Dict[str, float] = defaultdict(float) for pmid in disease_pmid_set: proteins_for_pmid = pmid_to_proteins.get(pmid) @@ -114,68 +112,135 @@ def get_all(self) -> Generator[List[Union[Node, Relationship]], None, None]: ) pair_count += 1 if len(batch) >= self.batch_size: + print( + f"TIN-X: yielding {len(batch)} importance edges " + f"after {processed_diseases + 1} diseases and {pair_count} total pairs" + ) yield batch batch = [] if self.max_pairs is not None and pair_count >= self.max_pairs: break + processed_diseases += 1 + if processed_diseases % self.progress_every == 0: + print( + f"TIN-X: processed {processed_diseases} diseases and emitted " + f"{pair_count} importance pairs" + ) if self.max_pairs is not None and pair_count >= self.max_pairs: break if batch: + print( + f"TIN-X: final importance batch {len(batch)} edges " + f"after {processed_diseases} diseases and {pair_count} total pairs" + ) yield batch - def _load_protein_mentions(self) -> Tuple[Dict[str, Set[str]], Dict[str, int], Dict[str, Set[str]]]: - protein_pmids: Dict[str, Set[str]] = {} + def _build_protein_pmid_maps(self) -> Tuple[Dict[str, int], Dict[str, List[str]]]: pmid_to_protein_count: Dict[str, int] = defaultdict(int) - pmid_to_proteins: Dict[str, Set[str]] = defaultdict(set) - loaded_proteins = 0 + pmid_to_proteins: Dict[str, List[str]] = defaultdict(list) + for loaded_proteins, (protein_id, pmids) in enumerate(self._iter_protein_mentions(), start=1): + for pmid in pmids: + pmid_to_protein_count[pmid] += 1 + pmid_to_proteins[pmid].append(protein_id) + if loaded_proteins % self.progress_every == 0: + print( + f"TIN-X: indexed {loaded_proteins} proteins into " + f"{len(pmid_to_proteins)} unique pmids" + ) + return pmid_to_protein_count, pmid_to_proteins + + def _build_disease_pmid_counts(self) -> Dict[str, int]: + pmid_to_disease_count: Dict[str, int] = defaultdict(int) + for loaded_diseases, (_, pmids) in enumerate(self._iter_disease_mentions(), start=1): + for pmid in pmids: + pmid_to_disease_count[pmid] += 1 + if loaded_diseases % self.progress_every == 0: + print( + f"TIN-X: indexed {loaded_diseases} diseases into " + f"{len(pmid_to_disease_count)} unique pmids" + ) + return pmid_to_disease_count + + def _yield_protein_novelty(self, pmid_to_protein_count: Dict[str, int]) -> Generator[List[Protein], None, None]: + batch: List[Protein] = [] + emitted = 0 + for protein_id, pmids in self._iter_protein_mentions(): + novelty = self._compute_novelty(pmids, pmid_to_protein_count) + if novelty is None: + continue + batch.append( + Protein( + id=EquivalentId(id=protein_id, type=Prefix.ENSEMBL).id_str(), + novelty=[novelty], + ) + ) + emitted += 1 + if len(batch) >= self.batch_size: + print(f"TIN-X: yielding {len(batch)} protein novelty nodes ({emitted} total)") + yield batch + batch = [] + if batch: + print(f"TIN-X: final protein novelty batch {len(batch)} ({emitted} total)") + yield batch + + def _yield_disease_novelty(self, pmid_to_disease_count: Dict[str, int]) -> Generator[List[Disease], None, None]: + batch: List[Disease] = [] + emitted = 0 + for doid, pmids in self._iter_disease_mentions(): + novelty = self._compute_novelty(pmids, pmid_to_disease_count) + if novelty is None: + continue + batch.append( + Disease( + id=doid, + novelty=[novelty], + ) + ) + emitted += 1 + if len(batch) >= self.batch_size: + print(f"TIN-X: yielding {len(batch)} disease novelty nodes ({emitted} total)") + yield batch + batch = [] + if batch: + print(f"TIN-X: final disease novelty batch {len(batch)} ({emitted} total)") + yield batch + + def _iter_protein_mentions(self) -> Generator[Tuple[str, Set[str]], None, None]: + seen_proteins: Set[str] = set() with open(self.protein_mentions_file_path, "r", encoding="utf-8", errors="replace") as handle: for raw_line in handle: row = raw_line.rstrip("\n").split("\t", 1) if len(row) < 2: continue protein_id = row[0].strip() - if not protein_id.startswith("ENSP"): + if not protein_id.startswith("ENSP") or protein_id in seen_proteins: continue pmids = self._parse_pmid_field(row[1]) if not pmids: continue - if protein_id in protein_pmids: - continue - protein_pmids[protein_id] = pmids - for pmid in pmids: - pmid_to_protein_count[pmid] += 1 - pmid_to_proteins[pmid].add(protein_id) - loaded_proteins += 1 - if self.max_proteins is not None and loaded_proteins >= self.max_proteins: + seen_proteins.add(protein_id) + yield protein_id, pmids + if self.max_proteins is not None and len(seen_proteins) >= self.max_proteins: break - return protein_pmids, pmid_to_protein_count, pmid_to_proteins - - def _load_disease_mentions(self) -> Tuple[Dict[str, Set[str]], Dict[str, int]]: - disease_pmids: Dict[str, Set[str]] = {} - pmid_to_disease_count: Dict[str, int] = defaultdict(int) - loaded_diseases = 0 + def _iter_disease_mentions(self) -> Generator[Tuple[str, Set[str]], None, None]: + seen_diseases: Set[str] = set() with open(self.disease_mentions_file_path, "r", encoding="utf-8", errors="replace") as handle: for raw_line in handle: row = raw_line.rstrip("\n").split("\t", 1) if len(row) < 2: continue doid = self._normalize_doid(row[0].strip()) - if doid is None: + if doid is None or doid in seen_diseases: continue pmids = self._parse_pmid_field(row[1]) if not pmids: continue - if doid in disease_pmids: - continue - disease_pmids[doid] = pmids - for pmid in pmids: - pmid_to_disease_count[pmid] += 1 - loaded_diseases += 1 - if self.max_diseases is not None and loaded_diseases >= self.max_diseases: + seen_diseases.add(doid) + yield doid, pmids + if self.max_diseases is not None and len(seen_diseases) >= self.max_diseases: break - return disease_pmids, pmid_to_disease_count def _download_date(self) -> Optional[date]: timestamps = []