diff --git a/CHANGELOG.md b/CHANGELOG.md index 6cf3ab0..5223180 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add `__version__` attribute to the package init [\#56](https://github.com/mllam/weather-model-graphs/pull/56) @AdMub +### Changed + +- Make `test_create_decode_mask` more specific and future-proof by checking only m2g (decoding) edges instead of total edges. [\#104](https://github.com/mllam/weather-model-graphs/pull/104) @zweihuehner + ## [v0.3.0](https://github.com/mllam/weather-model-graphs/releases/tag/v0.3.0) diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py index b45d486..9b72759 100644 --- a/tests/test_graph_creation.py +++ b/tests/test_graph_creation.py @@ -197,8 +197,8 @@ def test_create_lat_lon(kind): @pytest.mark.parametrize("kind", ["graphcast", "keisler", "oskarsson_hierarchical"]) def test_create_decode_mask(kind): """ - Tests that the decode mask for m2g works, resulting in less edges than - no filtering. + Tests that the decode mask for m2g works, resulting in fewer m2g edges than + without a decode mask. """ xy = test_utils.create_fake_irregular_coords(100) fn_name = f"create_{kind}_graph" @@ -214,8 +214,14 @@ def test_create_decode_mask(kind): coords=xy, mesh_node_distance=mesh_node_distance, decode_mask=decode_mask ) - # Check that some filtering has been performed - assert len(filtered_graph.edges) < len(unfiltered_graph.edges) + # Check that m2g edges (mesh-to-grid / decoding edges) are filtered + unfiltered_m2g_edges = [ + e for e in unfiltered_graph.edges(data=True) if e[2].get("component") == "m2g" + ] + filtered_m2g_edges = [ + e for e in filtered_graph.edges(data=True) if e[2].get("component") == "m2g" + ] + assert len(filtered_m2g_edges) < len(unfiltered_m2g_edges) @pytest.mark.parametrize("kind", ["graphcast", "oskarsson_hierarchical"])