-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #40 from kumo-ai/yyuan/refactor
refactor entire repo
- Loading branch information
Showing
26 changed files
with
838 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .encoder import HeteroEncoder | ||
from .rhs_embedding import RHSEmbedding | ||
|
||
__all__ = classes = ['HeteroEncoder', 'RHSEmbedding'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from typing import Any, Dict, List, Optional | ||
|
||
import torch | ||
import torch_frame | ||
from torch import Tensor | ||
from torch_frame.data.stats import StatType | ||
from torch_frame.nn.models import ResNet | ||
from torch_geometric.nn import PositionalEncoding | ||
from torch_geometric.typing import NodeType | ||
|
||
DEFAULT_STYPE_ENCODER_DICT: Dict[torch_frame.stype, Any] = { | ||
torch_frame.categorical: (torch_frame.nn.EmbeddingEncoder, {}), | ||
torch_frame.numerical: (torch_frame.nn.LinearEncoder, {}), | ||
torch_frame.multicategorical: ( | ||
torch_frame.nn.MultiCategoricalEmbeddingEncoder, | ||
{}, | ||
), | ||
torch_frame.embedding: (torch_frame.nn.LinearEmbeddingEncoder, {}), | ||
torch_frame.timestamp: (torch_frame.nn.TimestampEncoder, {}), | ||
} | ||
SECONDS_IN_A_DAY = 60 * 60 * 24 | ||
|
||
|
||
class HeteroEncoder(torch.nn.Module): | ||
r"""HeteroStypeWiseEncoder is a simple encoder to encode multi-modal | ||
data from different node types. | ||
Args: | ||
channels (int): The output channels for each node type. | ||
node_to_col_names_dict (Dict[NodeType, Dict[torch_frame.stype, List[str]]]): # noqa: E501 | ||
A dictionary mapping from node type to column names dictionary | ||
compatible to PyTorch Frame. | ||
node_to_col_stats (Dict[NodeType, Dict[str, Dict[StatType, Any]]]): | ||
A dictionary mapping from node type to column statistics dictionary | ||
compatible to PyTorch Frame. | ||
stype_encoder_cls_kwargs (Dict[torch_frame.stype, Any]): | ||
A dictionary mapping from :obj:`torch_frame.stype` object into a | ||
tuple specifying :class:`torch_frame.nn.StypeEncoder` class and its | ||
keyword arguments :obj:`kwargs`. | ||
torch_frame_model_cls: Model class for PyTorch Frame. The class object | ||
takes :class:`TensorFrame` object as input and outputs | ||
:obj:`channels`-dimensional embeddings. Default to | ||
:class:`torch_frame.nn.ResNet`. | ||
torch_frame_model_kwargs (Dict[str, Any]): Keyword arguments for | ||
:class:`torch_frame_model_cls` class. Default keyword argument is | ||
set specific for :class:`torch_frame.nn.ResNet`. Expect it to | ||
be changed for different :class:`torch_frame_model_cls`. | ||
""" | ||
def __init__( | ||
self, | ||
channels: int, | ||
node_to_col_names_dict: Dict[NodeType, Dict[torch_frame.stype, | ||
List[str]]], | ||
node_to_col_stats: Dict[NodeType, Dict[str, Dict[StatType, Any]]], | ||
stype_encoder_cls_kwargs: Dict[torch_frame.stype, Any], | ||
torch_frame_model_cls=ResNet, | ||
torch_frame_model_kwargs: Optional[Dict[str, Any]] = None, | ||
) -> None: | ||
super().__init__() | ||
|
||
self.encoders = torch.nn.ModuleDict() | ||
|
||
for node_type in node_to_col_names_dict.keys(): | ||
stype_encoder_dict = { | ||
stype: | ||
stype_encoder_cls_kwargs[stype][0]( | ||
**stype_encoder_cls_kwargs[stype][1]) | ||
for stype in node_to_col_names_dict[node_type].keys() | ||
} | ||
|
||
self.encoders[node_type] = torch_frame_model_cls( | ||
**torch_frame_model_kwargs, | ||
out_channels=channels, | ||
col_stats=node_to_col_stats[node_type], | ||
col_names_dict=node_to_col_names_dict[node_type], | ||
stype_encoder_dict=stype_encoder_dict, | ||
) | ||
|
||
def reset_parameters(self) -> None: | ||
for encoder in self.encoders.values(): | ||
encoder.reset_parameters() | ||
|
||
def forward( | ||
self, | ||
tf_dict: Dict[NodeType, torch_frame.TensorFrame], | ||
) -> Dict[NodeType, Tensor]: | ||
x_dict = { | ||
node_type: self.encoders[node_type](tf) | ||
for node_type, tf in tf_dict.items() | ||
} | ||
return x_dict | ||
|
||
|
||
class HeteroTemporalEncoder(torch.nn.Module): | ||
def __init__(self, node_types: List[NodeType], channels: int) -> None: | ||
super().__init__() | ||
|
||
self.encoder_dict = torch.nn.ModuleDict({ | ||
node_type: | ||
PositionalEncoding(channels) | ||
for node_type in node_types | ||
}) | ||
self.lin_dict = torch.nn.ModuleDict({ | ||
node_type: | ||
torch.nn.Linear(channels, channels) | ||
for node_type in node_types | ||
}) | ||
|
||
def reset_parameters(self) -> None: | ||
for encoder in self.encoder_dict.values(): | ||
encoder.reset_parameters() | ||
for lin in self.lin_dict.values(): | ||
lin.reset_parameters() | ||
|
||
def forward( | ||
self, | ||
seed_time: Tensor, | ||
time_dict: Dict[NodeType, Tensor], | ||
batch_dict: Dict[NodeType, Tensor], | ||
) -> Dict[NodeType, Tensor]: | ||
out_dict: Dict[NodeType, Tensor] = {} | ||
|
||
for node_type, time in time_dict.items(): | ||
rel_time = seed_time[batch_dict[node_type]] - time | ||
rel_time = rel_time / SECONDS_IN_A_DAY | ||
|
||
x = self.encoder_dict[node_type](rel_time) | ||
x = self.lin_dict[node_type](x) | ||
out_dict[node_type] = x | ||
|
||
return out_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from .graphsage import HeteroGraphSAGE | ||
from .idgnn import IDGNN | ||
from .contextgnn import ContextGNN | ||
from .shallowrhsgnn import ShallowRHSGNN | ||
from .rhsembeddinggnn import RHSEmbeddingGNN | ||
|
||
__all__ = classes = [ | ||
'HeteroGraphSAGE', 'IDGNN', 'ContextGNN', 'ShallowRHSGNN', | ||
'RHSEmbeddingGNN' | ||
] |
Oops, something went wrong.