Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 94 additions & 14 deletions trackpy/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ def wrapper(*args, **kwargs):
show_plot = (plt.get_backend() != "agg")
else:
show_plot = False

#if kwargs.get('fig') is None:
# kwargs['fig'] = plt.gcf()
# # show plot unless the matplotlib backend is headless
# show_plot = (plt.get_backend() != "agg")
#else:
# show_plot = False

# Delete legend keyword so remaining ones can be passed to plot().
legend = kwargs.pop('legend', False)
Expand Down Expand Up @@ -261,25 +268,29 @@ def scatter3d(*args, **kwargs):


@make_axes
def plot_traj(traj, colorby='particle', mpp=None, label=False,
superimpose=None, cmap=None, ax=None, t_column=None,
pos_columns=None, plot_style={}, **kwargs):
def plot_traj(traj, colorby='particle', mpp=None, fps=None, label=False,
superimpose=None, cmap=None, fig=None, ax=None, t_column=None,
pos_columns=None, plot_style={}, colorbar_units='frames', **kwargs):
"""Plot traces of trajectories for each particle.
Optionally superimpose it on a frame from the video.

Parameters
----------
traj : DataFrame
The DataFrame should include time and spatial coordinate columns.
colorby : {'particle', 'frame'}, optional
colorby : {'particle', 'frame', 'velocity'}, optional
mpp : float, optional
Microns per pixel. If omitted, the labels will have units of pixels.
fps : int, optional
Number of frames in one second.
label : boolean, optional
Set to True to write particle ID numbers next to trajectories.
superimpose : ndarray, optional
Background image, default None
cmap : colormap, optional
This is only used in colorby='frame' mode. Default = mpl.cm.winter
fig : matplotlib figure object, optional
Defaults to current figure
ax : matplotlib axes object, optional
Defaults to current axes
t_column : string, optional
Expand All @@ -288,6 +299,11 @@ def plot_traj(traj, colorby='particle', mpp=None, label=False,
Dataframe column names for spatial coordinates. Default is ['x', 'y'].
plot_style : dictionary
Keyword arguments passed through to the `Axes.plot(...)` command
colorbar_units : {None, 'frames', 'real'} string, optional
Set to None to disable the colorbar. If 'frames' is selected then
units like frames and px will be used. If 'real' is selected then
units like seconds and um will be used.


Returns
-------
Expand All @@ -301,6 +317,8 @@ def plot_traj(traj, colorby='particle', mpp=None, label=False,
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection

fig = plt.gcf()

if cmap is None:
cmap = plt.cm.winter
if t_column is None:
Expand All @@ -312,15 +330,37 @@ def plot_traj(traj, colorby='particle', mpp=None, label=False,
_plot_style = dict(linewidth=1)
_plot_style.update(**_normalize_kwargs(plot_style, 'line2d'))

colorbar_title = ""
frames_to_seconds_factor = 1.0

# Select the colorbar title according to the type of data plotted.
if colorbar_units == 'frames':
if colorby == 'frame':
colorbar_title = "Frame number"
if colorby == 'velocity':
colorbar_title = "Speed (px/s)"

if colorbar_units == 'real':
if colorby == 'frame':
colorbar_title = "Time (s)"
if fps is None:
fps = 24 # Set default framerate to 24 fps if not specified.
frames_to_seconds_factor = 1.0 / fps
if colorby == 'velocity':
if mpl.rcParams['text.usetex']:
colorbar_title = r'Speed (\textmu m/s)'
else:
colorbar_title = 'Speed (\xb5m/s)'

# Axes labels
if mpp is None:
_set_labels(ax, '{} [px]', pos_columns)
_set_labels(ax, '{} (px)', pos_columns)
mpp = 1. # for computations of image extent below
else:
if mpl.rcParams['text.usetex']:
_set_labels(ax, r'{} [\textmu m]', pos_columns)
_set_labels(ax, r'{} (\textmu m)', pos_columns)
else:
_set_labels(ax, r'{} [\xb5m]', pos_columns)
_set_labels(ax, '{} (\xb5m)', pos_columns)
# Background image
if superimpose is not None:
ax.imshow(superimpose, cmap=plt.cm.gray,
Expand All @@ -338,23 +378,63 @@ def plot_traj(traj, colorby='particle', mpp=None, label=False,
# Read https://scipy-cookbook.readthedocs.io/items/Matplotlib_MulticoloredLine.html
x = traj.set_index([t_column, 'particle'])['x'].unstack()
y = traj.set_index([t_column, 'particle'])['y'].unstack()
color_numbers = traj[t_column].values/float(traj[t_column].max())
norm = plt.Normalize(color_numbers.min(), color_numbers.max())
color_max = float(traj[t_column].max())
norm = plt.Normalize(traj[t_column].min() * frames_to_seconds_factor,
traj[t_column].max() * frames_to_seconds_factor)
logger.info("Drawing multicolor lines takes a while. "
"Come back in a minute.")
for particle in x:
x_series = x[particle].dropna()
y_series = y[particle].dropna()
frames = np.array(x_series.index)
frames_map = frames/color_max
points = np.array([x_series.values, y_series.values]).T.reshape(-1, 1, 2)
segments = np.concatenate([points[:-1], points[1:]], axis=1)
lc = LineCollection(segments, cmap=cmap, norm=norm)
lc.set_array(frames_map)
ax.add_collection(lc)
lc.set_array(frames * frames_to_seconds_factor)
lines = ax.add_collection(lc)
ax.set_xlim(x.apply(np.min).min(), x.apply(np.max).max())
ax.set_ylim(y.apply(np.min).min(), y.apply(np.max).max())
if colorbar_units is not None:
fig.colorbar(lines, shrink=0.5, label=colorbar_title)

if colorby == 'velocity':
x = traj.set_index([t_column, 'particle'])['x'].unstack()
y = traj.set_index([t_column, 'particle'])['y'].unstack()
logger.info("Drawing multicolor lines takes a while. "
"Come back in a minute.")
# Normalization of the colormap has to be done AFTER the velocities
# are calculated, so the results are stored in the following lists.
list_of_segments = []
list_of_min_velocities = []
list_of_max_velocities = []
list_of_velocities = []
# Computation of velocities and segments to be drawn.
for particle in x:
x_series = x[particle].dropna()
y_series = y[particle].dropna()
if x_series.shape[0] <= 2:
continue
frames = np.array(x_series.index)
xy_displacements = np.gradient(np.array([x_series, y_series]), frames, axis=1, edge_order=2)
displacement_um = np.linalg.norm(xy_displacements, axis=0) * mpp
velocity = displacement_um * frames_to_seconds_factor
list_of_velocities.append(velocity)
list_of_min_velocities.append(np.min(velocity))
list_of_max_velocities.append(np.max(velocity))
points = np.array([x_series.values, y_series.values]).T.reshape(-1, 1, 2)
segments = np.concatenate([points[:-1], points[1:]], axis=1)
list_of_segments.append(segments)
# Normalization of the colormap.
norm = plt.Normalize(np.min(np.array(list_of_min_velocities)),
np.max(np.array(list_of_max_velocities)))
# Drawing segments.
for segments, velocity in zip(list_of_segments, list_of_velocities):
lc = LineCollection(segments, cmap=cmap, norm=norm)
lc.set_array(velocity)
lines = ax.add_collection(lc)
ax.set_xlim(x.apply(np.min).min(), x.apply(np.max).max())
ax.set_ylim(y.apply(np.min).min(), y.apply(np.max).max())
ax.set_ylim(y.apply(np.min).min(), y.apply(np.max).max())
if colorbar_units is not None:
fig.colorbar(lines, shrink=0.5, label=colorbar_title)
if label:
unstacked = traj.set_index([t_column, 'particle'])[pos_columns].unstack()
first_frame = int(traj[t_column].min())
Expand Down
Loading