From b57d960f6a3c00e5af6f988d1380399d8639b897 Mon Sep 17 00:00:00 2001 From: Derek Thirstrup Date: Tue, 5 Aug 2025 10:08:42 -0700 Subject: [PATCH] Update gui.py for fix: GUI flow visualization indexing and gradZ contrast bugs - Fix IndexError when viewing flows with different array layouts (C,Z,Y,X) - Add robust Z-axis slicing for flow views with _pick_z_image() helper - Apply normalize99() to gradZ for contrast stretching - Correct flow array index mapping for 3D gradZ/cellprob views - Prevent TypeError from incompatible array shapes in pyqtgraph Resolves crashes and poor visibility in 3D flow visualizations. --- cellpose/gui/gui.py | 58 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 54 insertions(+), 4 deletions(-) diff --git a/cellpose/gui/gui.py b/cellpose/gui/gui.py index 99c7fcd7..2547ba89 100644 --- a/cellpose/gui/gui.py +++ b/cellpose/gui/gui.py @@ -1383,9 +1383,58 @@ def update_plot(self): self.img.setImage(image, autoLevels=False, lut=self.cmap[0]) self.img.setLevels(self.saturation[0][self.currentZ]) else: - image = np.zeros((self.Ly, self.Lx), np.uint8) - if len(self.flows) >= self.view - 1 and len(self.flows[self.view - 1]) > 0: - image = self.flows[self.view - 1][self.currentZ] + # --- Robust Z-slicing for flow views (e.g., gradXY, cellprob) --- + # Flows may come in as (Z, Y, X), (Z, Y, X, C), (C, Z, Y, X), or (Y, X, C) for NZ=1. + # We slice along the true Z axis and ensure channels (if any) are last. + def _pick_z_image(arr, z_idx, nZ, expect_rgb=False): + a = arr + shp = a.shape + # Identify the Z axis by matching nZ, if possible + z_axis = None + if nZ in shp: + for ax, s in enumerate(shp): + if s == nZ: + z_axis = ax + break + # Slice along Z when we have a stack + if z_axis is not None and a.ndim >= 3 and nZ > 1: + z_idx = max(0, min(z_idx, nZ - 1)) + a = np.take(a, z_idx, axis=z_axis) + # Ensure channels-last for pyqtgraph when we have a small channel axis (<=4) + if a.ndim == 3: + if a.shape[-1] <= 4: + pass # already channels-last + elif a.shape[0] <= 4: + a = np.moveaxis(a, 0, -1) + elif a.shape[1] <= 4: + a = np.moveaxis(a, 1, -1) + # If the last axis still looks like Z (>>4), collapse to 2D + if a.shape[-1] > 4: + a = a.mean(axis=-1) + # If there are extra dims, progressively collapse them + while a.ndim > 3: + a = a.mean(axis=0) + return a + + # Default canvas matches expected modality: RGB for gradXY, scalar for cellprob + image = (np.zeros((self.Ly, self.Lx, 3), np.uint8) + if self.view == 1 else np.zeros((self.Ly, self.Lx), np.uint8)) + + if (self.view - 1) < len(self.flows): + # Handle flow indexing correctly for 3D + # Flow array structure: [0]=gradXY, [1]=cellprob, [2]=rawXY, [3]=origCellprob, [4]=gradZ + flow_idx = self.view - 1 + if self.load_3D and self.view == 2: # cellprob view + flow_idx = 1 # Use flows[1] for cellprob (normalized) + elif self.load_3D and self.view == 3: # gradZ view in 3D + flow_idx = 4 if len(self.flows) > 4 else 1 # Use flows[4] for gradZ, fallback to flows[1] + + flow = self.flows[flow_idx] + nZ = getattr(self, "NZ", None) + if not nZ or nZ < 1: + nZ = self.stack.shape[0] if self.stack.ndim >= 3 else 1 + image = _pick_z_image(flow, self.currentZ, nZ, expect_rgb=(self.view == 1)) + if self.view > 1: self.img.setImage(image, autoLevels=False, lut=self.bwr) else: @@ -1947,7 +1996,8 @@ def compute_segmentation(self, custom=False, model_name=None, load_model=True): if self.load_3D: if stitch_threshold == 0.: - flows_new.append((flows[1][0] / 10 * 127 + 127).astype("uint8")) + # Apply the same normalize99 treatment to gradZ as cellprob gets for better contrast + flows_new.append((np.clip(normalize99(flows[1][0].copy()), 0, 1) * 255).astype("uint8")) else: flows_new.append(np.zeros(flows[1][0].shape, dtype="uint8"))