GATAS implementation
adeandrade committed Jun 10, 2020
1 parent ecfcaf6 commit a852215
Showing 42 changed files with 4,016 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

# IDEs
.idea/
8 changes: 8 additions & 0 deletions MLproject
@@ -0,0 +1,8 @@
name: gatas

entry_points:
  train-link-predictor:
    command: "python3.7 -m link_prediction.train"

  train-node-classifier:
    command: "python3.7 -m node_classification.train"
95 changes: 95 additions & 0 deletions README.md
@@ -0,0 +1,95 @@
# GATAS
Implementation of *Graph Representation Learning Network via Adaptive Sampling*: [http://arxiv.org/abs/2006.04637](http://arxiv.org/abs/2006.04637)

The algorithm represents each node by reducing its sampled neighbours' representations with attention. Multi-step neighbour representations incorporate different path properties, and neighbours are sampled using learnable depth coefficients.
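
Roughly, and as a paraphrase rather than the paper's exact formulation: the probability of sampling a neighbour `u` for a node `v` mixes the `k`-step transition probabilities with the learnable depth coefficients:

```tex
P(u \mid v) \propto \sum_{k=1}^{K} \theta_k \, T^{k}_{vu}
```

where `T` is the one-step transition matrix, `K` is the number of steps, and the coefficients `\theta_k` are learned. The representations of the sampled neighbours are then reduced with attention into a node representation.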


## Overview
This repository is organized as follows:
- `data/` contains the necessary files for the Cora, PubMed, Citeseer, PPI, Twitter and YouTube datasets.
- `framework/` contains helper libraries for model development, training and evaluation.
- `gatas/` contains the implementation of GATAS.
- `node_classification/` contains a node label classifier built on GATAS.
- `link_prediction/` contains a link prediction model built on GATAS.


## Instructions
First, we must create a binary CSR representation of the graph in which the stored values are the edge type indices. For the Cora, Citeseer and PubMed datasets, these representations are precomputed and included in `data/`. For PPI, the representation can be computed with:

```bash
python3 -m node_classification.datasets.ppi --path data/ppi
```

For the Twitter and YouTube datasets, it can be computed with:
```bash
python3 -m link_prediction.datasets.gatne --path {path to dataset} --num-edge-types {number of edge types}
```
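
For reference, here is a minimal sketch of the kind of structure these scripts produce: a sparse adjacency matrix in CSR form whose stored values are edge type indices (illustrative only, using `scipy`; the exact on-disk format is defined by the dataset scripts):

```python
import numpy as np
import scipy.sparse as sparse

# A toy graph with three edges: 0 -> 1 (type 1), 0 -> 2 (type 2), 2 -> 1 (type 1).
# Edge types are stored 1-based so that a stored type is never confused
# with an absent entry in the sparse matrix.
rows = np.array([0, 0, 2])
columns = np.array([1, 2, 1])
edge_types = np.array([1, 2, 1], dtype=np.int32)

adjacency = sparse.csr_matrix((edge_types, (rows, columns)), shape=(3, 3))

print(adjacency.indptr)   # [0 2 2 3]
print(adjacency.indices)  # [1 2 1]
print(adjacency.data)     # [1 2 1]
```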

The dataset scripts also collect and preprocess the node features when available, and create dataset splits with inputs and targets for each task. Once we have a CSR graph representation, we can compute the transition probabilities by running:
```bash
python3 -m gatas.transitions --path {path to dataset} --num_steps {number of steps}
```

GATAS has two components: `NeighbourSampler` and `NeighbourAggregator`. `NeighbourSampler` can be initialized with a path so the precomputed transition data can be used:

```python
from gatas.sampler import NeighbourSampler

neighbour_sampler = NeighbourSampler.from_path(num_steps=3, path='data/ppi')
```

`NeighbourAggregator` accepts a matrix of node features and can be initialized as follows:
```python
import numpy as np
from gatas.aggregator import NeighbourAggregator

node_features = np.load('data/ppi/node_embeddings.npy')

neighbour_aggregator = NeighbourAggregator(
    input_noise_rate=0.,
    dropout_rate=0.,
    num_nodes=node_features.shape[0],
    num_edge_types=neighbour_sampler.num_edge_types,
    num_steps=3,
    edge_type_embedding_size=5,
    node_embedding_size=None,
    layer_size=256,
    num_attention_heads=10,
    node_features=node_features,
)
```

We can call `neighbour_aggregator` with the output of `neighbour_sampler`, as in the sketch below.
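
A minimal sketch of the wiring (the variable names and call signatures here are illustrative; see `gatas/sampler.py` and `gatas/aggregator.py` for the actual API):

```python
import numpy as np

# Hypothetical: sample multi-step neighbourhoods for a batch of nodes,
# then reduce the sampled neighbour representations with attention.
batch_indices = np.arange(128, dtype=np.int32)

neighbour_sample = neighbour_sampler(batch_indices)
node_representations = neighbour_aggregator(neighbour_sample)
```

This pattern is used in the node classification and link prediction models, which you can train with: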

```bash
python3 -m node_classification.train --data-path {path to dataset}
```

or:

```bash
python3 -m link_prediction.train --data-path {path to dataset}
```

where additional parameters can be passed through the command line. Run with `--help` for a list of them:

```bash
python3 -m node_classification.train --help
```


## Reference
```tex
@misc{andrade2020graph,
    title={Graph Representation Learning Network via Adaptive Sampling},
    author={Anderson de Andrade and Chen Liu},
    year={2020},
    eprint={2006.04637},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
```


## License
MIT
Empty file added framework/__init__.py
Empty file added framework/common/__init__.py
85 changes: 85 additions & 0 deletions framework/common/parameters.py
@@ -0,0 +1,85 @@
import inspect
import sys
from typing import Callable, Union, Optional, List, Mapping, Set, Tuple

import tensorflow as tf


class DataSpecification:
    def __init__(self, specs: List[tf.TensorSpec]) -> None:
        self.specs = specs

    def get_types(self) -> Tuple[tf.DType, ...]:
        return tuple(spec.dtype for spec in self.specs)

    def get_shapes(self) -> Tuple[tf.TensorShape, ...]:
        return tuple(spec.shape for spec in self.specs)

    def get_names(self) -> Tuple[str, ...]:
        return tuple(spec.name for spec in self.specs)

    def create_placeholders(self) -> List[tf.Tensor]:
        placeholders = [
            tf.placeholder(spec.dtype, spec.shape, spec.name)
            for spec in self.specs
        ]

        return placeholders
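
A hypothetical usage sketch for `DataSpecification` (TF1-style graph mode, matching the `tf.placeholder` call above; the specs shown are made up):

```python
import tensorflow as tf

spec = DataSpecification([
    tf.TensorSpec(shape=(None, 128), dtype=tf.float32, name='features'),
    tf.TensorSpec(shape=(None,), dtype=tf.int32, name='labels'),
])

spec.get_types()   # (tf.float32, tf.int32)
spec.get_shapes()  # (TensorShape([None, 128]), TensorShape([None]))
spec.get_names()   # ('features', 'labels')

placeholders = spec.create_placeholders()  # graph-mode placeholder tensors
```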


PrimitiveType = Union[int, float, bool, str]


def get_positional_arguments(argv: Optional[List[str]] = None) -> List[str]:
    argv = argv if argv else sys.argv[1:]

    arguments = []

    for argument in argv:
        if argument.startswith('-'):
            break

        arguments.append(argument)

    return arguments


def get_keyword_arguments(argv: Optional[List[str]] = None) -> Set[str]:
    argv = argv if argv else sys.argv[1:]

    num_positional_arguments = len(get_positional_arguments(argv))

    # Keyword arguments are assumed to come as `--name value` pairs,
    # so we inspect every other token after the positional arguments.
    passed_arguments = {
        key.lstrip('-').replace('_', '-')
        for key in argv[num_positional_arguments::2]
        if key.startswith('-')
    }

    return passed_arguments


def get_script_parameters(function: Callable, ignore_keyword_arguments: bool = True) -> Mapping[str, PrimitiveType]:
    """
    Returns the arguments of a function with their values taken from the command line, or their default values.
    Underscores in the names of the arguments are transformed to dashes.
    Can optionally filter out keyword arguments obtained through the command line.
    :param function: the function to inspect.
    :param ignore_keyword_arguments: whether to filter out keyword command-line arguments.
    :return: a map from argument names to default values.
    """
    positional_arguments, keyword_arguments = get_positional_arguments(), get_keyword_arguments()
    signature = inspect.signature(function)

    arguments = {}

    for index, (name, parameter) in enumerate(signature.parameters.items()):
        transformed_name = name.replace('_', '-')

        if index < len(positional_arguments):
            arguments[transformed_name] = positional_arguments[index]

        elif not (ignore_keyword_arguments and transformed_name in keyword_arguments):
            if parameter.default != parameter.empty:
                arguments[transformed_name] = parameter.default

    return arguments
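
As a quick illustration of how `get_script_parameters` behaves (a hypothetical script and command line):

```python
def train(data_path: str, learning_rate: float = 0.001, num_epochs: int = 10) -> None:
    ...

# Invoked as: python script.py data/ppi --learning-rate 0.01
# get_script_parameters(train) returns {'data-path': 'data/ppi', 'num-epochs': 10}:
# the positional value is taken from argv, `learning-rate` is filtered out
# because it was passed on the command line, and `num-epochs` keeps its default.
```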
Empty file added framework/dataset/__init__.py
110 changes: 110 additions & 0 deletions framework/dataset/dataset.py
@@ -0,0 +1,110 @@
import numba
import numpy as np


@numba.njit()
def convert_incremental_lengths_to_indices(accumulated_lengths: np.ndarray) -> np.ndarray:
    """
    Convert a vector of accumulated item lengths into a matrix of indices.
    For instance, given [0, 2, 5], we obtain [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]].
    :param accumulated_lengths: A vector of accumulated lengths of size N + 1.
    :return: A matrix of size (total length, 2) representing the assignment of each index to an item in N.
    """
    num_indices = accumulated_lengths[-1] - accumulated_lengths[0]

    indices = np.empty((num_indices, 2), dtype=np.int32)

    index = 0

    for outer_index in range(accumulated_lengths.size - 1):
        length = accumulated_lengths[outer_index + 1] - accumulated_lengths[outer_index]

        for inner_index in range(length):
            indices[index, 0] = outer_index
            indices[index, 1] = inner_index

            index += 1

    return indices


@numba.njit()
def convert_lengths_to_indices(lengths: np.ndarray) -> np.ndarray:
    """
    Convert a vector of item lengths into a matrix of indices.
    For instance, given [2, 3], we obtain [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]].
    :param lengths: A vector of lengths of size N.
    :return: A matrix of size (total length, 2) representing the assignment of each index to an item in N.
    """
    num_indices = np.sum(lengths)

    indices = np.empty((num_indices, 2), dtype=np.int32)

    index = 0

    for outer_index, length in enumerate(lengths):
        for inner_index in range(length):
            indices[index, 0] = outer_index
            indices[index, 1] = inner_index

            index += 1

    return indices


@numba.njit()
def convert_incremental_lengths_to_segment_indices(accumulated_lengths: np.ndarray) -> np.ndarray:
    """
    Convert a vector of accumulated item lengths into a vector of segment indices.
    For instance, given [0, 2, 5], we obtain [0, 0, 1, 1, 1].
    :param accumulated_lengths: A vector of accumulated lengths of size N + 1.
    :return: A vector of size (total length) representing the assignment of each item index in N.
    """
    num_indices = accumulated_lengths[-1] - accumulated_lengths[0]

    segment_indices = np.empty(num_indices, dtype=np.int32)

    index = 0

    for outer_index in range(accumulated_lengths.size - 1):
        length = accumulated_lengths[outer_index + 1] - accumulated_lengths[outer_index]

        for _ in range(length):
            segment_indices[index] = outer_index

            index += 1

    return segment_indices


@numba.njit()
def convert_incremental_lengths_to_lengths(accumulated_lengths: np.ndarray) -> np.ndarray:
    """
    Convert a vector of accumulated item lengths into lengths.
    For instance, given [0, 2, 5], we obtain [2, 3].
    :param accumulated_lengths: A vector of accumulated lengths of size N + 1.
    :return: A vector of size N representing the length of each item.
    """
    num_items = accumulated_lengths.size - 1

    lengths = np.empty(num_items, dtype=np.int32)

    for index in range(num_items):
        lengths[index] = accumulated_lengths[index + 1] - accumulated_lengths[index]

    return lengths


@numba.njit()
def accumulate(vector: np.ndarray) -> np.ndarray:
    """
    Accumulates the values of a vector such that accumulated[i + 1] = accumulated[i] + vector[i], with accumulated[0] = 0.
    :param vector: A vector of size N.
    :return: A vector of size N + 1.
    """
    accumulated = np.zeros(vector.size + 1, dtype=np.int32)

    for index, value in enumerate(vector):
        accumulated[index + 1] = accumulated[index] + value

    return accumulated
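
A quick, hypothetical round-trip through these helpers:

```python
import numpy as np

lengths = np.array([2, 3], dtype=np.int32)

accumulated = accumulate(lengths)  # [0, 2, 5]
assert np.array_equal(convert_incremental_lengths_to_lengths(accumulated), lengths)

convert_lengths_to_indices(lengths)
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]

convert_incremental_lengths_to_segment_indices(accumulated)
# [0, 0, 1, 1, 1]
```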