2 changes: 2 additions & 0 deletions graphgen/configs/search_dna_config.yaml
@@ -14,4 +14,6 @@ pipeline:
tool: GraphGen # tool name for NCBI API
use_local_blast: true # whether to use local blast for DNA search
local_blast_db: refseq_release/refseq_release # path to local BLAST database (without .nhr extension)
blast_num_threads: 2 # number of threads for BLAST search (reduce to save memory)
max_concurrent: 5 # maximum number of concurrent search tasks (reduce to prevent OOM, default: unlimited)
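For context on the new blast_num_threads option: local DNA search with NCBI BLAST+ exposes the thread count through the -num_threads flag of blastn. The sketch below is illustrative only, assuming GraphGen shells out to blastn in roughly this way; the function name and command construction are not taken from the GraphGen codebase.

# Illustrative sketch, not GraphGen's actual search code: how blast_num_threads
# and local_blast_db from the config above would typically reach a local blastn run.
import subprocess

def run_local_blastn(query_fasta: str, local_blast_db: str, blast_num_threads: int = 2) -> str:
    """Run blastn against a local database and return tabular results."""
    cmd = [
        "blastn",
        "-db", local_blast_db,                    # path without the .nhr extension, as in the config
        "-query", query_fasta,
        "-num_threads", str(blast_num_threads),   # more threads: faster search, higher memory use
        "-outfmt", "6",                           # tab-separated hit table
    ]
    completed = subprocess.run(cmd, capture_output=True, text=True, check=True)
    return completed.stdout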

2 changes: 2 additions & 0 deletions graphgen/configs/search_protein_config.yaml
@@ -13,3 +13,5 @@ pipeline:
use_local_blast: true # whether to use local blast for uniprot search
local_blast_db: /your_path/2024_01/uniprot_sprot # format: /path/to/${RELEASE}/uniprot_sprot
# options: uniprot_sprot (recommended, high quality), uniprot_trembl, or uniprot_${RELEASE} (merged database)
blast_num_threads: 2 # number of threads for BLAST search (reduce to save memory)
max_concurrent: 5 # maximum number of concurrent search tasks (reduce to prevent OOM, default: unlimited)
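Note that the two new settings multiply: each concurrent search can spawn a BLAST job using blast_num_threads threads, so peak load is roughly max_concurrent times blast_num_threads. The tiny helper below is illustrative only and not part of GraphGen; it just makes the arithmetic explicit.

# Illustrative helper, not part of GraphGen: rough upper bound on simultaneous
# BLAST worker threads implied by the two settings above.
def estimate_peak_blast_threads(max_concurrent: int, blast_num_threads: int) -> int:
    return max_concurrent * blast_num_threads

# With the values shown above: 5 concurrent searches * 2 threads each = 10 threads.
print(estimate_peak_blast_threads(5, 2))  # 10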
2 changes: 2 additions & 0 deletions graphgen/configs/search_rna_config.yaml
@@ -12,3 +12,5 @@ pipeline:
rnacentral_params:
  use_local_blast: true # whether to use local blast for RNA search
  local_blast_db: rnacentral_ensembl_gencode_YYYYMMDD/ensembl_gencode_YYYYMMDD # path to local BLAST database (without .nhr extension)
  blast_num_threads: 2 # number of threads for BLAST search (reduce to save memory)
  max_concurrent: 5 # maximum number of concurrent search tasks (reduce to prevent OOM, default: unlimited)
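A common way to enforce a max_concurrent cap on async search tasks is an asyncio.Semaphore. The sketch below is a generic pattern, assuming the per-seed search is a coroutine; it is not necessarily how GraphGen implements the setting.

# Generic concurrency-limiting pattern; GraphGen's actual implementation may differ.
import asyncio

async def search_with_limit(seeds, search_one, max_concurrent: int = 5):
    """Run search_one(seed) for every seed, with at most max_concurrent in flight."""
    semaphore = asyncio.Semaphore(max_concurrent)

    async def bounded(seed):
        async with semaphore:
            return await search_one(seed)

    return await asyncio.gather(*(bounded(seed) for seed in seeds))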
102 changes: 84 additions & 18 deletions graphgen/graphgen.py
@@ -1,3 +1,4 @@
import hashlib
import os
import time
from typing import Dict
@@ -87,24 +88,45 @@ def __init__(
    @async_to_sync_method
    async def read(self, read_config: Dict):
        """
        read files from input sources with batch processing
        """
        # Get batch_size from config, default to 10000
        batch_size = read_config.pop("batch_size", 10000)

        doc_stream = read_files(**read_config, cache_dir=self.working_dir)

        batch = {}
        total_processed = 0

        for doc in doc_stream:
            doc_id = compute_mm_hash(doc, prefix="doc-")
            batch[doc_id] = doc

            # Process batch when it reaches batch_size
            if len(batch) >= batch_size:
                _add_doc_keys = self.full_docs_storage.filter_keys(list(batch.keys()))
                new_docs = {k: v for k, v in batch.items() if k in _add_doc_keys}
                if new_docs:
                    self.full_docs_storage.upsert(new_docs)
                    total_processed += len(new_docs)
                    logger.info("Processed batch: %d new documents (total: %d)", len(new_docs), total_processed)
                batch.clear()

        # TODO: configurable whether to use coreference resolution

        # Process remaining documents in the final partial batch
        if batch:
            _add_doc_keys = self.full_docs_storage.filter_keys(list(batch.keys()))
            new_docs = {k: v for k, v in batch.items() if k in _add_doc_keys}
            if new_docs:
                self.full_docs_storage.upsert(new_docs)
                total_processed += len(new_docs)
                logger.info("Processed final batch: %d new documents (total: %d)", len(new_docs), total_processed)

        if total_processed == 0:
            logger.warning("All documents are already in the storage")
        else:
            self.full_docs_storage.index_done_callback()

    @async_to_sync_method
    async def chunk(self, chunk_config: Dict):
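A usage sketch for the batched read() above. Only batch_size is the knob introduced in this diff; the import path, constructor arguments, and the remaining read_config keys are placeholders, not taken from GraphGen's documentation.

# Hypothetical usage of the batched read(); key names other than "batch_size" are placeholders.
from graphgen.graphgen import GraphGen  # class name and import path assumed

graph_gen = GraphGen(working_dir="cache")   # constructor arguments are illustrative
graph_gen.read(
    {
        "input_file": "data/corpus.jsonl",  # placeholder input key and path
        "batch_size": 5000,                 # upsert every 5000 docs instead of the default 10000
    }
)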
@@ -169,24 +191,68 @@ async def build_kg(self):
    async def search(self, search_config: Dict):
        logger.info("[Search] %s ...", ", ".join(search_config["data_sources"]))

        # Get search_batch_size from config (default: 10000)
        search_batch_size = search_config.get("search_batch_size", 10000)

        # Get save_interval from config (default: 1000, 0 to disable)
        save_interval = search_config.get("save_interval", 1000)

        # Process in batches to avoid OOM
        all_flattened_results = {}
        batch_num = 0

        for seeds_batch in self.full_docs_storage.iter_batches(batch_size=search_batch_size):
            if len(seeds_batch) == 0:
                continue

            batch_num += 1
            logger.info("Processing search batch %d with %d documents", batch_num, len(seeds_batch))

            search_results = await search_all(
                seed_data=seeds_batch,
                search_config=search_config,
                search_storage=self.search_storage if save_interval > 0 else None,
                save_interval=save_interval,
            )

            # Convert search_results from {data_source: [results]} to {key: result}
            # This maintains backward compatibility
            for data_source, result_list in search_results.items():
                if not isinstance(result_list, list):
                    continue
                for result in result_list:
                    if result is None:
                        continue
                    # Use _search_query as key if available, otherwise generate a key
                    if isinstance(result, dict) and "_search_query" in result:
                        query = result["_search_query"]
                        key = f"{data_source}:{query}"
                    else:
                        # Generate a unique key
                        result_str = str(result)
                        key_hash = hashlib.md5(result_str.encode()).hexdigest()[:8]
                        key = f"{data_source}:{key_hash}"
                    all_flattened_results[key] = result

        if len(all_flattened_results) == 0:
            logger.warning("No search results generated")
            return

        _add_search_keys = self.search_storage.filter_keys(list(all_flattened_results.keys()))
        search_results = {
            k: v for k, v in all_flattened_results.items() if k in _add_search_keys
        }
        if len(search_results) == 0:
            logger.warning("All search results are already in the storage")
            return

        # Only save if not using periodic saving (to avoid duplicate saves)
        if save_interval == 0:
            self.search_storage.upsert(search_results)
            self.search_storage.index_done_callback()
        else:
            # Results were already saved periodically; just update the index
            self.search_storage.index_done_callback()

    @async_to_sync_method
    async def quiz_and_judge(self, quiz_and_judge_config: Dict):
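Similarly, a hedged sketch of a search_config that exercises the new batching and periodic-save knobs. Only search_batch_size and save_interval come from this diff; the data_sources value and any other keys are placeholders.

# Hypothetical search_config; only "search_batch_size" and "save_interval" are
# the options added in this diff, everything else is a placeholder.
search_config = {
    "data_sources": ["uniprot"],   # placeholder data source name
    "search_batch_size": 2000,     # pull seed documents from storage 2000 at a time
    "save_interval": 500,          # persist partial results every 500 items; 0 disables periodic saving
}
graph_gen.search(search_config)    # graph_gen as constructed in the earlier read() sketch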
56 changes: 55 additions & 1 deletion graphgen/models/reader/jsonl_reader.py
@@ -1,5 +1,6 @@
import json
import os
from typing import Any, Dict, Iterator, List

from graphgen.bases.base_reader import BaseReader
from graphgen.utils import logger
@@ -28,3 +29,56 @@ def read(self, file_path: str) -> List[Dict[str, Any]]:
            except json.JSONDecodeError as e:
                logger.error("Error decoding JSON line: %s. Error: %s", line, e)
        return self.filter(docs)

    def read_stream(self, file_path: str) -> Iterator[Dict[str, Any]]:
        """
        Stream read JSONL files line by line without loading the entire file into memory.
        Returns an iterator that yields filtered documents.

        :param file_path: Path to the JSONL file.
        :return: Iterator of dictionaries containing the data.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    doc = json.loads(line)
                    assert "type" in doc, f"Missing 'type' in document: {doc}"
                    if doc.get("type") == "text" and self.text_column not in doc:
                        raise ValueError(
                            f"Missing '{self.text_column}' in document: {doc}"
                        )

                    # Apply filtering logic inline (similar to BaseReader.filter)
                    if doc.get("type") == "text":
                        content = doc.get(self.text_column, "").strip()
                        if content:
                            yield doc
                    elif doc.get("type") in ("image", "table", "equation"):
                        img_path = doc.get("img_path")
                        if self._image_exists(img_path):
                            yield doc
                    else:
                        yield doc
                except json.JSONDecodeError as e:
                    logger.error("Error decoding JSON line: %s. Error: %s", line, e)

    @staticmethod
    def _image_exists(path_or_url: str, timeout: int = 3) -> bool:
        """
        Check if an image exists at the given local path or URL.
        :param path_or_url: Local file path or remote URL of the image.
        :param timeout: Timeout for remote URL requests in seconds.
        :return: True if the image exists, False otherwise.
        """
        if not path_or_url:
            return False
        if not path_or_url.startswith(("http://", "https://", "ftp://")):
            path = path_or_url.replace("file://", "", 1)
            path = os.path.abspath(path)
            return os.path.isfile(path)
        try:
            # Lazy import: requests is only needed when checking remote URLs
            import requests
            resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout)
            return resp.status_code == 200
        except Exception:
            return False
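A usage sketch for the streaming reader. The class name JsonlReader and its constructor are assumed from the module path; only read_stream itself appears in this diff.

# Hypothetical usage; class name and constructor are assumed, not shown in the diff.
from graphgen.models.reader.jsonl_reader import JsonlReader

reader = JsonlReader()
for doc in reader.read_stream("data/corpus.jsonl"):   # placeholder path
    # Each yielded doc has already passed the inline type/content checks above,
    # so memory stays flat even for very large JSONL files.
    print(doc["type"])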