diff --git a/egui_node_graph/src/editor_ui.rs b/egui_node_graph/src/editor_ui.rs index 5ecfc60..b955111 100644 --- a/egui_node_graph/src/editor_ui.rs +++ b/egui_node_graph/src/editor_ui.rs @@ -178,7 +178,16 @@ where let start_pos = port_locations[locator]; // Find a port to connect to - fn snap_to_ports, Value>( + fn snap_to_ports< + NodeData, + UserState, + DataType: DataTypeTrait, + ValueType, + Key: slotmap::Key + Into, + Value, + >( + graph: &Graph, + port_type: &DataType, ports: &SlotMap, port_locations: &PortLocations, cursor_pos: Pos2, @@ -186,23 +195,45 @@ where ports .iter() .find_map(|(port_id, _)| { - port_locations.get(&port_id.into()).and_then(|port_pos| { - if port_pos.distance(cursor_pos) < DISTANCE_TO_CONNECT { - Some(*port_pos) - } else { - None - } - }) + let compatible_ports = graph + .any_param_type(port_id.into()) + .map(|other| other == port_type) + .unwrap_or(false); + + if compatible_ports { + port_locations.get(&port_id.into()).and_then(|port_pos| { + if port_pos.distance(cursor_pos) < DISTANCE_TO_CONNECT { + Some(*port_pos) + } else { + None + } + }) + } else { + None + } }) .unwrap_or(cursor_pos) } + let (src_pos, dst_pos) = match locator { AnyParameterId::Output(_) => ( start_pos, - snap_to_ports(&self.graph.inputs, &port_locations, cursor_pos), + snap_to_ports( + &self.graph, + port_type, + &self.graph.inputs, + &port_locations, + cursor_pos, + ), ), AnyParameterId::Input(_) => ( - snap_to_ports(&self.graph.outputs, &port_locations, cursor_pos), + snap_to_ports( + &self.graph, + port_type, + &self.graph.outputs, + &port_locations, + cursor_pos, + ), start_pos, ), }; @@ -261,7 +292,7 @@ where self.node_order.retain(|id| *id != *node_id); } NodeResponse::DisconnectEvent { input, output } => { - let other_node = self.graph.get_input(*input).node(); + let other_node = self.graph.get_output(*output).node; self.graph.remove_connection(*input); self.connection_in_progress = Some((other_node, AnyParameterId::Output(*output)));