Skip to content

Add Altair plotting functionality #2810

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 39 additions & 174 deletions mesa/visualization/backends/altair_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,13 @@ def draw_agents(
"x:Q",
title=xlabel,
scale=alt.Scale(type="linear", domain=[xmin, xmax]),
axis=None,
),
y=alt.Y(
"y:Q",
title=ylabel,
scale=alt.Scale(type="linear", domain=[ymin, ymax]),
axis=None,
),
size=alt.Size("size:Q", legend=None, scale=alt.Scale(domain=[0, 50])),
shape=alt.Shape(
Expand Down Expand Up @@ -352,8 +354,7 @@ def draw_propertylayer(
Returns:
alt.Chart: A tuple containing the base chart and the color bar chart.
"""
base = None
bar_chart_viz = None
main_charts = []

for layer_name in property_layers:
if layer_name == "empty":
Expand Down Expand Up @@ -384,7 +385,6 @@ def draw_propertylayer(
vmin = portrayal.vmin if portrayal.vmin is not None else np.min(data)
vmax = portrayal.vmax if portrayal.vmax is not None else np.max(data)

# Prepare data for Altair
df = pd.DataFrame(
{
"x": np.repeat(np.arange(data.shape[0]), data.shape[1]),
Expand All @@ -393,183 +393,48 @@ def draw_propertylayer(
}
)

current_chart = None
if color:
# Create a function to map values to RGBA colors with proper opacity scaling
def apply_rgba(
val, v_min=vmin, v_max=vmax, a=alpha, p_color=portrayal.color
):
# Normalize value to range [0,1] and clamp
normalized = max(
0,
min(
((val - v_min) / (v_max - v_min))
if (v_max - v_min) != 0
else 0.5,
1,
),
)

# Scale opacity by alpha parameter
opacity = normalized * a

# Convert color to RGB components
rgb_color_val = to_rgb(p_color)
r = int(rgb_color_val[0] * 255)
g = int(rgb_color_val[1] * 255)
b = int(rgb_color_val[2] * 255)
return f"rgba({r}, {g}, {b}, {opacity:.2f})"

# Apply color mapping to each value in the dataset
df["color_str"] = df["value"].apply(apply_rgba)

# Create chart for the property layer
current_chart = (
alt.Chart(df)
.mark_rect()
.encode(
x=alt.X("x:O", axis=None),
y=alt.Y("y:O", axis=None),
fill=alt.Fill("color_str:N", scale=None),
)
.properties(
width=chart_width, height=chart_height, title=layer_name
)
# For a single color gradient, we define the range from transparent to solid.
rgb = to_rgb(color)
r, g, b = (int(c * 255) for c in rgb)

min_color = f"rgba({r},{g},{b},0)"
max_color = f"rgba({r},{g},{b},{alpha})"
opacity = 1
color_scale = alt.Scale(
range=[min_color, max_color], domain=[vmin, vmax]
)
base = (
alt.layer(current_chart, base)
if base is not None
else current_chart
)

# Add colorbar if specified in portrayal
if portrayal.colorbar:
# Extract RGB components from base color
rgb_color_val = to_rgb(portrayal.color)
r_int = int(rgb_color_val[0] * 255)
g_int = int(rgb_color_val[1] * 255)
b_int = int(rgb_color_val[2] * 255)

# Define gradient endpoints
min_color_str = f"rgba({r_int},{g_int},{b_int},0)"
max_color_str = f"rgba({r_int},{g_int},{b_int},{alpha:.2f})"

# Define colorbar dimensions
colorbar_height = 20
colorbar_width = chart_width

# Create dataframe for gradient visualization
df_gradient = pd.DataFrame({"x_grad": [0, 1], "y_grad": [0, 1]})

# Create evenly distributed tick values
axis_values = np.linspace(vmin, vmax, 11)
tick_positions = np.linspace(10, colorbar_width - 10, 11)

# Prepare data for axis and labels
axis_data = pd.DataFrame(
{"value_axis": axis_values, "x_axis": tick_positions}
)

# Create colorbar with linear gradient
colorbar_chart_obj = (
alt.Chart(df_gradient)
.mark_rect(
x=20,
y=0,
width=colorbar_width - 20,
height=colorbar_height,
color=alt.Gradient(
gradient="linear",
stops=[
alt.GradientStop(color=min_color_str, offset=0),
alt.GradientStop(color=max_color_str, offset=1),
],
x1=0,
x2=1, # Horizontal gradient
y1=0,
y2=0, # Keep y constant
),
)
.encode(
x=alt.value(chart_width / 2), y=alt.value(8)
) # Center colorbar
.properties(width=colorbar_width, height=colorbar_height)
)
# Add tick marks to colorbar
axis_chart = (
alt.Chart(axis_data)
.mark_tick(thickness=2, size=10)
.encode(
x=alt.X("x_axis:Q", axis=None),
y=alt.value(colorbar_height - 2),
)
)
# Add value labels below tick marks
text_labels = (
alt.Chart(axis_data)
.mark_text(baseline="top", fontSize=10, dy=0)
.encode(
x=alt.X("x_axis:Q"),
text=alt.Text("value_axis:Q", format=".1f"),
y=alt.value(colorbar_height + 10),
)
)
# Add title to colorbar
title_chart = (
alt.Chart(pd.DataFrame([{"text_title": layer_name}]))
.mark_text(
fontSize=12,
fontWeight="bold",
baseline="bottom",
align="center",
)
.encode(
text="text_title:N",
x=alt.value(colorbar_width / 2),
y=alt.value(colorbar_height + 40),
)
)
# Combine all colorbar components
combined_colorbar = alt.layer(
colorbar_chart_obj, axis_chart, text_labels, title_chart
).properties(width=colorbar_width, height=colorbar_height + 50)

bar_chart_viz = (
alt.vconcat(bar_chart_viz, combined_colorbar).resolve_scale(
color="independent"
)
if bar_chart_viz is not None
else combined_colorbar
)

elif colormap:
cmap = colormap
cmap_scale = alt.Scale(scheme=cmap, domain=[vmin, vmax])

current_chart = (
alt.Chart(df)
.mark_rect(opacity=alpha)
.encode(
x=alt.X("x:O", axis=None),
y=alt.Y("y:O", axis=None),
color=alt.Color(
"value:Q",
scale=cmap_scale,
title=layer_name,
legend=alt.Legend(title=layer_name)
if portrayal.colorbar
else None,
),
)
.properties(width=chart_width, height=chart_height)
)
base = (
alt.layer(current_chart, base)
if base is not None
else current_chart
)
color_scale = alt.Scale(scheme=cmap, domain=[vmin, vmax])
opacity = alpha

else:
raise ValueError(
f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
)
return (base, bar_chart_viz)

current_chart = (
alt.Chart(df)
.mark_rect(opacity=opacity)
.encode(
x=alt.X("x:O", axis=None),
y=alt.Y("y:O", axis=None),
color=alt.Color(
"value:Q",
scale=color_scale,
title=layer_name,
legend=alt.Legend(title=layer_name, orient="bottom")
if portrayal.colorbar
else None,
),
)
.properties(width=chart_width, height=chart_height)
)

if current_chart is not None:
main_charts.append(current_chart)

base = alt.layer(*main_charts).resolve_scale(color="independent")
return base
11 changes: 6 additions & 5 deletions mesa/visualization/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

from collections.abc import Callable

from .altair_components import SpaceAltair, make_altair_space
from .altair_components import (
SpaceAltair,
make_altair_plot_component,
make_altair_space,
)
from .matplotlib_components import (
SpaceMatplotlib,
make_mpl_plot_component,
Expand Down Expand Up @@ -80,16 +84,13 @@ def make_plot_component(
backend: the backend to use {"matplotlib", "altair"}
plot_drawing_kwargs: additional keyword arguments to pass onto the backend specific function for making a plotting component
Notes:
altair plotting backend is not yet implemented and planned for mesa 3.1.
Returns:
function: A function that creates a plot component
"""
if backend == "matplotlib":
return make_mpl_plot_component(measure, post_process, **plot_drawing_kwargs)
elif backend == "altair":
raise NotImplementedError("altair line plots are not yet implemented")
return make_altair_plot_component(measure, post_process, **plot_drawing_kwargs)
else:
raise ValueError(
f"unknown backend {backend}, must be one of matplotlib, altair"
Expand Down
84 changes: 84 additions & 0 deletions mesa/visualization/components/altair_components.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Altair based solara components for visualization mesa spaces."""

import warnings
from collections.abc import Callable

import altair as alt
import numpy as np
Expand Down Expand Up @@ -448,3 +449,86 @@ def apply_rgba(val, vmin=vmin, vmax=vmax, alpha=alpha, portrayal=portrayal):
f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
)
return base, bar_chart


def make_altair_plot_component(
measure: str | dict[str, str] | list[str] | tuple[str],
post_process: Callable | None = None,
grid=False,
):
"""Create a plotting function for a specified measure.

Args:
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
post_process: a user-specified callable to do post-processing called with the Axes instance.
grid: Bool to draw grid or not.

Returns:
function: A function that creates a PlotAltair component.
"""

def MakePlotAltair(model):
return PlotAltair(model, measure, post_process=post_process, grid=grid)

return MakePlotAltair


@solara.component
def PlotAltair(model, measure, post_process: Callable | None = None, grid=False):
"""Create an Altair-based plot for a measure or measures.

Args:
model (mesa.Model): The model instance.
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
If a dict is given, keys are measure names and values are colors.
post_process: A user-specified callable for post-processing, called
with the Altair Chart instance.
grid: Bool to draw grid or not.

Returns:
solara.FigureAltair: A component for rendering the plot.
"""
update_counter.get()
df = model.datacollector.get_model_vars_dataframe().reset_index()
df = df.rename(columns={"index": "Step"})

y_title = "Value"
if isinstance(measure, str):
measures_to_plot = [measure]
y_title = measure
elif isinstance(measure, list | tuple):
measures_to_plot = list(measure)
elif isinstance(measure, dict):
measures_to_plot = list(measure.keys())

df_long = df.melt(
id_vars=["Step"],
value_vars=measures_to_plot,
var_name="Measure",
value_name="Value",
)

chart = (
alt.Chart(df_long)
.mark_line()
.encode(
x=alt.X("Step:Q", axis=alt.Axis(tickMinStep=1, title="Step", grid=grid)),
y=alt.Y("Value:Q", axis=alt.Axis(title=y_title, grid=grid)),
tooltip=["Step", "Measure", "Value"],
)
.properties(width=450, height=350)
.interactive()
)

if len(measures_to_plot) > 0:
color_args = {}
if isinstance(measure, dict):
color_args["scale"] = alt.Scale(
domain=list(measure.keys()), range=list(measure.values())
)
chart = chart.encode(color=alt.Color("Measure:N", **color_args))

if post_process is not None:
chart = post_process(chart)

return solara.FigureAltair(chart)
Loading
Loading