diff --git a/cellpose/plot.py b/cellpose/plot.py index 99bacddc..81e533dc 100644 --- a/cellpose/plot.py +++ b/cellpose/plot.py @@ -44,7 +44,7 @@ def dx_to_circ(dP): return rgb -def show_segmentation(fig, img, maski, flowi, channels=[0, 0], file_name=None): +def show_segmentation(fig, img, maski, flowi, channels=[0, 0], file_name=None, seg_norm=False, share_axes=False): """Plot segmentation results (like on website). Can save each panel of figure with file_name option. Use channels option if @@ -58,13 +58,14 @@ def show_segmentation(fig, img, maski, flowi, channels=[0, 0], file_name=None): channels (list of int, optional): Channels used to run Cellpose, no need to use if image is RGB. Defaults to [0, 0]. file_name (str, optional): File name of image. If file_name is not None, figure panels are saved. Defaults to None. seg_norm (bool, optional): Improve cell visibility under labels. Defaults to False. + share_axes (bool, optional): Share x and y axes between subplots for synchronized zooming. Defaults to False. """ if not MATPLOTLIB_ENABLED: raise ImportError( "matplotlib not installed, install with 'pip install matplotlib'") - ax = fig.add_subplot(1, 4, 1) + + # 1. Process initial image img0 = img.copy() - if img0.shape[0] < 4: img0 = np.transpose(img0, (1, 2, 0)) if img0.shape[-1] < 3 or img0.ndim < 3: @@ -72,33 +73,38 @@ def show_segmentation(fig, img, maski, flowi, channels=[0, 0], file_name=None): else: if img0.max() <= 50.0: img0 = np.uint8(np.clip(img0, 0, 1) * 255) - ax.imshow(img0) - ax.set_title("original image") - ax.axis("off") + # 2. Prepare plot data outlines = utils.masks_to_outlines(maski) - overlay = mask_overlay(img0, maski) - - ax = fig.add_subplot(1, 4, 2) + outX, outY = np.nonzero(outlines) imgout = img0.copy() - imgout[outX, outY] = np.array([255, 0, 0]) # pure red - - ax.imshow(imgout) - ax.set_title("predicted outlines") - ax.axis("off") - - ax = fig.add_subplot(1, 4, 3) - ax.imshow(overlay) - ax.set_title("predicted masks") - ax.axis("off") + imgout[outX, outY] = np.array([255, 0, 0]) + + # List of (data, title) for the 4 subplots + plot_data = [ + (img0, "original image"), + (imgout, "predicted outlines"), + (overlay, "predicted masks"), + (flowi, "predicted cell pose") + ] - ax = fig.add_subplot(1, 4, 4) - ax.imshow(flowi) - ax.set_title("predicted cell pose") - ax.axis("off") + # 3. Create subplots in a loop + axes = [] + for i, (data, title) in enumerate(plot_data): + # Determine sharing: only for subplots index 1, 2, 3 and if share_axes is True + if i > 0 and share_axes: + ax = fig.add_subplot(1, 4, i + 1, sharex=axes[0], sharey=axes[0]) + else: + ax = fig.add_subplot(1, 4, i + 1) + + ax.imshow(data) + ax.set_title(title) + ax.axis("off") + axes.append(ax) + # 4. Handle file saving if file_name is not None: save_path = os.path.splitext(file_name)[0] io.imsave(save_path + "_overlay.jpg", overlay)