
Commit

Merge pull request #40 from kumo-ai/yyuan/refactor
refactor entire repo
yiweny authored Oct 14, 2024
2 parents 2408e78 + 1c2326a commit 81a945d
Showing 26 changed files with 838 additions and 57 deletions.
10 changes: 7 additions & 3 deletions .github/workflows/testing.yml
@@ -1,4 +1,4 @@
-name: Testing PyTorch 2.2
+name: Testing PyTorch 2.3

on: # yamllint disable-line rule:truthy
push:
@@ -27,15 +27,19 @@ jobs:

- name: Install PyTorch
run: |
-pip install torch==2.3.0 --extra-index-url https://download.pytorch.org/whl/cpu
-pip install --no-index pyg-lib==0.4.0 -f https://data.pyg.org/whl/torch-2.3.0+cpu.html
+python3 -m venv .venv
+source .venv/bin/activate
+python3 -m pip install torch==2.3.0 --extra-index-url https://download.pytorch.org/whl/cpu
+python3 -m pip install --no-index pyg-lib==0.4.0 -f https://data.pyg.org/whl/torch-2.3.0+cpu.html
- name: Install main package
run: |
+source .venv/bin/activate
pip install -e .[full,test]
- name: Run tests
run: |
+source .venv/bin/activate
pytest --cov --cov-report=xml
- name: Upload coverage
14 changes: 7 additions & 7 deletions README.md
@@ -1,26 +1,26 @@
-# hybridgnn
+# contextgnn

- [Overleaf Latex link](https://www.overleaf.com/8255131161fxgzwccqftmz#5676c1)

-- [HybridGNN blogpost link](https://docs.google.com/document/d/1kcGl9zk_pHuZ5xE9HBBVCmJa6iLiZ_yOjX9eFPpHiXw/edit)
+- [contextgnn blogpost link](https://docs.google.com/document/d/1kcGl9zk_pHuZ5xE9HBBVCmJa6iLiZ_yOjX9eFPpHiXw/edit)

- [Spreadsheet of results](https://docs.google.com/spreadsheets/d/1bnNurVKLCgWjgvd9fCO-NexCgU75Xql9erfn6h3Wooo/edit?usp=sharing)


## How to Run

-Run [`benchmark/relbench_link_prediction_benchmark.py`](https://github.com/kumo-ai/hybridgnn/blob/master/benchmark/relbench_link_prediction_benchmark.py)
+Run [`benchmark/relbench_link_prediction_benchmark.py`](https://github.com/kumo-ai/contextgnn/blob/master/benchmark/relbench_link_prediction_benchmark.py)

```sh
-python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn
+python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model contextgnn
```


-Run [`examples/relbench_example.py`](https://github.com/kumo-ai/hybridgnn/blob/master/examples/relbench_example.py)
+Run [`examples/relbench_example.py`](https://github.com/kumo-ai/contextgnn/blob/master/examples/relbench_example.py)

```sh
-python relbench_example.py --dataset rel-trial --task site-sponsor-run --model hybridgnn
-python relbench_example.py --dataset rel-trial --task condition-sponsor-run --model hybridgnn
+python relbench_example.py --dataset rel-trial --task site-sponsor-run --model contextgnn
+python relbench_example.py --dataset rel-trial --task condition-sponsor-run --model contextgnn
```


20 changes: 10 additions & 10 deletions benchmark/relbench_link_prediction_benchmark.py
@@ -30,8 +30,8 @@
from torch_geometric.utils.cross_entropy import sparse_cross_entropy
from tqdm import tqdm

-from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN
-from hybridgnn.utils import GloveTextEmbedding, RHSEmbeddingMode
+from contextgnn.nn.models import IDGNN, ContextGNN, ShallowRHSGNN
+from contextgnn.utils import GloveTextEmbedding, RHSEmbeddingMode

TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"]
LINK_PREDICTION_METRIC = "link_prediction_map"
@@ -43,8 +43,8 @@
parser.add_argument(
"--model",
type=str,
default="hybridgnn",
choices=["hybridgnn", "idgnn", "shallowrhsgnn"],
default="contextgnn",
choices=["contextgnn", "idgnn", "shallowrhsgnn"],
)
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--num_trials", type=int, default=50,
@@ -103,7 +103,7 @@
int(args.num_neighbors // 2**i) for i in range(args.num_layers)
]

-model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN]]
+model_cls: Type[Union[IDGNN, ContextGNN, ShallowRHSGNN]]

if args.model == "idgnn":
model_search_space = {
@@ -118,7 +118,7 @@
"gamma_rate": [0.9, 0.95, 1.],
}
model_cls = IDGNN
elif args.model in ["hybridgnn", "shallowrhsgnn"]:
elif args.model in ["contextgnn", "shallowrhsgnn"]:
model_search_space = {
"encoder_channels": [32, 64, 128, 256, 512],
"encoder_layers": [2, 4, 8],
@@ -135,7 +135,7 @@
"base_lr": [0.001, 0.01],
"gamma_rate": [0.8, 1.],
}
-model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN)
+model_cls = (ContextGNN if args.model == "contextgnn" else ShallowRHSGNN)


def train(
@@ -173,7 +173,7 @@ def train(

loss = F.binary_cross_entropy_with_logits(out, target)
numel = out.numel()
elif args.model in ["hybridgnn", "shallowrhsgnn"]:
elif args.model in ["contextgnn", "shallowrhsgnn"]:
logits = model(batch, task.src_entity_table, task.dst_entity_table)
edge_label_index = torch.stack([src_batch, dst_index], dim=0)
loss = sparse_cross_entropy(logits, edge_label_index)
@@ -214,7 +214,7 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float:
device=out.device)
scores[batch[task.dst_entity_table].batch,
batch[task.dst_entity_table].n_id] = torch.sigmoid(out)
elif args.model in ["hybridgnn", "shallowrhsgnn"]:
elif args.model in ["contextgnn", "shallowrhsgnn"]:
# Get ground-truth
out = model(batch, task.src_entity_table,
task.dst_entity_table).detach()
@@ -257,7 +257,7 @@ def train_and_eval_with_cfg(
persistent_workers=args.num_workers > 0,
)

if args.model in ["hybridgnn", "shallowrhsgnn"]:
if args.model in ["contextgnn", "shallowrhsgnn"]:
model_cfg["num_nodes"] = num_dst_nodes_dict["train"]
model_cfg["dst_entity_table"] = task.dst_entity_table
elif args.model == "idgnn":
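To make the rename concrete, here is a minimal sketch of the benchmark's model-selection logic after this refactor; the constructor call is left as a comment because its keyword arguments are sampled from the search spaces shown above, and the snippet assumes only the imports introduced in this diff.

```python
from typing import Type, Union

from contextgnn.nn.models import IDGNN, ContextGNN, ShallowRHSGNN

model_name = "contextgnn"  # one of "contextgnn", "idgnn", "shallowrhsgnn"

# Pick the model class the same way the benchmark script does after the rename.
model_cls: Type[Union[IDGNN, ContextGNN, ShallowRHSGNN]]
if model_name == "idgnn":
    model_cls = IDGNN
elif model_name in ["contextgnn", "shallowrhsgnn"]:
    model_cls = ContextGNN if model_name == "contextgnn" else ShallowRHSGNN

# ContextGNN and ShallowRHSGNN additionally receive RHS node information, e.g.:
#   model_cfg["num_nodes"] = num_dst_nodes_dict["train"]
#   model_cfg["dst_entity_table"] = task.dst_entity_table
# model = model_cls(**model_cfg)  # kwargs sampled from the search space
```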
Empty file added contextgnn/__init__.py
Empty file.
4 changes: 4 additions & 0 deletions contextgnn/nn/__init__.py
@@ -0,0 +1,4 @@
from .encoder import HeteroEncoder
from .rhs_embedding import RHSEmbedding

__all__ = classes = ['HeteroEncoder', 'RHSEmbedding']
131 changes: 131 additions & 0 deletions contextgnn/nn/encoder.py
@@ -0,0 +1,131 @@
from typing import Any, Dict, List, Optional

import torch
import torch_frame
from torch import Tensor
from torch_frame.data.stats import StatType
from torch_frame.nn.models import ResNet
from torch_geometric.nn import PositionalEncoding
from torch_geometric.typing import NodeType

DEFAULT_STYPE_ENCODER_DICT: Dict[torch_frame.stype, Any] = {
torch_frame.categorical: (torch_frame.nn.EmbeddingEncoder, {}),
torch_frame.numerical: (torch_frame.nn.LinearEncoder, {}),
torch_frame.multicategorical: (
torch_frame.nn.MultiCategoricalEmbeddingEncoder,
{},
),
torch_frame.embedding: (torch_frame.nn.LinearEmbeddingEncoder, {}),
torch_frame.timestamp: (torch_frame.nn.TimestampEncoder, {}),
}
SECONDS_IN_A_DAY = 60 * 60 * 24


class HeteroEncoder(torch.nn.Module):
r"""HeteroStypeWiseEncoder is a simple encoder to encode multi-modal
data from different node types.
Args:
channels (int): The output channels for each node type.
node_to_col_names_dict (Dict[NodeType, Dict[torch_frame.stype, List[str]]]): # noqa: E501
A dictionary mapping from node type to column names dictionary
compatible to PyTorch Frame.
node_to_col_stats (Dict[NodeType, Dict[str, Dict[StatType, Any]]]):
A dictionary mapping from node type to column statistics dictionary
compatible to PyTorch Frame.
stype_encoder_cls_kwargs (Dict[torch_frame.stype, Any]):
A dictionary mapping from :obj:`torch_frame.stype` object into a
tuple specifying :class:`torch_frame.nn.StypeEncoder` class and its
keyword arguments :obj:`kwargs`.
torch_frame_model_cls: Model class for PyTorch Frame. The class object
takes :class:`TensorFrame` object as input and outputs
:obj:`channels`-dimensional embeddings. Default to
:class:`torch_frame.nn.ResNet`.
torch_frame_model_kwargs (Dict[str, Any]): Keyword arguments for
:class:`torch_frame_model_cls` class. Default keyword argument is
set specific for :class:`torch_frame.nn.ResNet`. Expect it to
be changed for different :class:`torch_frame_model_cls`.
"""
def __init__(
self,
channels: int,
node_to_col_names_dict: Dict[NodeType, Dict[torch_frame.stype,
List[str]]],
node_to_col_stats: Dict[NodeType, Dict[str, Dict[StatType, Any]]],
stype_encoder_cls_kwargs: Dict[torch_frame.stype, Any],
torch_frame_model_cls=ResNet,
torch_frame_model_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()

self.encoders = torch.nn.ModuleDict()

for node_type in node_to_col_names_dict.keys():
stype_encoder_dict = {
stype:
stype_encoder_cls_kwargs[stype][0](
**stype_encoder_cls_kwargs[stype][1])
for stype in node_to_col_names_dict[node_type].keys()
}

self.encoders[node_type] = torch_frame_model_cls(
# `torch_frame_model_kwargs` defaults to `None`, so guard the unpacking:
**(torch_frame_model_kwargs or {}),
out_channels=channels,
col_stats=node_to_col_stats[node_type],
col_names_dict=node_to_col_names_dict[node_type],
stype_encoder_dict=stype_encoder_dict,
)

def reset_parameters(self) -> None:
for encoder in self.encoders.values():
encoder.reset_parameters()

def forward(
self,
tf_dict: Dict[NodeType, torch_frame.TensorFrame],
) -> Dict[NodeType, Tensor]:
x_dict = {
node_type: self.encoders[node_type](tf)
for node_type, tf in tf_dict.items()
}
return x_dict


class HeteroTemporalEncoder(torch.nn.Module):
def __init__(self, node_types: List[NodeType], channels: int) -> None:
super().__init__()

self.encoder_dict = torch.nn.ModuleDict({
node_type:
PositionalEncoding(channels)
for node_type in node_types
})
self.lin_dict = torch.nn.ModuleDict({
node_type:
torch.nn.Linear(channels, channels)
for node_type in node_types
})

def reset_parameters(self) -> None:
for encoder in self.encoder_dict.values():
encoder.reset_parameters()
for lin in self.lin_dict.values():
lin.reset_parameters()

def forward(
self,
seed_time: Tensor,
time_dict: Dict[NodeType, Tensor],
batch_dict: Dict[NodeType, Tensor],
) -> Dict[NodeType, Tensor]:
out_dict: Dict[NodeType, Tensor] = {}

for node_type, time in time_dict.items():
rel_time = seed_time[batch_dict[node_type]] - time
rel_time = rel_time / SECONDS_IN_A_DAY

x = self.encoder_dict[node_type](rel_time)
x = self.lin_dict[node_type](x)
out_dict[node_type] = x

return out_dict
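As a quick orientation to the new encoder module, here is a minimal usage sketch for `HeteroEncoder`; the `drivers` node type, the toy columns, and the `torch_frame_model_kwargs` values are illustrative assumptions, not values taken from this repository.

```python
import pandas as pd
import torch_frame
from torch_frame.data import Dataset

from contextgnn.nn.encoder import DEFAULT_STYPE_ENCODER_DICT, HeteroEncoder

# Build a tiny PyTorch Frame dataset to obtain col_names_dict and col_stats.
df = pd.DataFrame({
    "age": [23.0, 31.0, 54.0, 47.0],
    "team": ["red", "blue", "red", "green"],
})
dataset = Dataset(df, col_to_stype={
    "age": torch_frame.numerical,
    "team": torch_frame.categorical,
}).materialize()

encoder = HeteroEncoder(
    channels=64,
    node_to_col_names_dict={"drivers": dataset.tensor_frame.col_names_dict},
    node_to_col_stats={"drivers": dataset.col_stats},
    stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT,
    # ResNet-specific kwargs; values here are assumptions for the sketch.
    torch_frame_model_kwargs={"channels": 64, "num_layers": 2},
)
x_dict = encoder({"drivers": dataset.tensor_frame})
print(x_dict["drivers"].shape)  # torch.Size([4, 64])
```

`HeteroTemporalEncoder` is used analogously: it maps per-node timestamps, taken relative to a seed time, into the same `channels`-dimensional space.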
10 changes: 10 additions & 0 deletions contextgnn/nn/models/__init__.py
@@ -0,0 +1,10 @@
from .graphsage import HeteroGraphSAGE
from .idgnn import IDGNN
from .contextgnn import ContextGNN
from .shallowrhsgnn import ShallowRHSGNN
from .rhsembeddinggnn import RHSEmbeddingGNN

__all__ = classes = [
'HeteroGraphSAGE', 'IDGNN', 'ContextGNN', 'ShallowRHSGNN',
'RHSEmbeddingGNN'
]
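After the refactor, downstream code imports the model classes from the `contextgnn` namespace instead of `hybridgnn`; a minimal sketch of the renamed public API exported above:

```python
# Public model classes re-exported by `contextgnn.nn.models` after the rename.
from contextgnn.nn.models import (
    ContextGNN,  # formerly HybridGNN
    HeteroGraphSAGE,
    IDGNN,
    RHSEmbeddingGNN,
    ShallowRHSGNN,
)
```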