Commits (30)
31c5a64  feat: add config and operator node types (ChenZiHong-Gavin, Dec 3, 2025)
8bcbe51  refactor: refactor readers with ray data (ChenZiHong-Gavin, Dec 3, 2025)
246348f  fix: delete param parallelism for readers (ChenZiHong-Gavin, Dec 3, 2025)
319e1e7  fix: fix import error (ChenZiHong-Gavin, Dec 3, 2025)
42fcb09  refactor read and chunk operators with no side effects (ChenZiHong-Gavin, Dec 4, 2025)
b458e48  fix: fix import error (ChenZiHong-Gavin, Dec 4, 2025)
95c4783  fix: fix return logic (ChenZiHong-Gavin, Dec 4, 2025)
c844d65  refactor: rename operator split to chunk (ChenZiHong-Gavin, Dec 4, 2025)
c447936  refactor: refactor build_kg to accomodate ray data (ChenZiHong-Gavin, Dec 4, 2025)
3edbb81  feat: add StorageFactory & global params (ChenZiHong-Gavin, Dec 4, 2025)
ee0639d  refactor: refactor quiz to accomodata ray data engine (ChenZiHong-Gavin, Dec 5, 2025)
157f0b0  fix: reload graph before quizzing (ChenZiHong-Gavin, Dec 5, 2025)
99a6e5f  Merge branch 'main' of https://github.com/open-sciencelab/GraphGen in… (ChenZiHong-Gavin, Dec 5, 2025)
ec2033b  Potential fix for pull request finding 'Unreachable code' (ChenZiHong-Gavin, Dec 5, 2025)
bc07222  fix: fix quiz params (ChenZiHong-Gavin, Dec 5, 2025)
c9435d7  refactor: refactor quiz&judge to ray actors (ChenZiHong-Gavin, Dec 10, 2025)
c55fc09  Merge branch 'refactor/refactor-with-ray-data' of https://github.com/… (ChenZiHong-Gavin, Dec 10, 2025)
d7d6c2a  fix: fix transferring quizzed data to JudgeService (ChenZiHong-Gavin, Dec 10, 2025)
a6aedaf  refactor: refactor partition to accomodate ray data (ChenZiHong-Gavin, Dec 10, 2025)
ea1603b  fix: fix lint problem (ChenZiHong-Gavin, Dec 10, 2025)
244deb4  refactor: refactor op generate (ChenZiHong-Gavin, Dec 11, 2025)
d460a2a  feat: write results in output folder (ChenZiHong-Gavin, Dec 11, 2025)
cd011ad  fix: raise error when no dataset is created (ChenZiHong-Gavin, Dec 11, 2025)
aab7438  fix: return generator in ece_partitioner (ChenZiHong-Gavin, Dec 11, 2025)
7643b9f  fix: return generator in ece_partitioner (ChenZiHong-Gavin, Dec 11, 2025)
c42b604  refactor: refactor data format to support multi-modal input (ChenZiHong-Gavin, Dec 11, 2025)
42dc73e  fix: delete fetching schema to avoid ray's duplicate execution (ChenZiHong-Gavin, Dec 11, 2025)
73f70a5  fix: fix operators' registry (ChenZiHong-Gavin, Dec 11, 2025)
37cbfcf  feat: refactor schema_guided_extraction & add examples (ChenZiHong-Gavin, Dec 11, 2025)
b400d2e  feat: seperate ray logs and service logs (ChenZiHong-Gavin, Dec 12, 2025)
graphgen/bases/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -13,4 +13,4 @@
     StorageNameSpace,
 )
 from .base_tokenizer import BaseTokenizer
-from .datatypes import Chunk, QAPair, Token
+from .datatypes import Chunk, Config, Node, QAPair, Token
graphgen/bases/base_partitioner.py (49 changes: 22 additions & 27 deletions)
@@ -7,7 +7,7 @@

 class BasePartitioner(ABC):
     @abstractmethod
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         **kwargs: Any,
@@ -20,39 +20,34 @@ async def partition(
         """

     @staticmethod
-    async def community2batch(
-        communities: List[Community], g: BaseGraphStorage
-    ) -> list[
-        tuple[
-            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
-        ]
+    def community2batch(
+        comm: Community, g: BaseGraphStorage
+    ) -> tuple[
+        list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
     ]:
         """
         Convert communities to batches of nodes and edges.
-        :param communities
+        :param comm: Community
         :param g: Graph storage instance
         :return: List of batches, each batch is a tuple of (nodes, edges)
         """
-        batches = []
-        for comm in communities:
-            nodes = comm.nodes
-            edges = comm.edges
-            nodes_data = []
-            for node in nodes:
-                node_data = g.get_node(node)
-                if node_data:
-                    nodes_data.append((node, node_data))
-            edges_data = []
-            for u, v in edges:
-                edge_data = g.get_edge(u, v)
-                if edge_data:
-                    edges_data.append((u, v, edge_data))
-                else:
-                    edge_data = g.get_edge(v, u)
-                    if edge_data:
-                        edges_data.append((v, u, edge_data))
-            batches.append((nodes_data, edges_data))
-        return batches
+        nodes = comm.nodes
+        edges = comm.edges
+        nodes_data = []
+        for node in nodes:
+            node_data = g.get_node(node)
+            if node_data:
+                nodes_data.append((node, node_data))
+        edges_data = []
+        for u, v in edges:
+            edge_data = g.get_edge(u, v)
+            if edge_data:
+                edges_data.append((u, v, edge_data))
+            else:
+                edge_data = g.get_edge(v, u)
+                if edge_data:
+                    edges_data.append((v, u, edge_data))
+        return nodes_data, edges_data

     @staticmethod
     def _build_adjacency_list(
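With community2batch now taking a single Community and returning one (nodes, edges) pair, batches can be produced lazily, one community at a time. A minimal usage sketch under assumed placeholder names (partitioner, communities, graph are not names from this PR):

# Hypothetical sketch: stream one batch per community instead of materializing
# the whole list, which fits a Ray-Data-style pipeline.
def iter_batches(partitioner, communities, graph):
    for comm in communities:
        # community2batch is a @staticmethod, so calling it via an instance also works
        yield partitioner.community2batch(comm, graph)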
graphgen/bases/base_reader.py (95 changes: 55 additions & 40 deletions)
@@ -1,8 +1,10 @@
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union

+import pandas as pd
 import requests
+from ray.data import Dataset


 class BaseReader(ABC):
@@ -14,52 +16,65 @@ def __init__(self, text_column: str = "content"):
         self.text_column = text_column

     @abstractmethod
-    def read(self, file_path: str) -> List[Dict[str, Any]]:
+    def read(self, input_path: Union[str, List[str]]) -> Dataset:
         """
         Read data from the specified file path.

-        :param file_path: Path to the input file.
-        :return: List of dictionaries containing the data.
+        :param input_path: Path to the input file or list of file paths.
+        :return: Ray Dataset containing the read data.
         """

-    @staticmethod
-    def filter(data: List[dict]) -> List[dict]:
+    def _should_keep_item(self, item: Dict[str, Any]) -> bool:
         """
-        Filter out entries with empty or missing text in the specified column.
+        Determine whether to keep the given item based on the text column.

-        :param data: List of dictionaries containing the data.
-        :return: Filtered list of dictionaries.
+        :param item: Dictionary representing a data entry.
+        :return: True if the item should be kept, False otherwise.
         """
+        item_type = item.get("type")
+        assert item_type in [
+            "text",
+            "image",
+            "table",
+            "equation",
+            "protein",
+        ], f"Unsupported item type: {item_type}"
+        if item_type == "text":
+            content = item.get(self.text_column, "").strip()
+            return bool(content)
+        return True

-        def _image_exists(path_or_url: str, timeout: int = 3) -> bool:
-            """
-            Check if an image exists at the given local path or URL.
-            :param path_or_url: Local file path or remote URL of the image.
-            :param timeout: Timeout for remote URL requests in seconds.
-            :return: True if the image exists, False otherwise.
-            """
-            if not path_or_url:
-                return False
-            if not path_or_url.startswith(("http://", "https://", "ftp://")):
-                path = path_or_url.replace("file://", "", 1)
-                path = os.path.abspath(path)
-                return os.path.isfile(path)
-            try:
-                resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout)
-                return resp.status_code == 200
-            except requests.RequestException:
-                return False
+    def _validate_batch(self, batch: pd.DataFrame) -> pd.DataFrame:
+        """
+        Validate data format.
+        """
+        if "type" not in batch.columns:
+            raise ValueError(f"Missing 'type' column. Found: {list(batch.columns)}")

-        filtered_data = []
-        for item in data:
-            if item.get("type") == "text":
-                content = item.get("content", "").strip()
-                if content:
-                    filtered_data.append(item)
-            elif item.get("type") in ("image", "table", "equation"):
-                img_path = item.get("img_path")
-                if _image_exists(img_path):
-                    filtered_data.append(item)
-            else:
-                filtered_data.append(item)
-        return filtered_data
+        if "text" in batch["type"].values:
+            if self.text_column not in batch.columns:
+                raise ValueError(
+                    f"Missing '{self.text_column}' column for text documents"
+                )
+
+        return batch
+
+    @staticmethod
+    def _image_exists(path_or_url: str, timeout: int = 3) -> bool:
+        """
+        Check if an image exists at the given local path or URL.
+        :param path_or_url: Local file path or remote URL of the image.
+        :param timeout: Timeout for remote URL requests in seconds.
+        :return: True if the image exists, False otherwise.
+        """
+        if not path_or_url:
+            return False
+        if not path_or_url.startswith(("http://", "https://", "ftp://")):
+            path = path_or_url.replace("file://", "", 1)
+            path = os.path.abspath(path)
+            return os.path.isfile(path)
+        try:
+            resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout)
+            return resp.status_code == 200
+        except requests.RequestException:
+            return False
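For context, a minimal reader sketch against the new contract: read() returns a Ray Dataset instead of a list of dicts, and per-item filtering goes through _should_keep_item. The JSONL format, class name, and import path below are illustrative assumptions, not part of this PR:

# Hypothetical JSONL reader; every record is assumed to carry the "type"
# field that _should_keep_item() asserts on.
import json
from typing import List, Union

from ray.data import Dataset, from_items

from graphgen.bases.base_reader import BaseReader


class JsonlReader(BaseReader):
    def read(self, input_path: Union[str, List[str]]) -> Dataset:
        paths = [input_path] if isinstance(input_path, str) else input_path
        items = []
        for path in paths:
            with open(path, encoding="utf-8") as f:
                for line in f:
                    item = json.loads(line)
                    if self._should_keep_item(item):
                        items.append(item)
        # Materialize the kept records as a Ray Dataset
        return from_items(items)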
graphgen/bases/base_splitter.py (6 changes: 3 additions & 3 deletions)
@@ -4,7 +4,7 @@
 from typing import Callable, Iterable, List, Literal, Optional, Union

 from graphgen.bases.datatypes import Chunk
-from graphgen.utils import logger
+from graphgen.utils.log import logger


 class BaseSplitter(ABC):
@@ -33,7 +33,7 @@ def split_text(self, text: str) -> List[str]:
         """
         Split the input text into smaller chunks.

-        :param text: The input text to be split.
+        :param text: The input text to be chunk.
         :return: A list of text chunks.
         """

@@ -111,7 +111,7 @@ def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
 def _split_text_with_regex(
     text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]]
 ) -> List[str]:
-    # Now that we have the separator, split the text
+    # Now that we have the separator, chunk the text
     if separator:
         if keep_separator:
             # The parentheses in the pattern keep the delimiters in the result.
graphgen/bases/base_storage.py (17 changes: 0 additions & 17 deletions)
@@ -16,23 +16,6 @@ def query_done_callback(self):
         """commit the storage operations after querying"""


-class BaseListStorage(Generic[T], StorageNameSpace):
-    def all_items(self) -> list[T]:
-        raise NotImplementedError
-
-    def get_by_index(self, index: int) -> Union[T, None]:
-        raise NotImplementedError
-
-    def append(self, data: T):
-        raise NotImplementedError
-
-    def upsert(self, data: list[T]):
-        raise NotImplementedError
-
-    def drop(self):
-        raise NotImplementedError
-
-
 class BaseKVStorage(Generic[T], StorageNameSpace):
     def all_keys(self) -> list[str]:
         raise NotImplementedError
graphgen/bases/datatypes.py (44 changes: 44 additions & 0 deletions)
@@ -2,6 +2,8 @@
 from dataclasses import dataclass, field
 from typing import List, Union

+from pydantic import BaseModel, Field, field_validator
+

 @dataclass
 class Chunk:
@@ -48,3 +50,45 @@ class Community:
     nodes: List[str] = field(default_factory=list)
     edges: List[tuple] = field(default_factory=list)
     metadata: dict = field(default_factory=dict)
+
+
+class Node(BaseModel):
+    id: str = Field(..., description="unique node id")
+    op_name: str = Field(..., description="operator name")
+    type: str = Field(
+        ..., description="task type, e.g., map, filter, flatmap, aggregate, map_batch"
+    )
+    params: dict = Field(default_factory=dict, description="operator parameters")
+    dependencies: List[str] = Field(
+        default_factory=list, description="list of dependent node ids"
+    )
+    execution_params: dict = Field(
+        default_factory=dict, description="execution parameters like replicas, batch_size"
+    )
+
+    @classmethod
+    @field_validator("type")
+    def validate_type(cls, v: str) -> str:
+        valid_types = {"map", "filter", "flatmap", "aggregate", "map_batch"}
+        if v not in valid_types:
+            raise ValueError(f"Invalid node type: {v}. Must be one of {valid_types}.")
+        return v
+
+
+class Config(BaseModel):
+    global_params: dict = Field(
+        default_factory=dict, description="global context for the computation graph"
+    )
+
+    nodes: List[Node] = Field(
+        ..., min_length=1, description="list of nodes in the computation graph"
+    )
+
+    @classmethod
+    @field_validator("nodes")
+    def validate_unique_ids(cls, v: List[Node]) -> List[Node]:
+        ids = [node.id for node in v]
+        if len(ids) != len(set(ids)):
+            duplicates = {id_ for id_ in ids if ids.count(id_) > 1}
+            raise ValueError(f"Duplicate node ids found: {duplicates}")
+        return v
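A short sketch of how these models might be instantiated; the node ids, operator names, and params below are placeholders, not values defined in this PR:

# Hypothetical two-node pipeline config; the field validators above are meant
# to reject unknown type values and duplicate node ids.
from graphgen.bases import Config, Node

config = Config(
    global_params={"working_dir": "cache"},
    nodes=[
        Node(id="read", op_name="read", type="map", params={"input_path": "data.jsonl"}),
        Node(id="chunk", op_name="chunk", type="map_batch", dependencies=["read"]),
    ],
)
print([n.id for n in config.nodes])  # ['read', 'chunk']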
graphgen/common/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -0,0 +1,2 @@
+from .init_llm import init_llm
+from .init_storage import init_storage
@@ -29,6 +29,7 @@ def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper:
             return HTTPClient(**config)
         if backend in ("openai_api", "azure_openai_api"):
             from graphgen.models.llm.api.openai_client import OpenAIClient
+
             # pass in concrete backend to the OpenAIClient so that internally we can distinguish
             # between OpenAI and Azure OpenAI
             return OpenAIClient(**config, backend=backend)
@@ -79,3 +80,6 @@ def init_llm(model_type: str) -> Optional[BaseLLMWrapper]:
     backend = config.pop("backend")
     llm_wrapper = LLMFactory.create_llm_wrapper(backend, config)
     return llm_wrapper
+
+
+# TODO: use ray serve when loading large models to avoid re-loading in each actor
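Usage sketch; the model_type value below is a placeholder and may differ from the keys this project actually configures:

# Hypothetical call site; init_llm() resolves the backend from config and
# returns a BaseLLMWrapper (or None).
from graphgen.common import init_llm

synthesizer_llm = init_llm("synthesizer")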
graphgen/common/init_storage.py (28 changes: 28 additions & 0 deletions)
@@ -0,0 +1,28 @@
+from graphgen.models import JsonKVStorage, NetworkXStorage
+
+
+class StorageFactory:
+    """
+    Factory class to create storage instances based on backend.
+    Supported backends:
+    kv_storage(key-value storage):
+        - json_kv: JsonKVStorage
+    graph_storage:
+        - networkx: NetworkXStorage (graph storage)
+    """
+
+    @staticmethod
+    def create_storage(backend: str, working_dir: str, namespace: str):
+        if backend == "json_kv":
+            return JsonKVStorage(working_dir, namespace=namespace)
+
+        if backend == "networkx":
+            return NetworkXStorage(working_dir, namespace=namespace)
+
+        raise NotImplementedError(
+            f"Storage backend '{backend}' is not implemented yet."
+        )
+
+
+def init_storage(backend: str, working_dir: str, namespace: str):
+    return StorageFactory.create_storage(backend, working_dir, namespace)
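Usage sketch for the factory; the working directory and namespace values are placeholders:

# Hypothetical call sites; the two backends shown are the ones the factory supports.
from graphgen.common import init_storage

kv_store = init_storage("json_kv", working_dir="cache", namespace="chunks")
graph_store = init_storage("networkx", working_dir="cache", namespace="graph")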