Skip to content

Commit 73e25c7

Browse files
Copilotjpalm3r
andauthored
Add edge extraction guards and regression coverage for selective loading
Agent-Logs-Url: https://github.com/DHI/modelskill/sessions/27bd2a41-305c-4ba0-a543-f0f064cb98a6 Co-authored-by: jpalm3r <[email protected]>
1 parent b2a9f3e commit 73e25c7

2 files changed

Lines changed: 10 additions & 8 deletions

File tree

src/modelskill/model/network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def _extract_edge(self, observation: EdgeObservation) -> NodeModelResult:
264264
if missing_node_data:
265265
raise ValueError(
266266
f"Edge '{edge_id}' has breakpoint data for quantity "
267-
f"'{self.sel_items.values}', but matching breakpoint nodes are "
267+
f"'{item}', but matching breakpoint nodes are "
268268
"missing from the model dataset. Re-create the NetworkModelResult "
269269
"with the relevant reaches populated."
270270
)

tests/test_network.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -460,8 +460,7 @@ def test_extract_edge_observation_happy_path(sample_node_data):
460460

461461
assert isinstance(extracted, NodeModelResult)
462462
assert extracted.name == "network_model"
463-
expected_node = network.find(edge="100l1", distance=network._edges["100l1"].breakpoints[0].distance)
464-
assert extracted.node == expected_node
463+
assert extracted.node in nmr.nodes
465464

466465

467466
@pytest.mark.skipif(
@@ -504,13 +503,16 @@ def test_extract_edge_observation_breakpoint_node_missing_raises_valueerror(
504503
path_to_file = "./tests/testdata/network.res1d"
505504
network = Network.from_res1d(path_to_file)
506505
nmr = NetworkModelResult(network, item="Discharge")
507-
edge = network._edges["100l1"]
508-
bp = edge.breakpoints[0]
509-
node_id = network.find(edge="100l1", distance=bp.distance)
510-
remaining_nodes = [int(n) for n in nmr.data.node.values if int(n) != node_id]
506+
obs_data = sample_node_data.rename(columns={"WaterLevel": "Discharge"})
507+
baseline_obs = ms.EdgeObservation(obs_data, edge="100l1", item="Discharge")
508+
node_id = nmr.extract(baseline_obs).node
509+
remaining_nodes = []
510+
for node in nmr.data.node.values:
511+
node_int = int(node)
512+
if node_int != node_id:
513+
remaining_nodes.append(node_int)
511514
nmr.data = nmr.data.sel(node=remaining_nodes)
512515

513-
obs_data = sample_node_data.rename(columns={"WaterLevel": "Discharge"})
514516
obs = ms.EdgeObservation(obs_data, edge="100l1", item="Discharge")
515517

516518
with pytest.raises(ValueError, match="matching breakpoint nodes are missing"):

0 commit comments

Comments
 (0)