Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions src/pathpyG/utils/dbgnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@

import torch

from torch_geometric.utils import coalesce
from torch_geometric.data import Data

from pathpyG.algorithms.lift_order import aggregate_edge_index
from pathpyG.core.graph import Graph
from pathpyG.core.index_map import IndexMap
from pathpyG.core.temporal_graph import TemporalGraph

import pathpyG.core.multi_order_model as mm


def generate_bipartite_edge_index(g: Graph, g2: Graph, mapping: str = "last") -> torch.Tensor:
Expand All @@ -22,3 +30,61 @@ def generate_bipartite_edge_index(g: Graph, g2: Graph, mapping: str = "last") ->
)

return bipartide_edge_index


def generate_second_order_model(g: TemporalGraph, delta: float | int = 1, weight: str = "edge_weight") -> mm.MultiOrderModel:
"""
Generate a multi-order model with first- and second-order layer from a temporal graph.
This method is optimized for the memory footprint of large graphs and it may be slower than creating small models with the default approach.
"""
data = g.data.sort_by_time()
edge_index1, timestamps1 = data.edge_index, data.time

node_sequence1 = torch.arange(data.num_nodes, device=edge_index1.device).unsqueeze(1)
if weight in data:
edge_weight = data[weight]
else:
edge_weight = torch.ones(edge_index1.size(1), device=edge_index1.device)

layer1 = aggregate_edge_index(
edge_index=edge_index1, node_sequence=node_sequence1, edge_weight=edge_weight
)
layer1.mapping = g.mapping

node_sequence2 = torch.cat([node_sequence1[edge_index1[0]], node_sequence1[edge_index1[1]][:, -1:]], dim=1)
node_sequence2, edge1_to_node2 = torch.unique(node_sequence2, dim=0, return_inverse=True)

edge_index2 = []
edge_weight2 = []
for timestamp in timestamps1.unique():
src_nodes2, src_nodes2_counts = edge1_to_node2[timestamps1 == timestamp].unique(return_counts=True)
dst_nodes2, dst_nodes2_counts = edge1_to_node2[(timestamps1 > timestamp) & (timestamps1 <= timestamp + delta)].unique(return_counts=True)
for src_node2, src_node2_count in zip(src_nodes2, src_nodes2_counts):
dst_node2 = dst_nodes2[node_sequence2[dst_nodes2, 0] == node_sequence2[src_node2, -1]]
dst_node2_count = dst_nodes2_counts[node_sequence2[dst_nodes2, 0] == node_sequence2[src_node2, -1]]

edge_index2.append(torch.stack([src_node2.expand(dst_node2.size(0)), dst_node2], dim=0))
edge_weight2.append(src_node2_count.expand(dst_node2.size(0)) * dst_node2_count)

edge_index2 = torch.cat(edge_index2, dim=1)
edge_weight2 = torch.cat(edge_weight2, dim=0)

edge_index2, edge_weight2 = coalesce(edge_index2, edge_attr=edge_weight2, num_nodes=node_sequence2.size(0), reduce="sum")

data2 = Data(
edge_index=edge_index2,
num_nodes=node_sequence2.size(0),
node_sequence=node_sequence2,
edge_weight=edge_weight2,
inverse_idx=edge1_to_node2,
)
layer2 = Graph(data2)
layer2.mapping = IndexMap(
[tuple(layer1.mapping.to_ids(v.cpu())) for v in node_sequence2]
)


m = mm.MultiOrderModel()
m.layers[1] = layer1
m.layers[2] = layer2
return m
29 changes: 29 additions & 0 deletions tests/utils/test_generate_second_order_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch

from pathpyG.core.multi_order_model import MultiOrderModel
from pathpyG.core.temporal_graph import TemporalGraph
from pathpyG.utils.dbgnn import generate_second_order_model

def test_generate_second_order_model():
tedges = [('a', 'b', 1), ('c', 'b', 1), ('c', 'a', 1), ('c', 'a', 1), ('f', 'c', 1),
('b', 'c', 5), ('a', 'd', 5), ('c', 'd', 9), ('a', 'd', 9), ('c', 'e', 9),
('c', 'f', 11), ('f', 'a', 13), ('a', 'g', 18), ('b', 'f', 21),
('a', 'g', 26), ('c', 'f', 27), ('h', 'f', 27), ('g', 'h', 28),
('a', 'c', 30), ('a', 'b', 31), ('c', 'h', 32), ('f', 'h', 33),
('b', 'i', 42), ('i', 'b', 42), ('c', 'i', 47), ('h', 'i', 50)]

g = TemporalGraph.from_edge_list(tedges)
reference = MultiOrderModel.from_temporal_graph(g, max_order=2, delta=10).to_dbgnn_data()

g = TemporalGraph.from_edge_list(tedges)
result = generate_second_order_model(g, delta=10).to_dbgnn_data()

assert result.num_nodes == reference.num_nodes
assert result.num_ho_nodes == reference.num_ho_nodes
assert torch.equal(result.x, reference.x)
assert torch.equal(result.edge_index, reference.edge_index)
assert torch.equal(result.edge_weights, reference.edge_weights)
assert torch.equal(result.x_h, reference.x_h)
assert torch.equal(result.edge_index_higher_order, reference.edge_index_higher_order)
assert torch.equal(result.edge_weights_higher_order, reference.edge_weights_higher_order)
assert torch.equal(result.bipartite_edge_index, reference.bipartite_edge_index)
Loading