Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 29 additions & 23 deletions cellpose/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def dx_to_circ(dP):
return rgb


def show_segmentation(fig, img, maski, flowi, channels=[0, 0], file_name=None):
def show_segmentation(fig, img, maski, flowi, channels=[0, 0], file_name=None, seg_norm=False, share_axes=False):
"""Plot segmentation results (like on website).

Can save each panel of figure with file_name option. Use channels option if
Expand All @@ -58,47 +58,53 @@ def show_segmentation(fig, img, maski, flowi, channels=[0, 0], file_name=None):
channels (list of int, optional): Channels used to run Cellpose, no need to use if image is RGB. Defaults to [0, 0].
file_name (str, optional): File name of image. If file_name is not None, figure panels are saved. Defaults to None.
seg_norm (bool, optional): Improve cell visibility under labels. Defaults to False.
share_axes (bool, optional): Share x and y axes between subplots for synchronized zooming. Defaults to False.
"""
if not MATPLOTLIB_ENABLED:
raise ImportError(
"matplotlib not installed, install with 'pip install matplotlib'")
ax = fig.add_subplot(1, 4, 1)

# 1. Process initial image
img0 = img.copy()

if img0.shape[0] < 4:
img0 = np.transpose(img0, (1, 2, 0))
if img0.shape[-1] < 3 or img0.ndim < 3:
img0 = image_to_rgb(img0, channels=channels)
else:
if img0.max() <= 50.0:
img0 = np.uint8(np.clip(img0, 0, 1) * 255)
ax.imshow(img0)
ax.set_title("original image")
ax.axis("off")

# 2. Prepare plot data
outlines = utils.masks_to_outlines(maski)

overlay = mask_overlay(img0, maski)

ax = fig.add_subplot(1, 4, 2)

outX, outY = np.nonzero(outlines)
imgout = img0.copy()
imgout[outX, outY] = np.array([255, 0, 0]) # pure red

ax.imshow(imgout)
ax.set_title("predicted outlines")
ax.axis("off")

ax = fig.add_subplot(1, 4, 3)
ax.imshow(overlay)
ax.set_title("predicted masks")
ax.axis("off")
imgout[outX, outY] = np.array([255, 0, 0])

# List of (data, title) for the 4 subplots
plot_data = [
(img0, "original image"),
(imgout, "predicted outlines"),
(overlay, "predicted masks"),
(flowi, "predicted cell pose")
]

ax = fig.add_subplot(1, 4, 4)
ax.imshow(flowi)
ax.set_title("predicted cell pose")
ax.axis("off")
# 3. Create subplots in a loop
axes = []
for i, (data, title) in enumerate(plot_data):
# Determine sharing: only for subplots index 1, 2, 3 and if share_axes is True
if i > 0 and share_axes:
ax = fig.add_subplot(1, 4, i + 1, sharex=axes[0], sharey=axes[0])
else:
ax = fig.add_subplot(1, 4, i + 1)

ax.imshow(data)
ax.set_title(title)
ax.axis("off")
axes.append(ax)

# 4. Handle file saving
if file_name is not None:
save_path = os.path.splitext(file_name)[0]
io.imsave(save_path + "_overlay.jpg", overlay)
Expand Down
Loading