Skip to content

Commit 24a1df3

Browse files
authored
solara: Implement visualization for network grid (#1767)
1 parent 1ee5cda commit 24a1df3

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

mesa/experimental/jupyter_viz.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import threading
22

33
import matplotlib.pyplot as plt
4+
import networkx as nx
45
import reacton.ipywidgets as widgets
56
import solara
67
from matplotlib.figure import Figure
78
from matplotlib.ticker import MaxNLocator
89

10+
import mesa
11+
912
# Avoid interactive backend
1013
plt.switch_backend("agg")
1114

@@ -91,10 +94,24 @@ def portray(self, g):
9194
return out
9295

9396

97+
def _draw_network_grid(viz, space_ax):
98+
graph = viz.model.grid.G
99+
pos = nx.spring_layout(graph, seed=0)
100+
nx.draw(
101+
graph,
102+
ax=space_ax,
103+
pos=pos,
104+
**viz.agent_portrayal(graph),
105+
)
106+
107+
94108
def make_space(viz):
95109
space_fig = Figure()
96110
space_ax = space_fig.subplots()
97-
space_ax.scatter(**viz.portray(viz.model.grid))
111+
if isinstance(viz.model.grid, mesa.space.NetworkGrid):
112+
_draw_network_grid(viz, space_ax)
113+
else:
114+
space_ax.scatter(**viz.portray(viz.model.grid))
98115
space_ax.set_axis_off()
99116
solara.FigureMatplotlib(space_fig, dependencies=[viz.model, viz.df])
100117

0 commit comments

Comments
 (0)