Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,17 +109,17 @@ def check_if_task_metadata_valid(
assert (
len(task_metadata_pb.supervision_edge_types) > 0
), "Must provide at least one supervision edge type."
graph_metadata_pb_edge_types = [
GbmlProtosTranslator.edge_type_from_EdgeTypePb(edge_type_pb=edge_type_pb)
for edge_type_pb in graph_metadata_pb.edge_types
]
graph_metadata_node_types = graph_metadata_pb.node_types
for edge_type_pb in task_metadata_pb.supervision_edge_types:
edge_type = GbmlProtosTranslator.edge_type_from_EdgeTypePb(
edge_type_pb=edge_type_pb
)
assert (
edge_type in graph_metadata_pb_edge_types
), f"Invalid supervision edge type: {edge_type}; not found in graphMetadata edge types {graph_metadata_pb_edge_types}."
edge_type.src_node_type in graph_metadata_node_types
), f"Invalid supervision edge type: {edge_type}; which contains a source node type not found in graphMetadata node types: {graph_metadata_node_types}."
assert (
edge_type.dst_node_type in graph_metadata_node_types
), f"Invalid supervision edge type: {edge_type}; which contains a destination node type not found in graphMetadata node types: {graph_metadata_node_types}."
else:
raise ValueError(
f"Invalid 'taskMetadata'; must be one of {[TaskMetadataType.NODE_BASED_TASK, TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK]}.",
Expand Down
71 changes: 71 additions & 0 deletions python/tests/unit/src/validation/task_metadata_is_valid_test.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: should this file be named template_config_checks_test.py?

Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import unittest

from gigl.src.validation_check.libs.template_config_checks import (
check_if_task_metadata_valid,
)
from snapchat.research.gbml import gbml_config_pb2, graph_schema_pb2
from tests.test_assets.graph_metadata_constants import (
DEFAULT_HOMOGENEOUS_GRAPH_METADATA_PB,
DEFAULT_HOMOGENEOUS_NODE_TYPE_STR,
Comment on lines +8 to +9
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated, and this shouldn't be done in this PR.

We should not have mutable constants like this [1]. If we do, one test could mutate them and cause other downstream tests to fail.

Ideally, we would have a create_graph_metadata_pb() helper that always creates a new object.

Can we add this to some task tracker?

)


class TaskMetadataIsValidTest(unittest.TestCase):
    """Tests for ``check_if_task_metadata_valid`` on link prediction tasks.

    The tests cover how supervision edge types are validated: an edge type
    whose source or destination node type is missing from the graph metadata
    must be rejected, while an edge type that is absent from the graph
    metadata's edge types is accepted as long as both of its node types are
    known.
    """

    def _create_link_prediction_task_config(
        self,
        supervision_edge_types: list[graph_schema_pb2.EdgeType],
        graph_metadata: graph_schema_pb2.GraphMetadata,
    ) -> gbml_config_pb2.GbmlConfig:
        """Build a node-anchor-based link prediction ``GbmlConfig``.

        Args:
            supervision_edge_types: Edge types to place in the task metadata's
                ``supervision_edge_types`` field.
            graph_metadata: Graph metadata to attach to the config.

        Returns:
            A config holding only the task metadata and the graph metadata.
        """
        return gbml_config_pb2.GbmlConfig(
            task_metadata=gbml_config_pb2.GbmlConfig.TaskMetadata(
                node_anchor_based_link_prediction_task_metadata=gbml_config_pb2.GbmlConfig.TaskMetadata.NodeAnchorBasedLinkPredictionTaskMetadata(
                    supervision_edge_types=supervision_edge_types
                )
            ),
            graph_metadata=graph_metadata,
        )

    def test_link_prediction_task_edge_with_invalid_node_types_raises_error(self):
        """Test that an unknown destination node type is rejected."""
        # Only the destination is invalid, so the source check passes and the
        # destination check is the one expected to fail.
        edge_with_invalid_nodes = graph_schema_pb2.EdgeType(
            src_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE_STR,  # valid node type
            relation="to",
            dst_node_type="nonexistent_dst_node_type",  # invalid destination node type
        )
        config = self._create_link_prediction_task_config(
            supervision_edge_types=[edge_with_invalid_nodes],
            graph_metadata=DEFAULT_HOMOGENEOUS_GRAPH_METADATA_PB,
        )

        # Match on the message so the test cannot pass because of some other,
        # unrelated assertion inside the validation function.
        with self.assertRaisesRegex(
            AssertionError, "destination node type not found"
        ):
            check_if_task_metadata_valid(config)

    def test_link_prediction_task_edge_with_invalid_src_node_type_raises_error(
        self,
    ):
        """Test that an unknown source node type is rejected."""
        # Mirror of the destination test: only the source is invalid.
        edge_with_invalid_src = graph_schema_pb2.EdgeType(
            src_node_type="nonexistent_src_node_type",  # invalid source node type
            relation="to",
            dst_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE_STR,  # valid node type
        )
        config = self._create_link_prediction_task_config(
            supervision_edge_types=[edge_with_invalid_src],
            graph_metadata=DEFAULT_HOMOGENEOUS_GRAPH_METADATA_PB,
        )

        with self.assertRaisesRegex(AssertionError, "source node type not found"):
            check_if_task_metadata_valid(config)

    def test_link_prediction_task_edge_not_in_graph_metadata_but_nodes_valid_passes(
        self,
    ):
        """Test that no error is raised when edge type is not in graph metadata but node types are valid."""
        # Create an edge type with valid node types but a relation that doesn't exist in graph metadata
        edge_with_new_relation = graph_schema_pb2.EdgeType(
            src_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE_STR,  # Valid node type
            relation="completely_new_relation",  # This relation doesn't exist in graph metadata
            dst_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE_STR,  # Valid node type
        )
        config = self._create_link_prediction_task_config(
            supervision_edge_types=[edge_with_new_relation],
            graph_metadata=DEFAULT_HOMOGENEOUS_GRAPH_METADATA_PB,
        )

        # This should not raise any errors
        check_if_task_metadata_valid(config)


# Allow this test module to be executed directly as a script.
if __name__ == "__main__":
    unittest.main()