diff --git a/egui_node_graph/src/column_node.rs b/egui_node_graph/src/column_node.rs new file mode 100644 index 0000000..465c3ba --- /dev/null +++ b/egui_node_graph/src/column_node.rs @@ -0,0 +1,353 @@ +use super::*; +use crate::utils::ColorUtils; +use egui::*; +use epaint::RectShape; + +pub type SimpleColumnNode = + ColumnNode, VerticalOutputPort>; + +/// A node inside the [`Graph`]. Nodes have input and output parameters, stored +/// as ids. They also contain a custom `NodeData` struct with whatever data the +/// user wants to store per-node. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "persistence", derive(Serialize, Deserialize))] +pub struct ColumnNode { + pub position: Pos2, + pub label: String, + pub content: Content, + /// The input ports of the graph + pub inputs: SlotMap, + /// The output ports of the graph + pub outputs: SlotMap, + + /// The size hint is used to automatically scale the widget to a desirable + /// size while still allowing right-side ports to be justified to the right + /// size of the node widget. If the desired size of a widget inside of the + /// node's frame changes then the node size should be fixed after one bad + /// rendering cycle. + pub size_hint: f32, +} + +impl ColumnNode { + pub fn new(position: Pos2, label: String, content: Content) -> Self { + Self { + position, + label, + content, + inputs: SlotMap::with_key(), + outputs: SlotMap::with_key(), + size_hint: 80.0, + } + } + + pub fn with_input(mut self, input: InputPort) -> Self { + self.inputs.insert(input); + self + } + + pub fn with_output(mut self, output: OutputPort) -> Self { + self.outputs.insert(output); + self + } + + pub fn with_size_hint(mut self, size_hint: f32) -> Self { + self.size_hint = size_hint; + self + } +} + +impl NodeTrait for ColumnNode +where + Content: NodeContentTrait, + InputPort: PortTrait, + OutputPort: PortTrait, +{ + type DataType = InputPort::DataType; + type Content = Content; + + fn show( + &mut self, + parent_ui: &mut egui::Ui, + app: &::AppState, + node_id: NodeId, + state: &mut EditorUiState, + style: &dyn GraphStyleTrait, + ) -> Vec> { + let mut ui = parent_ui.child_ui_with_id_source( + Rect::from_min_size(self.position + state.pan, [self.size_hint, 0.0].into()), + Layout::default(), + node_id, + ); + + let margin = egui::vec2(15.0, 5.0); + let mut responses = Vec::>::new(); + + let background_color = style.recommend_node_background_color(&ui, node_id); + let text_color = style.recommend_node_text_color(&ui, node_id); + + ui.visuals_mut().widgets.noninteractive.fg_stroke = Stroke::new(2.0, text_color); + + // Forward declare shapes to paint below contents + let outline_shape = ui.painter().add(Shape::Noop); + let background_shape = ui.painter().add(Shape::Noop); + + let outer_rect_bounds = ui.available_rect_before_wrap(); + let inner_rect = { + let mut inner_rect = outer_rect_bounds.shrink2(margin); + + // Try to use the size hint, unless our outer limits are smaller + inner_rect.max.x = inner_rect.max.x.min(self.size_hint + inner_rect.min.x); + + // Make sure we don't shrink to the negative + inner_rect.max.x = inner_rect.max.x.max(inner_rect.min.x); + inner_rect.max.y = inner_rect.max.y.max(inner_rect.min.y); + + inner_rect + }; + + let mut title_height = 0.0; + let mut child_ui = ui.child_ui(inner_rect, *ui.layout()); + child_ui.vertical(|ui| { + let title_rect = ui.horizontal(|ui| { + ui.add(Label::new( + RichText::new(&self.label) + .text_style(TextStyle::Button) + .color(style.recommend_node_text_color(ui, node_id)), + )); + ui.add_space(8.0); // The size of the little cross icon + }).response.rect; + self.size_hint = title_rect.width(); + ui.add_space(margin.y); + title_height = ui.min_size().y; + + for (input_id, port) in &mut self.inputs { + ui.horizontal(|ui| { + let (rect, port_responses): (egui::Rect, Vec>) = port.show( + ui, (node_id, PortId::Input(input_id)), state, style + ); + responses.extend(port_responses.into_iter().map(NodeResponse::Port)); + self.size_hint = self.size_hint.max(rect.width()); + }); + } + + for (output_id, port) in &mut self.outputs { + ui.with_layout(egui::Layout::right_to_left(egui::Align::Center), |ui| { + let (rect, port_responses): (egui::Rect, Vec>) = port.show( + ui, (node_id, PortId::Output(output_id)), state, style + ); + responses.extend(port_responses.into_iter().map(NodeResponse::Port)); + self.size_hint = self.size_hint.max(rect.width()); + }); + } + + let (rect, resp) = self.content.content_ui(ui, app, node_id); + self.size_hint = self.size_hint.max(rect.width() + 3.0*margin.x); + responses.extend(resp.into_iter().map(NodeResponse::Content)); + }); + + let (shape, outline, outer_rect) = { + let rounding_radius = 4.0; + let rounding = Rounding::same(rounding_radius); + + let outer_rect = child_ui.min_rect().expand2(margin); + let titlebar_height = title_height + margin.y; + let titlebar_rect = Rect::from_min_size( + outer_rect.min, vec2(outer_rect.width(), titlebar_height) + ); + let titlebar = Shape::Rect(RectShape{ + rect: titlebar_rect, + rounding, + fill: self.content.titlebar_color( + &ui, app, node_id, + ).unwrap_or_else(|| style.recommend_node_background_color( + &ui, node_id).lighten(0.8) + ), + stroke: Stroke::none(), + }); + + let body_rect = Rect::from_min_size( + outer_rect.min + vec2(0.0, titlebar_height - rounding_radius), + vec2(outer_rect.width(), outer_rect.height() - titlebar_height), + ); + let body = Shape::Rect(RectShape{ + rect: body_rect, + rounding: Rounding::none(), + fill: background_color, + stroke: Stroke::none(), + }); + + let bottom_body_rect = Rect::from_min_size( + body_rect.min + vec2(0.0, body_rect.height() - titlebar_height * 0.5), + vec2(outer_rect.width(), title_height), + ); + let bottom_body = Shape::Rect(RectShape { + rect: bottom_body_rect, + rounding, + fill: background_color, + stroke: Stroke::none(), + }); + + let outline = if state.selected_nodes.contains(&node_id) { + Shape::Rect(RectShape { + rect: titlebar_rect + .union(body_rect) + .union(bottom_body_rect) + .expand(1.0), + rounding, + fill: Color32::WHITE.lighten(0.8), + stroke: Stroke::none(), + }) + } else { + Shape::Noop + }; + + (Shape::Vec(vec![titlebar, body, bottom_body]), outline, outer_rect) + }; + + ui.painter().set(background_shape, shape); + ui.painter().set(outline_shape, outline); + + // Make close button + if { + let margin = 8.0; + let size = 10.0; + let x_size = 8.0; + let stroke_width = 2.0; + let offset = margin + size / 2.0; + + let position = pos2(outer_rect.right() - offset, outer_rect.top() + offset); + let rect = Rect::from_center_size(position, vec2(size, size)); + let x_rect = Rect::from_center_size(position, vec2(x_size, x_size)); + let resp = ui.allocate_rect(rect, Sense::click()); + + let (stroke, fill) = if resp.dragged() { + style.recommend_close_button_clicked_colors(&ui, node_id) + } else if resp.hovered() { + style.recommend_close_button_hover_colors(&ui, node_id) + } else { + style.recommend_close_button_passive_colors(&ui, node_id) + }; + + ui.painter().rect(rect, 0.5, fill, (0_f32, fill)); + + let stroke = Stroke { + width: stroke_width, + color: stroke, + }; + ui.painter().line_segment([x_rect.left_top(), x_rect.right_bottom()], stroke); + ui.painter().line_segment([x_rect.right_top(), x_rect.left_bottom()], stroke); + + resp + }.clicked() { + responses.push(NodeResponse::DeleteNodeUi(node_id)); + } + + let window_response = ui.interact( + outer_rect, + Id::new((node_id, "window")), + Sense::click_and_drag(), + ); + + // Movement + self.position += window_response.drag_delta(); + if window_response.drag_delta().length_sq() > 0.0 { + responses.push(NodeResponse::RaiseNode(node_id)); + } + + // Node selection + if responses.is_empty() && window_response.clicked_by(PointerButton::Primary) { + responses.push(NodeResponse::SelectNode(node_id)); + responses.push(NodeResponse::RaiseNode(node_id)); + } + + responses + } + + fn port_data_type(&self, port_id: PortId) -> Option { + match port_id { + PortId::Input(port_id) => self.inputs.get(port_id).map(|p| p.data_type()), + PortId::Output(port_id) => self.outputs.get(port_id).map(|p| p.data_type()), + } + } + + fn available_hook(&self, port_id: PortId) -> Option { + match port_id { + PortId::Input(port_id) => self.inputs.get(port_id).map(|p| p.available_hook()).flatten(), + PortId::Output(port_id) => self.outputs.get(port_id).map(|p| p.available_hook()).flatten(), + } + } + + fn drop_connection( + &mut self, + (port, hook): (PortId, HookId) + ) -> Result { + match port { + PortId::Input(input_port) => { + match self.inputs.get_mut(input_port) { + Some(input_port) => { + input_port.drop_connection(hook).map_err( + |err| NodeDropConnectionError::PortError { port, err } + ) + } + None => { + Err(NodeDropConnectionError::BadPort(port)) + } + } + } + PortId::Output(output_port) => { + match self.outputs.get_mut(output_port) { + Some(output_port) => { + output_port.drop_connection(hook).map_err( + |err| NodeDropConnectionError::PortError { port, err } + ) + } + None => { + Err(NodeDropConnectionError::BadPort(port)) + } + } + } + } + } + + fn drop_all_connections(&mut self) -> Vec<(PortId, HookId, ConnectionId)> { + let mut dropped = Vec::new(); + for (id, port) in &mut self.inputs { + dropped.extend(port.drop_all_connections().into_iter().map( + |(hook, connection)| (PortId::Input(id), hook, connection) + )); + } + + for (id, port) in &mut self.outputs { + dropped.extend(port.drop_all_connections().into_iter().map( + |(hook, connection)| (PortId::Output(id), hook, connection) + )); + } + + dropped + } + + fn connect(&mut self, (port, hook): (PortId, HookId), to: graph::ConnectionToken) -> Result<(), NodeAddConnectionError> { + match port { + PortId::Input(input_port) => { + match self.inputs.get_mut(input_port) { + Some(input_port) => { + input_port.connect(hook, to).map_err( + |err| NodeAddConnectionError::PortError { port, err } + ) + } + None => Err(NodeAddConnectionError::BadPort(port)) + } + } + PortId::Output(output_port) => { + match self.outputs.get_mut(output_port) { + Some(output_port) => { + output_port.connect(hook, to).map_err( + |err| NodeAddConnectionError::PortError { port, err } + ) + } + None => Err(NodeAddConnectionError::BadPort(port)) + } + } + } + } +} diff --git a/egui_node_graph/src/editor_ui.rs b/egui_node_graph/src/editor_ui.rs index dbc61a8..d00082e 100644 --- a/egui_node_graph/src/editor_ui.rs +++ b/egui_node_graph/src/editor_ui.rs @@ -1,27 +1,48 @@ -use std::collections::HashSet; - -use crate::color_hex_utils::*; -use crate::utils::ColorUtils; +use std::collections::{HashSet, HashMap}; use super::*; -use egui::epaint::{CubicBezierShape, RectShape}; +use egui::epaint::CubicBezierShape; use egui::*; -pub type PortLocations = std::collections::HashMap; -pub type NodeRects = std::collections::HashMap; - -const DISTANCE_TO_CONNECT: f32 = 10.0; +/// For each hook, this specifies the (location, tangent vector) for any curve +/// rendering a connection out of it. +pub type HookGeometry = HashMap; -/// Nodes communicate certain events to the parent graph when drawn. There is -/// one special `User` variant which can be used by users as the return value -/// when executing some custom actions in the UI of the node. +/// Ports communicate connection and disconnection events to the parent graph +/// when drawn. #[derive(Clone, Debug)] -pub enum NodeResponse { - ConnectEventStarted(NodeId, AnyParameterId), +pub enum PortResponse { + /// The user is creating a new connection from the hook of ConnectionId + ConnectEventStarted(ConnectionId), + /// The user is moving the connection that used to be attached to ConnectionId + MoveEvent(ConnectionId), + /// A connection has been accepted by a port ConnectEventEnded { output: OutputId, input: InputId, }, + /// The value of a port has changed + Value(::Response) +} + +impl PortResponse { + pub fn connect_event_ended(a: ConnectionId, b: ConnectionId) -> Option { + if let (Some(input), Some(output)) = (a.as_input(), b.as_output()) { + Some(PortResponse::ConnectEventEnded{output, input}) + } else if let (Some(output), Some(input)) = (a.as_output(), b.as_input()) { + Some(PortResponse::ConnectEventEnded{output, input}) + } else { + None + } + } +} + +/// Nodes communicate certain events to the parent graph when drawn. There is +/// one special `User` variant which can be used by users as the return value +/// when executing some custom actions in the UI of the node. +#[derive(Clone, Debug)] +pub enum NodeResponse { + Port(PortResponse), CreatedNode(NodeId), SelectNode(NodeId), /// As a user of this library, prefer listening for `DeleteNodeFull` which @@ -32,84 +53,65 @@ pub enum NodeResponse /// contents are passed along with the event. DeleteNodeFull { node_id: NodeId, - node: Node, - }, - DisconnectEvent { - output: OutputId, - input: InputId, + node: Node, }, + /// Emitted for each disconnection that has occurred in the graph + DisconnectEvent{ input: InputId, output: OutputId }, /// Emitted when a node is interacted with, and should be raised RaiseNode(NodeId), - MoveNode { - node: NodeId, - drag_delta: Vec2, - }, - User(UserResponse), + Content(ContentResponseOf), } -/// The return value of [`draw_graph_editor`]. This value can be used to make -/// user code react to specific events that happened when drawing the graph. -#[derive(Clone, Debug)] -pub struct GraphResponse { - /// Events that occurred during this frame of rendering the graph. Check the - /// [`UserResponse`] type for a description of each event. - pub node_responses: Vec>, - /// Is the mouse currently hovering the graph editor? Note that the node - /// finder is considered part of the graph editor, even when it floats - /// outside the graph editor rect. - pub cursor_in_editor: bool, - /// Is the mouse currently hovering the node finder? - pub cursor_in_finder: bool, +impl NodeResponse { + fn disconnect(a: ConnectionId, b: ConnectionId) -> Result { + let (input, output) = match a.port() { + PortId::Input(_) => { + match b.as_output() { + Some(output) => (a.assume_input(), output), + None => return Err(()), + } + } + PortId::Output(_) => { + match b.as_input() { + Some(input) => (input, a.assume_output()), + None => return Err(()), + } + } + }; + + Ok(Self::DisconnectEvent { input, output }) + } } -impl Default - for GraphResponse -{ - fn default() -> Self { - Self { - node_responses: Default::default(), - cursor_in_editor: false, - cursor_in_finder: false, - } + +/// Automatically convert a Port Response into a NodeResponse +impl From> for NodeResponse { + fn from(value: PortResponse) -> Self { + Self::Port(value) } } -pub struct GraphNodeWidget<'a, NodeData, DataType, ValueType> { - pub position: &'a mut Pos2, - pub graph: &'a mut Graph, - pub port_locations: &'a mut PortLocations, - pub node_rects: &'a mut NodeRects, - pub node_id: NodeId, - pub ongoing_drag: Option<(NodeId, AnyParameterId)>, - pub selected: bool, - pub pan: egui::Vec2, + +pub struct EditorUiState<'a, DataType> { + pub pan: Vec2, + pub hook_geometry: &'a mut HookGeometry, + pub ongoing_drag: Option<(ConnectionId, DataType)>, + pub selected_nodes: &'a Vec, +} + +/// The return value of [`draw_graph_editor`]. This value can be used to make +/// user code react to specific events that happened when drawing the graph. +#[derive(Clone, Debug)] +pub struct GraphResponse { + pub node_responses: Vec>, } -impl - GraphEditorState -where - NodeData: NodeDataTrait< - Response = UserResponse, - UserState = UserState, - DataType = DataType, - ValueType = ValueType, - >, - UserResponse: UserResponseTrait, - ValueType: - WidgetValueTrait, - NodeTemplate: NodeTemplateTrait< - NodeData = NodeData, - DataType = DataType, - ValueType = ValueType, - UserState = UserState, - >, - DataType: DataTypeTrait, -{ +impl GraphEditorState { #[must_use] pub fn draw_graph_editor( &mut self, ui: &mut Ui, - all_kinds: impl NodeTemplateIter, - user_state: &mut UserState, - ) -> GraphResponse { + all_kinds: impl NodeTemplateIter, + app_state: &mut AppStateOf, + ) -> GraphResponse { // This causes the graph editor to use as much free space as it can. // (so for windows it will use up to the resizeably set limit // and for a Panel it will fill it completely) @@ -120,44 +122,44 @@ where let mut cursor_in_editor = editor_rect.contains(cursor_pos); let mut cursor_in_finder = false; - // Gets filled with the node metrics as they are drawn - let mut port_locations = PortLocations::new(); - let mut node_rects = NodeRects::new(); + // Gets filled with the port locations as nodes are drawn + let mut hook_geometry = HookGeometry::new(); // The responses returned from node drawing have side effects that are best // executed at the end of this function. - let mut delayed_responses: Vec> = vec![]; + let mut delayed_responses: Vec> = vec![]; - // Used to detect when the background was clicked + // Used to detect when the background was clicked, to dismiss certain selfs let mut click_on_background = false; - // Used to detect drag events in the background - let mut drag_started_on_background = false; - let mut drag_released_on_background = false; - debug_assert_eq!( self.node_order.iter().copied().collect::>(), - self.graph.iter_nodes().collect::>(), + self.graph.iter_nodes().map(|(id, _)| id).collect::>(), "The node_order field of the GraphEditorself was left in an \ - inconsistent self. It has either more or less values than the graph." + inconsistent state. It has either more or less values than the graph." ); + let ongoing_drag = self.connection_in_progress + .and_then(|connection| self.graph.node(connection.node()).map(|n| (connection, n))) + .and_then(|(connection, n)| n.port_data_type(connection.port()).map(|d| (connection, d))); + + let mut state = EditorUiState { + pan: self.pan_zoom.pan + editor_rect.min.to_vec2(), + hook_geometry: &mut hook_geometry, + ongoing_drag: ongoing_drag.clone(), + selected_nodes: &self.selected_nodes, + }; + /* Draw nodes */ for node_id in self.node_order.iter().copied() { - let responses = GraphNodeWidget { - position: self.node_positions.get_mut(node_id).unwrap(), - graph: &mut self.graph, - port_locations: &mut port_locations, - node_rects: &mut node_rects, - node_id, - ongoing_drag: self.connection_in_progress, - selected: self - .selected_nodes - .iter() - .any(|selected| *selected == node_id), - pan: self.pan_zoom.pan + editor_rect.min.to_vec2(), - } - .show(ui, user_state); + let responses = match self.graph.node_mut(node_id, |node| { + node.show(ui, app_state, node_id, &mut state, &self.context) + }) { + Ok(responses) => responses, + // TODO(MXG): Should we alert the user to this error? It really + // shouldn't happen... + Err(()) => continue, + }; // Actions executed later delayed_responses.extend(responses); @@ -166,39 +168,28 @@ where let r = ui.allocate_rect(ui.min_rect(), Sense::click().union(Sense::drag())); if r.clicked() { click_on_background = true; - } else if r.drag_started() { - drag_started_on_background = true; - } else if r.drag_released() { - drag_released_on_background = true; } /* Draw the node finder, if open */ let mut should_close_node_finder = false; if let Some(ref mut node_finder) = self.node_finder { let mut node_finder_area = Area::new("node_finder").order(Order::Foreground); - if let Some(pos) = node_finder.position { - node_finder_area = node_finder_area.current_pos(pos); - } + node_finder_area = node_finder_area.current_pos(node_finder.position); node_finder_area.show(ui.ctx(), |ui| { - if let Some(node_kind) = node_finder.show(ui, all_kinds, user_state) { + if let Some(node_kind) = node_finder.show(ui, all_kinds) { let new_node = self.graph.add_node( - node_kind.node_graph_label(user_state), - node_kind.user_data(user_state), - |graph, node_id| node_kind.build_node(graph, user_state, node_id), - ); - self.node_positions.insert( - new_node, - cursor_pos - self.pan_zoom.pan - editor_rect.min.to_vec2(), + node_kind.build_node(node_finder.position - state.pan, app_state), ); + self.node_order.push(new_node); should_close_node_finder = true; delayed_responses.push(NodeResponse::CreatedNode(new_node)); } - let finder_rect = ui.min_rect(); - // If the cursor is not in the main editor, check if the cursor is in the finder + let finder_rect = ui.max_rect(); + // If the cursor is not in the main editor, check if the cursor *is* in the finder // if the cursor is in the finder, then we can consider that also in the editor. - if finder_rect.contains(cursor_pos) { + if !cursor_in_editor && finder_rect.contains(cursor_pos) { cursor_in_editor = true; cursor_in_finder = true; } @@ -209,131 +200,98 @@ where } /* Draw connections */ - if let Some((_, ref locator)) = self.connection_in_progress { - let port_type = self.graph.any_param_type(*locator).unwrap(); - let connection_color = port_type.data_type_color(user_state); - let start_pos = port_locations[locator]; - - // Find a port to connect to - fn snap_to_ports< - NodeData, - UserState, - DataType: DataTypeTrait, - ValueType, - Key: slotmap::Key + Into, - Value, - >( - graph: &Graph, - port_type: &DataType, - ports: &SlotMap, - port_locations: &PortLocations, - cursor_pos: Pos2, - ) -> Pos2 { - ports - .iter() - .find_map(|(port_id, _)| { - let compatible_ports = graph - .any_param_type(port_id.into()) - .map(|other| other == port_type) - .unwrap_or(false); - - if compatible_ports { - port_locations.get(&port_id.into()).and_then(|port_pos| { - if port_pos.distance(cursor_pos) < DISTANCE_TO_CONNECT { - Some(*port_pos) - } else { - None - } - }) - } else { - None - } - }) - .unwrap_or(cursor_pos) - } - - let (src_pos, dst_pos) = match locator { - AnyParameterId::Output(_) => ( - start_pos, - snap_to_ports( - &self.graph, - port_type, - &self.graph.inputs, - &port_locations, - cursor_pos, - ), - ), - AnyParameterId::Input(_) => ( - snap_to_ports( - &self.graph, - port_type, - &self.graph.outputs, - &port_locations, - cursor_pos, - ), - start_pos, - ), + if let Some((connection, data_type)) = &ongoing_drag { + let connection_color = self.context.recommend_data_type_color(&data_type); + let hook_geom = hook_geometry[connection]; + let cursor_geom = (cursor_pos, (hook_geom.0 - cursor_pos).normalized()); + let (start_geom, end_geom) = match connection.port() { + PortId::Input(_) => (cursor_geom, hook_geom), + PortId::Output(_) => (hook_geom, cursor_geom), }; - draw_connection(ui.painter(), src_pos, dst_pos, connection_color); + draw_connection(ui.painter(), start_geom, end_geom, connection_color); } for (input, output) in self.graph.iter_connections() { - let port_type = self + let data_type = self .graph - .any_param_type(AnyParameterId::Output(output)) - .unwrap(); - let connection_color = port_type.data_type_color(user_state); - let src_pos = port_locations[&AnyParameterId::Output(output)]; - let dst_pos = port_locations[&AnyParameterId::Input(input)]; - draw_connection(ui.painter(), src_pos, dst_pos, connection_color); + .node(input.node()).expect("node missing for a connection") + .port_data_type(input.port().into()).expect("port missing for a connection"); + let connection_color = self.context.recommend_data_type_color(&data_type); + let start_geom = hook_geometry[&output.into()]; + let end_geom = hook_geometry[&input.into()]; + draw_connection(ui.painter(), start_geom, end_geom, connection_color); } /* Handle responses from drawing nodes */ // Some responses generate additional responses when processed. These // are stored here to report them back to the user. - let mut extra_responses: Vec> = Vec::new(); + let mut extra_responses = Vec::new(); for response in delayed_responses.iter() { match response { - NodeResponse::ConnectEventStarted(node_id, port) => { - self.connection_in_progress = Some((*node_id, *port)); - } - NodeResponse::ConnectEventEnded { input, output } => { - self.graph.add_connection(*output, *input) + NodeResponse::Port(port_response) => { + match port_response { + PortResponse::ConnectEventStarted(connection) => { + self.connection_in_progress = Some(*connection); + } + PortResponse::MoveEvent(connection) => { + if let Ok(complement) = self.graph.drop_connection(*connection) { + extra_responses.push( + NodeResponse::disconnect(*connection, complement) + .expect("invalid input/output pair for connection") + ); + if let Some(available_hook) = self.graph.node(complement.node()) + .and_then(|n| n.available_hook(complement.port())) + { + self.connection_in_progress = Some( + ConnectionId(complement.node(), complement.port(), available_hook) + ); + } + } + } + PortResponse::ConnectEventEnded { output, input } => { + // TODO(MXG): Report errors for this? + self.graph.add_connection(*output, *input).ok(); + } + PortResponse::Value(_) => { + // User-defined response type + } + } } NodeResponse::CreatedNode(_) => { - //Convenience NodeResponse for users + // Convenience NodeResponse for users } NodeResponse::SelectNode(node_id) => { - self.selected_nodes = Vec::from([*node_id]); + if !ui.input().modifiers.shift { + self.selected_nodes.clear(); + } + if !self.selected_nodes.contains(node_id) { + self.selected_nodes.push(*node_id); + } + } + NodeResponse::DisconnectEvent { input, .. } => { + // TOOD(MXG): Report errors for this? + self.graph.drop_connection(input.clone().into()).ok(); } NodeResponse::DeleteNodeUi(node_id) => { - let (node, disc_events) = self.graph.remove_node(*node_id); - // Pass the disconnection responses first so user code can perform cleanup - // before node removal response. - extra_responses.extend( - disc_events - .into_iter() - .map(|(input, output)| NodeResponse::DisconnectEvent { input, output }), - ); - // Pass the full node as a response so library users can - // listen for it and get their user data. - extra_responses.push(NodeResponse::DeleteNodeFull { - node_id: *node_id, - node, - }); - self.node_positions.remove(*node_id); + if let Some((node, disc_events)) = self.graph.remove_node(*node_id) { + // Pass the full node as a response so library users can + // listen for it and get their user data. + extra_responses.push(NodeResponse::DeleteNodeFull { + node_id: *node_id, + node, + }); + extra_responses.extend( + disc_events + .into_iter() + .map(|(input, output)| NodeResponse::DisconnectEvent { input, output }), + ); + } // Make sure to not leave references to old nodes hanging self.selected_nodes.retain(|id| *id != *node_id); self.node_order.retain(|id| *id != *node_id); } - NodeResponse::DisconnectEvent { input, output } => { - let other_node = self.graph.get_output(*output).node; - self.graph.remove_connection(*input); - self.connection_in_progress = - Some((other_node, AnyParameterId::Output(*output))); - } NodeResponse::RaiseNode(node_id) => { let old_pos = self .node_order @@ -343,18 +301,7 @@ where self.node_order.remove(old_pos); self.node_order.push(*node_id); } - NodeResponse::MoveNode { node, drag_delta } => { - self.node_positions[*node] += *drag_delta; - // Handle multi-node selection movement - if self.selected_nodes.contains(node) && self.selected_nodes.len() > 1 { - for n in self.selected_nodes.iter().copied() { - if n != *node { - self.node_positions[n] += *drag_delta; - } - } - } - } - NodeResponse::User(_) => { + NodeResponse::Content(_) => { // These are handled by the user code. } NodeResponse::DeleteNodeFull { .. } => { @@ -363,30 +310,6 @@ where } } - // Handle box selection - if let Some(box_start) = self.ongoing_box_selection { - let selection_rect = Rect::from_two_pos(cursor_pos, box_start); - let bg_color = Color32::from_rgba_unmultiplied(200, 200, 200, 20); - let stroke_color = Color32::from_rgba_unmultiplied(200, 200, 200, 180); - ui.painter().rect( - selection_rect, - 2.0, - bg_color, - Stroke::new(3.0, stroke_color), - ); - - self.selected_nodes = node_rects - .into_iter() - .filter_map(|(node_id, rect)| { - if selection_rect.intersects(rect) { - Some(node_id) - } else { - None - } - }) - .collect(); - } - // Push any responses that were generated during response handling. // These are only informative for the end-user and need no special // treatment here. @@ -401,7 +324,7 @@ where self.connection_in_progress = None; } - if mouse.secondary_released() && cursor_in_editor && !cursor_in_finder { + if mouse.secondary_down() && cursor_in_editor && !cursor_in_finder { self.node_finder = Some(NodeFinder::new_at(cursor_pos)); } if ui.ctx().input().key_pressed(Key::Escape) { @@ -412,37 +335,46 @@ where self.pan_zoom.pan += ui.ctx().input().pointer.delta(); } - // Deselect and deactivate finder if the editor backround is clicked, - // *or* if the the mouse clicks off the ui - if click_on_background || (mouse.any_click() && !cursor_in_editor) { - self.selected_nodes = Vec::new(); - self.node_finder = None; + if click_on_background && !ui.input().modifiers.shift { + // Clear the selected nodes if the background is clicked when shift + // is not selected. + self.selected_nodes.clear(); } - if drag_started_on_background && mouse.primary_down() { - self.ongoing_box_selection = Some(cursor_pos); - } - if mouse.primary_released() || drag_released_on_background { - self.ongoing_box_selection = None; + if click_on_background || (mouse.any_click() && !cursor_in_editor) { + // Deactivate finder if the editor backround is clicked, + // *or* if the the mouse clicks off the ui + self.node_finder = None; } GraphResponse { node_responses: delayed_responses, - cursor_in_editor, - cursor_in_finder, } } } -fn draw_connection(painter: &Painter, src_pos: Pos2, dst_pos: Pos2, color: Color32) { +fn calculate_control( + a_pos: Pos2, + b_pos: Pos2, + tangent: Vec2, +) -> Pos2 { + let delta = ((a_pos - b_pos).dot(tangent).abs()/2.0).max(30.0); + a_pos + delta*tangent +} + +fn draw_connection( + painter: &Painter, + (start_pos, start_tangent): (Pos2, Vec2), + (end_pos, end_tangent): (Pos2, Vec2), + color: Color32 +) { let connection_stroke = egui::Stroke { width: 5.0, color }; - let control_scale = ((dst_pos.x - src_pos.x) / 2.0).max(30.0); - let src_control = src_pos + Vec2::X * control_scale; - let dst_control = dst_pos - Vec2::X * control_scale; + let start_control = calculate_control(start_pos, end_pos, start_tangent); + let end_control = calculate_control(end_pos, start_pos, end_tangent); let bezier = CubicBezierShape::from_points_stroke( - [src_pos, src_control, dst_control, dst_pos], + [start_pos, start_control, end_control, end_pos], false, Color32::TRANSPARENT, connection_stroke, @@ -450,419 +382,3 @@ fn draw_connection(painter: &Painter, src_pos: Pos2, dst_pos: Pos2, color: Color painter.add(bezier); } - -impl<'a, NodeData, DataType, ValueType, UserResponse, UserState> - GraphNodeWidget<'a, NodeData, DataType, ValueType> -where - NodeData: NodeDataTrait< - Response = UserResponse, - UserState = UserState, - DataType = DataType, - ValueType = ValueType, - >, - UserResponse: UserResponseTrait, - ValueType: - WidgetValueTrait, - DataType: DataTypeTrait, -{ - pub const MAX_NODE_SIZE: [f32; 2] = [200.0, 200.0]; - - pub fn show( - self, - ui: &mut Ui, - user_state: &mut UserState, - ) -> Vec> { - let mut child_ui = ui.child_ui_with_id_source( - Rect::from_min_size(*self.position + self.pan, Self::MAX_NODE_SIZE.into()), - Layout::default(), - self.node_id, - ); - - Self::show_graph_node(self, &mut child_ui, user_state) - } - - /// Draws this node. Also fills in the list of port locations with all of its ports. - /// Returns responses indicating multiple events. - fn show_graph_node( - self, - ui: &mut Ui, - user_state: &mut UserState, - ) -> Vec> { - let margin = egui::vec2(15.0, 5.0); - let mut responses = Vec::>::new(); - - let background_color; - let text_color; - if ui.visuals().dark_mode { - background_color = color_from_hex("#3f3f3f").unwrap(); - text_color = color_from_hex("#fefefe").unwrap(); - } else { - background_color = color_from_hex("#ffffff").unwrap(); - text_color = color_from_hex("#505050").unwrap(); - } - - ui.visuals_mut().widgets.noninteractive.fg_stroke = Stroke::new(2.0, text_color); - - // Preallocate shapes to paint below contents - let outline_shape = ui.painter().add(Shape::Noop); - let background_shape = ui.painter().add(Shape::Noop); - - let outer_rect_bounds = ui.available_rect_before_wrap(); - let mut inner_rect = outer_rect_bounds.shrink2(margin); - - // Make sure we don't shrink to the negative: - inner_rect.max.x = inner_rect.max.x.max(inner_rect.min.x); - inner_rect.max.y = inner_rect.max.y.max(inner_rect.min.y); - - let mut child_ui = ui.child_ui(inner_rect, *ui.layout()); - let mut title_height = 0.0; - - let mut input_port_heights = vec![]; - let mut output_port_heights = vec![]; - - child_ui.vertical(|ui| { - ui.horizontal(|ui| { - ui.add(Label::new( - RichText::new(&self.graph[self.node_id].label) - .text_style(TextStyle::Button) - .color(text_color), - )); - ui.add_space(8.0); // The size of the little cross icon - }); - ui.add_space(margin.y); - title_height = ui.min_size().y; - - // First pass: Draw the inner fields. Compute port heights - let inputs = self.graph[self.node_id].inputs.clone(); - for (param_name, param_id) in inputs { - if self.graph[param_id].shown_inline { - let height_before = ui.min_rect().bottom(); - if self.graph.connection(param_id).is_some() { - ui.label(param_name); - } else { - // NOTE: We want to pass the `user_data` to - // `value_widget`, but we can't since that would require - // borrowing the graph twice. Here, we make the - // assumption that the value is cheaply replaced, and - // use `std::mem::take` to temporarily replace it with a - // dummy value. This requires `ValueType` to implement - // Default, but results in a totally safe alternative. - let mut value = std::mem::take(&mut self.graph[param_id].value); - let node_responses = value.value_widget( - ¶m_name, - self.node_id, - ui, - user_state, - &self.graph[self.node_id].user_data, - ); - self.graph[param_id].value = value; - responses.extend(node_responses.into_iter().map(NodeResponse::User)); - } - let height_after = ui.min_rect().bottom(); - input_port_heights.push((height_before + height_after) / 2.0); - } - } - - let outputs = self.graph[self.node_id].outputs.clone(); - for (param_name, _param) in outputs { - let height_before = ui.min_rect().bottom(); - ui.label(¶m_name); - let height_after = ui.min_rect().bottom(); - output_port_heights.push((height_before + height_after) / 2.0); - } - - responses.extend( - self.graph[self.node_id] - .user_data - .bottom_ui(ui, self.node_id, self.graph, user_state) - .into_iter(), - ); - }); - - // Second pass, iterate again to draw the ports. This happens outside - // the child_ui because we want ports to overflow the node background. - - let outer_rect = child_ui.min_rect().expand2(margin); - let port_left = outer_rect.left(); - let port_right = outer_rect.right(); - - #[allow(clippy::too_many_arguments)] - fn draw_port( - ui: &mut Ui, - graph: &Graph, - node_id: NodeId, - user_state: &mut UserState, - port_pos: Pos2, - responses: &mut Vec>, - param_id: AnyParameterId, - port_locations: &mut PortLocations, - ongoing_drag: Option<(NodeId, AnyParameterId)>, - is_connected_input: bool, - ) where - DataType: DataTypeTrait, - UserResponse: UserResponseTrait, - NodeData: NodeDataTrait, - { - let port_type = graph.any_param_type(param_id).unwrap(); - - let port_rect = Rect::from_center_size(port_pos, egui::vec2(10.0, 10.0)); - - let sense = if ongoing_drag.is_some() { - Sense::hover() - } else { - Sense::click_and_drag() - }; - - let resp = ui.allocate_rect(port_rect, sense); - - // Check if the distance between the port and the mouse is the distance to connect - let close_enough = if let Some(pointer_pos) = ui.ctx().pointer_hover_pos() { - port_rect.center().distance(pointer_pos) < DISTANCE_TO_CONNECT - } else { - false - }; - - let port_color = if close_enough { - Color32::WHITE - } else { - port_type.data_type_color(user_state) - }; - ui.painter() - .circle(port_rect.center(), 5.0, port_color, Stroke::none()); - - if resp.drag_started() { - if is_connected_input { - let input = param_id.assume_input(); - let corresp_output = graph - .connection(input) - .expect("Connection data should be valid"); - responses.push(NodeResponse::DisconnectEvent { - input: param_id.assume_input(), - output: corresp_output, - }); - } else { - responses.push(NodeResponse::ConnectEventStarted(node_id, param_id)); - } - } - - if let Some((origin_node, origin_param)) = ongoing_drag { - if origin_node != node_id { - // Don't allow self-loops - if graph.any_param_type(origin_param).unwrap() == port_type - && close_enough - && ui.input().pointer.any_released() - { - match (param_id, origin_param) { - (AnyParameterId::Input(input), AnyParameterId::Output(output)) - | (AnyParameterId::Output(output), AnyParameterId::Input(input)) => { - responses.push(NodeResponse::ConnectEventEnded { input, output }); - } - _ => { /* Ignore in-in or out-out connections */ } - } - } - } - } - - port_locations.insert(param_id, port_rect.center()); - } - - // Input ports - for ((_, param), port_height) in self.graph[self.node_id] - .inputs - .iter() - .zip(input_port_heights.into_iter()) - { - let should_draw = match self.graph[*param].kind() { - InputParamKind::ConnectionOnly => true, - InputParamKind::ConstantOnly => false, - InputParamKind::ConnectionOrConstant => true, - }; - - if should_draw { - let pos_left = pos2(port_left, port_height); - draw_port( - ui, - self.graph, - self.node_id, - user_state, - pos_left, - &mut responses, - AnyParameterId::Input(*param), - self.port_locations, - self.ongoing_drag, - self.graph.connection(*param).is_some(), - ); - } - } - - // Output ports - for ((_, param), port_height) in self.graph[self.node_id] - .outputs - .iter() - .zip(output_port_heights.into_iter()) - { - let pos_right = pos2(port_right, port_height); - draw_port( - ui, - self.graph, - self.node_id, - user_state, - pos_right, - &mut responses, - AnyParameterId::Output(*param), - self.port_locations, - self.ongoing_drag, - false, - ); - } - - // Draw the background shape. - // NOTE: This code is a bit more involved than it needs to be because egui - // does not support drawing rectangles with asymmetrical round corners. - - let (shape, outline) = { - let rounding_radius = 4.0; - let rounding = Rounding::same(rounding_radius); - - let titlebar_height = title_height + margin.y; - let titlebar_rect = - Rect::from_min_size(outer_rect.min, vec2(outer_rect.width(), titlebar_height)); - let titlebar = Shape::Rect(RectShape { - rect: titlebar_rect, - rounding, - fill: self.graph[self.node_id] - .user_data - .titlebar_color(ui, self.node_id, self.graph, user_state) - .unwrap_or_else(|| background_color.lighten(0.8)), - stroke: Stroke::none(), - }); - - let body_rect = Rect::from_min_size( - outer_rect.min + vec2(0.0, titlebar_height - rounding_radius), - vec2(outer_rect.width(), outer_rect.height() - titlebar_height), - ); - let body = Shape::Rect(RectShape { - rect: body_rect, - rounding: Rounding::none(), - fill: background_color, - stroke: Stroke::none(), - }); - - let bottom_body_rect = Rect::from_min_size( - body_rect.min + vec2(0.0, body_rect.height() - titlebar_height * 0.5), - vec2(outer_rect.width(), titlebar_height), - ); - let bottom_body = Shape::Rect(RectShape { - rect: bottom_body_rect, - rounding, - fill: background_color, - stroke: Stroke::none(), - }); - - let node_rect = titlebar_rect.union(body_rect).union(bottom_body_rect); - let outline = if self.selected { - Shape::Rect(RectShape { - rect: node_rect.expand(1.0), - rounding, - fill: Color32::WHITE.lighten(0.8), - stroke: Stroke::none(), - }) - } else { - Shape::Noop - }; - - // Take note of the node rect, so the editor can use it later to compute intersections. - self.node_rects.insert(self.node_id, node_rect); - - (Shape::Vec(vec![titlebar, body, bottom_body]), outline) - }; - - ui.painter().set(background_shape, shape); - ui.painter().set(outline_shape, outline); - - // --- Interaction --- - - // Titlebar buttons - let can_delete = self.graph.nodes[self.node_id].user_data.can_delete( - self.node_id, - self.graph, - user_state, - ); - - if can_delete && Self::close_button(ui, outer_rect).clicked() { - responses.push(NodeResponse::DeleteNodeUi(self.node_id)); - }; - - let window_response = ui.interact( - outer_rect, - Id::new((self.node_id, "window")), - Sense::click_and_drag(), - ); - - // Movement - let drag_delta = window_response.drag_delta(); - if drag_delta.length_sq() > 0.0 { - responses.push(NodeResponse::MoveNode { - node: self.node_id, - drag_delta, - }); - responses.push(NodeResponse::RaiseNode(self.node_id)); - } - - // Node selection - // - // HACK: Only set the select response when no other response is active. - // This prevents some issues. - if responses.is_empty() && window_response.clicked_by(PointerButton::Primary) { - responses.push(NodeResponse::SelectNode(self.node_id)); - responses.push(NodeResponse::RaiseNode(self.node_id)); - } - - responses - } - - fn close_button(ui: &mut Ui, node_rect: Rect) -> Response { - // Measurements - let margin = 8.0; - let size = 10.0; - let stroke_width = 2.0; - let offs = margin + size / 2.0; - - let position = pos2(node_rect.right() - offs, node_rect.top() + offs); - let rect = Rect::from_center_size(position, vec2(size, size)); - let resp = ui.allocate_rect(rect, Sense::click()); - - let dark_mode = ui.visuals().dark_mode; - let color = if resp.clicked() { - if dark_mode { - color_from_hex("#ffffff").unwrap() - } else { - color_from_hex("#000000").unwrap() - } - } else if resp.hovered() { - if dark_mode { - color_from_hex("#dddddd").unwrap() - } else { - color_from_hex("#222222").unwrap() - } - } else { - #[allow(clippy::collapsible_else_if)] - if dark_mode { - color_from_hex("#aaaaaa").unwrap() - } else { - color_from_hex("#555555").unwrap() - } - }; - let stroke = Stroke { - width: stroke_width, - color, - }; - - ui.painter() - .line_segment([rect.left_top(), rect.right_bottom()], stroke); - ui.painter() - .line_segment([rect.right_top(), rect.left_bottom()], stroke); - - resp - } -} diff --git a/egui_node_graph/src/error.rs b/egui_node_graph/src/error.rs deleted file mode 100644 index 8033727..0000000 --- a/egui_node_graph/src/error.rs +++ /dev/null @@ -1,10 +0,0 @@ -use super::*; - -#[derive(Debug, thiserror::Error)] -pub enum EguiGraphError { - #[error("Node {0:?} has no parameter named {1}")] - NoParameterNamed(NodeId, String), - - #[error("Parameter {0:?} was not found in the graph.")] - InvalidParameterId(AnyParameterId), -} diff --git a/egui_node_graph/src/graph.rs b/egui_node_graph/src/graph.rs index 32301d7..d476e81 100644 --- a/egui_node_graph/src/graph.rs +++ b/egui_node_graph/src/graph.rs @@ -1,91 +1,365 @@ use super::*; +use std::{ + cell::RefCell, + sync::{Mutex, Arc}, + collections::HashMap, +}; +use thiserror::Error as ThisError; #[cfg(feature = "persistence")] use serde::{Deserialize, Serialize}; -/// A node inside the [`Graph`]. Nodes have input and output parameters, stored -/// as ids. They also contain a custom `NodeData` struct with whatever data the -/// user wants to store per-node. -#[derive(Debug, Clone)] -#[cfg_attr(feature = "persistence", derive(Serialize, Deserialize))] -pub struct Node { - pub id: NodeId, - pub label: String, - pub inputs: Vec<(String, InputId)>, - pub outputs: Vec<(String, OutputId)>, - pub user_data: NodeData, +#[cfg(feature = "persistence")] +fn shown_inline_default() -> bool { + true } -/// The three kinds of input params. These describe how the graph must behave -/// with respect to inline widgets and connections for this parameter. -#[derive(Debug, Clone, Copy)] +/// The graph, containing nodes, input parameters and output parameters. Because +/// graphs are full of self-referential structures, this type uses the `slotmap` +/// crate to represent all the inner references in the data. +#[derive(Debug)] #[cfg_attr(feature = "persistence", derive(Serialize, Deserialize))] -pub enum InputParamKind { - /// No constant value can be set. Only incoming connections can produce it - ConnectionOnly, - /// Only a constant value can be set. No incoming connections accepted. - ConstantOnly, - /// Both incoming connections and constants are accepted. Connections take - /// precedence over the constant values. - ConnectionOrConstant, +pub struct Graph { + /// The Nodes of the graph + nodes: SlotMap, + /// Connects the input of a port to the output of its predecessor that + /// produces it + connections: HashMap, + /// Keep track of what connections have dropped while something in the graph + /// was mutable + #[cfg_attr(feature = "persistence", serde(skip))] + dropped_connections: DroppedConnections, + /// This field is used as a buffer inside of process_dropped_connections(). + /// We maintain it as a field so that we don't need to repeatedly reallocate + /// the memory for this buffer. + #[cfg_attr(feature = "persistence", serde(skip))] + drop_buffer: Vec, } -#[cfg(feature = "persistence")] -fn shown_inline_default() -> bool { - true +impl Graph { + pub fn new() -> Self { + Self{ + nodes: SlotMap::default(), + connections: Default::default(), + dropped_connections: Default::default(), + drop_buffer: Vec::new(), + } + } + + pub fn add_node( + &mut self, + node: Node, + ) -> NodeId { + self.nodes.insert(node) + } + + pub fn node(&self, key: NodeId) -> Option<&Node> { + self.nodes.get(key) + } + + /// Operate on a mutable node. If any connections are dropped while mutating + /// the node, the graph will automatically update itself. You may optionally + /// have your function return a value. + /// + /// If the `node_id` is invalid then nothing will happen. To provide a fallback + /// behavior when the node is missing use [`node_mut_or`] + pub fn node_mut T, T>(&mut self, node_id: NodeId, f: F) -> Result { + if let Some(node) = self.nodes.get_mut(node_id) { + let output = f(node); + + // If the node dropped any connections while it was being mutated, + // we need to handle that to maintain consistency in the graph. + self.process_dropped_connections(); + Ok(output) + } else { + Err(()) + } + } + + /// Operate on a mutable node or else perform some fallback behavior. If any + /// connections are dropped while mutating the node, the graph will automatically + /// update itself. You may optionally have your function return a value. + pub fn node_mut_or T, E: FnOnce() -> U, T, U>( + &mut self, + key: NodeId, + found_f: F, + else_f: E + ) -> Result { + if let Some(node) = self.nodes.get_mut(key) { + let output = found_f(node); + self.process_dropped_connections(); + Ok(output) + } else { + Err(else_f()) + } + } + + /// Removes a node from the graph with the given `node_id`. This also removes + /// any incoming or outgoing connections from that node + /// + /// This function returns the list of connections that has been removed + /// after deleting this node as input-output pairs. Note that one of the two + /// ids in the pair (the one on `node_id`'s end) will be invalid after + /// calling this function. + pub fn remove_node(&mut self, node_id: NodeId) -> Option<(Node, Vec<(InputId, OutputId)>)> { + if let Some(mut removed_node) = self.nodes.remove(node_id) { + let dropped = removed_node.drop_all_connections().into_iter() + .map(|(port_id, hook_id, connection)| { + match port_id { + PortId::Input(input_port_id) => { + let input_id = InputId(node_id, input_port_id, hook_id); + let output_id = connection.assume_output(); + (input_id, output_id) + }, + PortId::Output(output_port_id) => { + let output_id = OutputId(node_id, output_port_id, hook_id); + let input_id = connection.assume_input(); + (input_id, output_id) + } + } + }).collect(); + + self.process_dropped_connections(); + return Some((removed_node, dropped)); + } + + return None; + } + + /// Adds a connection to the graph from an output to an input. + // TODO(@mxgrey): Should we test that the data types are compatible? + pub fn add_connection(&mut self, output_id: OutputId, input_id: InputId) -> Result<(), GraphAddConnectionError> { + self.connections.insert(input_id.into(), output_id.into()); + self.connections.insert(output_id.into(), input_id.into()); + + let result = { + // Note: We create the tokens inside this nested scope so that if + // either of the nodes is non-existent, the token will drop before + // we call [`process_dropped_connections`] + + // Also Note: The connection tokens intentionally contain the id of the + // complementary ConnectionId + let token_for_output_hook = ConnectionToken::new( + input_id.into(), + self.dropped_connections.clone(), + ); + let token_for_input_hook = ConnectionToken::new( + output_id.into(), + self.dropped_connections.clone(), + ); + + // 1. Tell the output node about the connection + // 2. If the output node connected successfully, tell the input node + // If either node fails to connect, then its ConnectionToken will + // drop once we exit this scope. When the token drops, it will let + // the dropped_connections field know, and then self.process_dropped_connections() + // will clean up any lingering traces of the connection. + match self.node_mut_or( + output_id.node(), + |output_node| { + output_node.connect( + output_id.into(), + token_for_output_hook, + ).map_err(|err| GraphAddConnectionError::OutputNodeError{node: output_id.0, err}) + }, + || () + ) { + Ok(result) => result, + Err(()) => Err(GraphAddConnectionError::BadOutputNode(output_id.0)), + }.and_then(|_| { + match self.node_mut_or( + input_id.node(), + |input_node| { + input_node.connect( + input_id.into(), + token_for_input_hook, + ).map_err(|err| GraphAddConnectionError::InputNodeError{node: input_id.0, err}) + }, + || () + ) { + Ok(result) => result, + Err(()) => Err(GraphAddConnectionError::BadInputNode(input_id.0)) + } + }) + }; + + self.process_dropped_connections(); + return result; + } + + pub fn drop_connection(&mut self, id: ConnectionId) -> Result { + match self.node_mut_or( + id.node(), + |node| { + node.drop_connection(id.into()) + .map_err(|err| GraphDropConnectionError::NodeError{node: id.node(), err}) + }, + || () + ) { + Ok(result) => result, + Err(()) => Err(GraphDropConnectionError::BadNodeId(id.node())) + } + } + + pub fn iter_nodes(&self) -> impl Iterator + '_ { + self.nodes.iter() + } + + pub fn iter_connections(&self) -> impl Iterator + '_ { + self.connections.iter().filter_map( + |(o, i)| o.as_output().map(|o| (o, i.assume_input())) + ) + } + + pub fn connection(&self, id: &ConnectionId) -> Option { + self.connections.get(id).copied() + } + + /// This will be called automatically after each mutable graph function so + /// users generally should not have to call this. However, if a Node + /// implementation defies the recommended practice of only allowing + /// connections to drop while mutable, then this function can be called to + /// correct the graph. + pub fn process_dropped_connections(&mut self) { + // If we keep self.dropped_connections locked and iterate over it + // directly to disconnect the complementary hooks, its mutex would + // deadlock when the dropped ConnectionTokens of the complementary hook + // try to lock it. + // + // So instead we temporarily lock the mutex of dropped_connections and + // transfer its information into drop_buffer. Then iterate over drop_buffer, + // telling the complementary hook to drop their connection. As those connections + // drop, the ConnectionToken will lock self.dropped_connections and add their + // value into it. + // + // self.drop_buffer is kept as a field so we don't need to dynamically + // reallocate its memory every time we need to transfer the data into it. + self.drop_buffer.extend( + self.dropped_connections + .lock().expect("the dropped_connections mutex is poisoned") + .borrow_mut().drain(..) + ); + + let mut complements = Vec::new(); + for connection in self.drop_buffer.drain(..) { + if let Some(complement) = self.connections.remove(&connection) { + complements.push((complement, connection)); + } + } + + for (complement, original) in complements { + self.node_mut( + original.node(), + |node| { + node.drop_connection(original.into()) + } + ).ok(); + if let Some(connection) = self.connections.remove(&complement) { + assert!(complement == connection); + } + } + + // Clear both buffers. + // drop_buffer can be emptied because we've already processed all its contents + self.drop_buffer.clear(); + // dropped_connections should be cleared even though it was drained + // earlier because the complementary tokens will have filled it with + // irrelevant dropped connection data + self.dropped_connections + .lock().expect("the dropped_connections mutex is poisoned") + .borrow_mut().clear(); + } } -/// An input parameter. Input parameters are inside a node, and represent data -/// that this node receives. Unlike their [`OutputParam`] counterparts, input -/// parameters also display an inline widget which allows setting its "value". -/// The `DataType` generic parameter is used to restrict the range of input -/// connections for this parameter, and the `ValueType` is use to represent the -/// data for the inline widget (i.e. constant) value. -#[derive(Debug, Clone)] -#[cfg_attr(feature = "persistence", derive(Serialize, Deserialize))] -pub struct InputParam { - pub id: InputId, - /// The data type of this node. Used to determine incoming connections. This - /// should always match the type of the InputParamValue, but the property is - /// not actually enforced. - pub typ: DataType, - /// The constant value stored in this parameter. - pub value: ValueType, - /// The input kind. See [`InputParamKind`] - pub kind: InputParamKind, - /// Back-reference to the node containing this parameter. - pub node: NodeId, - /// When true, the node is shown inline inside the node graph. - #[cfg_attr(feature = "persistence", serde(default = "shown_inline_default"))] - pub shown_inline: bool, +#[derive(ThisError, Debug, Clone, Copy, PartialEq, Eq)] +pub enum PortAddConnectionError { + #[error("there is no hook [{0:?}] for this port")] + BadHook(HookId), + #[error("hook [{0:?}] is already occupied")] + HookOccupied(HookId), } -/// An output parameter. Output parameters are inside a node, and represent the -/// data that the node produces. Output parameters can be linked to the input -/// parameters of other nodes. Unlike an [`InputParam`], output parameters -/// cannot have a constant inline value. -#[derive(Debug, Clone)] -#[cfg_attr(feature = "persistence", derive(Serialize, Deserialize))] -pub struct OutputParam { - pub id: OutputId, - /// Back-reference to the node containing this parameter. - pub node: NodeId, - pub typ: DataType, +#[derive(ThisError, Debug, Clone, Copy, PartialEq, Eq)] +pub enum NodeAddConnectionError { + #[error("there is no port [{0:?}] for this node")] + BadPort(PortId), + #[error("port [{port:?}] had a connection error: {err}")] + PortError{port: PortId, err: PortAddConnectionError}, } -/// The graph, containing nodes, input parameters and output parameters. Because -/// graphs are full of self-referential structures, this type uses the `slotmap` -/// crate to represent all the inner references in the data. -#[derive(Debug, Clone)] -#[cfg_attr(feature = "persistence", derive(Serialize, Deserialize))] -pub struct Graph { - /// The [`Node`]s of the graph - pub nodes: SlotMap>, - /// The [`InputParam`]s of the graph - pub inputs: SlotMap>, - /// The [`OutputParam`]s of the graph - pub outputs: SlotMap>, - // Connects the input of a node, to the output of its predecessor that - // produces it - pub connections: SecondaryMap, +#[derive(ThisError, Debug, Clone, Copy, PartialEq, Eq)] +pub enum GraphAddConnectionError { + #[error("attempting to add a connection to a NodeId that doesn't exist: {0:?}")] + BadOutputNode(NodeId), + #[error("attempting to add a connection to a NodeId that doesn't exist: {0:?}")] + BadInputNode(NodeId), + #[error("error from output node {node:?} while attempting to add a connection: {err:?}")] + OutputNodeError{node: NodeId, err: NodeAddConnectionError}, + #[error("error from input node {node:?} while attempting to add a connection: {err:?}")] + InputNodeError{node: NodeId, err: NodeAddConnectionError}, +} + +#[derive(ThisError, Debug, Clone, Copy, PartialEq, Eq)] +pub enum PortDropConnectionError { + #[error("there is no hook [{0:?}] for this port")] + BadHook(HookId), + #[error("hook [{0:?}] does not have any connection")] + NoConnection(HookId), +} + +#[derive(ThisError, Debug, Clone, Copy, PartialEq, Eq)] +pub enum NodeDropConnectionError { + #[error("there is no port [{0:?}] for this node")] + BadPort(PortId), + #[error("port [{port:?}] had a drop connection error: {err}")] + PortError{port: PortId, err: PortDropConnectionError}, +} + +#[derive(ThisError, Debug, Clone, Copy, PartialEq, Eq)] +pub enum GraphDropConnectionError { + #[error("node [{0:?}] does not exist in the graph")] + BadNodeId(NodeId), + #[error("node [{node:?}] experienced an error: {err}")] + NodeError{node: NodeId, err: NodeDropConnectionError}, +} + +impl Default for Graph { + fn default() -> Self { + Self::new() + } +} + +// TODO(@mxgrey): Can this be safely replaced with Rc>>? +// Should we use something like #[cfg(feature = "single_threaded")] to let the +// user choose the more efficient alternative? +pub type DroppedConnections = Arc>>>; + +#[derive(Debug)] +pub struct ConnectionToken { + connected_to: ConnectionId, + drop_list: DroppedConnections, +} + +impl ConnectionToken { + /// Only the Graph class is allowed to create connection tokens + fn new( + connected_to: ConnectionId, + drop_list: DroppedConnections, + ) -> Self { + Self{connected_to, drop_list} + } + + pub fn connected_to(&self) -> ConnectionId { + self.connected_to + } +} + +impl Drop for ConnectionToken { + fn drop(&mut self) { + self.drop_list.lock() + .and_then(|list_cell| { + list_cell.borrow_mut().push(self.connected_to); + Ok(()) + }).ok(); + } } diff --git a/egui_node_graph/src/graph_impls.rs b/egui_node_graph/src/graph_impls.rs index 44b5a5d..1bfe892 100644 --- a/egui_node_graph/src/graph_impls.rs +++ b/egui_node_graph/src/graph_impls.rs @@ -1,215 +1,2 @@ use super::*; - -impl Graph { - pub fn new() -> Self { - Self { - nodes: SlotMap::default(), - inputs: SlotMap::default(), - outputs: SlotMap::default(), - connections: SecondaryMap::default(), - } - } - - pub fn add_node( - &mut self, - label: String, - user_data: NodeData, - f: impl FnOnce(&mut Graph, NodeId), - ) -> NodeId { - let node_id = self.nodes.insert_with_key(|node_id| { - Node { - id: node_id, - label, - // These get filled in later by the user function - inputs: Vec::default(), - outputs: Vec::default(), - user_data, - } - }); - - f(self, node_id); - - node_id - } - - pub fn add_input_param( - &mut self, - node_id: NodeId, - name: String, - typ: DataType, - value: ValueType, - kind: InputParamKind, - shown_inline: bool, - ) -> InputId { - let input_id = self.inputs.insert_with_key(|input_id| InputParam { - id: input_id, - typ, - value, - kind, - node: node_id, - shown_inline, - }); - self.nodes[node_id].inputs.push((name, input_id)); - input_id - } - - pub fn remove_input_param(&mut self, param: InputId) { - let node = self[param].node; - self[node].inputs.retain(|(_, id)| *id != param); - self.inputs.remove(param); - self.connections.retain(|i, _| i != param); - } - - pub fn remove_output_param(&mut self, param: OutputId) { - let node = self[param].node; - self[node].outputs.retain(|(_, id)| *id != param); - self.outputs.remove(param); - self.connections.retain(|_, o| *o != param); - } - - pub fn add_output_param(&mut self, node_id: NodeId, name: String, typ: DataType) -> OutputId { - let output_id = self.outputs.insert_with_key(|output_id| OutputParam { - id: output_id, - node: node_id, - typ, - }); - self.nodes[node_id].outputs.push((name, output_id)); - output_id - } - - /// Removes a node from the graph with given `node_id`. This also removes - /// any incoming or outgoing connections from that node - /// - /// This function returns the list of connections that has been removed - /// after deleting this node as input-output pairs. Note that one of the two - /// ids in the pair (the one on `node_id`'s end) will be invalid after - /// calling this function. - pub fn remove_node(&mut self, node_id: NodeId) -> (Node, Vec<(InputId, OutputId)>) { - let mut disconnect_events = vec![]; - - self.connections.retain(|i, o| { - if self.outputs[*o].node == node_id || self.inputs[i].node == node_id { - disconnect_events.push((i, *o)); - false - } else { - true - } - }); - - // NOTE: Collect is needed because we can't borrow the input ids while - // we remove them inside the loop. - for input in self[node_id].input_ids().collect::>() { - self.inputs.remove(input); - } - for output in self[node_id].output_ids().collect::>() { - self.outputs.remove(output); - } - let removed_node = self.nodes.remove(node_id).expect("Node should exist"); - - (removed_node, disconnect_events) - } - - pub fn remove_connection(&mut self, input_id: InputId) -> Option { - self.connections.remove(input_id) - } - - pub fn iter_nodes(&self) -> impl Iterator + '_ { - self.nodes.iter().map(|(id, _)| id) - } - - pub fn add_connection(&mut self, output: OutputId, input: InputId) { - self.connections.insert(input, output); - } - - pub fn iter_connections(&self) -> impl Iterator + '_ { - self.connections.iter().map(|(o, i)| (o, *i)) - } - - pub fn connection(&self, input: InputId) -> Option { - self.connections.get(input).copied() - } - - pub fn any_param_type(&self, param: AnyParameterId) -> Result<&DataType, EguiGraphError> { - match param { - AnyParameterId::Input(input) => self.inputs.get(input).map(|x| &x.typ), - AnyParameterId::Output(output) => self.outputs.get(output).map(|x| &x.typ), - } - .ok_or(EguiGraphError::InvalidParameterId(param)) - } - - pub fn try_get_input(&self, input: InputId) -> Option<&InputParam> { - self.inputs.get(input) - } - - pub fn get_input(&self, input: InputId) -> &InputParam { - &self.inputs[input] - } - - pub fn try_get_output(&self, output: OutputId) -> Option<&OutputParam> { - self.outputs.get(output) - } - - pub fn get_output(&self, output: OutputId) -> &OutputParam { - &self.outputs[output] - } -} - -impl Default for Graph { - fn default() -> Self { - Self::new() - } -} - -impl Node { - pub fn inputs<'a, DataType, DataValue>( - &'a self, - graph: &'a Graph, - ) -> impl Iterator> + 'a { - self.input_ids().map(|id| graph.get_input(id)) - } - - pub fn outputs<'a, DataType, DataValue>( - &'a self, - graph: &'a Graph, - ) -> impl Iterator> + 'a { - self.output_ids().map(|id| graph.get_output(id)) - } - - pub fn input_ids(&self) -> impl Iterator + '_ { - self.inputs.iter().map(|(_name, id)| *id) - } - - pub fn output_ids(&self) -> impl Iterator + '_ { - self.outputs.iter().map(|(_name, id)| *id) - } - - pub fn get_input(&self, name: &str) -> Result { - self.inputs - .iter() - .find(|(param_name, _id)| param_name == name) - .map(|x| x.1) - .ok_or_else(|| EguiGraphError::NoParameterNamed(self.id, name.into())) - } - - pub fn get_output(&self, name: &str) -> Result { - self.outputs - .iter() - .find(|(param_name, _id)| param_name == name) - .map(|x| x.1) - .ok_or_else(|| EguiGraphError::NoParameterNamed(self.id, name.into())) - } -} - -impl InputParam { - pub fn value(&self) -> &ValueType { - &self.value - } - - pub fn kind(&self) -> InputParamKind { - self.kind - } - - pub fn node(&self) -> NodeId { - self.node - } -} +use slotmap::SlotMap; diff --git a/egui_node_graph/src/id_type.rs b/egui_node_graph/src/id_type.rs index 5c272e1..e6e2eca 100644 --- a/egui_node_graph/src/id_type.rs +++ b/egui_node_graph/src/id_type.rs @@ -1,37 +1,186 @@ slotmap::new_key_type! { pub struct NodeId; } -slotmap::new_key_type! { pub struct InputId; } -slotmap::new_key_type! { pub struct OutputId; } +slotmap::new_key_type! { pub struct HookId; } +slotmap::new_key_type! { pub struct InputPortId; } +slotmap::new_key_type! { pub struct OutputPortId; } #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] -pub enum AnyParameterId { - Input(InputId), - Output(OutputId), -} +pub struct InputId(pub NodeId, pub InputPortId, pub HookId); + +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub struct OutputId(pub NodeId, pub OutputPortId, pub HookId); + +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub struct ConnectionId(pub NodeId, pub PortId, pub HookId); + +impl ConnectionId { + pub fn as_input(&self) -> Option { + match self.1 { + PortId::Input(port) => Some(InputId(self.0, port, self.2)), + _ => None, + } + } + + pub fn as_output(&self) -> Option { + match self.1 { + PortId::Output(port) => Some(OutputId(self.0, port, self.2)), + _ => None, + } + } -impl AnyParameterId { pub fn assume_input(&self) -> InputId { - match self { - AnyParameterId::Input(input) => *input, - AnyParameterId::Output(output) => panic!("{:?} is not an InputId", output), + match self.as_input() { + Some(input) => input, + None => panic!("{:?} is not an InputId", self), } } pub fn assume_output(&self) -> OutputId { + match self.as_output() { + Some(output) => output, + None => panic!("{:?} is not an OutputId", self), + } + } + + pub fn node(&self) -> NodeId { + self.0 + } + + pub fn port(&self) -> PortId { + self.1 + } + + pub fn hook(&self) -> HookId { + self.2 + } +} + +impl InputId { + pub fn node(&self) -> NodeId { + self.0 + } + + pub fn port(&self) -> InputPortId { + self.1 + } + + pub fn hook(&self) -> HookId { + self.2 + } +} + +impl OutputId { + pub fn node(&self) -> NodeId { + self.0 + } + + pub fn port(&self) -> OutputPortId { + self.1 + } + + pub fn hook(&self) -> HookId { + self.2 + } +} + +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub enum PortId { + Input(InputPortId), + Output(OutputPortId), +} + +impl PortId { + pub fn assume_input(&self) -> InputPortId { + match self { + PortId::Input(input) => *input, + PortId::Output(output) => panic!("{:?} is not an InputPortId", output), + } + } + + pub fn assume_output(&self) -> OutputPortId { match self { - AnyParameterId::Output(output) => *output, - AnyParameterId::Input(input) => panic!("{:?} is not an OutputId", input), + PortId::Output(output) => *output, + PortId::Input(input) => panic!("{:?} is not an OutputPortId", input), + } + } + + pub fn is_complementary(&self, other: &Self) -> bool { + if let (PortId::Input(_), PortId::Output(_)) = (self, other) { + return true; + } + + if let (PortId::Output(_), PortId::Input(_)) = (self, other) { + return true; } + + return false; + } +} + +impl From for ConnectionId { + fn from(c: InputId) -> Self { + ConnectionId(c.0, c.into(), c.2) + } +} + +impl From for ConnectionId { + fn from(c: OutputId) -> Self { + ConnectionId(c.0, c.into(), c.2) + } +} + +impl From for NodeId { + fn from(c: ConnectionId) -> Self { + c.node() + } +} + +impl From for (NodeId, PortId) { + fn from(c: ConnectionId) -> Self { + (c.0, c.1) + } +} + +impl From for (PortId, HookId) { + fn from(c: ConnectionId) -> Self { + (c.1, c.2) + } +} + +impl From for PortId { + fn from(c: InputId) -> Self { + PortId::Input(c.1) + } +} + +impl From for PortId { + fn from(c: OutputId) -> Self { + PortId::Output(c.port()) + } +} + +impl From for PortId { + fn from(value: OutputPortId) -> Self { + PortId::Output(value) + } +} + +impl From for PortId { + fn from(value: InputPortId) -> Self { + PortId::Input(value) } } -impl From for AnyParameterId { - fn from(output: OutputId) -> Self { - Self::Output(output) +impl From for (PortId, HookId) { + fn from(c: OutputId) -> Self { + (PortId::Output(c.1), c.2) } } -impl From for AnyParameterId { - fn from(input: InputId) -> Self { - Self::Input(input) +impl From for (PortId, HookId) { + fn from(c: InputId) -> Self { + (PortId::Input(c.1), c.2) } } diff --git a/egui_node_graph/src/index_impls.rs b/egui_node_graph/src/index_impls.rs deleted file mode 100644 index f002330..0000000 --- a/egui_node_graph/src/index_impls.rs +++ /dev/null @@ -1,35 +0,0 @@ -use super::*; - -macro_rules! impl_index_traits { - ($id_type:ty, $output_type:ty, $arena:ident) => { - impl std::ops::Index<$id_type> for Graph { - type Output = $output_type; - - fn index(&self, index: $id_type) -> &Self::Output { - self.$arena.get(index).unwrap_or_else(|| { - panic!( - "{} index error for {:?}. Has the value been deleted?", - stringify!($id_type), - index - ) - }) - } - } - - impl std::ops::IndexMut<$id_type> for Graph { - fn index_mut(&mut self, index: $id_type) -> &mut Self::Output { - self.$arena.get_mut(index).unwrap_or_else(|| { - panic!( - "{} index error for {:?}. Has the value been deleted?", - stringify!($id_type), - index - ) - }) - } - } - }; -} - -impl_index_traits!(NodeId, Node, nodes); -impl_index_traits!(InputId, InputParam, inputs); -impl_index_traits!(OutputId, OutputParam, outputs); diff --git a/egui_node_graph/src/lib.rs b/egui_node_graph/src/lib.rs index 2b11be6..2375640 100644 --- a/egui_node_graph/src/lib.rs +++ b/egui_node_graph/src/lib.rs @@ -1,6 +1,6 @@ #![forbid(unsafe_code)] -use slotmap::{SecondaryMap, SlotMap}; +use slotmap::SlotMap; pub type SVec = smallvec::SmallVec<[T; 4]>; @@ -12,17 +12,6 @@ pub use graph::*; pub mod id_type; pub use id_type::*; -/// Implements the index trait for the Graph type, allowing indexing by all -/// three id types -pub mod index_impls; - -/// Implementing the main methods for the `Graph` -pub mod graph_impls; - -/// Custom error types, crate-wide -pub mod error; -pub use error::*; - /// The main struct in the library, contains all the necessary state to draw the /// UI graph pub mod ui_state; @@ -36,6 +25,12 @@ pub use node_finder::*; pub mod editor_ui; pub use editor_ui::*; +pub mod vertical_port; +pub use vertical_port::*; + +pub mod column_node; +pub use column_node::*; + /// Several traits that must be implemented by the user to customize the /// behavior of this library. pub mod traits; diff --git a/egui_node_graph/src/node_finder.rs b/egui_node_graph/src/node_finder.rs index 87a9596..ac6940b 100644 --- a/egui_node_graph/src/node_finder.rs +++ b/egui_node_graph/src/node_finder.rs @@ -1,40 +1,31 @@ -use std::marker::PhantomData; use crate::{color_hex_utils::*, NodeTemplateIter, NodeTemplateTrait}; use egui::*; -#[derive(Clone)] #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] -pub struct NodeFinder { +pub struct NodeFinder { pub query: String, - /// Reset every frame. When set, the node finder will be moved at that position - pub position: Option, + pub position: Pos2, pub just_spawned: bool, - _phantom: PhantomData, } -impl NodeFinder -where - NodeTemplate: NodeTemplateTrait, -{ +impl NodeFinder { pub fn new_at(pos: Pos2) -> Self { NodeFinder { query: "".into(), - position: Some(pos), + position: pos, just_spawned: true, - _phantom: Default::default(), } } /// Shows the node selector panel with a search bar. Returns whether a node /// archetype was selected and, in that case, the finder should be hidden on /// the next frame. - pub fn show( + pub fn show( &mut self, ui: &mut Ui, all_kinds: impl NodeTemplateIter, - user_state: &mut UserState, ) -> Option { let background_color; let text_color; @@ -76,7 +67,7 @@ where .show(ui, |ui| { ui.set_width(scroll_area_width); for kind in all_kinds.all_kinds() { - let kind_name = kind.node_finder_label(user_state).to_string(); + let kind_name = kind.node_finder_label().to_string(); if kind_name .to_lowercase() .contains(self.query.to_lowercase().as_str()) @@ -89,8 +80,8 @@ where } } } - }); - }); + }) + }) }); }); diff --git a/egui_node_graph/src/traits.rs b/egui_node_graph/src/traits.rs index a9cf2ef..2f3628c 100644 --- a/egui_node_graph/src/traits.rs +++ b/egui_node_graph/src/traits.rs @@ -1,38 +1,37 @@ use super::*; +use crate::{color_hex_utils::color_from_hex, utils::ColorUtils}; /// This trait must be implemented by the `ValueType` generic parameter of the /// [`Graph`]. The trait allows drawing custom inline widgets for the different /// types of the node graph. -/// -/// The [`Default`] trait bound is required to circumvent borrow checker issues -/// using `std::mem::take` Otherwise, it would be impossible to pass the -/// `node_data` parameter during `value_widget`. The default value is never -/// used, so the implementation is not important, but it should be reasonably -/// cheap to construct. -pub trait WidgetValueTrait: Default { - type Response; - type UserState; - type NodeData; - /// This method will be called for each input parameter with a widget. The - /// return value is a vector of custom response objects which can be used - /// to implement handling of side effects. If unsure, the response Vec can - /// be empty. - fn value_widget( +pub trait ValueTrait: Clone + std::fmt::Debug { + // TODO(MXG): We require the bounds Clone + Debug for ValueTrait in order to + // use the derive macro to implement Clone and Debug for PortResponse, but + // ValueTrait should not actually need Clone and Debug for that to work. We + // can remove this bound if we manually implement Clone and Debug or if this + // bug is ever fixed: https://github.com/rust-lang/rust/issues/26925 + + type Response: ResponseTrait; + /// This method will be called for each input parameter with a widget. + /// + /// The return value is a tuple with the recommended size of the widget and + /// a vector of custom response objects which can be used to implement + /// handling of side effects. If unsure, the response Vec can be empty. + fn show( &mut self, - param_name: &str, - node_id: NodeId, ui: &mut egui::Ui, - user_state: &mut Self::UserState, - node_data: &Self::NodeData, - ) -> Vec; + ) -> (egui::Rect, Vec); } -/// This trait must be implemented by the `DataType` generic parameter of the -/// [`Graph`]. This trait tells the library how to visually expose data types +/// This trait must be implemented by the `DataType` associated type of any +/// [`NodeTrait`]. This trait tells the library how to visually expose data types /// to the user. -pub trait DataTypeTrait: PartialEq + Eq { - /// The associated port color of this datatype - fn data_type_color(&self, user_state: &mut UserState) -> egui::Color32; +pub trait DataTypeTrait: Clone + std::fmt::Debug { + + /// This associated type gives the type of a raw value + type Value: ValueTrait; + + fn is_compatible(&self, other: &Self) -> bool; /// The name of this datatype. Return type is specified as Cow because /// some implementations will need to allocate a new string to provide an @@ -65,54 +64,115 @@ pub trait DataTypeTrait: PartialEq + Eq { fn name(&self) -> std::borrow::Cow; } -/// This trait must be implemented for the `NodeData` generic parameter of the -/// [`Graph`]. This trait allows customizing some aspects of the node drawing. -pub trait NodeDataTrait -where - Self: Sized, -{ - /// Must be set to the custom user `NodeResponse` type - type Response; - /// Must be set to the custom user `UserState` type - type UserState; - /// Must be set to the custom user `DataType` type - type DataType; - /// Must be set to the custom user `ValueType` type - type ValueType; +/// Trait to be implemented by port types. Ports belong to nodes and define the +/// behavior for connecting inputs and outputs for its node. +pub trait PortTrait { + type DataType: DataTypeTrait; + + /// Return the ideal Rect that the port would like to use so that the parent + /// Node widget can adjust its size if needed. + fn show( + &mut self, + ui: &mut egui::Ui, + id: (NodeId, PortId), + state: &mut EditorUiState, + style: &dyn GraphStyleTrait, + ) -> (egui::Rect, Vec>); + // TODO(@mxgrey): All of these Vec return types should be changed to + // impl IntoIterator when type_alias_impl_trait is a + // stable feature. That way we can avoid memory allocations in the return + // value. + + /// Get the data type information for this port. + fn data_type(&self) -> Self::DataType; + + /// Get the ID of an available hook, if one exists. + fn available_hook(&self) -> Option; + + /// Drops the connection at the specified hook. Returns the ID of the + /// other side of the dropped connection if the drop was successful or + /// [`Err`] if the hook does not exist or did not have a connection. + fn drop_connection(&mut self, id: HookId) -> Result; + + /// Remove all connections that this Port is holding and report the ID + /// information for them along with what they were connected to. + fn drop_all_connections(&mut self) -> Vec<(HookId, ConnectionId)>; + + /// Connect a hook in this port to another hook. This method should only be + /// called by a [`NodeTrait`] implementation. To create a connection as a + /// user, call [`Graph::add_connection`]. + fn connect(&mut self, from: HookId, to: graph::ConnectionToken) -> Result<(), PortAddConnectionError>; +} + +/// This trait must be implemented for the `Content` associated type of the +/// [`NodeTrait`]. This trait allows customizing some aspects of the node drawing. +pub trait NodeContentTrait: Sized { + type AppState; + type Response: ResponseTrait; /// Additional UI elements to draw in the nodes, after the parameters. - fn bottom_ui( - &self, + fn content_ui( + &mut self, ui: &mut egui::Ui, + app: &Self::AppState, node_id: NodeId, - graph: &Graph, - user_state: &mut Self::UserState, - ) -> Vec> - where - Self::Response: UserResponseTrait; + ) -> (egui::Rect, Vec); /// Set background color on titlebar /// If the return value is None, the default color is set. fn titlebar_color( &self, _ui: &egui::Ui, + _app: &Self::AppState, _node_id: NodeId, - _graph: &Graph, - _user_state: &mut Self::UserState, ) -> Option { None } +} - fn can_delete( - &self, - _node_id: NodeId, - _graph: &Graph, - _user_state: &mut Self::UserState, - ) -> bool { - true - } +pub trait NodeTrait { + type DataType: DataTypeTrait; + type Content: NodeContentTrait; + + fn show( + &mut self, + ui: &mut egui::Ui, + app: &::AppState, + id: NodeId, + state: &mut EditorUiState, + style: &dyn GraphStyleTrait, + ) -> Vec> where Self: Sized; + + /// Get the data type of the specified port if it exists, or None if the + /// port does not exist. + fn port_data_type(&self, port_id: PortId) -> Option; + + /// Get the ID of an available hook, if one exists. + fn available_hook(&self, port_id: PortId) -> Option; + + /// Drops the connection at the specified port and hook. Returns [`Ok`] if + /// the drop was successful or [`Err`] if the hook does not exist or did not + /// have a connection. + fn drop_connection( + &mut self, + id: (PortId, HookId) + ) -> Result; + + /// Remove all connections that this Node is holding and report the ID + /// information for them. + fn drop_all_connections(&mut self) -> Vec<(PortId, HookId, ConnectionId)>; + + /// Connect a hook in this node to another hook. This method can only be + /// called by the [`Graph`] class because only the graph module can produce + /// a ConnectionToken. To create a connection as a user, call [`Graph::add_connection`]. + fn connect(&mut self, from: (PortId, HookId), to: graph::ConnectionToken) -> Result<(), NodeAddConnectionError>; } +pub type ContentResponseOf = <::Content as NodeContentTrait>::Response; +pub type DataTypeOf = ::DataType; +pub type ValueResponseOf = < as DataTypeTrait>::Value as ValueTrait>::Response; +pub type AppStateOf = <::Content as NodeContentTrait>::AppState; + /// This trait can be implemented by any user type. The trait tells the library /// how to enumerate the node templates it will present to the user as part of /// the node finder. @@ -126,39 +186,165 @@ pub trait NodeTemplateIter { /// node template is what describes what kinds of nodes can be added to the /// graph, what is their name, and what are their input / output parameters. pub trait NodeTemplateTrait: Clone { - /// Must be set to the custom user `NodeData` type - type NodeData; - /// Must be set to the custom user `DataType` type - type DataType; - /// Must be set to the custom user `ValueType` type - type ValueType; - /// Must be set to the custom user `UserState` type - type UserState; + /// What kind of node can be produced by this template + type Node: NodeTrait; /// Returns a descriptive name for the node kind, used in the node finder. - /// - /// The return type is Cow to allow returning owned or borrowed values - /// more flexibly. Refer to the documentation for `DataTypeTrait::name` for - /// more information - fn node_finder_label(&self, user_state: &mut Self::UserState) -> std::borrow::Cow; + fn node_finder_label(&self) -> &str; /// Returns a descriptive name for the node kind, used in the graph. - fn node_graph_label(&self, user_state: &mut Self::UserState) -> String; - - /// Returns the user data for this node kind. - fn user_data(&self, user_state: &mut Self::UserState) -> Self::NodeData; + fn node_graph_label(&self) -> String; - /// This function is run when this node kind gets added to the graph. The - /// node will be empty by default, and this function can be used to fill its - /// parameters. + /// This function is run when this node kind gets added to the graph. fn build_node( &self, - graph: &mut Graph, - user_state: &mut Self::UserState, - node_id: NodeId, - ); + position: egui::Pos2, + app_state: &mut AppStateOf + ) -> Self::Node; } /// The custom user response types when drawing nodes in the graph must /// implement this trait. -pub trait UserResponseTrait: Clone + std::fmt::Debug {} +pub trait ResponseTrait: Clone + std::fmt::Debug {} +impl ResponseTrait for T {} + +pub trait GraphStyleTrait { + type DataType: DataTypeTrait; + + /// Recommend what color should be used for connections transmitting this data type + fn recommend_data_type_color(&self, typ: &Self::DataType) -> egui::Color32; + + /// Recommend what color should be used for the background of a node + fn recommend_node_background_color( + &self, + ui: &egui::Ui, + _node_id: NodeId, + ) -> egui::Color32 { + if ui.visuals().dark_mode { + color_from_hex("#3f3f3f").unwrap() + } else { + color_from_hex("#ffffff").unwrap() + } + } + + /// Recommend what color should be used for the text in a node + fn recommend_node_text_color( + &self, + ui: &egui::Ui, + _node_id: NodeId, + ) -> egui::Color32 { + if ui.visuals().dark_mode { + color_from_hex("#fefefe").unwrap() + } else { + color_from_hex("#505050").unwrap() + } + } + + fn recommend_port_passive_color( + &self, + ui: &egui::Ui, + (node_id, _port): (NodeId, PortId), + ) -> egui::Color32 { + self.recommend_node_background_color(ui, node_id).lighten(0.75) + } + + /// Ports may choose to be highlighted with this color when a connection + /// event is ongoing if their data type is compatible with the connection. + fn recommend_compatible_port_color( + &self, + _ui: &egui::Ui, + _port: (NodeId, PortId), + ) -> egui::Color32 { + color_from_hex("#D9F8C4").unwrap() + } + + /// Ports may choose to be highlighted with this color when a connection + /// event is ongoing if their data type is incompatible with the connection. + fn recommend_incompatible_port_color( + &self, + _ui: &egui::Ui, + _port: (NodeId, PortId), + ) -> egui::Color32 { + color_from_hex("#FFDEDE").unwrap() + } + + /// Ports may choose to be highlighted with this color when a compatible + /// connection is hovering over it. + fn recommend_port_accept_color( + &self, + _ui: &egui::Ui, + _port: (NodeId, PortId), + ) -> egui::Color32 { + color_from_hex("#00FFAB").unwrap() + } + + fn recommend_port_reject_color( + &self, + _ui: &egui::Ui, + _port: (NodeId, PortId), + ) -> egui::Color32 { + color_from_hex("#EB4747").unwrap() + } + + fn recommend_port_hover_color( + &self, + ui: &egui::Ui, + _port: (NodeId, PortId), + ) -> egui::Color32 { + if ui.visuals().dark_mode { + color_from_hex("#F9F3EE").unwrap() + } else { + color_from_hex("#C4DDFF").unwrap() + } + } + + /// (stroke, background) colors for the close button of a node when it is passive + fn recommend_close_button_passive_colors( + &self, + ui: &egui::Ui, + _node_id: NodeId, + ) -> (egui::Color32, egui::Color32) { + let dark = color_from_hex("#aaaaaa").unwrap(); + let light = color_from_hex("#555555").unwrap(); + if ui.visuals().dark_mode { + (light, dark) + } else { + (dark, light) + } + } + + /// (stroke, background) colors for the close button of a node when it is being hovered + fn recommend_close_button_hover_colors( + &self, + ui: &egui::Ui, + _node_id: NodeId, + ) -> (egui::Color32, egui::Color32) { + let dark = color_from_hex("#dddddd").unwrap(); + let light = color_from_hex("#222222").unwrap(); + if ui.visuals().dark_mode { + (light, dark) + } else { + (dark, light) + } + } + + /// (stroke, background) colors for the close button of a node when it is being clicked + fn recommend_close_button_clicked_colors( + &self, + ui: &egui::Ui, + _node_id: NodeId, + ) -> (egui::Color32, egui::Color32) { + let dark = color_from_hex("#ffffff").unwrap(); + let light = color_from_hex("#000000").unwrap(); + if ui.visuals().dark_mode { + (light, dark) + } else { + (dark, light) + } + } +} + +pub trait GraphContextTrait: GraphStyleTrait { + type Node: NodeTrait; + type NodeTemplate: NodeTemplateTrait; +} diff --git a/egui_node_graph/src/ui_state.rs b/egui_node_graph/src/ui_state.rs index 35a5b9a..a18309b 100644 --- a/egui_node_graph/src/ui_state.rs +++ b/egui_node_graph/src/ui_state.rs @@ -1,67 +1,47 @@ use super::*; -use std::marker::PhantomData; #[cfg(feature = "persistence")] use serde::{Deserialize, Serialize}; -#[derive(Default, Copy, Clone)] +#[derive(Copy, Clone)] #[cfg_attr(feature = "persistence", derive(Serialize, Deserialize))] pub struct PanZoom { pub pan: egui::Vec2, pub zoom: f32, } -#[derive(Clone)] #[cfg_attr(feature = "persistence", derive(Serialize, Deserialize))] -pub struct GraphEditorState { - pub graph: Graph, +pub struct GraphEditorState { + pub graph: Graph, /// Nodes are drawn in this order. Draw order is important because nodes /// that are drawn last are on top. pub node_order: Vec, /// An ongoing connection interaction: The mouse has dragged away from a /// port and the user is holding the click - pub connection_in_progress: Option<(NodeId, AnyParameterId)>, - /// The currently selected node. Some interface actions depend on the - /// currently selected node. + pub connection_in_progress: Option, + /// The currently selected nodes. Some interface actions depend on the + /// currently selected nodes. pub selected_nodes: Vec, - /// The mouse drag start position for an ongoing box selection. - pub ongoing_box_selection: Option, - /// The position of each node. - pub node_positions: SecondaryMap, /// The node finder is used to create new nodes. - pub node_finder: Option>, + pub node_finder: Option, /// The panning of the graph viewport. pub pan_zoom: PanZoom, - pub _user_state: PhantomData UserState>, + pub context: Context, } -impl - GraphEditorState -{ - pub fn new(default_zoom: f32) -> Self { +impl GraphEditorState { + pub fn new(default_zoom: f32, context: Context) -> Self { Self { + graph: Graph::new(), + node_order: Vec::new(), + connection_in_progress: None, + selected_nodes: Vec::new(), + node_finder: None, pan_zoom: PanZoom { pan: egui::Vec2::ZERO, zoom: default_zoom, }, - ..Default::default() - } - } -} -impl Default - for GraphEditorState -{ - fn default() -> Self { - Self { - graph: Default::default(), - node_order: Default::default(), - connection_in_progress: Default::default(), - selected_nodes: Default::default(), - ongoing_box_selection: Default::default(), - node_positions: Default::default(), - node_finder: Default::default(), - pan_zoom: Default::default(), - _user_state: Default::default(), + context, } } } diff --git a/egui_node_graph/src/vertical_port.rs b/egui_node_graph/src/vertical_port.rs new file mode 100644 index 0000000..7af63f1 --- /dev/null +++ b/egui_node_graph/src/vertical_port.rs @@ -0,0 +1,618 @@ +use super::*; +use std::collections::HashMap; + +/// The three kinds of input params. These describe how the graph must behave +/// with respect to inline widgets and connections for this parameter. +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "persistence", derive(Serialize, Deserialize))] +pub enum InputKind { + /// No constant value can be set. Only incoming connections can produce it + ConnectionOnly, + /// Only a constant value can be set. No incoming connections accepted. + ConstantOnly, + /// Both incoming connections and constants are accepted. Connections take + /// precedence over the constant values. + ConnectionOrConstant, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Side { + Left, + Right, +} + +/// A port that displays vertically. +#[derive(Debug)] +#[cfg_attr(feature = "persistence", derive(Serialize, Deserialize))] +pub struct VerticalPort { + /// Name of the port. This will be displayed next to the port icon. + pub label: String, + /// The data type of this node. Used to determine incoming connections. This + /// should always match the type of the InputParamValue, but the property is + /// not actually enforced. + pub data_type: DataType, + /// The limit on number of connections this port allows. A None value means + /// there is no limit. + connection_limit: Option, + /// What side of the block will this port be rendered on + pub side: Side, + /// List of existing hooks and whether or not they have a connection + hooks: SlotMap>, + /// The next hook that's available for a connection + available_hook: Option, + /// What order should the hooks be displayed in + ordering: Vec, +} + +pub struct VerticalOutputPort { + pub base: VerticalPort, +} + +impl VerticalOutputPort { + pub fn new( + label: String, + data_type: DataType, + connection_limit: Option, + ) -> Self { + let mut result = Self { + base: VerticalPort { + label, + data_type, + connection_limit, + side: Side::Right, + hooks: SlotMap::with_key(), + available_hook: None, + ordering: Vec::new(), + } + }; + result.base.consider_new_available_hook(); + result + } + + pub fn iter_hooks(&self) -> impl Iterator)> + '_ { + self.base.iter_hooks() + } +} + +pub struct VerticalInputPort { + /// The input kind. See [`InputKind`] + pub kind: InputKind, + pub default_value: Option, + pub base: VerticalPort, +} + +impl VerticalInputPort { + pub fn new( + label: String, + data_type: DataType, + connection_limit: Option, + kind: InputKind, + ) -> Self { + let mut result = Self { + kind, + default_value: None, + base: VerticalPort { + label, + data_type, + connection_limit, + side: Side::Left, + hooks: SlotMap::with_key(), + available_hook: None, + ordering: Vec::new(), + } + }; + result.base.consider_new_available_hook(); + result + } + + pub fn with_default_value(mut self, default_value: DataType::Value) -> Self { + self.default_value = Some(default_value); + self + } + + pub fn iter_hooks(&self) -> impl Iterator)> + '_ { + self.base.iter_hooks() + } + + pub fn using_default_value(&self) -> Option { + match self.kind { + InputKind::ConnectionOnly => { + None + } + InputKind::ConnectionOrConstant => { + if self.base.hooks.iter().find(|(_, c)| c.is_some()).is_none() { + // None of the hooks have a connection, so we will fall back + // on the default value. + self.default_value.clone() + } else { + None + } + } + InputKind::ConstantOnly => { + self.default_value.clone() + } + } + } +} + +impl VerticalPort { + + pub fn iter_hooks(&self) -> impl Iterator)> + '_ { + self.hooks.iter().map(|(id, token)| (id, token.as_ref().map(|t| t.connected_to()))) + } + + fn tangent(&self) -> egui::Vec2 { + match self.side { + Side::Left => egui::vec2(-1.0, 0.0), + Side::Right => egui::vec2(1.0, 0.0), + } + } + + pub fn show_impl( + &mut self, + ui: &mut egui::Ui, + id: (NodeId, PortId), + state: &mut EditorUiState, + context: &dyn GraphStyleTrait, + show_value: Option<&mut DataType::Value>, + ) -> (egui::Rect, Vec>) { + let (outer_left, outer_right) = (ui.min_rect().left(), ui.min_rect().right()); + let mut value_rect_opt = None; + let mut value_responses = Vec::<::Response>::default(); + let label_rect = ui.horizontal(|ui| { + ui.add_space(20.0); + ui.label(&self.label); + if let Some(value) = show_value { + if self.hooks.len() == 0 || (self.hooks.len() == 1 && self.available_hook.is_some()) { + // There are no connections to this port, so we should show + // the value input widget. + let (value_rect, value_resp) = value.show(ui); + value_rect_opt = Some(value_rect); + value_responses = value_resp; + } + } + }).response.rect; + + let row_rect = if let Some(value_rect) = value_rect_opt { + label_rect.union(value_rect) + } else { + label_rect + }; + + let (hook_x, port_edge_dx) = { + match self.side { + Side::Right => { + (outer_right - 6.0, -10.0) + } + Side::Left => { + (outer_left + 6.0, 10.0) + } + } + }; + + let edge_width = 1_f32; + let hook_spacing = 1_f32; + let radius = 5_f32; + let hook_count = self.hooks.len(); + let height_for_hooks: f32 = 2.0*edge_width + (2.0*radius + hook_spacing)*hook_count as f32 + hook_spacing; + let port_rect = { + let height = label_rect.height().max(height_for_hooks); + if port_edge_dx > 0.0 { + egui::Rect::from_min_size( + egui::pos2(hook_x, label_rect.top()), + egui::vec2(port_edge_dx, height), + ) + } else { + egui::Rect::from_min_size( + egui::pos2(hook_x + port_edge_dx, label_rect.top()), + egui::vec2(-port_edge_dx, height), + ) + } + }; + + let top_hook_y = { + if height_for_hooks >= label_rect.height() { + // The top hook needs to be as high in the port as possible + label_rect.y_range().start() + edge_width + hook_spacing + radius + } else { + // The hooks should be centered in the port + (label_rect.y_range().start() + label_rect.y_range().end())/2.0 - height_for_hooks/2.0 + edge_width + hook_spacing + radius + } + }; + + let (port_color, default_hook_color, hook_color_map, port_response): (_, _, HashMap, _) = 'port: { + // TODO(@mxgrey): It would be nice to move all this logic into its own + // utility function that can be used by different types of ports. + // That function would probably want to take in a port_rect and a + // hook_rect_iterator argument. + + // NOTE: We must allocate the hook rectangles before allocating the + // full port rectangles so that egui gives priority to the hooks + // over the port. For some reason the UI prioritizes sensing for + // the rectangles that are allocated sooner. + let hook_selected: Option<(HookId, egui::Response)> = { + let mut next_hook_y = top_hook_y; + self.ordering.iter().find_map(|hook_id| { + let hook_y = next_hook_y; + next_hook_y += hook_spacing + 2.0*radius; + let resp = ui.allocate_rect( + egui::Rect::from_center_size( + egui::pos2(hook_x, hook_y), + egui::vec2(2.0*radius, 2.0*radius), + ), + egui::Sense::click_and_drag(), + ); + + if ui.input().pointer.hover_pos().filter(|p| resp.rect.contains(*p)).is_some() { + Some((*hook_id, resp)) + } else { + None + } + }) + }; + + let ui_port_response = ui.allocate_rect(port_rect, egui::Sense::click_and_drag()); + if let Some((dragged_connection, dragged_data_type)) = &state.ongoing_drag { + let dragged_port: (NodeId, PortId) = dragged_connection.clone().into(); + if dragged_port == id { + // The port that is being dragged is this one. We should use + // the acceptance color while it is being dragged + let accept_color = context.recommend_port_accept_color(ui, id); + break 'port (accept_color, accept_color, HashMap::default(), None); + } + + let hovering_on_port = ui.input().pointer.hover_pos().filter(|p| { + if port_rect.contains(*p) { + true + } else if let Some((_, hook_resp)) = hook_selected { + hook_resp.rect.contains(*p) + } else { + false + } + }).is_some(); + + if let Some(available_hook) = self.available_hook { + let connection_possible = PortResponse::connect_event_ended( + ConnectionId(id.0, id.1, available_hook), + *dragged_connection, + ); + if let Some(connection_possible) = connection_possible { + if dragged_data_type.is_compatible(&self.data_type) && dragged_port.0 != id.0 { + // Check if the cursor is hovering on this port or one of its hooks + if hovering_on_port { + let resp = if ui.input().pointer.any_released() { + Some(connection_possible) + } else { + None + }; + + let accept_color = context.recommend_port_accept_color(ui, id); + // The port can accept or has accepted the connection + break 'port (accept_color, accept_color, HashMap::default(), resp); + } + + let compatible_color = context.recommend_compatible_port_color(ui, id); + // The connection is compatible but the user needs to + // drag it over to the port + break 'port ( + compatible_color, + context.recommend_data_type_color(&self.data_type.clone().into()), + HashMap::default(), + None, + ); + } + } + } + + // A connection is not possible, either because all the hooks + // are filled or because the data type that's being dragged + // is incompatible + let port_color = if hovering_on_port { + context.recommend_port_reject_color(ui, id) + } else { + context.recommend_incompatible_port_color(ui, id) + }; + + break 'port ( + port_color, + context.recommend_data_type_color(&self.data_type.clone().into()), + HashMap::default(), + None, + ); + } + + if let Some((hook_selected, hook_resp)) = hook_selected { + if self.available_hook.filter(|h| *h == hook_selected).is_some() { + if hook_resp.drag_started() { + let accept_color = context.recommend_port_accept_color(ui, id); + break 'port ( + accept_color, + context.recommend_data_type_color(&self.data_type), + HashMap::from_iter([(hook_selected, accept_color)]), + Some(PortResponse::ConnectEventStarted(ConnectionId(id.0, id.1, hook_selected))), + ); + } + + if hook_resp.hovered() { + // The user is hovering over the available hook. Show + // the user that we see the hovering. + let hover_color = context.recommend_port_hover_color(ui, id); + break 'port ( + hover_color, + context.recommend_data_type_color(&self.data_type), + HashMap::from_iter([(hook_selected, hover_color)]), + None, + ); + } + } else { + // The user is interacting with a hook that is part of a + // connection + if hook_resp.drag_started() { + // Dragging from a connected hook. Begin the connection + // moving event. + let accept_color = context.recommend_port_accept_color(ui, id); + break 'port ( + accept_color, + accept_color, + HashMap::default(), + Some(PortResponse::MoveEvent(ConnectionId(id.0, id.1, hook_selected))), + ); + } + + if hook_resp.hovered() { + // Hovering over a connected hook. Show the user that we + // see the hovering. + let hover_color = context.recommend_port_hover_color(ui, id); + break 'port ( + context.recommend_port_passive_color(ui, id), + context.recommend_data_type_color(&self.data_type), + HashMap::from_iter([(hook_selected, hover_color)]), + None, + ); + } + } + } + + if ui_port_response.drag_started() { + if let Some(available_hook) = self.available_hook { + // The user has started to drag a new connection from the + // port. + let accept_color = context.recommend_port_accept_color(ui, id); + break 'port ( + accept_color, + accept_color, + HashMap::default(), + Some(PortResponse::ConnectEventStarted(ConnectionId(id.0, id.1, available_hook))), + ); + } + } + + if ui_port_response.hovered() { + if let Some(available_hook) = self.available_hook { + // The user is hovering a port with an available hook. + // Show the user that we see the hovering. + let hover_color = context.recommend_port_hover_color(ui, id); + break 'port ( + hover_color, + context.recommend_data_type_color(&self.data_type), + HashMap::from_iter([(available_hook, hover_color)]), + None, + ); + } else { + // The user is hovering over a port that does not have an + // available hook. + let hover_color = context.recommend_incompatible_port_color(ui, id); + break 'port ( + hover_color, + context.recommend_data_type_color(&self.data_type), + HashMap::default(), + None, + ); + } + } + + if ui_port_response.drag_started() || ui_port_response.dragged() { + if self.available_hook.is_none() { + // The user is trying to drag on a port that has no available hook + let reject_color = context.recommend_port_reject_color(ui, id); + break 'port ( + reject_color, + context.recommend_data_type_color(&self.data_type), + HashMap::default(), + None, + ); + } + } + + // Nothing special is happening with this port + ( + context.recommend_port_passive_color(ui, id), + context.recommend_data_type_color(&self.data_type), + HashMap::default(), + None, + ) + }; + + let node_color = context.recommend_node_background_color(ui, id.0); + ui.painter().rect(port_rect, egui::Rounding::default(), port_color, (0_f32, node_color)); + + // Now draw the hooks and save their locations + let mut next_hook_y = top_hook_y; + for hook_id in &self.ordering { + let color = hook_color_map.get(&hook_id).unwrap_or(&default_hook_color); + let p = egui::pos2(hook_x, next_hook_y); + ui.painter().circle(p, radius, *color, (0_f32, *color)); + state.hook_geometry.insert(ConnectionId(id.0, id.1, *hook_id), (p, self.tangent())); + next_hook_y += hook_spacing + 2.0*radius; + } + + let responses = value_responses.into_iter().map(PortResponse::Value) + .chain([port_response].into_iter().filter_map(|r| r)).collect(); + return (row_rect.union(port_rect), responses); + } + + pub fn consider_new_available_hook(&mut self) { + if self.available_hook.is_none() { + if self.connection_limit.filter(|limit| *limit <= self.hooks.len()).is_none() { + let new_hook = self.hooks.insert(None); + self.available_hook = Some(new_hook); + self.ordering.push(new_hook); + } + } + } +} + +impl PortTrait for VerticalPort { + type DataType = DataType; + + fn show( + &mut self, + ui: &mut egui::Ui, + id: (NodeId, PortId), + state: &mut EditorUiState, + style: &dyn GraphStyleTrait, + ) -> (egui::Rect, Vec>) { + self.show_impl(ui, id, state, style, None) + } + + fn data_type(&self) -> Self::DataType { + self.data_type.clone() + } + + fn available_hook(&self) -> Option { + self.available_hook + } + + fn connect(&mut self, from: HookId, to: graph::ConnectionToken) -> Result<(), PortAddConnectionError> { + let connection = match self.hooks.get_mut(from) { + Some(connection) => connection, + None => return Err(PortAddConnectionError::BadHook(from)), + }; + + *connection = Some(to); + if self.available_hook == Some(from) { + // We are now using up the available hook, so we should decide + // whether to clear it or replace it. + self.available_hook = None; + self.consider_new_available_hook(); + } + + Ok(()) + } + + fn drop_connection(&mut self, id: HookId) -> Result { + let connection = match self.hooks.get(id) { + Some(Some(connection)) => connection, + Some(None) => return Err(PortDropConnectionError::NoConnection(id)), + None => return Err(PortDropConnectionError::BadHook(id)), + }.connected_to(); + + self.ordering.retain(|h| *h != id); + self.hooks.remove(id); + self.consider_new_available_hook(); + + Ok(connection) + } + + fn drop_all_connections(&mut self) -> Vec<(HookId, ConnectionId)> { + let mut dropped = Vec::new(); + for (id, connection) in &self.hooks { + if let Some(connection) = connection { + dropped.push((id, connection.connected_to())); + } + } + + self.ordering.clear(); + self.hooks.clear(); + self.consider_new_available_hook(); + + dropped + } +} + +impl PortTrait for VerticalInputPort { + type DataType = DataType; + + fn show( + &mut self, + ui: &mut egui::Ui, + id: (NodeId, PortId), + state: &mut EditorUiState, + style: &dyn GraphStyleTrait, + ) -> (egui::Rect, Vec>) { + match self.kind { + InputKind::ConnectionOnly => { + self.base.show_impl(ui, id, state, style, None) + }, + InputKind::ConstantOnly => { + let label_rect = ui.label(&self.base.label).rect; + if let Some(default_value) = &mut self.default_value { + let (value_rect, value_resp) = default_value.show(ui); + ( + label_rect.union(value_rect), + value_resp.into_iter().map(PortResponse::Value).collect(), + ) + } else { + (label_rect, Vec::new()) + } + }, + InputKind::ConnectionOrConstant => { + self.base.show_impl(ui, id, state, style, self.default_value.as_mut()) + } + } + } + + fn data_type(&self) -> Self::DataType { + self.base.data_type.clone() + } + + fn available_hook(&self) -> Option { + self.base.available_hook + } + + fn connect(&mut self, from: HookId, to: graph::ConnectionToken) -> Result<(), PortAddConnectionError> { + self.base.connect(from, to) + } + + fn drop_all_connections(&mut self) -> Vec<(HookId, ConnectionId)> { + self.base.drop_all_connections() + } + + fn drop_connection(&mut self, id: HookId) -> Result { + self.base.drop_connection(id) + } +} + +impl PortTrait for VerticalOutputPort { + type DataType = DataType; + + fn show( + &mut self, + ui: &mut egui::Ui, + id: (NodeId, PortId), + state: &mut EditorUiState, + style: &dyn GraphStyleTrait, + ) -> (egui::Rect, Vec>) { + self.base.show(ui, id, state, style) + } + + fn available_hook(&self) -> Option { + self.base.available_hook() + } + + fn connect(&mut self, from: HookId, to: graph::ConnectionToken) -> Result<(), PortAddConnectionError> { + self.base.connect(from, to) + } + + fn data_type(&self) -> Self::DataType { + self.base.data_type() + } + + fn drop_all_connections(&mut self) -> Vec<(HookId, ConnectionId)> { + self.base.drop_all_connections() + } + + fn drop_connection(&mut self, id: HookId) -> Result { + self.base.drop_connection(id) + } +} diff --git a/egui_node_graph_example/src/app.rs b/egui_node_graph_example/src/app.rs index 5ad68ca..affee63 100644 --- a/egui_node_graph_example/src/app.rs +++ b/egui_node_graph_example/src/app.rs @@ -1,4 +1,4 @@ -use std::{borrow::Cow, collections::HashMap}; +use std::{borrow::Cow, collections::{HashMap, HashSet}}; use eframe::egui::{self, DragValue, TextStyle}; use egui_node_graph::*; @@ -8,16 +8,20 @@ use egui_node_graph::*; /// The NodeData holds a custom data struct inside each node. It's useful to /// store additional information that doesn't live in parameters. For this /// example, the node data stores the template (i.e. the "type") of the node. -#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] -pub struct MyNodeData { +pub struct MyNodeContent { template: MyNodeTemplate, } +impl MyNodeContent { + pub fn new(template: MyNodeTemplate) -> Self { + Self { template } + } +} + /// `DataType`s are what defines the possible range of connections when /// attaching two ports together. The graph UI will make sure to not allow /// attaching incompatible datatypes. -#[derive(PartialEq, Eq)] -#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] +#[derive(PartialEq, Eq, Debug, Clone)] pub enum MyDataType { Scalar, Vec2, @@ -31,24 +35,27 @@ pub enum MyDataType { /// up to the user code in this example to make sure no parameter is created /// with a DataType of Scalar and a ValueType of Vec2. #[derive(Copy, Clone, Debug)] -#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub enum MyValueType { - Vec2 { value: egui::Vec2 }, - Scalar { value: f32 }, + Vec2(egui::Vec2), + Scalar(f32), } -impl Default for MyValueType { - fn default() -> Self { - // NOTE: This is just a dummy `Default` implementation. The library - // requires it to circumvent some internal borrow checker issues. - Self::Scalar { value: 0.0 } +impl From for MyValueType { + fn from(value: f32) -> Self { + MyValueType::Scalar(value) + } +} + +impl From for MyValueType { + fn from(value: egui::Vec2) -> Self { + MyValueType::Vec2(value) } } impl MyValueType { /// Tries to downcast this value type to a vector pub fn try_to_vec2(self) -> anyhow::Result { - if let MyValueType::Vec2 { value } = self { + if let MyValueType::Vec2(value) = self { Ok(value) } else { anyhow::bail!("Invalid cast from {:?} to vec2", self) @@ -57,7 +64,7 @@ impl MyValueType { /// Tries to downcast this value type to a scalar pub fn try_to_scalar(self) -> anyhow::Result { - if let MyValueType::Scalar { value } = self { + if let MyValueType::Scalar(value) = self { Ok(value) } else { anyhow::bail!("Invalid cast from {:?} to scalar", self) @@ -69,7 +76,6 @@ impl MyValueType { /// will display in the "new node" popup. The user code needs to tell the /// library how to convert a NodeTemplate into a Node. #[derive(Clone, Copy)] -#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub enum MyNodeTemplate { MakeVector, MakeScalar, @@ -94,20 +100,18 @@ pub enum MyResponse { /// parameter drawing callbacks. The contents of this struct are entirely up to /// the user. For this example, we use it to keep track of the 'active' node. #[derive(Default)] -#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] -pub struct MyGraphState { +pub struct MyAppState { pub active_node: Option, } // =========== Then, you need to implement some traits ============ // A trait for the data types, to tell the library how to display them -impl DataTypeTrait for MyDataType { - fn data_type_color(&self, _user_state: &mut MyGraphState) -> egui::Color32 { - match self { - MyDataType::Scalar => egui::Color32::from_rgb(38, 109, 211), - MyDataType::Vec2 => egui::Color32::from_rgb(238, 207, 109), - } +impl DataTypeTrait for MyDataType { + type Value = MyValueType; + + fn is_compatible(&self, other: &Self) -> bool { + *self == *other } fn name(&self) -> Cow<'_, str> { @@ -118,16 +122,15 @@ impl DataTypeTrait for MyDataType { } } +type MyNode = SimpleColumnNode; + // A trait for the node kinds, which tells the library how to build new nodes // from the templates in the node finder impl NodeTemplateTrait for MyNodeTemplate { - type NodeData = MyNodeData; - type DataType = MyDataType; - type ValueType = MyValueType; - type UserState = MyGraphState; + type Node = MyNode; - fn node_finder_label(&self, _user_state: &mut Self::UserState) -> Cow<'_, str> { - Cow::Borrowed(match self { + fn node_finder_label(&self) -> &str { + match self { MyNodeTemplate::MakeVector => "New vector", MyNodeTemplate::MakeScalar => "New scalar", MyNodeTemplate::AddScalar => "Scalar add", @@ -135,110 +138,88 @@ impl NodeTemplateTrait for MyNodeTemplate { MyNodeTemplate::AddVector => "Vector add", MyNodeTemplate::SubtractVector => "Vector subtract", MyNodeTemplate::VectorTimesScalar => "Vector times scalar", - }) + } } - fn node_graph_label(&self, user_state: &mut Self::UserState) -> String { + fn node_graph_label(&self) -> String { // It's okay to delegate this to node_finder_label if you don't want to // show different names in the node finder and the node itself. - self.node_finder_label(user_state).into() - } - - fn user_data(&self, _user_state: &mut Self::UserState) -> Self::NodeData { - MyNodeData { template: *self } + self.node_finder_label().into() } fn build_node( &self, - graph: &mut Graph, - _user_state: &mut Self::UserState, - node_id: NodeId, - ) { - // The nodes are created empty by default. This function needs to take - // care of creating the desired inputs and outputs based on the template - - // We define some closures here to avoid boilerplate. Note that this is - // entirely optional. - let input_scalar = |graph: &mut MyGraph, name: &str| { - graph.add_input_param( - node_id, - name.to_string(), - MyDataType::Scalar, - MyValueType::Scalar { value: 0.0 }, - InputParamKind::ConnectionOrConstant, - true, - ); - }; - let input_vector = |graph: &mut MyGraph, name: &str| { - graph.add_input_param( - node_id, - name.to_string(), - MyDataType::Vec2, - MyValueType::Vec2 { - value: egui::vec2(0.0, 0.0), - }, - InputParamKind::ConnectionOrConstant, - true, - ); - }; - - let output_scalar = |graph: &mut MyGraph, name: &str| { - graph.add_output_param(node_id, name.to_string(), MyDataType::Scalar); - }; - let output_vector = |graph: &mut MyGraph, name: &str| { - graph.add_output_param(node_id, name.to_string(), MyDataType::Vec2); - }; - + position: egui::Pos2, + _app_state: &mut MyAppState, + ) -> Self::Node { + let node = MyNode::new(position, self.node_graph_label(), MyNodeContent::new(*self)) + .with_size_hint(120.0); match self { MyNodeTemplate::AddScalar => { - // The first input param doesn't use the closure so we can comment - // it in more detail. - graph.add_input_param( - node_id, - // This is the name of the parameter. Can be later used to - // retrieve the value. Parameter names should be unique. - "A".into(), - // The data type for this input. In this case, a scalar - MyDataType::Scalar, - // The value type for this input. We store zero as default - MyValueType::Scalar { value: 0.0 }, - // The input parameter kind. This allows defining whether a - // parameter accepts input connections and/or an inline - // widget to set its value. - InputParamKind::ConnectionOrConstant, - true, - ); - input_scalar(graph, "B"); - output_scalar(graph, "out"); - } - MyNodeTemplate::SubtractScalar => { - input_scalar(graph, "A"); - input_scalar(graph, "B"); - output_scalar(graph, "out"); - } - MyNodeTemplate::VectorTimesScalar => { - input_scalar(graph, "scalar"); - input_vector(graph, "vector"); - output_vector(graph, "out"); + node + .with_input(VerticalInputPort::new("in".to_owned(), MyDataType::Scalar, None, InputKind::ConnectionOnly)) + .with_output(VerticalOutputPort::new("out".to_owned(), MyDataType::Scalar, None)) } MyNodeTemplate::AddVector => { - input_vector(graph, "v1"); - input_vector(graph, "v2"); - output_vector(graph, "out"); + node + .with_input(VerticalInputPort::new("in".to_owned(), MyDataType::Vec2, None, InputKind::ConnectionOnly)) + .with_output(VerticalOutputPort::new("out".to_owned(), MyDataType::Vec2, None)) } - MyNodeTemplate::SubtractVector => { - input_vector(graph, "v1"); - input_vector(graph, "v2"); - output_vector(graph, "out"); + MyNodeTemplate::MakeScalar => { + node + .with_input( + VerticalInputPort::new("value".to_owned(), MyDataType::Scalar, Some(1), InputKind::ConnectionOrConstant) + .with_default_value(MyValueType::Scalar(0.0)) + ) + .with_output(VerticalOutputPort::new("out".to_owned(), MyDataType::Scalar, None)) } MyNodeTemplate::MakeVector => { - input_scalar(graph, "x"); - input_scalar(graph, "y"); - output_vector(graph, "out"); + node + .with_input( + VerticalInputPort::new("x".to_owned(), MyDataType::Scalar, Some(1), InputKind::ConnectionOrConstant) + .with_default_value(MyValueType::Scalar(0.0)) + ) + .with_input( + VerticalInputPort::new("y".to_owned(), MyDataType::Scalar, Some(1), InputKind::ConnectionOrConstant) + .with_default_value(MyValueType::Scalar(0.0)) + ) + .with_output(VerticalOutputPort::new("out".to_owned(), MyDataType::Vec2, None)) } - MyNodeTemplate::MakeScalar => { - input_scalar(graph, "value"); - output_scalar(graph, "out"); + MyNodeTemplate::SubtractScalar => { + node + .with_input( + VerticalInputPort::new("value".to_owned(), MyDataType::Scalar, Some(1), InputKind::ConnectionOrConstant) + .with_default_value(MyValueType::Scalar(0.0)) + ) + .with_input( + VerticalInputPort::new("minus".to_owned(), MyDataType::Scalar, None, InputKind::ConnectionOrConstant) + .with_default_value(MyValueType::Scalar(0.0)) + ) + .with_output(VerticalOutputPort::new("out".to_owned(), MyDataType::Scalar, None)) + } + MyNodeTemplate::SubtractVector => { + node + .with_input( + VerticalInputPort::new("value".to_owned(), MyDataType::Vec2, Some(1), InputKind::ConnectionOrConstant) + .with_default_value(MyValueType::Vec2(egui::vec2(0.0, 0.0))) + ) + .with_input( + VerticalInputPort::new("minus".to_owned(), MyDataType::Vec2, None, InputKind::ConnectionOrConstant) + .with_default_value(MyValueType::Vec2(egui::vec2(0.0, 0.0))) + ) + .with_output(VerticalOutputPort::new("out".to_owned(), MyDataType::Vec2, None)) + } + MyNodeTemplate::VectorTimesScalar => { + node + .with_input( + VerticalInputPort::new("scalar".to_owned(), MyDataType::Scalar, Some(1), InputKind::ConnectionOrConstant) + .with_default_value(MyValueType::Scalar(1.0)) + ) + .with_input( + VerticalInputPort::new("vec".to_owned(), MyDataType::Vec2, Some(1), InputKind::ConnectionOrConstant) + .with_default_value(MyValueType::Vec2(egui::vec2(1.0, 1.0))) + ) + .with_output(VerticalOutputPort::new("out".to_owned(), MyDataType::Vec2, None)) } } } @@ -264,71 +245,53 @@ impl NodeTemplateIter for AllMyNodeTemplates { } } -impl WidgetValueTrait for MyValueType { +impl ValueTrait for MyValueType { type Response = MyResponse; - type UserState = MyGraphState; - type NodeData = MyNodeData; - fn value_widget( - &mut self, - param_name: &str, - _node_id: NodeId, - ui: &mut egui::Ui, - _user_state: &mut MyGraphState, - _node_data: &MyNodeData, - ) -> Vec { + fn show(&mut self, ui: &mut egui::Ui) -> (egui::Rect, Vec) { // This trait is used to tell the library which UI to display for the // inline parameter widgets. - match self { - MyValueType::Vec2 { value } => { - ui.label(param_name); + let rect = match self { + MyValueType::Vec2(value) => { ui.horizontal(|ui| { ui.label("x"); ui.add(DragValue::new(&mut value.x)); ui.label("y"); ui.add(DragValue::new(&mut value.y)); - }); + }) } - MyValueType::Scalar { value } => { + MyValueType::Scalar(value) => { ui.horizontal(|ui| { - ui.label(param_name); ui.add(DragValue::new(value)); - }); + }) } - } + }.response.rect; // This allows you to return your responses from the inline widgets. - Vec::new() + (rect, Vec::new()) } } -impl UserResponseTrait for MyResponse {} -impl NodeDataTrait for MyNodeData { +impl NodeContentTrait for MyNodeContent { + type AppState = MyAppState; type Response = MyResponse; - type UserState = MyGraphState; - type DataType = MyDataType; - type ValueType = MyValueType; // This method will be called when drawing each node. This allows adding // extra ui elements inside the nodes. In this case, we create an "active" // button which introduces the concept of having an active node in the // graph. This is done entirely from user code with no modifications to the // node graph library. - fn bottom_ui( - &self, + fn content_ui( + &mut self, ui: &mut egui::Ui, + app_state: &Self::AppState, node_id: NodeId, - _graph: &Graph, - user_state: &mut Self::UserState, - ) -> Vec> - where - MyResponse: UserResponseTrait, - { + ) -> (egui::Rect, Vec) { // This logic is entirely up to the user. In this case, we check if the // current node we're drawing is the active one, by comparing against // the value stored in the global user state, and draw different button // UIs based on that. let mut responses = vec![]; - let is_active = user_state + let is_active = app_state .active_node .map(|id| id == node_id) .unwrap_or(false); @@ -337,62 +300,64 @@ impl NodeDataTrait for MyNodeData { // or clear the active node. These responses do nothing by themselves, // the library only makes the responses available to you after the graph // has been drawn. See below at the update method for an example. - if !is_active { - if ui.button("👁 Set active").clicked() { - responses.push(NodeResponse::User(MyResponse::SetActiveNode(node_id))); + let rect = if !is_active { + let resp = ui.button("👁 Set active"); + if resp.clicked() { + responses.push(MyResponse::SetActiveNode(node_id)); } + resp.rect } else { let button = egui::Button::new(egui::RichText::new("👁 Active").color(egui::Color32::BLACK)) .fill(egui::Color32::GOLD); - if ui.add(button).clicked() { - responses.push(NodeResponse::User(MyResponse::ClearActiveNode)); + let resp = ui.add(button); + if resp.clicked() { + responses.push(MyResponse::ClearActiveNode); } - } + resp.rect + }; - responses + (rect, responses) } } -type MyGraph = Graph; -type MyEditorState = - GraphEditorState; +type MyGraph = Graph; #[derive(Default)] +struct MyGraphContext; +impl GraphStyleTrait for MyGraphContext { + type DataType = MyDataType; + fn recommend_data_type_color(&self, data_type: &MyDataType) -> egui::Color32 { + match data_type { + MyDataType::Scalar => egui::Color32::from_rgb(38, 109, 211), + MyDataType::Vec2 => egui::Color32::from_rgb(238, 207, 109), + } + } +} +impl GraphContextTrait for MyGraphContext { + type Node = MyNode; + type NodeTemplate = MyNodeTemplate; +} + +type MyEditorState = GraphEditorState; + pub struct NodeGraphExample { // The `GraphEditorState` is the top-level object. You "register" all your // custom types by specifying it as its generic parameters. - state: MyEditorState, - - user_state: MyGraphState, + editor: MyEditorState, + app_state: MyAppState, } -#[cfg(feature = "persistence")] -const PERSISTENCE_KEY: &str = "egui_node_graph"; - -#[cfg(feature = "persistence")] -impl NodeGraphExample { - /// If the persistence feature is enabled, Called once before the first frame. - /// Load previous app state (if any). - pub fn new(cc: &eframe::CreationContext<'_>) -> Self { - let state = cc - .storage - .and_then(|storage| eframe::get_value(storage, PERSISTENCE_KEY)) - .unwrap_or_default(); +impl Default for NodeGraphExample { + fn default() -> Self { Self { - state, - user_state: MyGraphState::default(), + editor: GraphEditorState::new(1.0, MyGraphContext::default()), + app_state: MyAppState { active_node: None }, } } } impl eframe::App for NodeGraphExample { - #[cfg(feature = "persistence")] - /// If the persistence function is enabled, - /// Called by the frame work to save state before shutdown. - fn save(&mut self, storage: &mut dyn eframe::Storage) { - eframe::set_value(storage, PERSISTENCE_KEY, &self.state); - } /// Called each time the UI needs repainting, which may be many times per second. /// Put your widgets into a `SidePanel`, `TopPanel`, `CentralPanel`, `Window` or `Area`. fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) { @@ -403,25 +368,26 @@ impl eframe::App for NodeGraphExample { }); let graph_response = egui::CentralPanel::default() .show(ctx, |ui| { - self.state - .draw_graph_editor(ui, AllMyNodeTemplates, &mut self.user_state) + self.editor.draw_graph_editor( + ui, AllMyNodeTemplates, &mut self.app_state + ) }) .inner; for node_response in graph_response.node_responses { // Here, we ignore all other graph events. But you may find // some use for them. For example, by playing a sound when a new // connection is created - if let NodeResponse::User(user_event) = node_response { + if let NodeResponse::Content(user_event) = node_response { match user_event { - MyResponse::SetActiveNode(node) => self.user_state.active_node = Some(node), - MyResponse::ClearActiveNode => self.user_state.active_node = None, + MyResponse::SetActiveNode(node) => self.app_state.active_node = Some(node), + MyResponse::ClearActiveNode => self.app_state.active_node = None, } } } - if let Some(node) = self.user_state.active_node { - if self.state.graph.nodes.contains_key(node) { - let text = match evaluate_node(&self.state.graph, node, &mut HashMap::new()) { + if let Some(node_id) = self.app_state.active_node { + if self.editor.graph.node(node_id).is_some() { + let text = match evaluate(&self.editor.graph, node_id) { Ok(value) => format!("The result is: {:?}", value), Err(err) => format!("Execution error: {}", err), }; @@ -433,159 +399,244 @@ impl eframe::App for NodeGraphExample { egui::Color32::WHITE, ); } else { - self.user_state.active_node = None; + self.app_state.active_node = None; } } } } -type OutputsCache = HashMap; - -/// Recursively evaluates all dependencies of this node, then evaluates the node itself. -pub fn evaluate_node( +pub fn evaluate( graph: &MyGraph, node_id: NodeId, - outputs_cache: &mut OutputsCache, ) -> anyhow::Result { - // To solve a similar problem as creating node types above, we define an - // Evaluator as a convenience. It may be overkill for this small example, - // but something like this makes the code much more readable when the - // number of nodes starts growing. - - struct Evaluator<'a> { - graph: &'a MyGraph, - outputs_cache: &'a mut OutputsCache, - node_id: NodeId, + + #[derive(Debug)] + enum InputType { + Connections(Vec), + Constant(MyValueType), } - impl<'a> Evaluator<'a> { - fn new(graph: &'a MyGraph, outputs_cache: &'a mut OutputsCache, node_id: NodeId) -> Self { - Self { - graph, - outputs_cache, - node_id, + + impl InputType { + fn dependencies(&self) -> Vec { + match self { + InputType::Connections(v) => v.clone(), + InputType::Constant(_) => vec![], } } - fn evaluate_input(&mut self, name: &str) -> anyhow::Result { - // Calling `evaluate_input` recursively evaluates other nodes in the - // graph until the input value for a paramater has been computed. - evaluate_input(self.graph, self.node_id, name, self.outputs_cache) - } - fn populate_output( - &mut self, - name: &str, - value: MyValueType, - ) -> anyhow::Result { - // After computing an output, we don't just return it, but we also - // populate the outputs cache with it. This ensures the evaluation - // only ever computes an output once. - // - // The return value of the function is the "final" output of the - // node, the thing we want to get from the evaluation. The example - // would be slightly more contrived when we had multiple output - // values, as we would need to choose which of the outputs is the - // one we want to return. Other outputs could be used as - // intermediate values. - // - // Note that this is just one possible semantic interpretation of - // the graphs, you can come up with your own evaluation semantics! - populate_output(self.graph, self.outputs_cache, self.node_id, name, value) - } - fn input_vector(&mut self, name: &str) -> anyhow::Result { - self.evaluate_input(name)?.try_to_vec2() - } - fn input_scalar(&mut self, name: &str) -> anyhow::Result { - self.evaluate_input(name)?.try_to_scalar() + + fn values(&self, evaluations: &HashMap<(NodeId, PortId), MyValueType>) -> Vec { + match self { + InputType::Connections(inputs) => { + inputs.iter().map(|OutputId(node, port, _)| { + evaluations.get(&(*node, (*port).into())).unwrap().clone() + }).collect() + } + InputType::Constant(value) => vec![value.clone()], + } } - fn output_vector(&mut self, name: &str, value: egui::Vec2) -> anyhow::Result { - self.populate_output(name, MyValueType::Vec2 { value }) + + fn sum_scalar_values(&self, evaluations: &HashMap<(NodeId, PortId), MyValueType>) -> f32 { + self.values(evaluations).iter().map(|v| v.try_to_scalar().unwrap()).fold(0_f32, |a, b| a + b) } - fn output_scalar(&mut self, name: &str, value: f32) -> anyhow::Result { - self.populate_output(name, MyValueType::Scalar { value }) + + fn sum_vector_values(&self, evaluations: &HashMap<(NodeId, PortId), MyValueType>) -> egui::Vec2 { + self.values(evaluations).iter().map(|v| v.try_to_vec2().unwrap()).fold(egui::vec2(0.0, 0.0), |a, b| a + b) } } - let node = &graph[node_id]; - let mut evaluator = Evaluator::new(graph, outputs_cache, node_id); - match node.user_data.template { - MyNodeTemplate::AddScalar => { - let a = evaluator.input_scalar("A")?; - let b = evaluator.input_scalar("B")?; - evaluator.output_scalar("out", a + b) + #[derive(Debug)] + enum NodeInput { + AddScalar(InputType), + AddVector(InputType), + MakeScalar(InputType), + MakeVector { + x: InputType, + y: InputType, + }, + SubtractScalar { + value: InputType, + minus: InputType, + }, + SubtractVector { + value: InputType, + minus: InputType, + }, + VectorTimesScalar { + scalar: InputType, + vec: InputType, + }, + } + + impl NodeInput { + fn new(node_id: NodeId, graph: &MyGraph) -> Self { + let node = graph.node(node_id).unwrap(); + match &node.content.template { + MyNodeTemplate::AddScalar => { + NodeInput::AddScalar(collect_inputs(node.inputs.iter().next().unwrap().1)) + } + MyNodeTemplate::AddVector => { + NodeInput::AddVector(collect_inputs(node.inputs.iter().next().unwrap().1)) + } + MyNodeTemplate::MakeScalar => { + NodeInput::MakeScalar(collect_inputs(node.inputs.iter().next().unwrap().1)) + } + MyNodeTemplate::MakeVector => { + NodeInput::MakeVector { + x: collect_inputs(node.inputs.iter().find(|p| p.1.base.label == "x").unwrap().1), + y: collect_inputs(node.inputs.iter().find(|p| p.1.base.label == "y").unwrap().1), + } + } + MyNodeTemplate::SubtractScalar => { + NodeInput::SubtractScalar { + value: collect_inputs(node.inputs.iter().find(|p| p.1.base.label == "value").unwrap().1), + minus: collect_inputs(node.inputs.iter().find(|p| p.1.base.label == "minus").unwrap().1), + } + } + MyNodeTemplate::SubtractVector => { + NodeInput::SubtractVector { + value: collect_inputs(node.inputs.iter().find(|p| p.1.base.label == "value").unwrap().1), + minus: collect_inputs(node.inputs.iter().find(|p| p.1.base.label == "minus").unwrap().1), + } + } + MyNodeTemplate::VectorTimesScalar => { + NodeInput::VectorTimesScalar { + scalar: collect_inputs(node.inputs.iter().find(|p| p.1.base.label == "scalar").unwrap().1), + vec: collect_inputs(node.inputs.iter().find(|p| p.1.base.label == "vec").unwrap().1), + } + } + } + } + + fn find_rank(&self, evaluatees: &HashMap) -> Option { + let mut rank = 0; + for dep in self.dependencies() { + if let Some((dep_rank, _)) = evaluatees.get(&dep.node()) { + // The rank of this node is the highest rank of its + // dependencies, plus one. + rank = rank.max(*dep_rank+1); + } else { + // This node has an unranked dependency, so we cannot rank + // it yet. + return None; + } + } + + Some(rank) } - MyNodeTemplate::SubtractScalar => { - let a = evaluator.input_scalar("A")?; - let b = evaluator.input_scalar("B")?; - evaluator.output_scalar("out", a - b) + + fn dependencies(&self) -> Vec { + match self { + NodeInput::AddScalar(v) => v.dependencies(), + NodeInput::AddVector(v) => v.dependencies(), + NodeInput::MakeScalar(v) => v.dependencies(), + NodeInput::MakeVector { x, y } => x.dependencies().iter().cloned().chain(y.dependencies().iter().cloned()).collect(), + NodeInput::SubtractScalar { value, minus } => value.dependencies().iter().cloned().chain(minus.dependencies().iter().cloned()).collect(), + NodeInput::SubtractVector { value, minus } => value.dependencies().iter().cloned().chain(minus.dependencies().iter().cloned()).collect(), + NodeInput::VectorTimesScalar { scalar, vec } => scalar.dependencies().iter().cloned().chain(vec.dependencies().iter().cloned()).collect(), + } } - MyNodeTemplate::VectorTimesScalar => { - let scalar = evaluator.input_scalar("scalar")?; - let vector = evaluator.input_vector("vector")?; - evaluator.output_vector("out", vector * scalar) + + fn evaluate(&self, evaluations: &HashMap<(NodeId, PortId), MyValueType>) -> MyValueType { + match &self { + NodeInput::AddScalar(input) => { + input.sum_scalar_values(evaluations).into() + } + NodeInput::AddVector(input) => { + input.sum_vector_values(evaluations).into() + } + NodeInput::MakeScalar(input) => { + // To gracefully handle cases where there are no connections + // and no constant value set, we implement this the same way + // as AddScalar + input.sum_scalar_values(evaluations).into() + } + NodeInput::MakeVector { x, y } => { + // To gracefully handle cases where there are no connections + // and no constant value set, we implement this similarly to + // AddScalar + let x = x.sum_scalar_values(evaluations); + let y = y.sum_scalar_values(evaluations); + egui::vec2(x, y).into() + } + NodeInput::SubtractScalar { value, minus } => { + let value = value.sum_scalar_values(evaluations); + let minus = minus.sum_scalar_values(evaluations); + (value - minus).into() + } + NodeInput::SubtractVector { value, minus } => { + let value = value.sum_vector_values(evaluations); + let minus = minus.sum_vector_values(evaluations); + (value - minus).into() + } + NodeInput::VectorTimesScalar { scalar, vec } => { + let scalar = scalar.sum_scalar_values(evaluations); + let vec = vec.sum_vector_values(evaluations); + (scalar * vec).into() + } + } } - MyNodeTemplate::AddVector => { - let v1 = evaluator.input_vector("v1")?; - let v2 = evaluator.input_vector("v2")?; - evaluator.output_vector("out", v1 + v2) + } + + fn collect_inputs(port: &VerticalInputPort) -> InputType { + if let Some(constant) = port.using_default_value() { + InputType::Constant(constant) + } else { + let connections = port.iter_hooks().filter_map(|(_, c)| c.map(|c| c.as_output()).flatten()).collect(); + InputType::Connections(connections) } - MyNodeTemplate::SubtractVector => { - let v1 = evaluator.input_vector("v1")?; - let v2 = evaluator.input_vector("v2")?; - evaluator.output_vector("out", v1 - v2) + } + + let mut ranking_queue = HashMap::::new(); + let mut evaluatees = HashMap::::new(); + ranking_queue.insert(node_id, NodeInput::new(node_id, graph)); + while !ranking_queue.is_empty() { + let mut next_queue = HashMap::::new(); + next_queue.reserve(ranking_queue.len()); + let mut previous_queue = HashSet::::new(); + for (node_id, _) in &ranking_queue { + previous_queue.insert(*node_id); } - MyNodeTemplate::MakeVector => { - let x = evaluator.input_scalar("x")?; - let y = evaluator.input_scalar("y")?; - evaluator.output_vector("out", egui::vec2(x, y)) + + for (node_id, node_input) in ranking_queue { + if let Some(rank) = node_input.find_rank(&evaluatees) { + evaluatees.insert(node_id, (rank, node_input)); + } else { + for dep in node_input.dependencies() { + if !previous_queue.contains(&dep.node()) && !evaluatees.contains_key(&dep.node()) { + next_queue.insert(dep.node(), NodeInput::new(dep.node(), graph)); + } + } + next_queue.insert(node_id, node_input); + } } - MyNodeTemplate::MakeScalar => { - let value = evaluator.input_scalar("value")?; - evaluator.output_scalar("out", value) + + if previous_queue.iter().find(|id| !next_queue.contains_key(*id)).is_none() { + if next_queue.iter().find(|(id, _)| !previous_queue.contains(*id)).is_none() { + anyhow::bail!("circular dependency!"); + } } + + ranking_queue = next_queue; } -} -fn populate_output( - graph: &MyGraph, - outputs_cache: &mut OutputsCache, - node_id: NodeId, - param_name: &str, - value: MyValueType, -) -> anyhow::Result { - let output_id = graph[node_id].get_output(param_name)?; - outputs_cache.insert(output_id, value); - Ok(value) -} + let mut evaluation_queue: Vec<(usize, NodeId, NodeInput)> = evaluatees.into_iter().map(|(id, (r, e))| (r, id, e)).collect(); + evaluation_queue.sort_by(|(r_a, _, _), (r_b, _, _)| r_a.cmp(r_b)); -// Evaluates the input value of -fn evaluate_input( - graph: &MyGraph, - node_id: NodeId, - param_name: &str, - outputs_cache: &mut OutputsCache, -) -> anyhow::Result { - let input_id = graph[node_id].get_input(param_name)?; - - // The output of another node is connected. - if let Some(other_output_id) = graph.connection(input_id) { - // The value was already computed due to the evaluation of some other - // node. We simply return value from the cache. - if let Some(other_value) = outputs_cache.get(&other_output_id) { - Ok(*other_value) - } - // This is the first time encountering this node, so we need to - // recursively evaluate it. - else { - // Calling this will populate the cache - evaluate_node(graph, graph[other_output_id].node, outputs_cache)?; - - // Now that we know the value is cached, return it - Ok(*outputs_cache - .get(&other_output_id) - .expect("Cache should be populated")) + let mut evaluations = HashMap::<(NodeId, PortId), MyValueType>::new(); + for (_, node_id, evaluatee) in evaluation_queue { + if let Some((output_port, _)) = graph.node(node_id).unwrap().outputs.iter().next() { + let evaluation = evaluatee.evaluate(&evaluations); + evaluations.insert((node_id, output_port.into()), evaluation); + } else { + anyhow::bail!("missing output port for node {:?}", node_id); } } - // No existing connection, take the inline value instead. - else { - Ok(graph[input_id].value) + + if let Some((output_port, _)) = graph.node(node_id).unwrap().outputs.iter().next() { + return evaluations.get(&(node_id, output_port.into())).cloned().ok_or_else( + || anyhow::format_err!("failed to include active node {:?} in evaluation", node_id) + ); + } else { + anyhow::bail!("missing output port for active node {:?}", node_id); } }