def validate_graph(edge_index, num_nodes, name="graph"):
    """
    Run basic sanity checks on an edge index tensor to catch malformed
    graphs early.

    Parameters
    ----------
    edge_index : torch.Tensor
        Edge index tensor, expected to be 2-D with shape [2, num_edges].
    num_nodes : int
        Number of nodes the graph is expected to contain.
    name : str
        Label prepended to all log and error messages.

    Raises
    ------
    ValueError
        If ``edge_index`` is not of shape [2, num_edges], is empty, or
        contains negative node indices.

    Notes
    -----
    Indices ``>= num_nodes`` only trigger a warning, since they can be
    legitimate when node indices are offset (global indexing).
    """
    # Require exactly two dimensions: checking shape[0] alone would accept a
    # 1-D tensor of length 2 and fail with IndexError on a 0-D tensor.
    if edge_index.dim() != 2 or edge_index.shape[0] != 2:
        raise ValueError(
            f"[{name}] edge_index should have shape [2, num_edges]"
        )

    # Must come before min()/max(), which raise on empty tensors
    if edge_index.numel() == 0:
        raise ValueError(f"[{name}] edge_index is empty")

    if edge_index.min() < 0:
        raise ValueError(f"[{name}] found negative node indices")

    if edge_index.max() >= num_nodes:
        logger.warning(
            f"[{name}] edge_index contains node index >= num_nodes ({num_nodes}) "
            f"(may be valid for offset/global indexing)"
        )

    logger.info(f"[{name}] validation passed")


def compute_graph_stats(edge_index, num_nodes, name="graph"):
    """
    Log a few useful stats so it's easier to understand the graph structure.

    Parameters
    ----------
    edge_index : torch.Tensor
        Integer edge index tensor of shape [2, num_edges]. Row 0 is used for
        out-degrees and row 1 for in-degrees.
    num_nodes : int
        Number of nodes in the graph. Node indices are allowed to be offset,
        i.e. to exceed ``num_nodes - 1``.
    name : str
        Label prepended to all log messages.

    Returns
    -------
    dict or None
        The logged statistics, with keys ``num_nodes``, ``num_edges``,
        ``avg_degree``, ``max_degree``, ``isolated_nodes``,
        ``avg_in_degree`` and ``avg_out_degree``. ``None`` if the graph has
        no edges.
    """
    num_edges = int(edge_index.shape[1])

    if num_edges == 0:
        logger.warning(f"[{name}] graph has no edges")
        return None

    # Node indices may be offset beyond num_nodes. Size both degree vectors
    # to the same length so that they can always be added together.
    num_slots = max(int(num_nodes), int(edge_index.max()) + 1)
    in_deg = torch.bincount(edge_index[1], minlength=num_slots)
    out_deg = torch.bincount(edge_index[0], minlength=num_slots)
    deg = in_deg + out_deg

    # Derive node-level stats from num_nodes rather than the length of the
    # degree vectors, so unused index slots below an offset are not counted
    # as isolated nodes and do not dilute the averages.
    connected_nodes = int((deg > 0).sum().item())
    isolated_nodes = max(int(num_nodes) - connected_nodes, 0)
    denominator = int(num_nodes) if num_nodes > 0 else connected_nodes

    max_degree = int(deg.max().item())
    # Each edge contributes one in-degree and one out-degree
    avg_in = num_edges / denominator
    avg_out = num_edges / denominator
    avg_degree = avg_in + avg_out

    logger.info(f"[{name}] nodes: {num_nodes}")
    logger.info(f"[{name}] edges: {num_edges}")
    logger.info(f"[{name}] avg degree: {avg_degree:.2f}")
    logger.info(f"[{name}] max degree: {max_degree}")
    logger.info(f"[{name}] isolated nodes: {isolated_nodes}")
    logger.info(f"[{name}] avg in-degree: {avg_in:.2f}")
    logger.info(f"[{name}] avg out-degree: {avg_out:.2f}")

    return {
        "num_nodes": num_nodes,
        "num_edges": num_edges,
        "avg_degree": avg_degree,
        "max_degree": max_degree,
        "isolated_nodes": isolated_nodes,
        "avg_in_degree": avg_in,
        "avg_out_degree": avg_out,
    }
cli(input_args=None): ) args = parser.parse_args(input_args) - if args.config_path is None: - raise ValueError("Specify your config with --config_path") + assert ( + args.config_path is not None + ), "Specify your config with --config_path" # Load neural-lam configuration and datastore to use _, datastore = load_config_and_datastore(config_path=args.config_path) diff --git a/neural_lam/custom_loggers.py b/neural_lam/custom_loggers.py index 635f515e..533c86fa 100644 --- a/neural_lam/custom_loggers.py +++ b/neural_lam/custom_loggers.py @@ -6,6 +6,8 @@ import mlflow.pytorch import pytorch_lightning as pl from loguru import logger +from typing import List, Optional +from matplotlib.figure import Figure class CustomMLFlowLogger(pl.loggers.MLFlowLogger): @@ -15,7 +17,7 @@ class CustomMLFlowLogger(pl.loggers.MLFlowLogger): of version `2.0.3` at least. """ - def __init__(self, experiment_name, tracking_uri, run_name): + def __init__(self, experiment_name: str, tracking_uri: str, run_name: str) -> None: super().__init__( experiment_name=experiment_name, tracking_uri=tracking_uri ) @@ -25,7 +27,7 @@ def __init__(self, experiment_name, tracking_uri, run_name): mlflow.log_param("run_id", self.run_id) @property - def save_dir(self): + def save_dir(self) -> str: """ Returns the directory where the MLFlow artifacts are saved. Used to define the path to save output when using the logger. @@ -37,7 +39,7 @@ def save_dir(self): """ return "mlruns" - def log_image(self, key, images, step=None): + def log_image(self,key: str,images: List[Figure],step: Optional[int] = None,) -> None: """ Log a matplotlib figure as an image to MLFlow