def _get_version():
    """Return the current UTC time formatted as a version string.

    The version is a 14-digit timestamp of the form ``YYYYmmddHHMMSS``
    taken from a timezone-aware UTC clock reading.
    """
    utc_now = datetime.datetime.now(tz=datetime.timezone.utc)
    # A format spec on a datetime is forwarded to ``strftime``.
    return f"{utc_now:%Y%m%d%H%M%S}"
def _get_metadata_version(metadata):
    """Return the version recorded in the graph metadata.

    :param metadata: Either the already-loaded metadata dictionary, or the
        path to a ``metadata.json`` file, which is then opened and parsed.
    :type metadata: dict or str or os.PathLike
    :raises AssertionError: If the metadata has no "version" entry.
    :return: The value stored under the "version" key.
    """
    # Callers may hand over the path to metadata.json instead of its parsed
    # content; a membership test on the raw path string would only be a
    # substring check, so load the file first.
    if isinstance(metadata, (str, os.PathLike)):
        with open(metadata, "r", encoding="utf-8") as fptr:
            metadata = json.load(fptr)

    # Raise explicitly (rather than ``assert``) so the check still runs when
    # Python is started with -O; the exception type is kept unchanged.
    if "version" not in metadata:
        raise AssertionError(
            "The metadata does not contain the version information. "
            "Please add the version information to the metadata.")
    return metadata["version"]
@@ -1285,7 +1317,8 @@ def save_task_node_regression(name, test_ratio=test_ratio, num_samples=num_samples, task_id=task_id, - save_dir=save_dir) + save_dir=save_dir, + latest_supported_ver=latest_supported_ver) def save_task_node_classification(name, @@ -1301,7 +1334,8 @@ def save_task_node_classification(name, test_ratio=0.1, num_samples=None, task_id=1, - save_dir="."): + save_dir=".", + latest_supported_ver=None): """Save the node classification task information into task json and data files. :param name: The name of the dataset. @@ -1450,4 +1484,5 @@ def save_task_node_classification(name, test_ratio=test_ratio, num_samples=num_samples, task_id=task_id, - save_dir=save_dir) + save_dir=save_dir, + latest_supported_ver=latest_supported_ver)