Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions neural_lam/create_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,65 @@
from .config import load_config_and_datastore
from .datastore.base import BaseRegularGridDatastore

def validate_graph(edge_index, num_nodes, name="graph"):
    """
    Run basic sanity checks on an edge index to catch common graph
    construction issues early.

    Parameters
    ----------
    edge_index : torch.Tensor
        Edge index of shape [2, num_edges] (row 0 = source nodes,
        row 1 = target nodes).
    num_nodes : int
        Expected number of nodes in the graph.
    name : str
        Label used in log and error messages to identify the graph.

    Raises
    ------
    ValueError
        If the edge index has the wrong shape, is empty, or contains
        negative node indices.
    """
    if edge_index.shape[0] != 2:
        raise ValueError(f"[{name}] edge_index should have shape [2, num_edges]")

    if edge_index.numel() == 0:
        raise ValueError(f"[{name}] edge_index is empty")

    if edge_index.min() < 0:
        raise ValueError(f"[{name}] found negative node indices")

    # Hierarchical m2m graphs built via from_networkx_with_start_index()
    # intentionally keep globally-offset node ids, so for level 1+ the max
    # index can legitimately exceed num_nodes — warn instead of raising.
    if edge_index.max() >= num_nodes:
        logger.warning(
            f"[{name}] edge_index contains node index >= num_nodes ({num_nodes}) "
            f"(may be valid for offset/global indexing)"
        )

    logger.info(f"[{name}] validation passed")


def compute_graph_stats(edge_index, num_nodes, name="graph"):
    """
    Log summary statistics for a graph so its structure is easier to
    understand: edge/node counts, total degree (in + out), in-degree and
    out-degree averages, and the number of isolated nodes.

    Parameters
    ----------
    edge_index : torch.Tensor
        Edge index of shape [2, num_edges] (row 0 = source nodes,
        row 1 = target nodes).
    num_nodes : int
        Number of nodes in the graph (used to size the degree counts).
    name : str
        Label used in log messages to identify the graph.
    """
    # torch is imported at module level; no need for a local import here.
    num_edges = edge_index.shape[1]

    # Guard against empty graphs so the degree computations below
    # don't operate on empty tensors.
    if num_edges == 0:
        logger.warning(f"[{name}] graph has no edges")
        return

    # minlength=num_nodes ensures nodes with no edges still get a
    # zero entry, so isolated nodes are counted correctly.
    in_deg = torch.bincount(edge_index[1], minlength=num_nodes)
    out_deg = torch.bincount(edge_index[0], minlength=num_nodes)
    deg = in_deg + out_deg

    avg_degree = deg.float().mean().item()
    max_degree = deg.max().item()
    isolated_nodes = (deg == 0).sum().item()

    avg_in = in_deg.float().mean().item()
    avg_out = out_deg.float().mean().item()

    logger.info(f"[{name}] nodes: {num_nodes}")
    logger.info(f"[{name}] edges: {num_edges}")
    logger.info(f"[{name}] avg degree: {avg_degree:.2f}")
    logger.info(f"[{name}] max degree: {max_degree}")
    logger.info(f"[{name}] isolated nodes: {isolated_nodes}")
    logger.info(f"[{name}] avg in-degree: {avg_in:.2f}")
    logger.info(f"[{name}] avg out-degree: {avg_out:.2f}")


def plot_graph(graph, title=None):
fig, axis = plt.subplots(figsize=(8, 8), dpi=200) # W,H
Expand Down Expand Up @@ -361,6 +420,9 @@ def create_graph(
)
for level_graph, start_index in zip(G, first_index_level)
]
for level, graph in enumerate(m2m_graphs):
validate_graph(graph.edge_index, graph.num_nodes, f"m2m_level_{level}")
compute_graph_stats(graph.edge_index, graph.num_nodes, f"m2m_level_{level}")

mesh_pos = [graph.pos.to(torch.float32) for graph in m2m_graphs]

Expand Down Expand Up @@ -486,6 +548,8 @@ def create_graph(
)

pyg_g2m = from_networkx(G_g2m)
validate_graph(pyg_g2m.edge_index, pyg_g2m.num_nodes, "g2m")
compute_graph_stats(pyg_g2m.edge_index, pyg_g2m.num_nodes, "g2m")

if create_plot:
plot_graph(pyg_g2m, title="Grid-to-mesh")
Expand Down Expand Up @@ -525,6 +589,8 @@ def create_graph(
G_m2g, first_label=0, ordering="sorted"
)
pyg_m2g = from_networkx(G_m2g_int)
validate_graph(pyg_m2g.edge_index, pyg_m2g.num_nodes, "m2g")
compute_graph_stats(pyg_m2g.edge_index, pyg_m2g.num_nodes, "m2g")

if create_plot:
plot_graph(pyg_m2g, title="Mesh-to-grid")
Expand Down Expand Up @@ -594,8 +660,9 @@ def cli(input_args=None):
)
args = parser.parse_args(input_args)

if args.config_path is None:
raise ValueError("Specify your config with --config_path")
assert (
args.config_path is not None
), "Specify your config with --config_path"

# Load neural-lam configuration and datastore to use
_, datastore = load_config_and_datastore(config_path=args.config_path)
Expand Down
8 changes: 5 additions & 3 deletions neural_lam/custom_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import mlflow.pytorch
import pytorch_lightning as pl
from loguru import logger
from typing import List, Optional
from matplotlib.figure import Figure


class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
Expand All @@ -15,7 +17,7 @@ class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
of version `2.0.3` at least.
"""

def __init__(self, experiment_name, tracking_uri, run_name):
def __init__(self, experiment_name: str, tracking_uri: str, run_name: str) -> None:
super().__init__(
experiment_name=experiment_name, tracking_uri=tracking_uri
)
Expand All @@ -25,7 +27,7 @@ def __init__(self, experiment_name, tracking_uri, run_name):
mlflow.log_param("run_id", self.run_id)

@property
def save_dir(self):
def save_dir(self) -> str:
"""
Returns the directory where the MLFlow artifacts are saved.
Used to define the path to save output when using the logger.
Expand All @@ -37,7 +39,7 @@ def save_dir(self):
"""
return "mlruns"

def log_image(self, key, images, step=None):
def log_image(self,key: str,images: List[Figure],step: Optional[int] = None,) -> None:
"""
Log a matplotlib figure as an image to MLFlow

Expand Down