Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 54 additions & 4 deletions cellpose/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1383,9 +1383,58 @@ def update_plot(self):
self.img.setImage(image, autoLevels=False, lut=self.cmap[0])
self.img.setLevels(self.saturation[0][self.currentZ])
else:
image = np.zeros((self.Ly, self.Lx), np.uint8)
if len(self.flows) >= self.view - 1 and len(self.flows[self.view - 1]) > 0:
image = self.flows[self.view - 1][self.currentZ]
# --- Robust Z-slicing for flow views (e.g., gradXY, cellprob) ---
# Flows may come in as (Z, Y, X), (Z, Y, X, C), (C, Z, Y, X), or (Y, X, C) for NZ=1.
# We slice along the true Z axis and ensure channels (if any) are last.
def _pick_z_image(arr, z_idx, nZ, expect_rgb=False):
a = arr
shp = a.shape
# Identify the Z axis by matching nZ, if possible
z_axis = None
if nZ in shp:
for ax, s in enumerate(shp):
if s == nZ:
z_axis = ax
break
# Slice along Z when we have a stack
if z_axis is not None and a.ndim >= 3 and nZ > 1:
z_idx = max(0, min(z_idx, nZ - 1))
a = np.take(a, z_idx, axis=z_axis)
# Ensure channels-last for pyqtgraph when we have a small channel axis (<=4)
if a.ndim == 3:
if a.shape[-1] <= 4:
pass # already channels-last
elif a.shape[0] <= 4:
a = np.moveaxis(a, 0, -1)
elif a.shape[1] <= 4:
a = np.moveaxis(a, 1, -1)
# If the last axis still looks like Z (>>4), collapse to 2D
if a.shape[-1] > 4:
a = a.mean(axis=-1)
# If there are extra dims, progressively collapse them
while a.ndim > 3:
a = a.mean(axis=0)
return a

# Default canvas matches expected modality: RGB for gradXY, scalar for cellprob
image = (np.zeros((self.Ly, self.Lx, 3), np.uint8)
if self.view == 1 else np.zeros((self.Ly, self.Lx), np.uint8))

if (self.view - 1) < len(self.flows):
# Handle flow indexing correctly for 3D
# Flow array structure: [0]=gradXY, [1]=cellprob, [2]=rawXY, [3]=origCellprob, [4]=gradZ
flow_idx = self.view - 1
if self.load_3D and self.view == 2: # cellprob view
flow_idx = 1 # Use flows[1] for cellprob (normalized)
elif self.load_3D and self.view == 3: # gradZ view in 3D
flow_idx = 4 if len(self.flows) > 4 else 1 # Use flows[4] for gradZ, fallback to flows[1]

flow = self.flows[flow_idx]
nZ = getattr(self, "NZ", None)
if not nZ or nZ < 1:
nZ = self.stack.shape[0] if self.stack.ndim >= 3 else 1
image = _pick_z_image(flow, self.currentZ, nZ, expect_rgb=(self.view == 1))

if self.view > 1:
self.img.setImage(image, autoLevels=False, lut=self.bwr)
else:
Expand Down Expand Up @@ -1947,7 +1996,8 @@ def compute_segmentation(self, custom=False, model_name=None, load_model=True):

if self.load_3D:
if stitch_threshold == 0.:
flows_new.append((flows[1][0] / 10 * 127 + 127).astype("uint8"))
# Apply the same normalize99 treatment to gradZ as cellprob gets for better contrast
flows_new.append((np.clip(normalize99(flows[1][0].copy()), 0, 1) * 255).astype("uint8"))
else:
flows_new.append(np.zeros(flows[1][0].shape, dtype="uint8"))

Expand Down
Loading