diff --git a/mesa/visualization/backends/altair_backend.py b/mesa/visualization/backends/altair_backend.py index edbb28730a9..2f9315b6562 100644 --- a/mesa/visualization/backends/altair_backend.py +++ b/mesa/visualization/backends/altair_backend.py @@ -299,11 +299,13 @@ def draw_agents( "x:Q", title=xlabel, scale=alt.Scale(type="linear", domain=[xmin, xmax]), + axis=None, ), y=alt.Y( "y:Q", title=ylabel, scale=alt.Scale(type="linear", domain=[ymin, ymax]), + axis=None, ), size=alt.Size("size:Q", legend=None, scale=alt.Scale(domain=[0, 50])), shape=alt.Shape( @@ -352,8 +354,7 @@ def draw_propertylayer( Returns: alt.Chart: A tuple containing the base chart and the color bar chart. """ - base = None - bar_chart_viz = None + main_charts = [] for layer_name in property_layers: if layer_name == "empty": @@ -384,7 +385,6 @@ def draw_propertylayer( vmin = portrayal.vmin if portrayal.vmin is not None else np.min(data) vmax = portrayal.vmax if portrayal.vmax is not None else np.max(data) - # Prepare data for Altair df = pd.DataFrame( { "x": np.repeat(np.arange(data.shape[0]), data.shape[1]), @@ -393,183 +393,48 @@ def draw_propertylayer( } ) - current_chart = None if color: - # Create a function to map values to RGBA colors with proper opacity scaling - def apply_rgba( - val, v_min=vmin, v_max=vmax, a=alpha, p_color=portrayal.color - ): - # Normalize value to range [0,1] and clamp - normalized = max( - 0, - min( - ((val - v_min) / (v_max - v_min)) - if (v_max - v_min) != 0 - else 0.5, - 1, - ), - ) - - # Scale opacity by alpha parameter - opacity = normalized * a - - # Convert color to RGB components - rgb_color_val = to_rgb(p_color) - r = int(rgb_color_val[0] * 255) - g = int(rgb_color_val[1] * 255) - b = int(rgb_color_val[2] * 255) - return f"rgba({r}, {g}, {b}, {opacity:.2f})" - - # Apply color mapping to each value in the dataset - df["color_str"] = df["value"].apply(apply_rgba) - - # Create chart for the property layer - current_chart = ( - alt.Chart(df) - .mark_rect() - .encode( - x=alt.X("x:O", axis=None), - y=alt.Y("y:O", axis=None), - fill=alt.Fill("color_str:N", scale=None), - ) - .properties( - width=chart_width, height=chart_height, title=layer_name - ) + # For a single color gradient, we define the range from transparent to solid. + rgb = to_rgb(color) + r, g, b = (int(c * 255) for c in rgb) + + min_color = f"rgba({r},{g},{b},0)" + max_color = f"rgba({r},{g},{b},{alpha})" + opacity = 1 + color_scale = alt.Scale( + range=[min_color, max_color], domain=[vmin, vmax] ) - base = ( - alt.layer(current_chart, base) - if base is not None - else current_chart - ) - - # Add colorbar if specified in portrayal - if portrayal.colorbar: - # Extract RGB components from base color - rgb_color_val = to_rgb(portrayal.color) - r_int = int(rgb_color_val[0] * 255) - g_int = int(rgb_color_val[1] * 255) - b_int = int(rgb_color_val[2] * 255) - - # Define gradient endpoints - min_color_str = f"rgba({r_int},{g_int},{b_int},0)" - max_color_str = f"rgba({r_int},{g_int},{b_int},{alpha:.2f})" - - # Define colorbar dimensions - colorbar_height = 20 - colorbar_width = chart_width - - # Create dataframe for gradient visualization - df_gradient = pd.DataFrame({"x_grad": [0, 1], "y_grad": [0, 1]}) - - # Create evenly distributed tick values - axis_values = np.linspace(vmin, vmax, 11) - tick_positions = np.linspace(10, colorbar_width - 10, 11) - - # Prepare data for axis and labels - axis_data = pd.DataFrame( - {"value_axis": axis_values, "x_axis": tick_positions} - ) - - # Create colorbar with linear gradient - colorbar_chart_obj = ( - alt.Chart(df_gradient) - .mark_rect( - x=20, - y=0, - width=colorbar_width - 20, - height=colorbar_height, - color=alt.Gradient( - gradient="linear", - stops=[ - alt.GradientStop(color=min_color_str, offset=0), - alt.GradientStop(color=max_color_str, offset=1), - ], - x1=0, - x2=1, # Horizontal gradient - y1=0, - y2=0, # Keep y constant - ), - ) - .encode( - x=alt.value(chart_width / 2), y=alt.value(8) - ) # Center colorbar - .properties(width=colorbar_width, height=colorbar_height) - ) - # Add tick marks to colorbar - axis_chart = ( - alt.Chart(axis_data) - .mark_tick(thickness=2, size=10) - .encode( - x=alt.X("x_axis:Q", axis=None), - y=alt.value(colorbar_height - 2), - ) - ) - # Add value labels below tick marks - text_labels = ( - alt.Chart(axis_data) - .mark_text(baseline="top", fontSize=10, dy=0) - .encode( - x=alt.X("x_axis:Q"), - text=alt.Text("value_axis:Q", format=".1f"), - y=alt.value(colorbar_height + 10), - ) - ) - # Add title to colorbar - title_chart = ( - alt.Chart(pd.DataFrame([{"text_title": layer_name}])) - .mark_text( - fontSize=12, - fontWeight="bold", - baseline="bottom", - align="center", - ) - .encode( - text="text_title:N", - x=alt.value(colorbar_width / 2), - y=alt.value(colorbar_height + 40), - ) - ) - # Combine all colorbar components - combined_colorbar = alt.layer( - colorbar_chart_obj, axis_chart, text_labels, title_chart - ).properties(width=colorbar_width, height=colorbar_height + 50) - - bar_chart_viz = ( - alt.vconcat(bar_chart_viz, combined_colorbar).resolve_scale( - color="independent" - ) - if bar_chart_viz is not None - else combined_colorbar - ) elif colormap: cmap = colormap - cmap_scale = alt.Scale(scheme=cmap, domain=[vmin, vmax]) - - current_chart = ( - alt.Chart(df) - .mark_rect(opacity=alpha) - .encode( - x=alt.X("x:O", axis=None), - y=alt.Y("y:O", axis=None), - color=alt.Color( - "value:Q", - scale=cmap_scale, - title=layer_name, - legend=alt.Legend(title=layer_name) - if portrayal.colorbar - else None, - ), - ) - .properties(width=chart_width, height=chart_height) - ) - base = ( - alt.layer(current_chart, base) - if base is not None - else current_chart - ) + color_scale = alt.Scale(scheme=cmap, domain=[vmin, vmax]) + opacity = alpha + else: raise ValueError( f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'." ) - return (base, bar_chart_viz) + + current_chart = ( + alt.Chart(df) + .mark_rect(opacity=opacity) + .encode( + x=alt.X("x:O", axis=None), + y=alt.Y("y:O", axis=None), + color=alt.Color( + "value:Q", + scale=color_scale, + title=layer_name, + legend=alt.Legend(title=layer_name, orient="bottom") + if portrayal.colorbar + else None, + ), + ) + .properties(width=chart_width, height=chart_height) + ) + + if current_chart is not None: + main_charts.append(current_chart) + + base = alt.layer(*main_charts).resolve_scale(color="independent") + return base diff --git a/mesa/visualization/components/__init__.py b/mesa/visualization/components/__init__.py index db21723c404..bdd43207fda 100644 --- a/mesa/visualization/components/__init__.py +++ b/mesa/visualization/components/__init__.py @@ -4,7 +4,11 @@ from collections.abc import Callable -from .altair_components import SpaceAltair, make_altair_space +from .altair_components import ( + SpaceAltair, + make_altair_plot_component, + make_altair_space, +) from .matplotlib_components import ( SpaceMatplotlib, make_mpl_plot_component, @@ -80,16 +84,13 @@ def make_plot_component( backend: the backend to use {"matplotlib", "altair"} plot_drawing_kwargs: additional keyword arguments to pass onto the backend specific function for making a plotting component - Notes: - altair plotting backend is not yet implemented and planned for mesa 3.1. - Returns: function: A function that creates a plot component """ if backend == "matplotlib": return make_mpl_plot_component(measure, post_process, **plot_drawing_kwargs) elif backend == "altair": - raise NotImplementedError("altair line plots are not yet implemented") + return make_altair_plot_component(measure, post_process, **plot_drawing_kwargs) else: raise ValueError( f"unknown backend {backend}, must be one of matplotlib, altair" diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index 2ad3b249fa0..9df5636d251 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -1,6 +1,7 @@ """Altair based solara components for visualization mesa spaces.""" import warnings +from collections.abc import Callable import altair as alt import numpy as np @@ -448,3 +449,86 @@ def apply_rgba(val, vmin=vmin, vmax=vmax, alpha=alpha, portrayal=portrayal): f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'." ) return base, bar_chart + + +def make_altair_plot_component( + measure: str | dict[str, str] | list[str] | tuple[str], + post_process: Callable | None = None, + grid=False, +): + """Create a plotting function for a specified measure. + + Args: + measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot. + post_process: a user-specified callable to do post-processing called with the Axes instance. + grid: Bool to draw grid or not. + + Returns: + function: A function that creates a PlotAltair component. + """ + + def MakePlotAltair(model): + return PlotAltair(model, measure, post_process=post_process, grid=grid) + + return MakePlotAltair + + +@solara.component +def PlotAltair(model, measure, post_process: Callable | None = None, grid=False): + """Create an Altair-based plot for a measure or measures. + + Args: + model (mesa.Model): The model instance. + measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot. + If a dict is given, keys are measure names and values are colors. + post_process: A user-specified callable for post-processing, called + with the Altair Chart instance. + grid: Bool to draw grid or not. + + Returns: + solara.FigureAltair: A component for rendering the plot. + """ + update_counter.get() + df = model.datacollector.get_model_vars_dataframe().reset_index() + df = df.rename(columns={"index": "Step"}) + + y_title = "Value" + if isinstance(measure, str): + measures_to_plot = [measure] + y_title = measure + elif isinstance(measure, list | tuple): + measures_to_plot = list(measure) + elif isinstance(measure, dict): + measures_to_plot = list(measure.keys()) + + df_long = df.melt( + id_vars=["Step"], + value_vars=measures_to_plot, + var_name="Measure", + value_name="Value", + ) + + chart = ( + alt.Chart(df_long) + .mark_line() + .encode( + x=alt.X("Step:Q", axis=alt.Axis(tickMinStep=1, title="Step", grid=grid)), + y=alt.Y("Value:Q", axis=alt.Axis(title=y_title, grid=grid)), + tooltip=["Step", "Measure", "Value"], + ) + .properties(width=450, height=350) + .interactive() + ) + + if len(measures_to_plot) > 0: + color_args = {} + if isinstance(measure, dict): + color_args["scale"] = alt.Scale( + domain=list(measure.keys()), range=list(measure.values()) + ) + chart = chart.encode(color=alt.Color("Measure:N", **color_args)) + + if post_process is not None: + chart = post_process(chart) + + return solara.FigureAltair(chart) diff --git a/mesa/visualization/solara_viz.py b/mesa/visualization/solara_viz.py index f554b8913ed..5dbae2598a7 100644 --- a/mesa/visualization/solara_viz.py +++ b/mesa/visualization/solara_viz.py @@ -142,7 +142,7 @@ def SolaraViz( if renderer is not None: if isinstance(renderer, SpaceRenderer): renderer = solara.use_reactive(renderer) # noqa: RUF100 # noqa: SH102 - display_components.append(create_space_component(renderer.value)) + display_components.insert(0, create_space_component(renderer.value)) with solara.AppBar(): solara.AppBarTitle(name if name else model.value.__class__.__name__) @@ -297,7 +297,7 @@ def SpaceRendererComponent( else: structure = renderer.space_mesh if renderer.space_mesh else None agents = renderer.agent_mesh if renderer.agent_mesh else None - prop_base, prop_cbar = renderer.propertylayer_mesh or (None, None) + propertylayer = renderer.propertylayer_mesh or None if renderer.space_mesh: structure = renderer.draw_structure(**renderer.space_kwargs) @@ -306,34 +306,22 @@ def SpaceRendererComponent( renderer.agent_portrayal, **renderer.agent_kwargs ) if renderer.propertylayer_mesh: - prop_base, prop_cbar = renderer.draw_propertylayer( + propertylayer = renderer.draw_propertylayer( renderer.propertylayer_portrayal ) spatial_charts_list = [ - chart for chart in [structure, prop_base, agents] if chart + chart for chart in [structure, propertylayer, agents] if chart ] - main_spatial = None + final_chart = None if spatial_charts_list: - main_spatial = ( + final_chart = ( spatial_charts_list[0] if len(spatial_charts_list) == 1 else alt.layer(*spatial_charts_list) ) - # Determine final chart by combining with color bar if present - final_chart = None - if main_spatial and prop_cbar: - final_chart = alt.vconcat(main_spatial, prop_cbar).configure_view( - stroke=None - ) - elif main_spatial: # Only main_spatial, no prop_cbar - final_chart = main_spatial - elif prop_cbar: # Only prop_cbar, no main_spatial - final_chart = prop_cbar - final_chart = final_chart.configure_view(grid=False) - if final_chart is None: # If no charts are available, return an empty chart final_chart = ( diff --git a/mesa/visualization/space_renderer.py b/mesa/visualization/space_renderer.py index c8ef183972a..f5924916cd7 100644 --- a/mesa/visualization/space_renderer.py +++ b/mesa/visualization/space_renderer.py @@ -309,6 +309,7 @@ def render( self.draw_propertylayer(propertylayer_portrayal) self.post_process_func = post_process + return self @property def canvas(self): diff --git a/tests/test_backends.py b/tests/test_backends.py index a98e08c7a74..d4c15d36185 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -271,8 +271,7 @@ def propertylayer_portrayal_color(layer): result = ab.draw_propertylayer( space, space._mesa_property_layers, propertylayer_portrayal_color ) - assert result[0] is not None - assert result[1] is None + assert result is not None # Test with colormap def propertylayer_portrayal_colormap(layer): @@ -283,7 +282,7 @@ def propertylayer_portrayal_colormap(layer): result = ab.draw_propertylayer( space, space._mesa_property_layers, propertylayer_portrayal_colormap ) - assert result[0] is not None + assert result is not None # Test with no color or colormap def propertylayer_portrayal(layer):