diff --git a/.gitignore b/.gitignore index b6e4761..b72833f 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,6 @@ dmypy.json # Pyre type checker .pyre/ + +# IDEs +.idea/ diff --git a/MLproject b/MLproject new file mode 100644 index 0000000..187e358 --- /dev/null +++ b/MLproject @@ -0,0 +1,8 @@ +name: gatas + +entry_points: + train-link-predictor: + command: "python3.7 -m link_prediction.train" + + train-node-classifier: + command: "python3.7 -m node_classification.train" diff --git a/README.md b/README.md new file mode 100644 index 0000000..108e670 --- /dev/null +++ b/README.md @@ -0,0 +1,95 @@ +# GATAS +Implementation of *Graph Representation Learning Network via Adaptive Sampling*: [http://arxiv.org/abs/2006.04637](http://arxiv.org/abs/2006.04637) + +The algorithm represents nodes by reducing their neighbour representations with attention. Multi-step neighbour representations incorporate different path properties. Neighbours are sampled using learnable depth coefficients. + + +## Overview +This repository is organized as follows: +- `data/` contains the necessary files for the Cora, Pubmed, Citeseer, PPI, Twitter and YouTube datasets. +- `framework/` contains helper libraries for model development, training and evaluation. +- `gatas/` contains the implementation of GATAS. +- `node_classification/` contains a node label classifier using the model. +- `link_prediction/` contains a link prediction model using the model. + + +## Instructions +First we must create a CSR binary representation of the graph where the values are the edge type indices. For the Cora, Citeseer and PubMed datasets, we have precomputed and placed them in `data/`. For PPI, it can be computed with: + +```bash +python3 -m node_classification.datasets.ppi --path data/ppi +``` + +For the Twitter and YouTube datasets, it can be computed with: +```bash +python3 -m link_prediction.datasets.gatne --path {path to dataset} --num-edge-types {number of edge types} +``` + +These scripts will also collect and preprocess the node features when available, and create dataset splits with inputs and targets for the tasks. Once we have a CSR graph representation, we can compute the transition probabilities by running: +```bash +python3 -m gatas.transitions --path {path to dataset} --num_steps {number of steps} +``` + +GATAS has two components: `NeighbourSampler` and `NeighbourAggregator`. `NeighbourSampler` can be initialized with a path so the precomputed transition data can be used: + +```python +from gatas.sampler import NeighbourSampler + +neighbour_sampler = NeighbourSampler.from_path(num_steps=3, path='data/ppi') +``` + +`NeighbourAggregator` can receive a matrix of node features and can be initialized as follows: +```python +import numpy as np +from gatas.aggregator import NeighbourAggregator + +node_features = np.load('data/ppi/node_embeddings.npy') + +neighbour_aggregator = NeighbourAggregator( + input_noise_rate=0., + dropout_rate=0., + num_nodes=node_features.shape[0], + num_edge_types=neighbour_sampler.num_edge_types, + num_steps=3, + edge_type_embedding_size=5, + node_embedding_size=None, + layer_size=256, + num_attention_heads=10, + node_features=node_features, +) +``` + +We can call `neighbour_aggregator` with the output of `neighbour_sampler`. This pattern is used in the node classification and link prediction tasks. 
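+For example, the sampler output can be fed to the aggregator as follows (a minimal sketch: the anchor node indices and `sample_size` below are placeholder values; the argument names match `NeighbourSampler.__call__` and `NeighbourAggregator.__call__` in `gatas/`):
+
+```python
+import tensorflow as tf
+
+# Nodes for which we want representations (placeholder indices).
+anchor_indices = tf.constant([0, 1, 2], dtype=tf.int32)
+
+# Sample multi-step neighbours for each anchor node (sample_size is illustrative).
+sample = neighbour_sampler(anchor_indices, sample_size=100)
+
+# Reduce the sampled neighbour representations with attention.
+node_representations = neighbour_aggregator(
+    anchor_indices=anchor_indices,
+    neighbour_indices=sample.indices,
+    neighbour_assignments=sample.segments,
+    neighbour_weights=sample.weights,
+    neighbour_path_indices=sample.path_indices,
+    training=True,
+)
+```
+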
You can train those models with: + +```bash +python3 -m node_classification.train --data-path {path to dataset} +``` + +or: + +```bash +python3 -m link_prediction.train --data-path {path to dataset} +``` + +where additional parameters can be passed through the command line. Run with `--help` for a list of them: + +```bash +python3 -m node_classification.train --help +``` + + +## Reference +```tex +@misc{andrade2020graph, + title={Graph Representation Learning Network via Adaptive Sampling}, + author={Anderson de Andrade and Chen Liu}, + year={2020}, + eprint={2006.04637}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +``` + + +## License +MIT diff --git a/framework/__init__.py b/framework/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/framework/common/__init__.py b/framework/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/framework/common/parameters.py b/framework/common/parameters.py new file mode 100644 index 0000000..30469e4 --- /dev/null +++ b/framework/common/parameters.py @@ -0,0 +1,85 @@ +import inspect +import sys +from typing import Callable, Union, Optional, List, Mapping, Set, Tuple + +import tensorflow as tf + + +class DataSpecification: + def __init__(self, specs: List[tf.TensorSpec]) -> None: + self.specs = specs + + def get_types(self) -> Tuple[tf.DType, ...]: + return tuple(spec.dtype for spec in self.specs) + + def get_shapes(self) -> Tuple[tf.TensorShape, ...]: + return tuple(spec.shape for spec in self.specs) + + def get_names(self) -> Tuple[str, ...]: + return tuple(spec.name for spec in self.specs) + + def create_placeholders(self) -> List[tf.Tensor]: + placeholders = [ + tf.placeholder(spec.dtype, spec.shape, spec.name) + for spec in self.specs + ] + + return placeholders + + +PrimitiveType = Union[int, float, bool, str] + + +def get_positional_arguments(argv: Optional[List[str]] = None) -> List[str]: + argv = argv if argv else sys.argv[1:] + + arguments = [] + + for argument in argv: + if argument.startswith('-'): + break + + arguments.append(argument) + + return arguments + + +def get_keyword_arguments(argv: Optional[List[str]] = None) -> Set[str]: + argv = argv if argv else sys.argv[1:] + + num_positional_arguments = len(get_positional_arguments(argv)) + + passed_arguments = { + key.lstrip('-').replace('_', '-') + for key in argv[num_positional_arguments::2] + if key.startswith('-') + } + + return passed_arguments + + +def get_script_parameters(function: Callable, ignore_keyword_arguments: bool = True) -> Mapping[str, PrimitiveType]: + """ + Returns the arguments of a function with its values specified by the command line or its default values. + Underscores in the name of the arguments are transformed to dashes. + Can optionally filter out keyword arguments obtained through the command line. + :param function: the function to inspect. + :param ignore_keyword_arguments: whether to filter out keyword command line arguments. + :return: a map from argument names to default values. 
+ """ + positional_arguments, keyword_arguments = get_positional_arguments(), get_keyword_arguments() + signature = inspect.signature(function) + + arguments = {} + + for index, (name, parameter) in enumerate(signature.parameters.items()): + transformed_name = name.replace('_', '-') + + if index < len(positional_arguments): + arguments[transformed_name] = positional_arguments[index] + + elif not (ignore_keyword_arguments and transformed_name in keyword_arguments): + if parameter.default != parameter.empty: + arguments[transformed_name] = parameter.default + + return arguments diff --git a/framework/dataset/__init__.py b/framework/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/framework/dataset/dataset.py b/framework/dataset/dataset.py new file mode 100644 index 0000000..e9556eb --- /dev/null +++ b/framework/dataset/dataset.py @@ -0,0 +1,110 @@ +import numba +import numpy as np + + +@numba.njit() +def convert_incremental_lengths_to_indices(accumulated_lengths: np.ndarray) -> np.ndarray: + """ + Convert a vector of accumulated item lengths into a matrix of indices. + For instance, given [0, 2, 5], we obtain [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]. + :param accumulated_lengths: A vector of accumulated lengths of size N + 1. + :return: A matrix of size (total length, 2) representing the assignment of each index to an item in N. + """ + num_indices = accumulated_lengths[-1] - accumulated_lengths[0] + + indices = np.empty((num_indices, 2), dtype=np.int32) + + index = 0 + + for outer_index in range(accumulated_lengths.size - 1): + length = accumulated_lengths[outer_index + 1] - accumulated_lengths[outer_index] + + for inner_index in range(length): + indices[index, 0] = outer_index + indices[index, 1] = inner_index + + index += 1 + + return indices + + +@numba.njit() +def convert_lengths_to_indices(lengths: np.ndarray) -> np.ndarray: + """ + Convert a vector of item lengths into a matrix of indices. + For instance, given [2, 3], we obtain [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]. + :param lengths: A vector of lengths of size N. + :return: A matrix of size (total length, 2) representing the assignment of each index to an item in N. + """ + num_indices = np.sum(lengths) + + indices = np.empty((num_indices, 2), dtype=np.int32) + + index = 0 + + for outer_index, length in enumerate(lengths): + for inner_index in range(length): + indices[index, 0] = outer_index + indices[index, 1] = inner_index + + index += 1 + + return indices + + +@numba.njit() +def convert_incremental_lengths_to_segment_indices(accumulated_lengths: np.ndarray) -> np.ndarray: + """ + Convert a vector of accumulated item lengths into a segment of indices. + For instance, given [0, 2, 5], we obtain [0, 0, 1, 1, 1]. + :param accumulated_lengths: A vector of accumulated lengths of size N + 1. + :return: A vector of size (total length) representing the assignment of each item index in N. + """ + num_indices = accumulated_lengths[-1] - accumulated_lengths[0] + + segment_indices = np.empty(num_indices, dtype=np.int32) + + index = 0 + + for outer_index in range(accumulated_lengths.size - 1): + length = accumulated_lengths[outer_index + 1] - accumulated_lengths[outer_index] + + for _ in range(length): + segment_indices[index] = outer_index + + index += 1 + + return segment_indices + + +@numba.njit() +def convert_incremental_lengths_to_lengths(accumulated_lengths: np.ndarray) -> np.ndarray: + """ + Convert a vector of accumulated item lengths into lengths. + For instance, given [0, 2, 5], we obtain [2, 3]. 
+ :param accumulated_lengths: A vector of accumulated lengths of size N + 1. + :return: A vector of size N representing the length of each item. + """ + num_items = accumulated_lengths.size - 1 + + lengths = np.empty(num_items, dtype=np.int32) + + for index in range(num_items): + lengths[index] = accumulated_lengths[index + 1] - accumulated_lengths[index] + + return lengths + + +@numba.njit() +def accumulate(vector: np.ndarray) -> np.ndarray: + """ + Accumulates the values of a vector such that accumulated[i] = accumulated[i - 1] + vector[i] + :param vector: A vector of size N. + :return: A vector of size N + 1. + """ + accumulated = np.zeros(vector.size + 1, dtype=np.int32) + + for index, value in enumerate(vector): + accumulated[index + 1] = accumulated[index] + value + + return accumulated diff --git a/framework/dataset/io.py b/framework/dataset/io.py new file mode 100644 index 0000000..92017c3 --- /dev/null +++ b/framework/dataset/io.py @@ -0,0 +1,171 @@ +import glob +import io +import os +import re +import tempfile +from contextlib import contextmanager +from typing import Optional, List, Tuple, Union, Generator +from urllib.parse import urlparse + +import boto3 +import numpy as np +import s3fs + + +def get_bucket_and_key(path: str) -> Tuple[str, str]: + url = urlparse(path) + + return url.netloc, url.path.strip('/') + + +def download_s3(bucket: str, key: str, destination: str) -> None: + boto3.client('s3').download_file(Bucket=bucket, Key=key, Filename=destination) + + +def put_s3(bucket: str, key: str, reader: io.BytesIO) -> None: + boto3.client('s3').put_object(Bucket=bucket, Key=key, Body=reader) + + +def get_s3_keys(bucket: str, prefix: str = '') -> List[str]: + return [file['Key'] for file in boto3.client('s3').list_objects_v2(Bucket=bucket, Prefix=prefix)['Contents']] + + +def get_relative_path(shallow_path: str, deep_path: str) -> str: + return re.sub('^' + shallow_path, '', deep_path).lstrip('/') + + +def get_file_in_directory(directory: str, file: str) -> str: + return os.path.join(directory, file.split('/')[-1]) + + +def upload_s3(source: str, bucket: str, key: str) -> None: + boto3.client('s3').upload_file(Filename=source, Bucket=bucket, Key=key) + + +def get_temporary_local_directory_path() -> str: + return tempfile.mkdtemp() + + +def get_temporary_local_path() -> str: + return tempfile.mkstemp()[1] + + +def get_relative_path_in_directory(directory: str, shallow_path: str, deep_path: str) -> str: + return os.path.join(directory, get_relative_path(shallow_path, deep_path)) + + +def get_local_directory_path(path: str, directory: Optional[str] = None) -> str: + if not path.startswith('s3'): + return path + + directory = get_temporary_local_directory_path() if directory is None else directory + + local_path = get_file_in_directory(directory, path) + + bucket, prefix = get_bucket_and_key(path) + + for key in get_s3_keys(bucket, prefix): + download_s3(bucket, key, get_file_in_directory(directory, key)) + + return local_path + + +def get_local_path(path: str) -> str: + if not path.startswith('s3'): + return path + + bucket, key = get_bucket_and_key(path) + + local_path = get_temporary_local_path() + + download_s3(bucket, key, local_path) + + return local_path + + +def load_npy(path: str, mmap_mode: Optional[str] = None) -> np.ndarray: + return np.load(get_local_path(path), mmap_mode=mmap_mode) + + +def load_bin(path: str, dtype: np.dtype.type = np.uint8, mode: str = 'r+', shape: List[int] = None) -> np.ndarray: + return np.memmap(get_local_path(path), dtype, mode, shape=shape) + 
+ +@contextmanager +def open(path: str, mode: str = 'rb', encoding: Optional[str] = None) -> Generator[io.TextIOBase, None, None]: + open_function = s3fs.S3FileSystem.current().open if path.startswith('s3') else io.open + + with open_function(path, mode=mode, encoding=encoding) as file_object: + yield file_object + + +@contextmanager +def with_local_path(path: str) -> Generator[str, None, None]: + local_path = get_temporary_local_path() if path.startswith('s3') else path + + yield local_path + + if path.startswith('s3'): + bucket, key = get_bucket_and_key(path) + upload_s3(local_path, bucket, key) + + +@contextmanager +def with_local_directory_path(path: str) -> Generator[str, None, None]: + local_path = get_temporary_local_directory_path() if path.startswith('s3') else path + + yield local_path + + if path.startswith('s3'): + sync_remote(local_path, path) + + +def read(path: str, encoding: Optional[str] = None) -> Union[bytes, str]: + if not path.startswith('s3'): + if encoding: + return io.open(path, mode='r', encoding=encoding).read() + else: + return io.open(path, mode='rb').read() + + bucket, key = get_bucket_and_key(path) + + content = boto3.client('s3').get_object(Bucket=bucket, Key=key)['Body'].read() + + if encoding: + return content.decode(encoding) + + return content + + +def save_npy(path: str, data: np.ndarray) -> None: + if not path.startswith('s3'): + np.save(path, data) + + else: + bucket, key = get_bucket_and_key(path) + + writer = io.BytesIO() + + np.save(writer, data) + + writer.seek(0) + + put_s3(bucket, key, writer) + + +def sync_remote(local_path: str, remote_path: str) -> None: + paths = (path for path in glob.glob(os.path.join(local_path, '**'), recursive=True) if os.path.isfile(path)) + + for path in paths: + path_in_remote_path = get_relative_path_in_directory(remote_path, local_path, path) + + upload(path, path_in_remote_path) + + +def upload(source: str, destination: str) -> None: + if not destination.startswith('s3'): + raise IOError(f'Destination path {destination} not supported.') + + bucket, key = get_bucket_and_key(destination) + + upload_s3(source, bucket, key) diff --git a/framework/trackers/__init__.py b/framework/trackers/__init__.py new file mode 100644 index 0000000..b845675 --- /dev/null +++ b/framework/trackers/__init__.py @@ -0,0 +1,5 @@ +from typing import Union, Callable + +Numeric = Union[int, float] +PrimitiveType = Union[int, float, str, bool] +Aggregator = Callable[..., Numeric] diff --git a/framework/trackers/aggregator.py b/framework/trackers/aggregator.py new file mode 100644 index 0000000..5167d44 --- /dev/null +++ b/framework/trackers/aggregator.py @@ -0,0 +1,115 @@ +from enum import Enum +from typing import Mapping, Union, MutableMapping, Collection, Tuple, List, Optional, NamedTuple + +import numpy as np + +from framework.trackers import Numeric, Aggregator +from framework.trackers.metrics import Metric + + +class Statistic(Enum): + TRAINING_TARGET = 'Training Target' + VALIDATION_TARGET = 'Validation Target' + TESTING_TARGET = 'Testing Target' + + TRAINING_PREDICTION = 'Training Prediction' + VALIDATION_PREDICTION = 'Validation Prediction' + TESTING_PREDICTION = 'Testing Prediction' + + TRAINING_PROBABILITY = 'Training Probability' + VALIDATION_PROBABILITY = 'Validation Probability' + TESTING_PROBABILITY = 'Testing Probability' + + TRAINING_COST = 'Training Cost' + VALIDATION_COST = 'Validation Cost' + TESTING_COST = 'Testing Cost' + + BETA = 'Beta' + + +MetricKey = Union[str, Metric] +StatisticKey = Union[str, Statistic] 
+StatisticValue = Union[List[Numeric], Numeric, np.ndarray] + + +class Aggregation(NamedTuple): + metric: MetricKey + statistics: Collection[StatisticKey] + function: Aggregator + + +class MetricAggregator: + def __init__(self, aggregators: Optional[Collection[Aggregation]] = None): + self.aggregators = {} # type: MutableMapping[MetricKey, Tuple[Collection[StatisticKey], Aggregator]] + self.state = {} # type: MutableMapping[StatisticKey, List[Numeric]] + + if aggregators: + self.register_aggregators(aggregators) + + @staticmethod + def to_numeric_list(value: StatisticValue) -> List[Numeric]: + if isinstance(value, np.ndarray): + return value.tolist() + elif isinstance(value, List): + return value + else: + return [value] + + @staticmethod + def to_str(value: MetricKey) -> str: + return value.value if isinstance(value, Enum) else value + + def register_aggregator(self, metric: MetricKey, statistics: Collection[StatisticKey], function: Aggregator) -> None: + self.aggregators[metric] = (statistics, function) + + for statistic in statistics: + if statistic not in self.state: + self.state[statistic] = [] + + def register_aggregators(self, aggregators: Collection[Aggregation]): + for aggregator in aggregators: + self.register_aggregator(aggregator.metric, aggregator.statistics, aggregator.function) + + def add_statistic(self, key: StatisticKey, value: StatisticValue) -> None: + self.state[key].extend(self.to_numeric_list(value)) + + def add_statistics(self, statistics: Mapping[StatisticKey, StatisticValue]) -> None: + for key, value in statistics.items(): + self.add_statistic(key, value) + + def _compute_metric(self, key: MetricKey) -> Optional[Numeric]: + statistics, function = self.aggregators[key] + + if not any(len(self.state[statistic]) != 0 for statistic in statistics): + return None + + return function(*(np.array(self.state[statistic]) for statistic in statistics)) + + def compute_metric(self, key: MetricKey) -> Numeric: + value = self._compute_metric(key) + + assert value is not None + + return value + + def compute_metrics(self, keys: Collection[MetricKey]) -> Mapping[MetricKey, Numeric]: + return { + key: value + for key, value + in ((key, self._compute_metric(key)) for key in keys) + if value is not None + } + + def get_metrics(self) -> Mapping[MetricKey, Numeric]: + return self.compute_metrics(self.aggregators.keys()) + + def clear(self) -> None: + for sequence in self.state.values(): + sequence.clear() + + def flush(self) -> Mapping[MetricKey, Numeric]: + metrics = self.compute_metrics(self.aggregators.keys()) + + self.clear() + + return metrics diff --git a/framework/trackers/metrics.py b/framework/trackers/metrics.py new file mode 100644 index 0000000..de4f489 --- /dev/null +++ b/framework/trackers/metrics.py @@ -0,0 +1,114 @@ +from enum import Enum +from typing import Optional + +import numpy as np +import sklearn.metrics as sk_metrics + +from framework.trackers import Numeric, Aggregator + + +class Metric(Enum): + TRAINING_ACCURACY = 'Training Accuracy' + VALIDATION_ACCURACY = 'Validation Accuracy' + TESTING_ACCURACY = 'Testing Accuracy' + + VALIDATION_WEIGHTED_F1 = 'Validation Weighted F1 Score' + TESTING_WEIGHTED_F1 = 'Testing Weighted F1 Score' + VALIDATION_WEIGHTED_PRECISION = 'Validation Weighted Precision' + TESTING_WEIGHTED_PRECISION = 'Testing Weighted Precision' + VALIDATION_WEIGHTED_RECALL = 'Validation Weighted Recall' + TESTING_WEIGHTED_RECALL = 'Testing Weighted Recall' + + TRAINING_MICRO_F1 = 'Training Micro F1 Score' + VALIDATION_MICRO_F1 = 'Validation Micro F1 
Score' + TESTING_MICRO_F1 = 'Testing Micro F1 Score' + VALIDATION_MICRO_PRECISION = 'Validation Micro Precision' + TESTING_MICRO_PRECISION = 'Testing Micro Precision' + VALIDATION_MICRO_RECALL = 'Validation Micro Recall' + TESTING_MICRO_RECALL = 'Testing Micro Recall' + + TRAINING_MACRO_F1 = 'Training Macro F1 Score' + VALIDATION_MACRO_F1 = 'Validation Macro F1 Score' + TESTING_MACRO_F1 = 'Testing Macro F1 Score' + VALIDATION_MACRO_PRECISION = 'Validation Macro Precision' + TESTING_MACRO_PRECISION = 'Testing Macro Precision' + VALIDATION_MACRO_RECALL = 'Validation Macro Recall' + TESTING_MACRO_RECALL = 'Testing Macro Recall' + + TRAINING_MEAN_COST = 'Training Mean Cost' + VALIDATION_MEAN_COST = 'Validation Mean Cost' + TESTING_MEAN_COST = 'Testing Mean Cost' + + VALIDATION_JACCARD_SCORE = 'Validation Jaccard Score' + + MEAN_BETA = 'Mean Beta' + + +class MetricFunctions: + @staticmethod + def accuracy(targets: np.ndarray, predictions: np.ndarray) -> Numeric: + return sk_metrics.accuracy_score(targets, predictions) + + @staticmethod + def precision(targets: np.ndarray, predictions: np.ndarray, average: Optional[str] = 'weighted') -> Numeric: + return sk_metrics.precision_score(targets, predictions, average=average, zero_division=0) + + @staticmethod + def recall(targets: np.ndarray, predictions: np.ndarray, average: Optional[str] = 'weighted') -> Numeric: + return sk_metrics.recall_score(targets, predictions, average=average, zero_division=0) + + @staticmethod + def f1_score(targets: np.ndarray, predictions: np.ndarray, average: Optional[str] = 'weighted') -> Numeric: + return sk_metrics.f1_score(targets, predictions, average=average, zero_division=0) + + @staticmethod + def jaccard_score(targets: np.ndarray, predictions: np.ndarray, average: Optional[str] = 'weighted') -> Numeric: + return sk_metrics.jaccard_score(targets, predictions, average=average) + + @staticmethod + def jaccard_score_multiclass(targets: np.ndarray, predictions: np.ndarray) -> Numeric: + overlap = np.sum(targets * predictions, axis=1) + + union = np.sum(targets + predictions > 0, axis=1) + + score = float(np.mean(np.divide(overlap, union, out=np.zeros_like(union, dtype=np.float32), where=union != 0))) + + return score + + @staticmethod + def mean(array: np.ndarray) -> Numeric: + return float(np.mean(array)) + + @staticmethod + def calculate_thresholds(targets: np.ndarray, probabilities: np.ndarray) -> np.array: + num_classes = targets.shape[1] + + best_thresholds = np.empty(shape=num_classes, dtype=np.float32) + + for index in range(num_classes): + precisions, recalls, thresholds = sk_metrics.precision_recall_curve( + y_true=targets[:, index], + probas_pred=probabilities[:, index], + ) + + f1_scores_denominator = precisions + recalls + f1_scores = np.divide( + 2 * precisions * recalls, + f1_scores_denominator, + out=np.zeros_like(f1_scores_denominator), + where=f1_scores_denominator != 0, + ) + + best_thresholds[index] = thresholds[np.nanargmax(f1_scores) - 1] + + return best_thresholds + + @classmethod + def calculate_with_optimized_predictions(cls, targets: np.ndarray, probabilities: np.ndarray, function: Aggregator) -> Numeric: + thresholds = cls.calculate_thresholds(targets, probabilities) + + predictions = probabilities >= thresholds + + score = function(targets, predictions.astype(np.float32)) + + return score diff --git a/framework/trackers/tracker.py b/framework/trackers/tracker.py new file mode 100644 index 0000000..cb2ff24 --- /dev/null +++ b/framework/trackers/tracker.py @@ -0,0 +1,26 @@ +from typing import 
Mapping, Optional + +from framework.trackers import Numeric + + +class Tracker: + def set_tags(self, tags: Mapping[str, str]) -> None: + raise NotImplementedError + + def register_parameters(self, parameters: Mapping[str, Numeric]) -> None: + raise NotImplementedError + + def log_metrics(self, metrics: Mapping[str, Numeric], step: Optional[int] = None) -> None: + raise NotImplementedError + + def save_model(self, path: str) -> None: + raise NotImplementedError + + def finish_epoch(self) -> None: + raise NotImplementedError + + def __enter__(self) -> 'Tracker': + raise NotImplementedError + + def __exit__(self, context_type, value, traceback) -> None: + raise NotImplementedError diff --git a/framework/trackers/tracker_mlflow.py b/framework/trackers/tracker_mlflow.py new file mode 100644 index 0000000..1460821 --- /dev/null +++ b/framework/trackers/tracker_mlflow.py @@ -0,0 +1,53 @@ +import glob +from typing import Mapping, Collection, Optional + +import mlflow + +from framework.trackers import Numeric, PrimitiveType +from framework.trackers.aggregator import Aggregation, MetricAggregator +from framework.trackers.tracker import Tracker + + +class MLFlowTracker(MetricAggregator, Tracker): + def __init__( + self, + experiment_name: str, + tracking_uri: Optional[str] = None, + aggregators: Optional[Collection[Aggregation]] = None): + + super().__init__(aggregators) + + self.step = 0 + + if tracking_uri: + mlflow.set_tracking_uri(tracking_uri) + + mlflow.set_experiment(experiment_name) + + def set_tags(self, tags: Mapping[str, str]) -> None: + mlflow.set_tags(tags) + + def register_parameters(self, parameters: Mapping[str, PrimitiveType]) -> None: + mlflow.log_params({self.to_str(key): value for key, value in parameters.items()}) + + def log_metrics(self, metrics: Mapping[str, Numeric], step: Optional[int] = None) -> None: + mlflow.log_metrics(metrics, step) + + def save_model(self, path: str) -> None: + for file_path in glob.glob(path + '*'): + mlflow.log_artifact(file_path) + + def finish_epoch(self) -> None: + metrics = {self.to_str(key): value for key, value in self.get_metrics().items()} + + self.log_metrics(metrics, self.step) + + self.flush() + + self.step += 1 + + def __enter__(self) -> 'MLFlowTracker': + return self + + def __exit__(self, context_type, value, traceback) -> None: + mlflow.end_run() diff --git a/gatas/__init__.py b/gatas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gatas/aggregator.py b/gatas/aggregator.py new file mode 100644 index 0000000..655005d --- /dev/null +++ b/gatas/aggregator.py @@ -0,0 +1,298 @@ +from typing import Tuple, Optional, Union + +import numba +import numpy as np +import tensorflow as tf + +from gatas import common + + +class NeighbourAggregator(tf.Module): + def __init__( + self, + input_noise_rate: float, + dropout_rate: float, + num_nodes: int, + num_edge_types: int, + num_steps: int, + edge_type_embedding_size: int, + node_embedding_size: Optional[int], + layer_size: int, + num_attention_heads: int, + node_features: Optional[Union[tf.Tensor, np.ndarray]] = None) -> None: + + super().__init__() + + self.num_edge_types = num_edge_types + self.layer_size = layer_size + self.num_attention_heads = num_attention_heads + + if node_features is not None: + self.node_features = tf.convert_to_tensor(node_features, name='node_embeddings') + node_size = layer_size + node_embedding_size if node_embedding_size else layer_size + + else: + self.node_features = None + node_size = node_embedding_size if node_embedding_size else 0 + + 
initialize = tf.glorot_uniform_initializer() + + self.feature_transformer = tf.keras.Sequential( + layers=[ + tf.keras.layers.LayerNormalization(), + tf.keras.layers.Dropout(rate=input_noise_rate), + tf.keras.layers.Dense(layer_size, use_bias=True, activation=tf.nn.elu), + tf.keras.layers.Dropout(rate=dropout_rate), + ], + name='input_transformation', + ) + + if node_embedding_size: + self.node_embeddings = tf.Variable( + initial_value=initialize(shape=(num_nodes, node_embedding_size), dtype=tf.float32), + name='trainable_node_embeddings', + ) + + else: + self.node_embeddings = None + + self.dropout = tf.keras.layers.Dropout(dropout_rate) + + self.edge_type_embeddings = tf.Variable( + initial_value=initialize(shape=(num_edge_types, edge_type_embedding_size), dtype=tf.float32), + name='edge_type_embeddings', + ) + + self.positional_embeddings = tf.convert_to_tensor(common.compute_positional_embeddings( + max_length=num_steps, + num_features=edge_type_embedding_size, + )) + + self.coefficient_transformer = tf.keras.Sequential([ + tf.keras.layers.Dense(units=1, use_bias=False), + tf.keras.layers.Dropout(dropout_rate), + ]) + + self.value_transformer = tf.keras.Sequential([ + tf.keras.layers.Dense(units=layer_size, use_bias=True, activation=tf.nn.elu), + tf.keras.layers.Dropout(dropout_rate), + ]) + + self.attention_hidden_weights = tf.Variable( + initial_value=initialize(shape=(node_size + layer_size, layer_size, num_attention_heads), dtype=tf.float32), + name='attention_hidden_weights', + ) + + self.attention_hidden_biases = tf.Variable( + initial_value=tf.zeros(shape=(layer_size, num_attention_heads), dtype=tf.float32), + name='attention_hidden_biases', + ) + + self.attention_weights = tf.Variable( + initial_value=initialize(shape=(layer_size, num_attention_heads), dtype=tf.float32), + name='attention_weights', + ) + + self.values_weights = tf.Variable( + initial_value=initialize(shape=(layer_size, layer_size, num_attention_heads), dtype=tf.float32), + name='values_weights', + ) + + self.values_biases = tf.Variable( + initial_value=tf.zeros(shape=(layer_size, num_attention_heads), dtype=tf.float32), + name='values_biases', + ) + + def __call__( + self, + anchor_indices: tf.Tensor, + neighbour_indices: tf.Tensor, + neighbour_assignments: tf.Tensor, + neighbour_weights: tf.Tensor, + neighbour_path_indices: tf.Tensor, + training: bool = False, + concatenate: bool = True) -> tf.Tensor: + + anchor_features, neighbour_features = self.generate_features( + anchor_indices=anchor_indices, + neighbour_indices=neighbour_indices, + training=training, + ) + + if self.node_embeddings is not None: + anchor_features, neighbour_features = self.build_embeddings( + anchor_indices=anchor_indices, + neighbour_indices=neighbour_indices, + anchor_features=anchor_features, + neighbour_features=neighbour_features, + training=training, + ) + + neighbour_embeddings = self.build_embeddings_with_path( + path_indices=neighbour_path_indices, + anchor_features=anchor_features, + neighbour_features=neighbour_features, + neighbour_assignments=neighbour_assignments, + training=training, + ) + + node_representations = self.create_attended_representations( + anchor_embeddings=anchor_features, + neighbour_embeddings=neighbour_embeddings, + neighbour_assignments=neighbour_assignments, + neighbour_weights=neighbour_weights, + training=training, + concatenate=concatenate, + ) + + return node_representations + + def generate_features( + self, + anchor_indices: tf.Tensor, + neighbour_indices: tf.Tensor, + training: bool) -> 
Tuple[Optional[tf.Tensor], Optional[tf.Tensor]]: + + if self.node_features is None: + return None, None + + anchor_features = self.feature_transformer( + inputs=tf.nn.embedding_lookup(self.node_features, anchor_indices), + training=training, + ) + + neighbour_features = self.feature_transformer( + inputs=tf.nn.embedding_lookup(self.node_features, neighbour_indices), + training=training, + ) + + return anchor_features, neighbour_features + + def build_embeddings( + self, + anchor_indices: tf.Tensor, + neighbour_indices: tf.Tensor, + anchor_features: Optional[tf.Tensor], + neighbour_features: Optional[tf.Tensor], + training: bool) -> Tuple[tf.Tensor, tf.Tensor]: + + anchor_embeddings = self.dropout( + inputs=tf.nn.embedding_lookup(self.node_embeddings, anchor_indices), + training=training, + ) + + neighbour_embeddings = self.dropout( + inputs=tf.nn.embedding_lookup(self.node_embeddings, neighbour_indices), + training=training, + ) + + if anchor_features is not None: + anchor_embeddings = tf.concat((anchor_embeddings, anchor_features), axis=-1) + + if neighbour_features is not None: + neighbour_embeddings = tf.concat((neighbour_embeddings, neighbour_features), axis=-1) + + return anchor_embeddings, neighbour_embeddings + + def get_relation_path_sequences(self, path_indices: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + segment_indices_op, position_indices_op, edge_type_indices_op = tf.numpy_function( + func=lambda x: compute_edge_type_path_sequences(x, self.num_edge_types), + inp=(path_indices,), + Tout=(tf.int32, tf.int32, tf.int32), + ) + + segment_indices_op.set_shape((None,)) + position_indices_op.set_shape((None,)) + edge_type_indices_op.set_shape((None,)) + + return segment_indices_op, position_indices_op, edge_type_indices_op + + def build_embeddings_with_path( + self, + path_indices: tf.Tensor, + anchor_features: tf.Tensor, + neighbour_features: tf.Tensor, + neighbour_assignments: tf.Tensor, + training: bool) -> tf.Tensor: + + segment_indices, position_indices, edge_type_indices = self.get_relation_path_sequences(path_indices) + + path_embeddings = tf.add( + x=tf.gather(self.edge_type_embeddings, edge_type_indices), + y=tf.gather(self.positional_embeddings, position_indices), + ) + + coefficients = tf.gather(tf.gather(anchor_features, neighbour_assignments), segment_indices) + coefficients = tf.concat((coefficients, path_embeddings), axis=-1) + coefficients = self.coefficient_transformer(coefficients, training=training) + coefficients = common.segment_softmax(coefficients, segment_indices) + + values = tf.concat((tf.gather(neighbour_features, segment_indices), path_embeddings), axis=-1) + values = self.value_transformer(values, training=training) + + neighbour_embeddings = tf.math.segment_sum(coefficients * values, segment_indices) + + return neighbour_embeddings + + def create_attended_representations( + self, + anchor_embeddings: tf.Tensor, + neighbour_embeddings: tf.Tensor, + neighbour_assignments: tf.Tensor, + neighbour_weights: tf.Tensor, + training: bool, + concatenate: bool) -> tf.Tensor: + + embeddings = tf.concat((tf.gather(anchor_embeddings, neighbour_assignments), neighbour_embeddings), axis=1) + + attention_coefficients = tf.einsum('bd,dfa->bfa', embeddings, self.attention_hidden_weights) + attention_coefficients = tf.nn.elu(attention_coefficients + self.attention_hidden_biases[tf.newaxis, :, :]) + attention_coefficients = tf.einsum('bda,da->ba', attention_coefficients, self.attention_weights) + attention_coefficients = tf.math.log(neighbour_weights)[:, 
tf.newaxis] + attention_coefficients + attention_coefficients = common.segment_softmax(attention_coefficients, neighbour_assignments) + attention_coefficients = self.dropout(attention_coefficients[:, tf.newaxis, :], training=training) + + neighbour_embeddings = tf.einsum('bd,dfa->bfa', neighbour_embeddings, self.values_weights) + neighbour_embeddings = tf.nn.elu(neighbour_embeddings + self.values_biases[tf.newaxis, :, :]) + neighbour_embeddings = self.dropout(neighbour_embeddings, training=training) + neighbour_embeddings = attention_coefficients * neighbour_embeddings + + attention_heads = tf.nn.elu(tf.math.segment_sum(neighbour_embeddings, neighbour_assignments)) + + if concatenate: + anchors_attended = tf.reshape(attention_heads, (-1, self.layer_size * self.num_attention_heads)) + else: + anchors_attended = tf.reduce_mean(attention_heads, axis=-1) + + return anchors_attended + + +@numba.njit +def compute_edge_type_path_sequences( + path_indices: np.ndarray, + num_edge_types: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + + segment_indices, position_indices, edge_type_indices = [], [], [] + + for segment_index, edge_type_path_index in enumerate(path_indices): + if edge_type_path_index == 0: + segment_indices.append(segment_index) + position_indices.append(0) + edge_type_indices.append(0) + + else: + position_index = 1 + + while edge_type_path_index >= 0: + segment_indices.append(segment_index) + position_indices.append(position_index) + edge_type_indices.append(edge_type_path_index % num_edge_types) + + edge_type_path_index = (edge_type_path_index / num_edge_types) - 1 + position_index += 1 + + segment_indices = np.array(segment_indices, dtype=np.int32) + position_indices = np.array(position_indices, dtype=np.int32) + edge_type_indices = np.array(edge_type_indices, dtype=np.int32) + + return segment_indices, position_indices, edge_type_indices diff --git a/gatas/common.py b/gatas/common.py new file mode 100644 index 0000000..b1dcacc --- /dev/null +++ b/gatas/common.py @@ -0,0 +1,34 @@ +import numpy as np +import tensorflow as tf + + +def segment_softmax(logits: tf.Tensor, segment_ids: tf.Tensor) -> tf.Tensor: + logits_max = tf.math.segment_max(logits, segment_ids) + + logits_exp = tf.math.exp(logits - tf.gather(logits_max, segment_ids)) + + partitions = tf.gather(tf.math.segment_sum(logits_exp, segment_ids), segment_ids) + + softmax = logits_exp / partitions + + return softmax + + +def segment_normalize(logits: tf.Tensor, segment_ids: tf.Tensor) -> tf.Tensor: + partitions = tf.gather(tf.math.segment_sum(logits, segment_ids), segment_ids) + + probabilities = logits / partitions + + return probabilities + + +def compute_positional_embeddings(max_length: int, num_features: int) -> np.ndarray: + feature_indices, positions = np.arange(num_features, dtype=np.float32), np.arange(max_length + 1, dtype=np.float32) + + angle_rates = 1 / np.power(10000, 2 * (feature_indices // 2) / num_features) + positional_encodings = positions[:, np.newaxis] * angle_rates[np.newaxis, :] + + positional_encodings[:, 0::2] = np.sin(positional_encodings[:, 0::2]) + positional_encodings[:, 1::2] = np.cos(positional_encodings[:, 1::2]) + + return positional_encodings diff --git a/gatas/sampler.py b/gatas/sampler.py new file mode 100644 index 0000000..3cf7233 --- /dev/null +++ b/gatas/sampler.py @@ -0,0 +1,283 @@ +import math +import os +from typing import NamedTuple, Tuple + +import numba +import numpy as np +import tensorflow as tf + +from framework.dataset import io + + +class NeighbourSample(NamedTuple): + 
indices: tf.Tensor + segments: tf.Tensor + path_indices: tf.Tensor + weights: tf.Tensor + + +class NeighbourSampler(tf.Module): + def __init__( + self, + accumulated_transition_lengths: np.ndarray, + neighbours: np.ndarray, + path_indices: np.ndarray, + probabilities: np.ndarray, + num_edge_types: int, + num_steps: int) -> None: + + super().__init__() + + self.accumulated_transition_lengths = accumulated_transition_lengths + self.neighbours = neighbours + self.path_indices = path_indices + self.probabilities = probabilities + + self.num_edge_types = num_edge_types + self.num_steps = num_steps + + self.path_depths = compute_path_depths(num_steps, num_edge_types) + + self.coefficients = tf.Variable( + initial_value=self.compute_initial_coefficients(num_steps), + shape=(num_steps + 1,), + dtype=tf.float32, + ) + + @classmethod + def get_num_edge_types(cls, path: str, suffix: str) -> int: + edge_types = io.load_bin(os.path.join(path, f'edge_types{suffix}.bin'), dtype=np.int32) + + num_edge_types = np.max(edge_types) + 2 + + return num_edge_types + + @classmethod + def from_path(cls, num_steps: int, path: str, suffix: str = '') -> 'NeighbourSampler': + accumulated_transition_lengths = io.load_npy( + path=os.path.join(path, f'accumulated_transition_lengths{suffix}.npy'), + mmap_mode='r', + ) + neighbours = io.load_npy(os.path.join(path, f'neighbours{suffix}.npy'), mmap_mode='r') + path_indices = io.load_npy(os.path.join(path, f'path_indices{suffix}.npy'), mmap_mode='r') + probabilities = io.load_npy(os.path.join(path, f'probabilities{suffix}.npy'), mmap_mode='r') + + instance = cls( + accumulated_transition_lengths=accumulated_transition_lengths, + neighbours=neighbours, + path_indices=path_indices, + probabilities=probabilities, + num_edge_types=cls.get_num_edge_types(path, suffix), + num_steps=num_steps, + ) + + return instance + + @staticmethod + def compute_initial_coefficients(num_steps: int) -> np.ndarray: + steps = np.concatenate((np.array([0], dtype=np.float32), np.arange(num_steps, dtype=np.float32))) + + decaying_distribution = -steps / np.log(num_steps + 1, dtype=np.float32) + + return decaying_distribution + + def __call__(self, node_indices: tf.Tensor, sample_size: int, noisify: bool = True) -> NeighbourSample: + step_probabilities = tf.nn.softmax(self.coefficients) + + indices, segments, path_indices, steps, probabilities = self.generate_sample( + node_indices=node_indices, + coefficients=step_probabilities, + sample_size=sample_size, + noisify=noisify, + ) + + weights = tf.gather(step_probabilities, steps) * probabilities + + sample = NeighbourSample(indices, segments, path_indices, weights) + + return sample + + def generate_sample( + self, + node_indices: tf.Tensor, + coefficients: tf.Tensor, + sample_size: int, + noisify: bool) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: + + indices, segments, path_indices, steps, probabilities= tf.numpy_function( + func=lambda x, y: generate_sample( + node_indices=x, + transition_pointers=self.accumulated_transition_lengths, + neighbours=self.neighbours, + path_indices=self.path_indices, + probabilities=self.probabilities, + coefficients=y, + path_depths=self.path_depths, + num_steps=self.num_steps, + sample_size=sample_size, + noisify=noisify, + ), + inp=(node_indices, coefficients), + Tout=(tf.int32, tf.int32, tf.int32, tf.int32, tf.float32), + ) + + indices.set_shape((None,)) + segments.set_shape((None,)) + path_indices.set_shape((None,)) + steps.set_shape((None,)) + probabilities.set_shape((None,)) + + return indices, 
segments, path_indices, steps, probabilities + + +@numba.njit +def generate_sample( + node_indices: np.ndarray, + transition_pointers: np.ndarray, + neighbours: np.ndarray, + path_indices: np.ndarray, + probabilities: np.ndarray, + coefficients: np.ndarray, + path_depths: np.ndarray, + num_steps: int, + sample_size: int, + noisify: bool) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + + indices = [] + segments = [] + path_indices_subset = [] + steps = [] + probabilities_subset = [] + + for segment_id, node_index in enumerate(node_indices): + start_index, end_index = transition_pointers[node_index], transition_pointers[node_index + 1] + + neighbour_sparse_indices, _ = compute_top_k( + probabilities=probabilities[start_index:end_index], + path_indices=path_indices[start_index:end_index], + path_depths=path_depths, + coefficients=coefficients, + num_steps=num_steps, + k=sample_size, + noisify=noisify, + ) + + for neighbour_sparse_index in neighbour_sparse_indices: + neighbour_index = neighbour_sparse_index + start_index + + indices.append(neighbours[neighbour_index]) + segments.append(segment_id) + path_indices_subset.append(path_indices[neighbour_index]) + steps.append(path_depths[path_indices[neighbour_index]]) + probabilities_subset.append(probabilities[neighbour_index]) + + indices = np.array(indices, dtype=np.int32) + segments = np.array(segments, dtype=np.int32) + path_indices_subset = np.array(path_indices_subset, dtype=np.int32) + steps = np.array(steps, dtype=np.int32) + probabilities_subset = np.array(probabilities_subset, dtype=np.float32) + + return indices, segments, path_indices_subset, steps, probabilities_subset + + +@numba.njit +def get_path_depth(index: int, num_edge_types: int, step: int = 0) -> int: + if (index == 0 and step == 0) or index < 0: + return step + + parent_index = math.floor(index / num_edge_types) - 1 + + return get_path_depth(parent_index, num_edge_types, step + 1) + + +@numba.njit +def compute_path_depths(num_steps: int, num_edge_types: int) -> np.ndarray: + num_paths = 0 + + for step in range(num_steps): + num_paths += num_edge_types ** (step + 1) + + path_depths = np.empty((num_paths,), dtype=np.int32) + + for path_index in range(num_paths): + path_depths[path_index] = get_path_depth(path_index, num_edge_types) + + return path_depths + + +@numba.njit +def compute_top_k( + probabilities: np.ndarray, + path_indices: np.ndarray, + path_depths: np.ndarray, + coefficients: np.ndarray, + num_steps: int, + k: int, + noisify: bool) -> Tuple[np.ndarray, np.ndarray]: + + top_indices = np.full(fill_value=-1, shape=(k,), dtype=np.int32) + top_values = np.full(fill_value=-np.inf, shape=(k,), dtype=np.float32) + + if k == 0: + return top_indices, top_values + + last_index = k - 1 + + for index, probability in enumerate(probabilities): + step = path_depths[path_indices[index]] + + if step > num_steps: + continue + + value = calculate_transition_logits(probability, coefficients[step], noisify=noisify) + + if value <= top_values[0]: + continue + + # heap bubble-down operation + node_index = 0 + + top_indices[node_index], top_values[node_index] = index, value + + while True: + child_index = 2 * node_index + 1 + + swap_index = node_index + + if child_index <= last_index and top_values[node_index] > top_values[child_index]: + swap_index = child_index + + if child_index + 1 <= last_index and top_values[swap_index] > top_values[child_index + 1]: + swap_index = child_index + 1 + + if swap_index == node_index: + break + + temp_index, temp_value = 
top_indices[swap_index], top_values[swap_index] + top_indices[swap_index] = top_indices[node_index] + top_values[swap_index] = top_values[node_index] + top_indices[node_index], top_values[node_index] = temp_index, temp_value + + node_index = swap_index + + # extract indices from fixed-length heaps + indices = np.where(top_indices >= 0)[0] + top_indices, top_values = top_indices[indices], top_values[indices] + + return top_indices, top_values + + +@numba.njit +def calculate_transition_logits( + probabilities: np.ndarray, + coefficients: np.ndarray, + noisify: bool, + eps: float = 1e-20) -> float: + + logits = np.log(probabilities * coefficients) + + if noisify: + gumbel_sample = -np.log(-np.log(np.random.random(np.shape(probabilities)) + eps) + eps) + logits += gumbel_sample + + return logits diff --git a/gatas/transitions.py b/gatas/transitions.py new file mode 100644 index 0000000..2c413f0 --- /dev/null +++ b/gatas/transitions.py @@ -0,0 +1,155 @@ +import os +from typing import List, Tuple, MutableMapping + +import defopt +import numba +import numpy as np +from numba import types as nb_types + +from framework.dataset import io + + +TransitionType = Tuple[int, int, int] + + +@numba.experimental.jitclass([ + ('stack_primary', nb_types.List(nb_types.Tuple([numba.int64, numba.int64, numba.int64]))), + ('stack_secondary', nb_types.List(nb_types.Tuple([numba.int64, numba.int64, numba.int64]))), +]) +class Queue: + def __init__(self, element: TransitionType) -> None: + self.stack_primary = [element] + self.stack_secondary: List[TransitionType] = [self.stack_primary.pop()] + + def size(self) -> int: + return len(self.stack_primary) + len(self.stack_secondary) + + def enqueue(self, element: TransitionType) -> None: + self.stack_primary.append(element) + + def dequeue(self) -> TransitionType: + if len(self.stack_secondary) == 0: + while len(self.stack_primary) > 0: + self.stack_secondary.append(self.stack_primary.pop()) + + return self.stack_secondary.pop() + + +@numba.njit +def extend_path(path_index: int, edge_type: int, num_edge_types: int) -> int: + if path_index == 0: + return edge_type + 1 + + return (path_index + 1) * (num_edge_types + 1) + edge_type + 1 + + +@numba.njit +def traverse( + start_index: int, + accumulated_num_edges: np.ndarray, + adjacencies: np.ndarray, + edge_types: np.ndarray, + num_steps: int, + num_edge_types: int) -> Tuple[MutableMapping[Tuple[int, int, int], int], List[int]]: + + neighbour_path_counts: MutableMapping[Tuple[int, int, int], int] = {} + step_counts: List[int] = [0] * (num_steps + 1) + + visited: MutableMapping[int, int] = {} + queue = Queue((start_index, 0, 0)) + + while queue.size() > 0: + key = queue.dequeue() + node_index, path_index, step = key + + if node_index in visited and visited[node_index] < step: + continue + + visited[node_index] = step + + if key in neighbour_path_counts: + neighbour_path_counts[key] += 1 + else: + neighbour_path_counts[key] = 1 + + step_counts[step] += 1 + + if step >= num_steps: + continue + + for sparse_index in range(accumulated_num_edges[node_index], accumulated_num_edges[node_index + 1]): + neighbour_id, edge_type = adjacencies[sparse_index], edge_types[sparse_index] + extended_path_id = extend_path(path_index, edge_type, num_edge_types) + + queue.enqueue((neighbour_id, extended_path_id, step + 1)) + + return neighbour_path_counts, step_counts + + +@numba.njit +def create_transition_tensors( + accumulated_num_edges: np.ndarray, + adjacencies: np.ndarray, + edge_types: np.ndarray, + num_steps: int) -> Tuple[np.ndarray, 
np.ndarray, np.ndarray, np.ndarray]: + + num_edge_types = np.max(edge_types) + 1 + + accumulated_transition_lengths = [0] + neighbours = [] + path_indices = [] + probabilities = [] + + for start_index in range(accumulated_num_edges.size - 1): + neighbour_path_counts, step_counts = traverse( + start_index=start_index, + accumulated_num_edges=accumulated_num_edges, + adjacencies=adjacencies, + edge_types=edge_types, + num_steps=num_steps, + num_edge_types=num_edge_types, + ) + + step_probability = [1 / num_steps if num_steps != 0 else 0. for num_steps in step_counts] + accumulated_transition_lengths.append(accumulated_transition_lengths[-1] + len(neighbour_path_counts)) + + for (end_index, path_index, step), count in neighbour_path_counts.items(): + neighbours.append(end_index) + path_indices.append(path_index) + probabilities.append(count * step_probability[step]) + + accumulated_transition_lengths = np.array(accumulated_transition_lengths, dtype=np.int32) + neighbours = np.array(neighbours, dtype=np.int32) + path_indices = np.array(path_indices, dtype=np.int32) + probabilities = np.array(probabilities, dtype=np.float32) + + return accumulated_transition_lengths, neighbours, path_indices, probabilities + + +def store_transition_tensor(*, path: str, num_steps: int, suffix: str = '') -> None: + """ + Computes and stores a transition tensor. + + :param path: The path to the CSR sparse representation vectors. + :param num_steps: Number of steps to compute. + :param suffix: Suffix of the input files. + """ + accumulated_num_edges = io.load_bin(os.path.join(path, f'accumulated_num_edges{suffix}.bin'), dtype=np.int32) + adjacencies = io.load_bin(os.path.join(path, f'adjacencies{suffix}.bin'), dtype=np.int32) + edge_types = io.load_bin(os.path.join(path, f'edge_types{suffix}.bin'), dtype=np.int32) + + accumulated_transition_lengths, neighbours, path_indices, probabilities = create_transition_tensors( + accumulated_num_edges=accumulated_num_edges, + adjacencies=adjacencies, + edge_types=edge_types, + num_steps=num_steps, + ) + + np.save(os.path.join(path, f'accumulated_transition_lengths{suffix}.npy'), accumulated_transition_lengths) + np.save(os.path.join(path, f'neighbours{suffix}.npy'), neighbours) + np.save(os.path.join(path, f'path_indices{suffix}.npy'), path_indices) + np.save(os.path.join(path, f'probabilities{suffix}.npy'), probabilities) + + +if __name__ == '__main__': + defopt.run(store_transition_tensor) diff --git a/link_prediction/__init__.py b/link_prediction/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/link_prediction/dataset.py b/link_prediction/dataset.py new file mode 100644 index 0000000..43d80b1 --- /dev/null +++ b/link_prediction/dataset.py @@ -0,0 +1,159 @@ +import os +from typing import Tuple, Generator, Optional, Union + +import numba +import numpy as np +from sklearn.model_selection import ShuffleSplit + +from framework.dataset import io + + +BatchType = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] + + +class LinkPredictionDataset: + def __init__( + self, + triplets: np.ndarray, + targets: np.ndarray, + num_classes: int, + random_state: int, + indices: Optional[np.ndarray] = None) -> None: + + self.triplets = triplets + self.targets = targets + self.num_classes = num_classes + self.random_state = random_state + + self.indices = indices if indices is not None else np.arange(targets.size, dtype=np.int32) + self.size = self.indices.size + + @classmethod + def from_path(cls, path: str, num_classes: int, random_state: int) -> 
'LinkPredictionDataset': + triplets = io.load_npy(os.path.join(path, 'triplets.npy'), mmap_mode='r') + targets = io.load_npy(os.path.join(path, 'targets.npy'), mmap_mode='r') + + return cls(triplets, targets, num_classes, random_state) + + def get_batches(self, size: int) -> Generator[BatchType, None, None]: + np.random.shuffle(self.indices) + + for start_index in range(0, self.size, size): + end_index = min(start_index + size, self.size) + + batch_indices = self.indices[start_index:end_index] + + yield self.get_batch(self.triplets[batch_indices], self.targets[batch_indices], self.num_classes) + + @staticmethod + @numba.njit + def get_batch( + triplets: np.ndarray, + targets: np.ndarray, + num_classes: int) -> BatchType: + + size = targets.size + + node_ids = np.empty((size * 2,), dtype=np.int32) + pair_indices = np.empty((size * 2, 2), dtype=np.int32) + + edge_targets = np.zeros((size, num_classes), dtype=np.float32) + edge_targets_mask = np.zeros((size, num_classes), dtype=np.float32) + + for index in range(size): + node_ids[index * 2] = triplets[index, 0] + pair_indices[index * 2, :] = [index, 0] + + node_ids[index * 2 + 1] = triplets[index, 1] + pair_indices[index * 2 + 1, :] = [index, 1] + + edge_targets[index, triplets[index, 2]] = targets[index] + edge_targets_mask[index, triplets[index, 2]] = 1. + + return node_ids, pair_indices, edge_targets, edge_targets_mask + + def subset(self, indices: np.ndarray) -> 'LinkPredictionDataset': + dataset = LinkPredictionDataset( + triplets=self.triplets, + targets=self.targets, + num_classes=self.num_classes, + random_state=self.random_state, + indices=self.indices[indices], + ) + + return dataset + + +def get_splitted_dataset( + path: str, + num_classes: int, + random_state: int) -> Tuple[LinkPredictionDataset, LinkPredictionDataset, LinkPredictionDataset]: + + train_dataset = LinkPredictionDataset( + triplets=io.load_npy(os.path.join(path, 'triplets_train.npy')), + targets=io.load_npy(os.path.join(path, 'targets_train.npy')), + num_classes=num_classes, + random_state=random_state, + ) + + validation_dataset = LinkPredictionDataset( + triplets=io.load_npy(os.path.join(path, 'triplets_validation.npy')), + targets=io.load_npy(os.path.join(path, 'targets_validation.npy')), + num_classes=num_classes, + random_state=random_state, + ) + + test_dataset = LinkPredictionDataset( + triplets=io.load_npy(os.path.join(path, 'triplets_test.npy')), + targets=io.load_npy(os.path.join(path, 'targets_test.npy')), + num_classes=num_classes, + random_state=random_state, + ) + + return train_dataset, validation_dataset, test_dataset + + +def get_dataset_splits( + dataset: LinkPredictionDataset, + train_size: Union[int, float], + validation_size: Union[int, float], + test_size: Union[int, float], + number_splits: int, + random_state: int) -> Generator[Tuple[LinkPredictionDataset, LinkPredictionDataset, LinkPredictionDataset], None, None]: + + if isinstance(train_size, float): + train_size = round(dataset.size * train_size) + + if isinstance(validation_size, float): + validation_size = round(dataset.size * validation_size) + + if isinstance(test_size, float): + test_size = round(dataset.size * test_size) + + train_test_splitter = ShuffleSplit( + n_splits=number_splits, + train_size=train_size, + test_size=validation_size + test_size, + random_state=random_state, + ) + + validation_test_splitter = ShuffleSplit( + n_splits=1, + train_size=validation_size, + test_size=test_size, + random_state=random_state, + ) + + for train_indices, validation_test_indices in 
train_test_splitter.split(np.empty(dataset.size, dtype=np.bool)): + train_dataset = dataset.subset(train_indices) + + validation_indices, test_indices = \ + next(validation_test_splitter.split(np.empty(validation_test_indices.size, dtype=np.bool))) + + validation_indices = validation_test_indices[validation_indices] + test_indices = validation_test_indices[test_indices] + + validation_dataset = dataset.subset(validation_indices) + test_dataset = dataset.subset(test_indices) + + yield train_dataset, validation_dataset, test_dataset diff --git a/link_prediction/datasets/__init__.py b/link_prediction/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/link_prediction/datasets/gatne.py b/link_prediction/datasets/gatne.py new file mode 100644 index 0000000..1de2cc1 --- /dev/null +++ b/link_prediction/datasets/gatne.py @@ -0,0 +1,162 @@ +import io +import os +from typing import Generator, Tuple, Mapping, Callable, Iterator, Set + +import defopt +import numpy as np + +from link_prediction.datasets import twitter + + +def read_triplet_file(path: str) -> Generator[Tuple[int, int, int], None, None]: + with io.open(path, mode='r') as text_file: + for line in text_file: + columns = line.strip().split() + + yield int(columns[1]), int(columns[2]), int(columns[0]) - 1 + + +def read_triplet_file_with_targets(path: str) -> Generator[Tuple[int, int, int, int], None, None]: + with io.open(path, mode='r') as text_file: + for line in text_file: + columns = line.strip().split() + + yield int(columns[1]), int(columns[2]), int(columns[0]) - 1, int(columns[3]) + + +def get_node_map(path: str) -> Mapping[int, int]: + node_ids = set() + + for head, tail, _ in read_triplet_file(os.path.join(path, 'train.txt')): + node_ids.add(head) + node_ids.add(tail) + + for dataset_type in ['validation', 'test']: + for head, tail, _, _ in read_triplet_file_with_targets(os.path.join(path, f'{dataset_type}.txt')): + node_ids.add(head) + node_ids.add(tail) + + return twitter.compute_node_maps(node_ids)[0] + + +def read_mapped_triplets(node_map: Mapping[int, int], path: str) -> Generator[Tuple[int, int, int], None, None]: + for head, tail, edge_type in read_triplet_file(path): + yield node_map[head], node_map[tail], edge_type + + +def read_mapped_triplets_with_targets(node_map: Mapping[int, int], path: str) -> Generator[Tuple[int, int, int, int], None, None]: + for head, tail, edge_type, target in read_triplet_file_with_targets(path): + yield node_map[head], node_map[tail], edge_type, target + + +def get_all_samples(node_map: Mapping[int, int], path: str) -> Set[Tuple[int, int, int]]: + train = set(read_mapped_triplets(node_map, os.path.join(path, 'train.txt'))) + + validation = set(( + (head, tail, edge_type) + for head, tail, edge_type, target + in read_mapped_triplets_with_targets(node_map, os.path.join(path, 'validation.txt')) + )) + + test = set(( + (head, tail, edge_type) + for head, tail, edge_type, target + in read_mapped_triplets_with_targets(node_map, os.path.join(path, 'test.txt')) + )) + + return train | validation | test + + +def generate_negative_dataset( + node_map: Mapping[int, int], + num_edge_types: int, + path: str) -> Generator[Tuple[int, int, int], None, None]: + + all_samples = get_all_samples(node_map, path) + negative_samples = set() + + while True: + triplet = twitter.generate_sample(len(node_map), num_edge_types) + + if triplet not in all_samples and triplet not in negative_samples: + negative_samples.add(triplet) + + yield triplet + + +def generate_triplets( + positive_samples: 
Set[Tuple[int, int, int]], + node_map: Mapping[int, int], + num_edge_types: int, + path: str) -> Generator[Tuple[int, int, int, int], None, None]: + + negative_samples = generate_negative_dataset(node_map, num_edge_types, path) + + for index, (positive_triplet, negative_triplet) in enumerate(zip(positive_samples, negative_samples)): + yield positive_triplet[0], positive_triplet[1], positive_triplet[2], 1 + yield negative_triplet[0], negative_triplet[1], negative_triplet[2], 0 + + +def convert_triplets(generator: Callable[[], Iterator[Tuple[int, int, int]]]) -> np.ndarray: + size = sum(1 for _ in generator()) + + triplets = np.empty((size, 3), dtype=np.int32) + + indices = np.arange(size) + np.random.shuffle(indices) + + for index, (head, tail, edge_type) in zip(indices, generator()): + triplets[index] = [head, tail, edge_type] + + return triplets + + +def convert_triplets_with_targets(generator: Callable[[], Iterator[Tuple[int, int, int, int]]]) -> Tuple[np.ndarray, np.ndarray]: + size = sum(1 for _ in generator()) + + triplets = np.empty((size, 3), dtype=np.int32) + targets = np.empty((size,), dtype=np.int32) + + indices = np.arange(size) + np.random.shuffle(indices) + + for index, (head, tail, edge_type, target) in zip(indices, generator()): + triplets[index] = [head, tail, edge_type] + targets[index] = target + + return triplets, targets + + +def save_triplets(node_map: Mapping[int, int], num_edge_types: int, path: str) -> None: + triplets, targets = convert_triplets_with_targets(lambda: generate_triplets( + positive_samples=set(read_mapped_triplets(node_map, os.path.join(path, 'train.txt'))), + node_map=node_map, + num_edge_types=num_edge_types, + path=path, + )) + np.save(os.path.join(path, 'triplets_train.npy'), triplets) + np.save(os.path.join(path, 'targets_train.npy'), targets) + + triplets, targets = convert_triplets_with_targets(lambda: read_mapped_triplets_with_targets(node_map, os.path.join(path, 'validation.txt'))) + np.save(os.path.join(path, 'triplets_validation.npy'), triplets) + np.save(os.path.join(path, 'targets_validation.npy'), targets) + + triplets, targets = convert_triplets_with_targets(lambda: read_mapped_triplets_with_targets(node_map, os.path.join(path, 'test.txt'))) + np.save(os.path.join(path, 'triplets_test.npy'), triplets) + np.save(os.path.join(path, 'targets_test.npy'), targets) + + +def get_triplets_train_positive(node_map: Mapping[int, int], path: str) -> np.ndarray: + return convert_triplets(lambda: read_mapped_triplets(node_map, os.path.join(path, 'train.txt'))) + + +def main(*, path: str, num_edge_types: int) -> None: + node_map = get_node_map(path) + save_triplets(node_map, num_edge_types, path) + triplets_train_positive = get_triplets_train_positive(node_map, path) + graph = twitter.get_graph_by_triplets(triplets_train_positive, len(node_map), num_edge_types) + twitter.create_graph(graph, path) + + +if __name__ == '__main__': + defopt.run(main) diff --git a/link_prediction/datasets/twitter.py b/link_prediction/datasets/twitter.py new file mode 100644 index 0000000..6197e29 --- /dev/null +++ b/link_prediction/datasets/twitter.py @@ -0,0 +1,371 @@ +import io +import os +import random +from functools import reduce +from typing import Generator, Tuple, Mapping, Set, Optional, Iterable, List, MutableMapping + +import defopt +import numpy as np + + +FILES = [ + 'higgs-mention_network.edgelist', + 'higgs-reply_network.edgelist', + 'higgs-retweet_network.edgelist', + 'higgs-social_network.edgelist', +] + + +def read_triplet_file(path: str) -> 
Generator[Tuple[int, int], None, None]: + with io.open(path, mode='r') as text_file: + for line in text_file: + columns = line.split() + + yield int(columns[0]), int(columns[1]) + + +def get_triplets(path: str) -> Generator[Tuple[int, int, int], None, None]: + for edge_type, file in enumerate(FILES): + for head, tail in read_triplet_file(os.path.join(path, file)): + yield head, tail, edge_type + + +def get_mapped_triplets(path: str, node_map: Mapping[int, int]) -> Generator[Tuple[int, int, int], None, None]: + for head, tail, edge_type in get_triplets(path): + if head in node_map and tail in node_map: + yield node_map[head], node_map[tail], edge_type + + +def get_graph_by_node_map(path: str, node_map: Mapping[int, int]) -> List[MutableMapping[int, Set[int]]]: + num_nodes = len(node_map) + + graph: List[MutableMapping[int, Set[int]]] = [{} for _ in range(num_nodes)] + + for head, tail, edge_type in get_mapped_triplets(path, node_map): + graph[head].setdefault(tail, set()).add(edge_type) + + return graph + + +def get_graph_by_triplets(triplets: np.ndarray, num_nodes: int, num_edge_types: int) -> List[MutableMapping[int, Set[int]]]: + graph: List[MutableMapping[int, Set[int]]] = [{} for _ in range(num_nodes)] + + for head, tail, edge_type in triplets: + graph[head].setdefault(tail, set()).add(edge_type) + graph[tail].setdefault(head, set()).add(edge_type + num_edge_types) + + return graph + + +def get_largest_subgraph(graph: List[MutableMapping[int, Set[int]]]) -> Set[int]: + num_nodes = len(graph) + visited: Set[int] = set() + subgraphs = [] + + while len(visited) < num_nodes: + subgraph = set() + + stack = [next( + node_index + for node_index in range(num_nodes) + if node_index not in visited + )] + + while stack: + node = stack.pop() + + if node not in subgraph: + subgraph.add(node) + visited.add(node) + + for neighbour in graph[node].keys(): + stack.append(neighbour) + + subgraphs.append(subgraph) + + sorted_subgraphs = sorted( + [(subgraph, len(subgraph)) for subgraph in subgraphs], + key=lambda graph_tuple: graph_tuple[1], + reverse=True, + ) + + return sorted_subgraphs[0][0] + + +def compute_node_maps(node_ids: Iterable[int]) -> Tuple[Mapping[int, int], Mapping[int, int]]: + id_to_index = {node_id: index for index, node_id in enumerate(node_ids)} + index_to_id = {index: node_id for index, node_id in enumerate(node_ids)} + + return id_to_index, index_to_id + + +def get_biggest_nodes(graph: List[Mapping[int, Set[int]]], subgraph: Set[int], num_nodes: int) -> List[int]: + node_lengths = [ + (node_index, sum(len(edge_types) for edge_types in graph[node_index].values())) + for node_index in subgraph + ] + + node_lengths_sorted = sorted(node_lengths, key=lambda length: length[1], reverse=True) + + node_indices = [node_index for node_index, _ in node_lengths_sorted[:num_nodes]] + + return node_indices + + +def create_node_ids(path: str, num_nodes: Optional[int] = None) -> Mapping[int, int]: + node_ids = set() + + for head, tail, _ in get_triplets(path): + node_ids.add(head) + node_ids.add(tail) + + node_id_to_index, node_index_to_id = compute_node_maps(node_ids) + + if num_nodes is not None: + graph = get_graph_by_node_map(path, node_id_to_index) + + largest_subgraph = get_largest_subgraph(graph) + + # node_indices = get_biggest_nodes(graph, largest_subgraph, num_nodes) + node_indices = random.sample(largest_subgraph, num_nodes) + + node_ids = {node_index_to_id[node_index] for node_index in node_indices} + + np.save(os.path.join(path, 'node_ids.npy'), np.array(node_ids, dtype=np.int32)) + + 
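+    # Re-index the selected node IDs into a dense 0..N-1 mapping before returning it.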
return compute_node_maps(node_ids)[0] + + +def generate_sample(num_nodes: int, num_edge_types: int) -> Tuple[int, int, int]: + head = random.randint(0, num_nodes - 1) + tail = random.randint(0, num_nodes - 1) + edge_type = random.randint(0, num_edge_types - 1) + + return head, tail, edge_type + + +def generate_negative_dataset(path: str, node_map: Mapping[int, int], num_edge_types: int) -> Generator[Tuple[int, int, int], None, None]: + num_nodes = len(node_map) + + positive_samples, negative_samples = set(get_mapped_triplets(path, node_map)), set() + + while True: + triplet = generate_sample(num_nodes, num_edge_types) + + if triplet not in positive_samples and triplet not in negative_samples: + negative_samples.add(triplet) + + yield triplet + + +def create_triplets(path: str, node_map: Mapping[int, int], num_edge_types: int) -> Tuple[np.ndarray, np.ndarray]: + positive_samples = set(get_mapped_triplets(path, node_map)) + negative_samples = generate_negative_dataset(path, node_map, num_edge_types) + + dataset_size = len(positive_samples) * 2 + + triplets = np.empty((dataset_size, 3), dtype=np.int32) + targets = np.empty((dataset_size,), dtype=np.int32) + + for index, (positive_triplet, negative_triplet) in enumerate(zip(positive_samples, negative_samples)): + triplets[index * 2, :] = positive_triplet + targets[index * 2] = 1 + + triplets[index * 2 + 1, :] = negative_triplet + targets[index * 2 + 1] = 0 + + return triplets, targets + + +def split_triplets( + triplets: np.ndarray, + targets: np.ndarray, + train_size: float, + validation_size: float, + test_size: float, + path: str) -> np.ndarray: + + size = triplets.shape[0] + + indices = np.arange(size) + np.random.shuffle(indices) + + train_size = round(train_size * size) + validation_size = round(validation_size * size) + test_size = round(test_size * size) + + train_indices = indices[:train_size] + np.save(os.path.join(path, 'triplets_train.npy'), triplets[train_indices]) + np.save(os.path.join(path, 'targets_train.npy'), targets[train_indices]) + + validation_indices = indices[train_size:train_size + validation_size] + np.save(os.path.join(path, 'triplets_validation.npy'), triplets[validation_indices]) + np.save(os.path.join(path, 'targets_validation.npy'), targets[validation_indices]) + + test_indices = indices[train_size + validation_size:train_size + validation_size + test_size] + np.save(os.path.join(path, 'triplets_test.npy'), triplets[test_indices]) + np.save(os.path.join(path, 'targets_test.npy'), targets[test_indices]) + + triplets_train_positive = triplets[train_indices][targets[train_indices].astype(np.bool)] + + return triplets_train_positive + + +def split_triplets_variation( + triplets: np.ndarray, + targets: np.ndarray, + base_size: float, + train_size: float, + validation_size: float, + test_size: float, + path: str) -> np.ndarray: + + size = triplets.shape[0] + + indices = np.arange(size) + np.random.shuffle(indices) + + base_size = round(base_size * size) + train_size = round(train_size * size) + validation_size = round(validation_size * size) + test_size = round(test_size * size) + + base_indices = indices[:base_size] + + train_indices = indices[base_size:base_size + train_size] + np.save(os.path.join(path, 'triplets_train.npy'), triplets[train_indices]) + np.save(os.path.join(path, 'targets_train.npy'), targets[train_indices]) + + validation_indices = indices[base_size + train_size:base_size + train_size + validation_size] + np.save(os.path.join(path, 'triplets_validation.npy'), triplets[validation_indices]) + 
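+    # Each saved triplet split has a matching targets file written alongside it.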
np.save(os.path.join(path, 'targets_validation.npy'), targets[validation_indices]) + + test_indices = indices[base_size + train_size + validation_size:base_size + train_size + validation_size + test_size] + np.save(os.path.join(path, 'triplets_test.npy'), triplets[test_indices]) + np.save(os.path.join(path, 'targets_test.npy'), targets[test_indices]) + + triplets_base_positive = triplets[base_indices][targets[base_indices].astype(np.bool)] + + return triplets_base_positive + + +def create_graph(graph: List[MutableMapping[int, Set[int]]], path: str) -> None: + accumulated_num_edges, adjacencies, edge_types = [0], [], [] + + for tails in graph: + edges = sorted( + (tail, edge_type) + for tail, pair_edge_types in tails.items() + for edge_type in pair_edge_types + ) + + for tail, edge_type in edges: + adjacencies.append(tail) + edge_types.append(edge_type) + + accumulated_num_edges.append(accumulated_num_edges[-1] + len(edges)) + + np.array(accumulated_num_edges, dtype=np.int32).tofile(os.path.join(path, 'accumulated_num_edges.bin')) + np.array(adjacencies, dtype=np.int32).tofile(os.path.join(path, 'adjacencies.bin')) + np.array(edge_types, dtype=np.int32).tofile(os.path.join(path, 'edge_types.bin')) + + +def count_edge_types( + graph: List[Mapping[int, Set[int]]], + node_ids: Iterable[int], + num_edge_types: int) -> List[List[int]]: + + node_counts = [] + + for head in node_ids: + edge_type_count = [0] * num_edge_types + + for edge_types in graph[head].values(): + for edge_type in edge_types: + edge_type_count[edge_type] += 1 + + node_counts.append(edge_type_count) + + return node_counts + + +def select_nodes( + graph: List[Mapping[int, Set[int]]], + node_ids: List[int], + num_nodes: int, + num_edge_types: int) -> List[int]: + + edge_types_counts = count_edge_types(graph, node_ids, num_edge_types) + + num_nodes_per_edge_type = round(num_nodes / num_edge_types) + + node_indices_subset: Set[int] = set() + + for edge_type in range(num_edge_types): + edge_type_counts = sorted( + ( + (node_index, counts[edge_type]) + for node_index, counts + in enumerate(edge_types_counts) + if node_index not in node_indices_subset + ), + key=lambda node_counts: node_counts[1], + reverse=True, + ) + + node_indices_subset.update(node_index for node_index, _ in edge_type_counts[:num_nodes_per_edge_type]) + + node_ids_subset = [node_ids[node_index] for node_index in node_indices_subset] + + return node_ids_subset + + +def compute_overlap_node_ids(path: str, num_edge_types: int, num_nodes: Optional[int] = None) -> Mapping[int, int]: + sets: MutableMapping[int, Set[int]] = {} + + for head, tail, edge_type in get_triplets(path): + if edge_type not in sets: + sets[edge_type] = set() + + sets[edge_type].add(head) + sets[edge_type].add(tail) + + node_ids = reduce(lambda x, y: x & y, sets.values()) + + node_id_to_index, node_index_to_id = compute_node_maps(node_ids) + + if num_nodes is not None: + graph = get_graph_by_node_map(path, node_id_to_index) + + largest_subgraph = get_largest_subgraph(graph) + + node_indices = random.sample(largest_subgraph, num_nodes) + # node_indices = get_biggest_nodes(graph, largest_subgraph, num_nodes) + # node_indices = select_nodes(graph, list(largest_subgraph), num_nodes, num_edge_types) + + node_ids = {node_index_to_id[node_index] for node_index in node_indices} + + np.save(os.path.join(path, 'node_ids.npy'), np.array(node_ids, dtype=np.int32)) + + return compute_node_maps(node_ids)[0] + + +def main( + *, + path: str, + train_size: float = .85, + validation_size: float = .05, + test_size: 
float = .1, + num_nodes: Optional[int] = None) -> None: + + num_edge_types = len(FILES) + + node_map = compute_overlap_node_ids(path, num_edge_types, num_nodes) + triplets, targets = create_triplets(path, node_map, num_edge_types) + triplets_train_positive = split_triplets(triplets, targets, train_size, validation_size, test_size, path) + graph = get_graph_by_triplets(triplets_train_positive, len(node_map), num_edge_types) + create_graph(graph, path) + + +if __name__ == '__main__': + defopt.run(main) diff --git a/link_prediction/metrics.py b/link_prediction/metrics.py new file mode 100644 index 0000000..c38e90f --- /dev/null +++ b/link_prediction/metrics.py @@ -0,0 +1,52 @@ +import numpy as np +import sklearn.metrics as sk_metrics + +from framework.trackers import Numeric + + +def f1_score_micro(targets: np.ndarray, probabilities: np.ndarray, weights: np.ndarray) -> Numeric: + predictions = np.round(probabilities) + + mask = weights.astype(np.bool) + + score = sk_metrics.f1_score(targets[mask], predictions[mask]) + + return score + + +def f1_score_macro(targets: np.ndarray, probabilities: np.ndarray, weights: np.ndarray) -> Numeric: + predictions = np.round(probabilities) + + scores = [] + + class_data = zip(np.transpose(targets), np.transpose(predictions), np.transpose(weights)) + + for class_targets, class_predictions, class_weights in class_data: + if np.sum(class_weights) == 0: + continue + + score = sk_metrics.f1_score(class_targets, class_predictions, sample_weight=class_weights) + + scores.append(score) + + score = sum(scores) / len(scores) + + return score + + +def roc_auc(targets: np.ndarray, probabilities: np.ndarray, weights: np.ndarray) -> Numeric: + scores = [] + + class_data = zip(np.transpose(targets), np.transpose(probabilities), np.transpose(weights)) + + for class_targets, class_probabilities, class_weights in class_data: + if np.sum(class_weights) == 0: + continue + + score = sk_metrics.roc_auc_score(class_targets, class_probabilities, sample_weight=class_weights) + + scores.append(score) + + score = sum(scores) / len(scores) + + return score diff --git a/link_prediction/model.py b/link_prediction/model.py new file mode 100644 index 0000000..dfd20ef --- /dev/null +++ b/link_prediction/model.py @@ -0,0 +1,133 @@ +from typing import Optional, Union + +import numpy as np +import tensorflow as tf + +from framework.common.parameters import DataSpecification +from gatas.aggregator import NeighbourAggregator +from gatas.sampler import NeighbourSampler + + +class LinkPredictor(tf.Module): + def __init__( + self, + path: str, + num_nodes: int, + layer_size: int, + layer_size_classifier: int, + num_attention_heads: int, + edge_type_embedding_size: int, + node_embedding_size: Optional[int], + num_steps: int, + sample_size: int, + input_noise_rate: float, + dropout_rate: float, + lambda_coefficient: float, + num_classes: int, + group_size: int, + node_features: Optional[Union[tf.Tensor, np.ndarray]] = None): + + super().__init__() + + self.sample_size = sample_size + self.lambda_coefficient = lambda_coefficient + self.group_size = group_size + + self.neighbour_sampler = NeighbourSampler.from_path(num_steps, path) + + self.neighbour_aggregator = NeighbourAggregator( + input_noise_rate=input_noise_rate, + dropout_rate=dropout_rate, + num_nodes=num_nodes, + num_edge_types=self.neighbour_sampler.num_edge_types, + num_steps=num_steps, + edge_type_embedding_size=edge_type_embedding_size, + node_embedding_size=node_embedding_size, + layer_size=layer_size, + 
num_attention_heads=num_attention_heads, + node_features=node_features, + ) + + self.node_transformer = tf.keras.layers.Dense(units=layer_size_classifier, activation=tf.nn.elu) + + self.classifier = tf.keras.Sequential([ + tf.keras.layers.Dense(units=layer_size_classifier, activation=tf.nn.elu), + tf.keras.layers.Dense(units=layer_size_classifier, activation=tf.nn.elu), + tf.keras.layers.Dense(units=num_classes, use_bias=False), + ]) + + @classmethod + def get_schema(cls, num_classes: int) -> DataSpecification: + schema = DataSpecification([ + tf.TensorSpec(tf.TensorShape((None,)), tf.int32, 'anchor_indices'), + tf.TensorSpec(tf.TensorShape((None, 2)), tf.int32, 'group_indices'), + tf.TensorSpec(tf.TensorShape((None, num_classes)), tf.int32, 'targets'), + tf.TensorSpec(tf.TensorShape((None, num_classes)), tf.float32, 'target_weights'), + ]) + + return schema + + def __call__(self, anchor_indices: tf.Tensor, group_indices: tf.Tensor, training: Union[tf.Tensor, bool]): + neighbour_sample = self.neighbour_sampler(anchor_indices, self.sample_size) + + node_representations = self.neighbour_aggregator( + anchor_indices=anchor_indices, + neighbour_indices=neighbour_sample.indices, + neighbour_assignments=neighbour_sample.segments, + neighbour_weights=neighbour_sample.weights, + neighbour_path_indices=neighbour_sample.path_indices, + training=training, + ) + + node_representations = self.node_transformer(node_representations) + + num_groups = tf.reduce_max(group_indices, axis=0)[0] + 1 + node_representations_size = node_representations.shape[1] + + node_representations_grouped = tf.scatter_nd( + indices=group_indices, + updates=node_representations, + shape=(num_groups, self.group_size, node_representations_size), + ) + + logits = self.classifier(tf.reshape( + tensor=node_representations_grouped, + shape=(num_groups, self.group_size * node_representations_size), + )) + + probabilities = tf.nn.sigmoid(logits) + + return logits, probabilities + + def calculate_loss(self, logits: tf.Tensor, targets: tf.Tensor, target_weights: tf.Tensor) -> tf.Tensor: + loss = tf.losses.sigmoid_cross_entropy( + multi_class_labels=targets, + logits=logits, + weights=target_weights, + reduction=tf.losses.Reduction.NONE, + ) + + l2_norm = tf.add_n([ + tf.nn.l2_loss(variable) + for variable in tf.trainable_variables() + if variable.name not in {'coefficient_indices:0', 'edge_type_embeddings:0'} + ]) + + loss = tf.reduce_mean(tf.reduce_sum(loss, axis=-1)) + self.lambda_coefficient * l2_norm + + return loss + + @classmethod + def get_clipped_gradient_updates( + cls, + loss: tf.Tensor, + optimizer: tf.train.Optimizer, + max_gradient_norm: float = 5.0) -> tf.Tensor: + + gradients, variables = zip(*optimizer.compute_gradients(loss)) + + clipped_gradients, _ = tf.clip_by_global_norm(gradients, max_gradient_norm) + + updates = optimizer.apply_gradients(zip(clipped_gradients, variables)) + + return updates diff --git a/link_prediction/train.py b/link_prediction/train.py new file mode 100644 index 0000000..9a3bfda --- /dev/null +++ b/link_prediction/train.py @@ -0,0 +1,424 @@ +import os +import traceback +from typing import Tuple, Callable, NamedTuple, List, Optional + +import defopt +import numpy as np +import tensorflow as tf + +from framework.common import parameters +from framework.trackers.aggregator import Aggregation, Statistic +from framework.trackers.metrics import MetricFunctions, Metric +from framework.trackers.tracker_mlflow import MLFlowTracker +from link_prediction import dataset +from link_prediction import metrics 
as metric_functions +from link_prediction.model import LinkPredictor + + +TRAINING_TARGET_WEIGHTS = 'Training Target Weights' +VALIDATION_TARGET_WEIGHTS = 'Validation Target Weights' +TESTING_TARGET_WEIGHTS = 'Testing Target Weights' + +TRAINING_ROC_AUC = 'Training ROC AUC' +VALIDATION_ROC_AUC = 'Validation ROC AUC' +TESTING_ROC_AUC = 'Testing ROC AUC' + + +class ModelMetrics(NamedTuple): + test_roc_auc: float + test_macro_f1: float + test_micro_f1: float + + +class NodeClassifierTrainer: + def __init__( + self, + data_path: str, + model_path: str, + num_nodes: int, + num_classes: int, + max_steps: int, + sample_size: int, + layer_size: int, + layer_size_classifier: int, + num_attention_heads: int, + edge_type_embedding_size: int, + node_embedding_size: Optional[int], + input_noise_rate: float, + dropout_rate: float, + lambda_coefficient: float, + learning_rate: float) -> None: + + self.data_path = data_path + + self.checkpoint_path = os.path.join(model_path, 'model.ckpt') + + self.iterator = self.create_iterator(num_classes) + + self.model = LinkPredictor( + path=data_path, + num_nodes=num_nodes, + layer_size=layer_size, + layer_size_classifier=layer_size_classifier, + num_attention_heads=num_attention_heads, + edge_type_embedding_size=edge_type_embedding_size, + node_embedding_size=node_embedding_size, + num_steps=max_steps, + sample_size=sample_size, + input_noise_rate=input_noise_rate, + dropout_rate=dropout_rate, + lambda_coefficient=lambda_coefficient, + num_classes=num_classes, + group_size=2, + ) + + self.anchor_indices, self.group_indices, self.targets, self.target_weights = self.iterator.get_next() + self.training = tf.placeholder_with_default(False, shape=(), name='training') + self.logits, self.probabilities = self.model(self.anchor_indices, self.group_indices, self.training) + self.loss = self.model.calculate_loss(self.logits, self.targets, self.target_weights) + self.updates = self.model.get_clipped_gradient_updates(self.loss, optimizer=tf.contrib.opt.NadamOptimizer(learning_rate)) + + self.tracker = MLFlowTracker('link-prediction', aggregators=[ + Aggregation( + Metric.TRAINING_MEAN_COST, + [Statistic.TRAINING_COST], + MetricFunctions.mean), + Aggregation( + Metric.VALIDATION_MEAN_COST, + [Statistic.VALIDATION_COST], + MetricFunctions.mean), + Aggregation( + Metric.TESTING_MEAN_COST, + [Statistic.TESTING_COST], + MetricFunctions.mean), + Aggregation( + Metric.TRAINING_MICRO_F1, + [Statistic.TRAINING_TARGET, Statistic.TRAINING_PROBABILITY, TRAINING_TARGET_WEIGHTS], + lambda x, y, z: metric_functions.f1_score_micro(x, y, z)), + Aggregation( + Metric.VALIDATION_MICRO_F1, + [Statistic.VALIDATION_TARGET, Statistic.VALIDATION_PROBABILITY, VALIDATION_TARGET_WEIGHTS], + lambda x, y, z: metric_functions.f1_score_micro(x, y, z)), + Aggregation( + Metric.TESTING_MICRO_F1, + [Statistic.TESTING_TARGET, Statistic.TESTING_PROBABILITY, TESTING_TARGET_WEIGHTS], + lambda x, y, z: metric_functions.f1_score_micro(x, y, z)), + Aggregation( + Metric.TRAINING_MACRO_F1, + [Statistic.TRAINING_TARGET, Statistic.TRAINING_PROBABILITY, TRAINING_TARGET_WEIGHTS], + lambda x, y, z: metric_functions.f1_score_macro(x, y, z)), + Aggregation( + Metric.VALIDATION_MACRO_F1, + [Statistic.VALIDATION_TARGET, Statistic.VALIDATION_PROBABILITY, VALIDATION_TARGET_WEIGHTS], + lambda x, y, z: metric_functions.f1_score_macro(x, y, z)), + Aggregation( + Metric.TESTING_MACRO_F1, + [Statistic.TESTING_TARGET, Statistic.TESTING_PROBABILITY, TESTING_TARGET_WEIGHTS], + lambda x, y, z: metric_functions.f1_score_macro(x, y, z)), + 
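+            # The ROC AUC aggregations below are registered under the string metric names
+            # defined at the top of this file rather than attributes of Metric.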
Aggregation( + TRAINING_ROC_AUC, + [Statistic.TRAINING_TARGET, Statistic.TRAINING_PROBABILITY, TRAINING_TARGET_WEIGHTS], + lambda x, y, z: metric_functions.roc_auc(x, y, z)), + Aggregation( + VALIDATION_ROC_AUC, + [Statistic.VALIDATION_TARGET, Statistic.VALIDATION_PROBABILITY, VALIDATION_TARGET_WEIGHTS], + lambda x, y, z: metric_functions.roc_auc(x, y, z)), + Aggregation( + TESTING_ROC_AUC, + [Statistic.TESTING_TARGET, Statistic.TESTING_PROBABILITY, TESTING_TARGET_WEIGHTS], + lambda x, y, z: metric_functions.roc_auc(x, y, z)), + ]) + + self.tracker.set_tags({ + 'Tier': 'Development', + 'Problem': 'Link Prediction', + }) + + @staticmethod + def create_iterator(num_classes: int) -> tf.data.Iterator: + iterator = tf.data.Iterator.from_structure( + output_types=LinkPredictor.get_schema(num_classes).get_types(), + output_shapes=LinkPredictor.get_schema(num_classes).get_shapes(), + ) + + return iterator + + @staticmethod + def initialize_iterator( + iterator: tf.data.Iterator, + generator: Callable, + arguments: Tuple, + queue_size: int = -1) -> tf.Operation: + + validation_dataset = tf.data.Dataset \ + .from_generator( + generator=generator, + output_types=iterator.output_types, + output_shapes=iterator.output_shapes, + args=arguments, + ) \ + .prefetch(queue_size) + + iterator = iterator.make_initializer(validation_dataset) + + return iterator + + def train_early_stopping( + self, + session: tf.Session, + train_iterator: tf.Operation, + validation_iterator: tf.Operation, + test_iterator: tf.Operation, + save_path: str, + num_epochs: int, + early_stopping_threshold: int, + previous_best_metric: float = -np.inf, + saver_variables: Optional[List[tf.Tensor]] = None) -> ModelMetrics: + + saver = tf.train.Saver(var_list=saver_variables) + + non_improvement_times = 0 + + for epoch in range(num_epochs): + session.run(train_iterator) + + try: + while True: + cost_, probabilities_, targets_, target_weights_, _ = session.run( + fetches=(self.loss, self.probabilities, self.targets, self.target_weights, self.updates), + feed_dict={self.training: True}, + ) + + self.tracker.add_statistics({ + Statistic.TRAINING_COST: cost_, + Statistic.TRAINING_PROBABILITY: probabilities_, + Statistic.TRAINING_TARGET: targets_, + TRAINING_TARGET_WEIGHTS: target_weights_, + }) + + except tf.errors.OutOfRangeError: + pass + + session.run(validation_iterator) + + try: + while True: + cost_, probabilities_, targets_, target_weights_ = session.run( + fetches=(self.loss, self.probabilities, self.targets, self.target_weights), + ) + + self.tracker.add_statistics({ + Statistic.VALIDATION_COST: cost_, + Statistic.VALIDATION_PROBABILITY: probabilities_, + Statistic.VALIDATION_TARGET: targets_, + VALIDATION_TARGET_WEIGHTS: target_weights_, + }) + + except tf.errors.OutOfRangeError: + pass + + training_loss = self.tracker.compute_metric(Metric.TRAINING_MEAN_COST) + train_metric = self.tracker.compute_metric(TRAINING_ROC_AUC) + validation_metric = self.tracker.compute_metric(VALIDATION_ROC_AUC) + + print(f'Epoch: {epoch}, training loss: {training_loss}, train metric: {train_metric}, validation metric: {validation_metric}') + + try: + if validation_metric > previous_best_metric: + non_improvement_times, previous_best_metric = 0, validation_metric + + saver.save(session, save_path) + + elif non_improvement_times < early_stopping_threshold: + non_improvement_times += 1 + + else: + print('Stopping after no improvement.') + break + + finally: + self.tracker.finish_epoch() + + saver.restore(session, save_path) + + session.run(test_iterator) + 
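+        # Evaluate the restored best checkpoint on the held-out test split;
+        # statistics accumulate until the test iterator is exhausted.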
+ try: + while True: + cost_, probabilities_, targets_, target_weights_ = session.run( + fetches=(self.loss, self.probabilities, self.targets, self.target_weights), + ) + + self.tracker.add_statistics({ + Statistic.TESTING_COST: cost_, + Statistic.TESTING_PROBABILITY: probabilities_, + Statistic.TESTING_TARGET: targets_, + TESTING_TARGET_WEIGHTS: target_weights_, + }) + + except tf.errors.OutOfRangeError: + pass + + metrics = ModelMetrics( + test_roc_auc=self.tracker.compute_metric(TESTING_ROC_AUC), + test_macro_f1=self.tracker.compute_metric(Metric.TESTING_MACRO_F1), + test_micro_f1=self.tracker.compute_metric(Metric.TESTING_MICRO_F1), + ) + + self.tracker.clear() + self.tracker.save_model(save_path) + + return metrics + + def multi_step_train_with_early_stopping( + self, + data_path: str, + num_folds: int, + batch_size: int, + max_num_epochs: int, + maximum_non_improvement_epochs: int, + num_classes: int, + random_state: int) -> None: + + initializers = tf.global_variables_initializer() + fold_metrics = [] + + splits = (dataset.get_splitted_dataset(data_path, num_classes, random_state) for _ in range(num_folds)) + + for fold_index, (train_dataset, validation_dataset, test_dataset) in enumerate(splits): + train_iterator = self.initialize_iterator( + iterator=self.iterator, + generator=train_dataset.get_batches, + arguments=(batch_size,)) + + validation_iterator = self.initialize_iterator( + iterator=self.iterator, + generator=validation_dataset.get_batches, + arguments=(batch_size,)) + + test_iterator = self.initialize_iterator( + iterator=self.iterator, + generator=test_dataset.get_batches, + arguments=(batch_size,)) + + with tf.Session() as session: + session.run(initializers) + + metrics = self.train_early_stopping( + session=session, + train_iterator=train_iterator, + validation_iterator=validation_iterator, + test_iterator=test_iterator, + save_path=self.checkpoint_path, + num_epochs=max_num_epochs, + early_stopping_threshold=maximum_non_improvement_epochs, + ) + + print(f'\nFold: {fold_index + 1}, ROC AUC: {metrics.test_roc_auc}, Macro F1 Score: {metrics.test_macro_f1}') + self.tracker.log_metrics({ + 'Fold ROC AUC': metrics.test_roc_auc, + 'Fold Macro F1 Score': metrics.test_macro_f1, + 'Fold Micro F1 Score': metrics.test_micro_f1, + }, fold_index + 1) + + fold_metrics.append(metrics) + + roc_auc_mean = float(np.mean([metrics.test_roc_auc for metrics in fold_metrics])) + roc_auc_std = float(np.std([metrics.test_roc_auc for metrics in fold_metrics])) + macro_f1_mean = float(np.mean([metrics.test_macro_f1 for metrics in fold_metrics])) + macro_f1_std = float(np.std([metrics.test_macro_f1 for metrics in fold_metrics])) + + print(f'\nROC AUC: {roc_auc_mean} (±{roc_auc_std}), F1 Macro: {macro_f1_mean} (±{macro_f1_std})') + self.tracker.log_metrics({ + 'Test ROC AUC Mean': roc_auc_mean, + 'Test ROC AUC Standard Deviation': roc_auc_std, + 'Test Macro F1 Score Mean': macro_f1_mean, + 'Test Macro F1 Score Standard Deviation': macro_f1_std, + }) + + +def train( + *, + data_path: str, + model_path: str = '.', + num_nodes: int = 2000, + num_classes: int = 5, + max_steps: int = 2, + sample_size: int = 100, + learning_rate: float = .001, + lambda_coefficient: float = 0, + batch_size: int = 100, + input_noise_rate: float = 0., + dropout_rate: float = 0., + layer_size: int = 50, + layer_size_classifier: int = 250, + num_attention_heads: int = 10, + edge_type_embedding_size: int = 50, + node_embedding_size: Optional[int] = 50, + max_num_epochs: int = 1000, + num_folds: int = 10, + 
maximum_non_improvement_epochs: int = 5, + random_state: int = 110069) -> None: + """ + Trains a link predictor. + + :param data_path: Path to data. + :param model_path: Path to model. + :param num_nodes: Number of nodes in the graph. + :param num_classes: Number of edge types to classify. + :param max_steps: Maximum random walk steps. + :param sample_size: Neighbourhood sample size. + :param learning_rate: Learning rate for the optimizer. + :param lambda_coefficient: L2 loss coefficient. + :param batch_size: Batch size for stochastic gradient descend. + :param input_noise_rate: Node feature drop rate during training. + :param dropout_rate: Dropout probability. + :param layer_size: The size of the output for each layer in the neighbour aggregator. + :param layer_size_classifier: The size of the output for each layer in the classifier. + :param num_attention_heads: The number of attention heads for a GATAS node. + :param edge_type_embedding_size: The size of the trainable edge type embeddings. + :param node_embedding_size: The size of the trainable node embeddings, if any. + :param max_num_epochs: Maximum number of epochs to train for. + :param num_folds: Number of runs. + :param maximum_non_improvement_epochs: Number of epochs for early stopping (patience). + :param random_state: Random seed for dataset random processes. + """ + trainer = NodeClassifierTrainer( + data_path=data_path, + model_path=model_path, + num_nodes=num_nodes, + num_classes=num_classes, + max_steps=max_steps, + sample_size=sample_size, + layer_size=layer_size, + layer_size_classifier=layer_size_classifier, + num_attention_heads=num_attention_heads, + edge_type_embedding_size=edge_type_embedding_size, + node_embedding_size=node_embedding_size, + input_noise_rate=input_noise_rate, + dropout_rate=dropout_rate, + lambda_coefficient=lambda_coefficient, + learning_rate=learning_rate, + ) + + trainer.tracker.register_parameters(parameters.get_script_parameters(train)) + + with trainer.tracker: + try: + trainer.multi_step_train_with_early_stopping( + data_path=data_path, + num_folds=num_folds, + batch_size=batch_size, + max_num_epochs=max_num_epochs, + maximum_non_improvement_epochs=maximum_non_improvement_epochs, + num_classes=num_classes, + random_state=random_state, + ) + + except Exception as error: + trainer.tracker.set_tags({'Error': traceback.format_exc()}) + raise error + + +if __name__ == '__main__': + defopt.run(train) diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..976ba02 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +ignore_missing_imports = True diff --git a/node_classification/__init__.py b/node_classification/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/node_classification/dataset.py b/node_classification/dataset.py new file mode 100644 index 0000000..29c903c --- /dev/null +++ b/node_classification/dataset.py @@ -0,0 +1,132 @@ +import math +import os +from typing import Tuple, Generator, Optional + +import numba +import numpy as np +from sklearn.model_selection import StratifiedShuffleSplit + +from framework.dataset import io + + +@numba.experimental.jitclass([ + ('num_nodes', numba.int32), + ('node_indices', numba.int32[:]), + ('class_indices', numba.int32[:, :]), + # ('class_indices', numba.int32[:]), + ('random_state', numba.int32), +]) +class NodeClassifierDataset: + def __init__(self, node_indices: np.ndarray, class_indices: np.ndarray, random_state: int) -> None: + self.num_nodes = node_indices.size + + self.node_indices = node_indices + 
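+        # class_indices is declared as a 2D int32 array in the jitclass spec above (one row of labels per node).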
self.class_indices = class_indices + + self.random_state = random_state + + def get_batches(self, size: int) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]: + indices = np.arange(self.num_nodes) + + np.random.seed(self.random_state) + + np.random.shuffle(indices) + + for start_index in range(0, self.num_nodes, size): + end_index = min(start_index + size, self.num_nodes) + + batch_indices = indices[start_index:end_index] + + yield self.node_indices[batch_indices], self.class_indices[batch_indices] + + def get_subset(self, indices: np.ndarray) -> 'NodeClassifierDataset': + dataset = NodeClassifierDataset(self.node_indices[indices], self.class_indices[indices], self.random_state) + + return dataset + + +def get_dataset_splits( + path: str, + train_ratio: float, + validation_ratio: Optional[float] = None, + test_ratio: Optional[float] = None, + random_state: int = 110069) -> Tuple[NodeClassifierDataset, NodeClassifierDataset, NodeClassifierDataset]: + + class_ids = io.load_npy(os.path.join(path, 'class_ids.npy'), mmap_mode='r') + num_nodes = np.size(class_ids) + + indices = np.arange(num_nodes, dtype=np.int32) + np.random.RandomState(random_state).shuffle(indices) + + train_split = math.ceil(num_nodes * train_ratio) + train_indices = indices[:train_split] + train_dataset = NodeClassifierDataset(train_indices, class_ids[train_indices], random_state) + + validation_split = train_split + math.ceil(num_nodes * validation_ratio) if validation_ratio else num_nodes + validation_indices = indices[train_split:validation_split] + validation_dataset = NodeClassifierDataset(validation_indices, class_ids[validation_indices], random_state) + + if test_ratio is None: + return train_dataset, validation_dataset, validation_dataset + + test_split = validation_split + math.ceil(num_nodes * test_ratio) + test_indices = indices[validation_split:test_split] + test_dataset = NodeClassifierDataset(test_indices, class_ids[test_indices], random_state) + + return train_dataset, validation_dataset, test_dataset + + +def get_stratified_dataset_splits( + path: str, + train_size: int, + validation_size: int, + test_size: Optional[int] = None, + number_splits: int = 10, + random_state: int = 110069) -> Generator[Tuple[NodeClassifierDataset, NodeClassifierDataset, NodeClassifierDataset], None, None]: + + train_test_splitter = StratifiedShuffleSplit( + n_splits=number_splits, + train_size=train_size, + test_size=validation_size + (test_size if test_size else 0), + random_state=random_state) + + validation_test_splitter = StratifiedShuffleSplit( + n_splits=1, + train_size=validation_size, + test_size=(test_size if test_size else 0), + random_state=random_state) + + class_ids = io.load_npy(os.path.join(path, 'class_ids.npy'), mmap_mode='r') + + for train_indices, validation_test_indices in train_test_splitter.split(np.zeros_like(class_ids), class_ids): + train_dataset = NodeClassifierDataset(train_indices.astype(np.int32), class_ids[train_indices], random_state) + + if test_size is None: + validation_dataset = NodeClassifierDataset(validation_test_indices.astype(np.int32), class_ids[validation_test_indices], random_state) + + yield train_dataset, validation_dataset, validation_dataset + + validation_test_class_ids = class_ids[validation_test_indices] + validation_indices, test_indices = next(validation_test_splitter.split(np.zeros_like(validation_test_class_ids), validation_test_class_ids)) + validation_indices = validation_test_indices[validation_indices] + test_indices = validation_test_indices[test_indices] + + 
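+        # Construct the validation and test datasets from the indices remapped to the full class_ids array.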
validation_dataset = NodeClassifierDataset(validation_indices.astype(np.int32), class_ids[validation_indices], random_state) + test_dataset = NodeClassifierDataset(test_indices.astype(np.int32), class_ids[test_indices], random_state) + + yield train_dataset, validation_dataset, test_dataset + + +def get_splitted_dataset(path: str, random_state: int = 110069) -> Tuple[NodeClassifierDataset, NodeClassifierDataset, NodeClassifierDataset]: + class_ids = io.load_npy(os.path.join(path, 'class_ids.npy'), mmap_mode='r') + + train_indices = io.load_npy(os.path.join(path, 'train_indices.npy')) + train_dataset = NodeClassifierDataset(train_indices, class_ids[train_indices], random_state) + + validation_indices = io.load_npy(os.path.join(path, 'validation_indices.npy')) + validation_dataset = NodeClassifierDataset(validation_indices, class_ids[validation_indices], random_state) + + test_indices = io.load_npy(os.path.join(path, 'test_indices.npy')) + test_dataset = NodeClassifierDataset(test_indices, class_ids[test_indices], random_state) + + return train_dataset, validation_dataset, test_dataset diff --git a/node_classification/datasets/__init__.py b/node_classification/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/node_classification/datasets/cora.py b/node_classification/datasets/cora.py new file mode 100644 index 0000000..246511c --- /dev/null +++ b/node_classification/datasets/cora.py @@ -0,0 +1,62 @@ +import os +import sys +from typing import Generator, Tuple + +import numpy as np +import pandas as pd +import scipy.sparse as sp + + +def create_triplets(heads: np.ndarray, tails: np.ndarray) -> Generator[Tuple[int, int, int], None, None]: + for head, tail in zip(heads, tails): + if head != tail: + yield head, tail, 0 + yield tail, head, 1 + + else: + yield tail, head, 2 + + +def normalize_features(features: sp.spmatrix) -> sp.spmatrix: + row_sum = np.array(features.sum(1)) + + row_inverse = np.power(row_sum, -1).flatten() + row_inverse[np.isinf(row_inverse)] = 0 + + normalized_features = sp.diags(row_inverse).dot(features) + + return normalized_features + + +def main(path: str): + cites = pd.read_csv(os.path.join(path, 'cites.csv'), header=None, names=('cited', 'citing')) + content = pd.read_csv(os.path.join(path, 'content.csv'), header=None, names=('id', 'word')) + paper = pd.read_csv(os.path.join(path, 'paper.csv'), header=None, names=('id', 'label')) + + paper_ids = set(cites['cited'].values) | set(cites['citing'].values) | set(content['id'].values) | set(paper['id'].values) + paper_id_to_index = {paper_id: index for index, paper_id in enumerate(paper_ids)} + label_to_index = {value: index for index, value in enumerate(set(paper['label']))} + + cites['cited'] = cites['cited'].map(lambda paper_id: paper_id_to_index[paper_id]) + cites['citing'] = cites['citing'].map(lambda paper_id: paper_id_to_index[paper_id]) + content['id'] = content['id'].map(lambda paper_id: paper_id_to_index[paper_id]) + paper['id'] = paper['id'].map(lambda paper_id: paper_id_to_index[paper_id]) + + heads, tails, edge_types = zip(*create_triplets(cites['cited'].values, cites['citing'].values)) + + adjacency_matrix = sp.coo_matrix((edge_types, (heads, tails)), dtype=np.int32).tocsr() + adjacency_matrix.indptr.tofile(os.path.join(path, 'accumulated_num_edges.bin')) + adjacency_matrix.indices.tofile(os.path.join(path, 'adjacencies.bin')) + adjacency_matrix.data.tofile(os.path.join(path, 'edge_types.bin')) + + content['word'] = content['word'].map(lambda string: int(string.replace('word', '')) - 
1) + features = sp.coo_matrix((np.ones((len(content),), np.float32), (content['id'], content['word']))).todense() + np.save(os.path.join(path, 'node_embeddings.npy'), features) + + paper['label'] = paper['label'].map(lambda string: label_to_index[string]) + targets = paper.sort_values('id')['label'].values.astype(np.int32) + np.save(os.path.join(path, 'class_ids.npy'), targets) + + +if __name__ == '__main__': + main(*sys.argv[:1]) diff --git a/node_classification/datasets/ppi.py b/node_classification/datasets/ppi.py new file mode 100644 index 0000000..33369c6 --- /dev/null +++ b/node_classification/datasets/ppi.py @@ -0,0 +1,67 @@ +import json +import os +from typing import Iterable, Tuple, List, Dict + +import defopt +import numpy as np +import scipy.sparse as sp + + +def get_triplets(links: List[Dict[str, int]]) -> Tuple[Iterable[int], Iterable[int], Iterable[int]]: + heads, tails, edge_types = zip(*( + (head, tail, edge_type) + for link in links + for head, tail, edge_type in ((link['source'], link['target'], 0), (link['target'], link['source'], 0)) + )) + + return heads, tails, edge_types + + +def main(*, path: str): + graph = json.load(open(os.path.join(path, 'ppi-G.json'))) + + num_nodes = len(graph['nodes']) + + train_indices, validation_indices, test_indices = [], [], [] + + for node in graph['nodes']: + index = node['id'] + + if node['val']: + validation_indices.append(index) + + elif node['test']: + test_indices.append(index) + + else: + train_indices.append(index) + + np.save(os.path.join(path, 'train_indices.npy'), np.array(train_indices, dtype=np.int32)) + np.save(os.path.join(path, 'validation_indices.npy'), np.array(validation_indices, dtype=np.int32)) + np.save(os.path.join(path, 'test_indices.npy'), np.array(test_indices, dtype=np.int32)) + + heads, tails, edge_types = get_triplets(graph['links']) + + adjacency_matrix = sp.coo_matrix((edge_types, (heads, tails)), dtype=np.int32, shape=(num_nodes, num_nodes)).tocsr() + + adjacency_matrix.indptr.tofile(os.path.join(path, 'accumulated_num_edges.bin')) + adjacency_matrix.indices.tofile(os.path.join(path, 'adjacencies.bin')) + adjacency_matrix.data.tofile(os.path.join(path, 'edge_types.bin')) + + features = np.load(os.path.join(path, 'ppi-feats.npy')).astype(np.float32) + + np.save(os.path.join(path, 'node_embeddings.npy'), features) + + class_map = json.load(open(os.path.join(path, 'ppi-class_map.json'))) + num_labels = len(next(iter(class_map.values()))) + + targets = np.zeros((num_nodes, num_labels), dtype=np.int32) + + for key, labels in sorted([(int(key), value) for key, value in class_map.items()], key=lambda values: values[0]): + targets[key, :] = labels + + np.save(os.path.join(path, 'class_ids.npy'), targets) + + +if __name__ == '__main__': + defopt.run(main) diff --git a/node_classification/model.py b/node_classification/model.py new file mode 100644 index 0000000..4444bf4 --- /dev/null +++ b/node_classification/model.py @@ -0,0 +1,111 @@ +from typing import Tuple, Union, Optional + +import numpy as np +import tensorflow as tf + +from framework.common.parameters import DataSpecification +from gatas.aggregator import NeighbourAggregator +from gatas.sampler import NeighbourSampler + + +class NodeClassifier(tf.Module): + def __init__( + self, + path: str, + num_nodes: int, + num_classes: int, + layer_size: int, + layer_size_classifier: int, + num_attention_heads: int, + edge_type_embedding_size: int, + node_embedding_size: Optional[int], + num_steps: int, + sample_size: int, + input_noise_rate: float, + dropout_rate: 
float, + lambda_coefficient: float, + node_features: Optional[Union[tf.Tensor, np.ndarray]] = None): + + super().__init__() + + self.sample_size = sample_size + self.lambda_coefficient = lambda_coefficient + + self.neighbour_sampler = NeighbourSampler.from_path(num_steps, path) + + self.neighbour_aggregator = NeighbourAggregator( + input_noise_rate=input_noise_rate, + dropout_rate=dropout_rate, + num_nodes=num_nodes, + num_edge_types=self.neighbour_sampler.num_edge_types, + num_steps=num_steps, + edge_type_embedding_size=edge_type_embedding_size, + node_embedding_size=node_embedding_size, + layer_size=layer_size, + num_attention_heads=num_attention_heads, + node_features=node_features, + ) + + self.classifier = tf.keras.Sequential([ + tf.keras.layers.Dense(units=layer_size_classifier, activation=tf.nn.elu), + tf.keras.layers.Dense(units=layer_size_classifier, activation=tf.nn.elu), + tf.keras.layers.Dense(units=num_classes, use_bias=False), + ]) + + @staticmethod + def get_schema(num_classes: int) -> DataSpecification: + schema = DataSpecification([ + tf.TensorSpec(tf.TensorShape((None,)), tf.int32, 'anchor_indices'), + tf.TensorSpec(tf.TensorShape((None, num_classes)), tf.int32, 'targets'), + # tf.TensorSpec(tf.TensorShape((None,)), tf.int32, 'targets'), + ]) + + return schema + + def __call__(self, anchor_indices: tf.Tensor, training: Union[tf.Tensor, bool]) -> Tuple[tf.Tensor, tf.Tensor]: + neighbour_sample = self.neighbour_sampler(anchor_indices, self.sample_size) + + node_representations = self.neighbour_aggregator( + anchor_indices=anchor_indices, + neighbour_indices=neighbour_sample.indices, + neighbour_assignments=neighbour_sample.segments, + neighbour_weights=neighbour_sample.weights, + neighbour_path_indices=neighbour_sample.path_indices, + training=training, + ) + + logits = self.classifier(node_representations) + + predictions = tf.cast(tf.round(tf.nn.sigmoid(logits)), tf.int32) + # predictions = tf.argmax(tf.nn.softmax(logits), axis=1, output_type=tf.int32) + + return logits, predictions + + def calculate_loss(self, logits: tf.Tensor, targets: tf.Tensor) -> tf.Tensor: + # loss = tf.losses.sparse_softmax_cross_entropy(labels=targets, logits=logits) + loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=targets, logits=logits) + + l2_norm = tf.add_n([ + tf.nn.l2_loss(variable) + for variable in tf.trainable_variables() + if variable.name not in {'coefficient_indices:0', 'edge_type_embeddings:0'} + ]) + + loss = tf.reduce_mean(loss) + self.lambda_coefficient * l2_norm + + return loss + + @classmethod + def get_clipped_gradient_updates( + cls, + loss: tf.Tensor, + optimizer: tf.train.Optimizer, + max_gradient_norm: float = 5.0) -> tf.Tensor: + + gradients, variables = zip(*optimizer.compute_gradients(loss)) + + clipped_gradients, _ = tf.clip_by_global_norm(gradients, max_gradient_norm) + + updates = optimizer.apply_gradients(zip(clipped_gradients, variables)) + + return updates diff --git a/node_classification/train.py b/node_classification/train.py new file mode 100644 index 0000000..f610d59 --- /dev/null +++ b/node_classification/train.py @@ -0,0 +1,381 @@ +import os +import traceback +from typing import Tuple, Callable, NamedTuple, List, Optional + +import defopt +import numpy as np +import tensorflow as tf + +from framework.common import parameters +from framework.dataset import io +from framework.trackers.aggregator import Aggregation, Statistic +from framework.trackers.metrics import MetricFunctions, Metric +from framework.trackers.tracker_mlflow import MLFlowTracker 
+from node_classification import dataset +from node_classification.model import NodeClassifier + + +class ModelMetrics(NamedTuple): + test_accuracy: float + test_micro_f1: float + + +class NodeClassifierTrainer: + def __init__( + self, + data_path: str, + model_path: str, + max_steps: int, + sample_size: int, + layer_size: int, + layer_size_classifier: int, + num_attention_heads: int, + edge_type_embedding_size: int, + node_embedding_size: Optional[int], + input_noise_rate: float, + dropout_rate: float, + lambda_coefficient: float, + learning_rate: float) -> None: + + self.data_path = data_path + + self.checkpoint_path = os.path.join(model_path, 'model.ckpt') + + # num_classes = np.max(io.load_npy(os.path.join(data_path, 'class_ids.npy'))) + 1 + num_classes = io.load_npy(os.path.join(data_path, 'class_ids.npy')).shape[1] + node_embeddings = io.load_npy(os.path.join(data_path, 'node_embeddings.npy'), mmap_mode='r') + + self.iterator = self.create_iterator(num_classes) + self.training = tf.placeholder_with_default(False, shape=(), name='training') + + self.model = NodeClassifier( + path=data_path, + num_nodes=node_embeddings.shape[0], + num_classes=num_classes, + layer_size=layer_size, + layer_size_classifier=layer_size_classifier, + num_attention_heads=num_attention_heads, + edge_type_embedding_size=edge_type_embedding_size, + node_embedding_size=node_embedding_size, + num_steps=max_steps, + sample_size=sample_size, + input_noise_rate=input_noise_rate, + dropout_rate=dropout_rate, + lambda_coefficient=lambda_coefficient, + node_features=node_embeddings, + ) + + self.anchor_indices, self.targets = self.iterator.get_next() + self.training = tf.placeholder_with_default(False, shape=(), name='training') + self.logits, self.predictions = self.model(self.anchor_indices, self.training) + self.loss = self.model.calculate_loss(self.logits, self.targets) + self.updates = self.model.get_clipped_gradient_updates(self.loss, optimizer=tf.contrib.opt.NadamOptimizer(learning_rate)) + + self.tracker = MLFlowTracker('node-classifier', aggregators=[ + Aggregation( + Metric.TRAINING_MEAN_COST, + [Statistic.TRAINING_COST], + MetricFunctions.mean), + Aggregation( + Metric.VALIDATION_MEAN_COST, + [Statistic.VALIDATION_COST], + MetricFunctions.mean), + Aggregation( + Metric.TESTING_MEAN_COST, + [Statistic.TESTING_COST], + MetricFunctions.mean), + Aggregation( + Metric.TRAINING_MICRO_F1, + [Statistic.TRAINING_TARGET, Statistic.TRAINING_PREDICTION], + lambda x, y: MetricFunctions.f1_score(x, y, average='micro')), + Aggregation( + Metric.VALIDATION_MICRO_F1, + [Statistic.VALIDATION_TARGET, Statistic.VALIDATION_PREDICTION], + lambda x, y: MetricFunctions.f1_score(x, y, average='micro')), + Aggregation( + Metric.TESTING_MICRO_F1, + [Statistic.TESTING_TARGET, Statistic.TESTING_PREDICTION], + lambda x, y: MetricFunctions.f1_score(x, y, average='micro')), + Aggregation( + Metric.TRAINING_ACCURACY, + [Statistic.TRAINING_TARGET, Statistic.TRAINING_PREDICTION], + lambda x, y: MetricFunctions.accuracy(x, y)), + Aggregation( + Metric.VALIDATION_ACCURACY, + [Statistic.VALIDATION_TARGET, Statistic.VALIDATION_PREDICTION], + lambda x, y: MetricFunctions.accuracy(x, y)), + Aggregation( + Metric.TESTING_ACCURACY, + [Statistic.TESTING_TARGET, Statistic.TESTING_PREDICTION], + lambda x, y: MetricFunctions.accuracy(x, y)), + ]) + + self.tracker.set_tags({ + 'Tier': 'Development', + 'Problem': 'Node Classification', + }) + + @staticmethod + def create_iterator(num_classes: int) -> tf.data.Iterator: + iterator = 
tf.data.Iterator.from_structure( + output_types=NodeClassifier.get_schema(num_classes).get_types(), + output_shapes=NodeClassifier.get_schema(num_classes).get_shapes(), + ) + + return iterator + + @staticmethod + def initialize_iterator( + iterator: tf.data.Iterator, + generator: Callable, + arguments: Tuple, + queue_size: int = -1) -> tf.Operation: + + validation_dataset = tf.data.Dataset \ + .from_generator( + generator=generator, + output_types=iterator.output_types, + output_shapes=iterator.output_shapes, + args=arguments) \ + .prefetch(queue_size) + + iterator = iterator.make_initializer(validation_dataset) + + return iterator + + def train_early_stopping( + self, + session: tf.Session, + train_iterator: tf.Operation, + validation_iterator: tf.Operation, + test_iterator: tf.Operation, + save_path: str, + num_epochs: int, + early_stopping_threshold: int, + previous_best_metric: float = -np.inf, + saver_variables: Optional[List[tf.Tensor]] = None) -> ModelMetrics: + + saver = tf.train.Saver(var_list=saver_variables) + + non_improvement_times = 0 + + for epoch in range(num_epochs): + session.run(train_iterator) + + try: + while True: + cost_, predictions_, targets_, _ = session.run( + fetches=(self.loss, self.predictions, self.targets, self.updates), + feed_dict={self.training: True}, + ) + + self.tracker.add_statistics({ + Statistic.TRAINING_COST: cost_, + Statistic.TRAINING_PREDICTION: predictions_, + Statistic.TRAINING_TARGET: targets_, + }) + + except tf.errors.OutOfRangeError: + pass + + session.run(validation_iterator) + + try: + while True: + cost_, predictions_, targets_ = session.run(fetches=(self.loss, self.predictions, self.targets)) + + self.tracker.add_statistics({ + Statistic.VALIDATION_COST: cost_, + Statistic.VALIDATION_PREDICTION: predictions_, + Statistic.VALIDATION_TARGET: targets_, + }) + + except tf.errors.OutOfRangeError: + pass + + training_loss = self.tracker.compute_metric(Metric.TRAINING_MEAN_COST) + train_metric = self.tracker.compute_metric(Metric.TRAINING_MICRO_F1) + validation_metric = self.tracker.compute_metric(Metric.VALIDATION_MICRO_F1) + + print(f'Epoch: {epoch}, training loss: {training_loss}, train metric: {train_metric}, validation metric: {validation_metric}') + + try: + if validation_metric > previous_best_metric: + non_improvement_times, previous_best_metric = 0, validation_metric + + saver.save(session, save_path) + + elif non_improvement_times < early_stopping_threshold: + non_improvement_times += 1 + + else: + print('Stopping after no improvement.') + break + + finally: + self.tracker.finish_epoch() + + saver.restore(session, save_path) + + session.run(test_iterator) + + try: + while True: + cost_, predictions_, targets_ = session.run(fetches=(self.loss, self.predictions, self.targets)) + + self.tracker.add_statistics({ + Statistic.TESTING_COST: cost_, + Statistic.TESTING_PREDICTION: predictions_, + Statistic.TESTING_TARGET: targets_, + }) + + except tf.errors.OutOfRangeError: + pass + + evaluation_metrics = ModelMetrics( + test_accuracy=self.tracker.compute_metric(Metric.TESTING_ACCURACY), + test_micro_f1=self.tracker.compute_metric(Metric.TESTING_MICRO_F1), + ) + + self.tracker.clear() + self.tracker.save_model(save_path) + + return evaluation_metrics + + def multi_step_train_with_early_stopping( + self, + num_folds: int, + batch_size: int, + max_num_epochs: int, + maximum_non_improvement_epochs: int) -> None: + + initializers = tf.global_variables_initializer() + fold_metrics = [] + + train_dataset, validation_dataset, test_dataset = 
dataset.get_splitted_dataset(self.data_path) + + train_iterator = self.initialize_iterator( + iterator=self.iterator, + generator=train_dataset.get_batches, + arguments=(batch_size,)) + + validation_iterator = self.initialize_iterator( + iterator=self.iterator, + generator=validation_dataset.get_batches, + arguments=(batch_size,)) + + test_iterator = self.initialize_iterator( + iterator=self.iterator, + generator=test_dataset.get_batches, + arguments=(batch_size,)) + + for fold_index in range(num_folds): + with tf.Session() as session: + session.run(initializers) + + metrics = self.train_early_stopping( + session=session, + train_iterator=train_iterator, + validation_iterator=validation_iterator, + test_iterator=test_iterator, + save_path=self.checkpoint_path, + num_epochs=max_num_epochs, + early_stopping_threshold=maximum_non_improvement_epochs, + ) + + print(f'\nFold: {fold_index + 1}, Test Accuracy: {metrics.test_accuracy}, Micro F1 Score: {metrics.test_micro_f1}') + self.tracker.log_metrics({ + 'Fold Accuracy': metrics.test_accuracy, + 'Fold Micro F1 Score': metrics.test_micro_f1, + }, fold_index + 1) + + fold_metrics.append(metrics) + + accuracy_mean = float(np.mean([metrics.test_accuracy for metrics in fold_metrics])) + accuracy_std = float(np.std([metrics.test_accuracy for metrics in fold_metrics])) + micro_f1_mean = float(np.mean([metrics.test_micro_f1 for metrics in fold_metrics])) + micro_f1_std = float(np.std([metrics.test_micro_f1 for metrics in fold_metrics])) + + print(f'\nAccuracy: {accuracy_mean} (±{accuracy_std}), F1 Micro: {micro_f1_mean} (±{micro_f1_std})') + self.tracker.log_metrics({ + 'Test Accuracy Mean': accuracy_mean, + 'Test Accuracy Standard Deviation': accuracy_std, + 'Test Micro F1 Score Mean': micro_f1_mean, + 'Test Micro F1 Score Standard Deviation': micro_f1_std, + }) + + +def train( + *, + data_path: str, + model_path: str = '.', + max_steps: int = 3, + sample_size: int = 500, + learning_rate: float = 0.001, + lambda_coefficient: float = 0, + batch_size: int = 100, + input_noise_rate: float = 0.0, + dropout_rate: float = 0.0, + layer_size: int = 50, + layer_size_classifier: int = 256, + num_attention_heads: int = 10, + edge_type_embedding_size: int = 5, + node_embedding_size: Optional[int] = None, + max_num_epochs: int = 1000, + num_folds: int = 10, + maximum_non_improvement_epochs: int = 10) -> None: + """ + Trains a node classifier. + + :param data_path: Path to data. + :param model_path: Path to model. + :param max_steps: Maximum random walk steps. + :param sample_size: Neighbourhood sample size. + :param learning_rate: Learning rate for the optimizer. + :param lambda_coefficient: L2 loss coefficient. + :param batch_size: Batch size for stochastic gradient descend. + :param input_noise_rate: Node feature drop rate during training. + :param dropout_rate: Dropout probability. + :param layer_size: The size of the output for each layer in the neighbour aggregator. + :param layer_size_classifier: The size of the output for each layer in the classifier. + :param num_attention_heads: The number of attention heads for a GATAS node. + :param edge_type_embedding_size: The size of the trainable edge type embeddings. + :param node_embedding_size: The size of the trainable node embeddings, if any. + :param max_num_epochs: Maximum number of epochs to train for. + :param num_folds: Number of runs. + :param maximum_non_improvement_epochs: Number of epochs for early stopping (patience). 
+ """ + trainer = NodeClassifierTrainer( + data_path=data_path, + model_path=model_path, + max_steps=max_steps, + sample_size=sample_size, + layer_size=layer_size, + layer_size_classifier=layer_size_classifier, + num_attention_heads=num_attention_heads, + edge_type_embedding_size=edge_type_embedding_size, + node_embedding_size=node_embedding_size, + input_noise_rate=input_noise_rate, + dropout_rate=dropout_rate, + lambda_coefficient=lambda_coefficient, + learning_rate=learning_rate, + ) + + trainer.tracker.register_parameters(parameters.get_script_parameters(train)) + + with trainer.tracker: + try: + trainer.multi_step_train_with_early_stopping( + num_folds=num_folds, + batch_size=batch_size, + max_num_epochs=max_num_epochs, + maximum_non_improvement_epochs=maximum_non_improvement_epochs, + ) + + except Exception as error: + trainer.tracker.set_tags({'Error': traceback.format_exc()}) + raise error + + +if __name__ == '__main__': + defopt.run(train) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..7915342 --- /dev/null +++ b/setup.py @@ -0,0 +1,17 @@ +from setuptools import setup, find_packages + + +setup( + name='gatas', + version='1.0.0', + packages=find_packages(exclude=('tests', 'tests.*')), + python_requires='~=3.7', + install_requires=[ + 'mlflow~=1.8', + 'defopt~=6.0', + 'numba~=0.49', + 'numpy~=1.18.0', + 's3fs~=0.4.0', + 'scikit-learn~=0.22.0', + 'tensorflow~=1.15', + ]) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/gatas/__init__.py b/tests/gatas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/gatas/test_aggregator.py b/tests/gatas/test_aggregator.py new file mode 100644 index 0000000..5ea2e23 --- /dev/null +++ b/tests/gatas/test_aggregator.py @@ -0,0 +1,52 @@ +import tensorflow as tf + +from gatas.aggregator import NeighbourAggregator + + +class NeighbourAggregatorTest(tf.test.TestCase): + @classmethod + def setUpClass(cls): + tf.enable_eager_execution() + + def test_node_representations(self): + anchor_indices = [1, 2, 3] + + neighbour_indices = [0, 4, 5] + neighbour_path_indices = [0, 1, 4] + neighbour_assignments = [0, 0, 2] + neighbour_weights = [1., 1., 1.] 
+ + node_features = [ + [0., 1., 2., 3., 4.], + [5., 6., 7., 8., 9.], + [10., 11., 12., 13., 14.], + [8., 11., 10., 14., 20.], + [0., 1., 2., 3., 4.], + [5., 6., 7., 8., 9.], + ] + + layer_size = 10 + num_attention_heads = 4 + + model = NeighbourAggregator( + input_noise_rate=0.0, + dropout_rate=0.0, + num_nodes=6, + num_edge_types=2, + num_steps=3, + edge_type_embedding_size=6, + node_embedding_size=6, + layer_size=layer_size, + num_attention_heads=num_attention_heads, + node_features=tf.convert_to_tensor(node_features), + ) + + neighbour_probabilities = model( + anchor_indices=tf.convert_to_tensor(anchor_indices), + neighbour_indices=tf.convert_to_tensor(neighbour_indices), + neighbour_assignments=tf.convert_to_tensor(neighbour_assignments), + neighbour_weights=tf.convert_to_tensor(neighbour_weights), + neighbour_path_indices=tf.convert_to_tensor(neighbour_path_indices), + ) + + self.assertAllEqual(neighbour_probabilities.shape, [3, layer_size * num_attention_heads]) diff --git a/tests/gatas/test_sampler.py b/tests/gatas/test_sampler.py new file mode 100644 index 0000000..f265a2d --- /dev/null +++ b/tests/gatas/test_sampler.py @@ -0,0 +1,271 @@ +import numpy as np +import tensorflow as tf + +from gatas import sampler +from gatas.sampler import NeighbourSampler + + +class NeighbourhoodSamplerTest(tf.test.TestCase): + @classmethod + def setUpClass(cls): + tf.enable_eager_execution() + + def test_compute_path_depths(self): + path_depths = sampler.compute_path_depths(3, 3) + expected_path_depths = np.array([ + 0, + 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + ], dtype=np.int32) + + self.assertAllEqual(path_depths, expected_path_depths) + + def test_compute_top_k(self): + probabilities = np.array( + [0.16288152, 0.04801334, 0.7079845, 0.01123138, 0.06209923, 0.00779003], + dtype=np.float32, + ) + + path_indices = np.array([0, 0, 0, 0, 0, 0], dtype=np.int32) + path_depths = np.array([0], dtype=np.int32) + coefficients = np.array([1.], dtype=np.int32) + + max_indices, max_values = sampler.compute_top_k( + probabilities=probabilities, + path_indices=path_indices, + path_depths=path_depths, + coefficients=coefficients, + num_steps=100, + k=2, + noisify=False, + ) + + self.assertAllEqual(np.sort(max_indices), np.sort(np.argsort(probabilities)[-2:])) + self.assertAllClose(np.sort(max_values), np.sort(np.sort(np.log(probabilities))[-2:])) + + def test_compute_top_large(self): + probabilities = np.array( + [0.16288152, 0.04801334, 0.7079845, 0.01123138, 0.06209923, 0.00779003], + dtype=np.float32, + ) + + path_indices = np.array([0, 0, 0, 0, 0, 0], dtype=np.int32) + path_depths = np.array([0], dtype=np.int32) + coefficients = np.array([1.], dtype=np.int32) + + max_indices, max_values = sampler.compute_top_k( + probabilities=probabilities, + path_indices=path_indices, + path_depths=path_depths, + coefficients=coefficients, + num_steps=100, + k=100, + noisify=False, + ) + + self.assertAllEqual(np.sort(max_indices), np.arange(probabilities.size)) + self.assertAllClose(np.sort(max_values), np.sort(np.log(probabilities))) + + def test_compute_top_k_small(self): + probabilities = np.array( + [0.9654813, 0.01837953, 0.00135655, 0.01478269], + dtype=np.float32, + ) + + path_indices = np.array([0, 0, 0, 0, 0, 0], dtype=np.int32) + path_depths = np.array([0], dtype=np.int32) + coefficients = np.array([1.], dtype=np.int32) + + max_indices, max_values = sampler.compute_top_k( + probabilities=probabilities, + path_indices=path_indices, + 
path_depths=path_depths, + coefficients=coefficients, + num_steps=100, + k=1, + noisify=False, + ) + + self.assertAllEqual(max_indices, [np.argmax(probabilities)]) + self.assertAllClose(max_values, [np.max(np.log(probabilities))]) + + def test_compute_top_k_zero(self): + probabilities = np.array( + [ + 0.16288152, 0.04801334, 0.7079845, 0.01123138, 0.06209923, 0.00779003, + ], + dtype=np.float32) + + path_indices = np.array([0, 0, 0, 0, 0, 0], dtype=np.int32) + path_depths = np.array([0], dtype=np.int32) + coefficients = np.array([1.], dtype=np.int32) + + max_indices, max_values = sampler.compute_top_k( + probabilities=probabilities, + path_indices=path_indices, + path_depths=path_depths, + coefficients=coefficients, + num_steps=10, + k=0, + noisify=False, + ) + + self.assertAllEqual(max_indices, []) + self.assertAllEqual(max_values, []) + + def test_calculate_transition_logits(self): + probabilities = np.array( + [ + 0.3333, 0.1110, 0.0370, + 0.3333, 0.1110, 0.0370, + 0.3333, 0.1944, 0.0856, + 0.0833, 0.0486, + 0.0833, 0.0486, + 0.0833, 0.0486, + 0.25, 0.0625, 0.0156, + 0.25, 0.0625, 0.0156, + 0.25, 0.0625, 0.0156, + 0.25, 0.0625, 0.0156, + ], + dtype=np.float32, + ) + + path_indices = np.array( + [ + 0, 1, 3, + 0, 2, 4, + 0, 1, 5, + 2, 6, + 1, 7, + 2, 8, + 0, 1, 9, + 0, 2, 10, + 0, 1, 11, + 0, 2, 3, + ], + dtype=np.int32, + ) + + path_depths = sampler.compute_path_depths(2, 3) + + coefficients = 1 - np.arange(3, dtype=np.float32) / 3 + coefficients = np.exp(coefficients) + coefficients /= np.sum(coefficients) + + logits = sampler.calculate_transition_logits( + probabilities=probabilities, + coefficients=coefficients[path_depths[path_indices]], + noisify=False, + ) + + expected_probabilities = [ + 0.1495, 0.0357, 0.0085, + 0.1495, 0.0357, 0.0085, + 0.1495, 0.0625, 0.0197, + 0.0268, 0.0112, + 0.0268, 0.0112, + 0.0268, 0.0112, + 0.1121, 0.0201, 0.0036, + 0.1121, 0.0201, 0.0036, + 0.1121, 0.0201, 0.0036, + 0.1121, 0.0201, 0.0036, + ] + + self.assertAllClose(np.exp(logits), expected_probabilities, rtol=1e-4, atol=1e-4) + + def test_generate_sample(self): + accumulated_transition_lengths = np.array([0, 3, 3, 7, 7, 7, 7], np.int32) + + neighbours = np.array([ + 1, 1, 2, + 2, 3, 4, 5, + ], dtype=np.int32) + + path_indices = np.array([ + 1, 2, 2, + 3, 4, 5, 6, + ], dtype=np.int32) + + probabilities = np.array([ + 0.1, 0.2, 0.3, + 0.4, 0.5, 0.6, 0.7, + ], dtype=np.float32) + + coefficients = np.array([0.1, 0.2, 0.3], dtype=np.float32) + + path_depths = sampler.compute_path_depths(2, 3) + + indices, segments, path_indices_subset, steps, probabilities_subset = sampler.generate_sample( + node_indices=np.array([0, 1, 2], np.int32), + transition_pointers=accumulated_transition_lengths, + neighbours=neighbours, + path_indices=path_indices, + probabilities=probabilities, + coefficients=coefficients, + path_depths=path_depths, + num_steps=2, + sample_size=100, + noisify=False, + ) + + expected_order = np.array([0, 2, 1, 3, 5, 4, 6]) + expected_indices = neighbours[expected_order] + expected_path_indices = path_indices[expected_order] + expected_segments = [0, 0, 0, 2, 2, 2, 2] + expected_steps = path_depths[expected_path_indices] + expected_probabilities = probabilities[expected_order] + + self.assertAllEqual(indices, expected_indices) + self.assertAllEqual(path_indices_subset, expected_path_indices) + self.assertAllEqual(segments, expected_segments) + self.assertAllEqual(steps, expected_steps) + self.assertAllClose(probabilities_subset, expected_probabilities) + + def test_neighbour_sampler(self): + 
accumulated_transition_lengths = np.array([0, 3, 3, 7, 7, 7, 7], np.int32) + + neighbours = np.array([ + 1, 1, 2, + 2, 3, 4, 5, + ], dtype=np.int32) + + path_indices = np.array([ + 1, 2, 2, + 3, 4, 5, 6, + ], dtype=np.int32) + + probabilities = np.array([ + 0.1, 0.2, 0.3, + 0.4, 0.5, 0.6, 0.7, + ], dtype=np.float32) + + neighbour_sampler = NeighbourSampler( + accumulated_transition_lengths=accumulated_transition_lengths, + neighbours=neighbours, + path_indices=path_indices, + probabilities=probabilities, + num_edge_types=3, + num_steps=2, + ) + + sample = neighbour_sampler( + node_indices=tf.convert_to_tensor([0, 1, 2], dtype=tf.int32), + sample_size=100, + noisify=False, + ) + + coefficients = neighbour_sampler.coefficients.numpy() + coefficients = np.exp(coefficients) + coefficients /= np.sum(coefficients) + + expected_order = np.array([0, 2, 1, 3, 5, 4, 6]) + expected_indices = neighbours[expected_order] + expected_path_indices = path_indices[expected_order] + expected_segments = [0, 0, 0, 2, 2, 2, 2] + expected_weights = (probabilities * coefficients[neighbour_sampler.path_depths[path_indices]])[expected_order] + + self.assertAllEqual(sample.indices, expected_indices) + self.assertAllEqual(sample.path_indices, expected_path_indices) + self.assertAllEqual(sample.segments, expected_segments) + self.assertAllClose(sample.weights, expected_weights) diff --git a/tests/gatas/test_transitions.py b/tests/gatas/test_transitions.py new file mode 100644 index 0000000..536348f --- /dev/null +++ b/tests/gatas/test_transitions.py @@ -0,0 +1,63 @@ +import numpy as np +import tensorflow as tf + +from gatas import transitions + + +class TransitionTensorsTest(tf.test.TestCase): + def test_transition_tensors(self): + accumulated_num_edges = np.array([0, 3, 3, 7, 7, 7, 7], dtype=np.int32) + adjacencies = np.array([0, 1, 2, 2, 3, 4, 5], dtype=np.int32) + edge_types = np.array([0, 1, 0, 1, 0, 1, 0], dtype=np.int32) + + accumulated_transition_lengths, neighbours, path_indices, probabilities = transitions.create_transition_tensors( + accumulated_num_edges=accumulated_num_edges, + adjacencies=adjacencies, + edge_types=edge_types, + num_steps=2, + ) + + expected_accumulated_transition_lengths = [ + 0, 6, 7, 11, 12, 13, 14 + ] + + expected_neighbours = [ + 0, + 1, 2, + 3, 4, 5, + 1, + 2, + 3, 4, 5, + 3, + 4, + 5, + ] + + expected_path_indices = [ + 0, + 2, 1, + 7, 8, 7, + 0, + 0, + 1, 2, 1, + 0, + 0, + 0, + ] + + expected_probabilities = [ + 1.0, + 0.5, 0.5, + 0.3333, 0.3333, 0.3333, + 1.0, + 1.0, + 0.3333, 0.3333, 0.3333, + 1.0, + 1.0, + 1.0 + ] + + self.assertAllEqual(expected_accumulated_transition_lengths, accumulated_transition_lengths) + self.assertAllEqual(expected_neighbours, neighbours) + self.assertAllEqual(expected_path_indices, path_indices) + self.assertAllClose(expected_probabilities, probabilities, rtol=1e-4, atol=1e-4) diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..2bff02d --- /dev/null +++ b/tox.ini @@ -0,0 +1,2 @@ +[pycodestyle] +max-line-length = 150
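A note on the new test suite: the diff adds unit tests under `tests/` but pins no test runner (`tox.ini` only configures `pycodestyle`). Because the test classes extend `tf.test.TestCase`, which is `unittest`-compatible, standard-library discovery is enough to run them. The sketch below is an assumption rather than part of the diff: it presumes the repository root as the working directory and an environment with the package dependencies (including TensorFlow ~=1.15) installed.

```python
# Minimal sketch under assumed setup: run from the repository root with the
# package dependencies installed. tf.test.TestCase subclasses
# unittest.TestCase, so plain unittest discovery picks up tests/gatas/*.
import unittest

if __name__ == '__main__':
    suite = unittest.defaultTestLoader.discover('tests', pattern='test_*.py')
    unittest.TextTestRunner(verbosity=2).run(suite)
```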
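Similarly, since `defopt.run(train)` only wraps the plain keyword-only `train` function, the node classifier can also be invoked directly from Python. The following is a hedged sketch, not an endorsed workflow: the data path is a placeholder and is assumed to already contain the preprocessed graph and transition tensors, and the overridden values are arbitrary.

```python
# Hypothetical programmatic invocation; 'data/ppi' is a placeholder for a
# dataset directory that has already been preprocessed.
from node_classification.train import train

train(
    data_path='data/ppi',  # placeholder path
    num_folds=1,           # a single run instead of the default 10
    max_num_epochs=50,     # shortened schedule for a quick sanity check
)
```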