Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add versioning in save methods #442

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 42 additions & 7 deletions gli/io.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Helper functions for creating datasets in GLI format."""
import datetime
import json
import os
from typing import Dict, List, Optional, Tuple, Union
import warnings
import numpy as np
from scipy.sparse import isspmatrix, spmatrix, coo_matrix
import time

from gli.utils import save_data

Expand Down Expand Up @@ -165,8 +167,7 @@ def save_graph(
description, cite, save_dir)
# verify the inputs are dict for heterograph
if not isinstance(edge, dict):
raise TypeError(
"The input edge must be a dictionary for heterograph.")
raise TypeError("The input edge must be a dictionary for heterograph.")
if num_nodes is not None and not isinstance(num_nodes, dict):
raise TypeError(
"The input num_nodes must be a dictionary for heterograph.")
Expand Down Expand Up @@ -247,6 +248,11 @@ def _attr_to_metadata_dict(key_to_loc, prefix, a):
return metadata


def _get_version():
    """Build a version string from the current UTC time.

    :return: The current UTC time rendered as ``YYYYMMDDHHMMSS``
        (14 digits, no separators).
    :rtype: str
    """
    # A timezone-aware "now" keeps the stamp independent of the local
    # timezone of the machine that saves the dataset.
    utc_now = datetime.datetime.now(tz=datetime.timezone.utc)
    return f"{utc_now:%Y%m%d%H%M%S}"


def save_homograph(
name: str,
edge: np.ndarray,
Expand Down Expand Up @@ -455,6 +461,10 @@ def save_homograph(

metadata["citation"] = citation
metadata["is_heterogeneous"] = False
metadata["version"] = _get_version()
print("The graph metadata is saved to",
os.path.join(save_dir, "metadata.json"))
print("Version:", metadata["version"])

if citation == "":
warnings.warn("The citation is empty.")
Expand Down Expand Up @@ -907,6 +917,10 @@ def save_heterograph(
graph_dict[attr.name] = _attr_to_metadata_dict(key_to_loc, "Graph",
attr)
metadata["data"]["Graph"] = graph_dict
metadata["version"] = _get_version()
print("The graph metadata is saved to",
os.path.join(save_dir, "metadata.json"))
print("Version:", metadata["version"])

if citation == "":
warnings.warn("The citation is empty.")
Expand Down Expand Up @@ -969,6 +983,14 @@ def _check_feature(feature):
"Each element in `feature` must be a node/edge/graph attribute."


def _get_metadata_version(metadata):
    """Return the version recorded in a dataset's metadata.

    :param metadata: Either the already-parsed metadata dictionary, or the
        path to a ``metadata.json`` file, which is then read and parsed.
    :type metadata: dict, str or os.PathLike

    :raises AssertionError: If the metadata has no ``"version"`` entry.

    :return: The value stored under the ``"version"`` key.
    """
    # The task-saving code passes the path to metadata.json rather than its
    # parsed content; without loading it, the membership test below would be
    # a substring check on the path and the lookup would fail on a str.
    if isinstance(metadata, (str, os.PathLike)):
        with open(metadata, "r", encoding="utf-8") as fptr:
            metadata = json.load(fptr)

    # Raise explicitly (rather than `assert`) so the check still runs under
    # `python -O`; the exception type callers see is unchanged.
    if "version" not in metadata:
        raise AssertionError(
            "The metadata does not contain the version information. "
            "Please add the version information to the metadata.")
    return metadata["version"]


def _save_task_reg_or_cls(task_type,
name,
description,
Expand Down Expand Up @@ -1078,6 +1100,12 @@ def _save_task_reg_or_cls(task_type,
val_ratio, test_ratio,
num_samples)

# Check if metadata.json exists.
metadata_path = os.path.join(save_dir, "metadata.json")
assert os.path.exists(metadata_path), \
"metadata.json does not exist. Please create it first."
current_metadata_version = _get_metadata_version(metadata_path)

# Task-dependent checks.
if task_type in ("NodeClassification", "NodeRegression"):
assert target.startswith("Node/"), \
Expand All @@ -1096,7 +1124,8 @@ def _save_task_reg_or_cls(task_type,
"description": description,
"type": task_type,
"feature": feature,
"target": target
"target": target,
"version": current_metadata_version,
}
if num_classes is not None:
task_dict["num_classes"] = num_classes
Expand Down Expand Up @@ -1141,7 +1170,8 @@ def save_task_node_regression(name,
test_ratio=0.1,
num_samples=None,
task_id=1,
save_dir="."):
save_dir=".",
latest_supported_ver=None):
"""Save the node regression task information into task json and data files.

:param name: The name of the dataset.
Expand Down Expand Up @@ -1200,6 +1230,8 @@ def save_task_node_regression(name,
:param save_dir: The directory to save the task json and data files.
Default: ".".
:type save_dir: str
:param latest_supported_ver: The latest supported version of the metadata.
:type latest_supported_ver: str

:raises ValueError: If `task_type` is not "NodeRegression" or
"NodeClassification".
Expand Down Expand Up @@ -1285,7 +1317,8 @@ def save_task_node_regression(name,
test_ratio=test_ratio,
num_samples=num_samples,
task_id=task_id,
save_dir=save_dir)
save_dir=save_dir,
latest_supported_ver=latest_supported_ver)


def save_task_node_classification(name,
Expand All @@ -1301,7 +1334,8 @@ def save_task_node_classification(name,
test_ratio=0.1,
num_samples=None,
task_id=1,
save_dir="."):
save_dir=".",
latest_supported_ver=None):
"""Save the node classification task information into task json and data files.

:param name: The name of the dataset.
Expand Down Expand Up @@ -1450,4 +1484,5 @@ def save_task_node_classification(name,
test_ratio=test_ratio,
num_samples=num_samples,
task_id=task_id,
save_dir=save_dir)
save_dir=save_dir,
latest_supported_ver=latest_supported_ver)