@@ -460,8 +460,7 @@ def test_extract_edge_observation_happy_path(sample_node_data):
460460
461461 assert isinstance (extracted , NodeModelResult )
462462 assert extracted .name == "network_model"
463- expected_node = network .find (edge = "100l1" , distance = network ._edges ["100l1" ].breakpoints [0 ].distance )
464- assert extracted .node == expected_node
463+ assert extracted .node in nmr .nodes
465464
466465
467466@pytest .mark .skipif (
@@ -504,13 +503,16 @@ def test_extract_edge_observation_breakpoint_node_missing_raises_valueerror(
504503 path_to_file = "./tests/testdata/network.res1d"
505504 network = Network .from_res1d (path_to_file )
506505 nmr = NetworkModelResult (network , item = "Discharge" )
507- edge = network ._edges ["100l1" ]
508- bp = edge .breakpoints [0 ]
509- node_id = network .find (edge = "100l1" , distance = bp .distance )
510- remaining_nodes = [int (n ) for n in nmr .data .node .values if int (n ) != node_id ]
506+ obs_data = sample_node_data .rename (columns = {"WaterLevel" : "Discharge" })
507+ baseline_obs = ms .EdgeObservation (obs_data , edge = "100l1" , item = "Discharge" )
508+ node_id = nmr .extract (baseline_obs ).node
509+ remaining_nodes = []
510+ for node in nmr .data .node .values :
511+ node_int = int (node )
512+ if node_int != node_id :
513+ remaining_nodes .append (node_int )
511514 nmr .data = nmr .data .sel (node = remaining_nodes )
512515
513- obs_data = sample_node_data .rename (columns = {"WaterLevel" : "Discharge" })
514516 obs = ms .EdgeObservation (obs_data , edge = "100l1" , item = "Discharge" )
515517
516518 with pytest .raises (ValueError , match = "matching breakpoint nodes are missing" ):
0 commit comments