From 3631b4b858b06b83609e0e978eee7fc62bb81015 Mon Sep 17 00:00:00 2001 From: "Kevin M. Dean" Date: Sat, 21 Mar 2026 15:34:27 -0500 Subject: [PATCH 1/4] Implement 3D registration pipeline and preprocessing GUI controls --- src/clearex/gui/app.py | 372 +++++- src/clearex/main.py | 262 +++- src/clearex/registration/pipeline.py | 1799 ++++++++++++++++++++++++++ src/clearex/workflow.py | 81 +- tests/gui/test_gui_execution.py | 63 + tests/registration/test_pipeline.py | 365 ++++++ tests/test_main.py | 126 +- tests/test_workflow.py | 65 +- 8 files changed, 3114 insertions(+), 19 deletions(-) create mode 100644 src/clearex/registration/pipeline.py create mode 100644 tests/registration/test_pipeline.py diff --git a/src/clearex/gui/app.py b/src/clearex/gui/app.py index 58a1fd1..6fdc27f 100644 --- a/src/clearex/gui/app.py +++ b/src/clearex/gui/app.py @@ -6865,10 +6865,9 @@ class AnalysisSelectionDialog(QDialog): _OPERATION_TABS: tuple[tuple[str, tuple[str, ...]], ...] = ( ( "Preprocessing", - ("flatfield", "deconvolution", "shear_transform"), + ("flatfield", "deconvolution", "shear_transform", "registration"), ), ("Segmentation", ("particle_detection", "usegment3d")), - ("Postprocessing", ("registration",)), ("Visualization", ("display_pyramid", "visualization", "mip_export")), ) _OPERATION_OUTPUT_COMPONENTS: Dict[str, str] = { @@ -6884,7 +6883,7 @@ class AnalysisSelectionDialog(QDialog): "results/particle_detection/latest/detections" ) _DEFAULT_USEGMENT3D_PARAMETERS: Dict[str, Any] = { - "execution_order": 5, + "execution_order": 6, "input_source": "data", "force_rerun": False, "chunk_basis": "3d", @@ -6940,6 +6939,22 @@ class AnalysisSelectionDialog(QDialog): "postprocess_edt_fixed_point_percentile": 0.01, "output_dtype": "uint32", } + _DEFAULT_REGISTRATION_PARAMETERS: Dict[str, Any] = { + "execution_order": 4, + "input_source": "data", + "force_rerun": False, + "chunk_basis": "3d", + "detect_2d_per_slice": False, + "use_map_overlap": True, + "overlap_zyx": [8, 32, 32], + 
"memory_overhead_factor": 2.5, + "registration_channel": 0, + "registration_type": "rigid", + "input_resolution_level": 0, + "anchor_mode": "central", + "anchor_position": None, + "blend_mode": "feather", + } _PARAMETER_HINTS: Dict[str, str] = { "input_source": ( "Input source controls which dataset this operation reads from. " @@ -7157,6 +7172,29 @@ class AnalysisSelectionDialog(QDialog): "Input pyramid level used for segmentation (0 = full resolution). " "Higher levels reduce memory/runtime." ), + "registration_channel": ( + "Source channel used to estimate pairwise tile transforms. " + "The solved transforms are then applied to all channels." + ), + "registration_type": ( + "Pairwise ANTsPy transform family for overlap registration: " + "translation, rigid, or similarity." + ), + "registration_resolution_level": ( + "Input pyramid level used only for pairwise registration " + "(0 = full resolution). Final fusion always uses full resolution." + ), + "registration_anchor_mode": ( + "Fix either the most central tile automatically or a manually " + "selected anchor position during global optimization." + ), + "registration_anchor_position": ( + "Tile index held fixed when anchor mode is set to manual." + ), + "registration_blend_mode": ( + "Overlap fusion mode for the final stitched volume. Feather " + "weights edges to reduce seams; average uses uniform weights." + ), "usegment3d_output_reference_space": ( "Choose whether final labels are stored at level 0 (original " "resolution) or native selected input level." 
@@ -7342,6 +7380,15 @@ def __init__(self, initial: WorkflowConfig) -> None: ) ) self._operation_defaults["usegment3d"] = dict(self._usegment3d_defaults) + self._registration_defaults = dict( + self._operation_defaults.get( + "registration", + self._DEFAULT_REGISTRATION_PARAMETERS, + ) + ) + self._operation_defaults["registration"] = dict( + self._registration_defaults + ) self._visualization_defaults = dict( self._operation_defaults.get("visualization", {}) ) @@ -8269,6 +8316,9 @@ def _build_ui(self) -> None: self._usegment3d_resolution_level_spin.valueChanged.connect( self._set_usegment3d_parameter_enabled_state ) + self._registration_anchor_mode_combo.currentIndexChanged.connect( + self._set_registration_parameter_enabled_state + ) self._usegment3d_output_reference_combo.currentIndexChanged.connect( self._set_usegment3d_parameter_enabled_state ) @@ -8460,6 +8510,8 @@ def _build_operation_panel(self, operation_name: str) -> QWidget: self._build_particle_parameter_rows(form) elif operation_name == "usegment3d": self._build_usegment3d_parameter_rows(form) + elif operation_name == "registration": + self._build_registration_parameter_rows(form) elif operation_name == "visualization": self._build_visualization_parameter_rows(form) elif operation_name == "mip_export": @@ -9602,6 +9654,121 @@ def _build_usegment3d_parameter_rows(self, form: QFormLayout) -> None: ) form.addRow(postprocess_section) + def _build_registration_parameter_rows(self, form: QFormLayout) -> None: + """Add registration parameter controls to a form. + + Parameters + ---------- + form : QFormLayout + Parent form layout receiving registration controls. + + Returns + ------- + None + Widgets are created and attached in-place. 
+ """ + pairwise_section, pairwise_form = self._build_parameter_section_card( + "Pairwise Registration" + ) + self._registration_channel_spin = QSpinBox() + self._registration_channel_spin.setRange(0, 0) + pairwise_form.addRow("channel", self._registration_channel_spin) + self._register_parameter_hint( + self._registration_channel_spin, + self._PARAMETER_HINTS["registration_channel"], + ) + + self._registration_type_combo = QComboBox() + self._registration_type_combo.addItem("Translation", "translation") + self._registration_type_combo.addItem("Rigid", "rigid") + self._registration_type_combo.addItem("Similarity", "similarity") + pairwise_form.addRow("type", self._registration_type_combo) + self._register_parameter_hint( + self._registration_type_combo, + self._PARAMETER_HINTS["registration_type"], + ) + + self._registration_resolution_level_spin = QSpinBox() + self._registration_resolution_level_spin.setRange(0, 0) + pairwise_form.addRow( + "resolution level", + self._registration_resolution_level_spin, + ) + self._register_parameter_hint( + self._registration_resolution_level_spin, + self._PARAMETER_HINTS["registration_resolution_level"], + ) + form.addRow(pairwise_section) + + global_section, global_form = self._build_parameter_section_card( + "Global Optimization" + ) + self._registration_anchor_mode_combo = QComboBox() + self._registration_anchor_mode_combo.addItem( + "Central tile", "central" + ) + self._registration_anchor_mode_combo.addItem( + "Manual tile", "manual" + ) + global_form.addRow("anchor mode", self._registration_anchor_mode_combo) + self._register_parameter_hint( + self._registration_anchor_mode_combo, + self._PARAMETER_HINTS["registration_anchor_mode"], + ) + + self._registration_anchor_position_spin = QSpinBox() + self._registration_anchor_position_spin.setRange(0, 0) + global_form.addRow( + "anchor tile", + self._registration_anchor_position_spin, + ) + self._register_parameter_hint( + self._registration_anchor_position_spin, + 
self._PARAMETER_HINTS["registration_anchor_position"], + ) + form.addRow(global_section) + + fusion_section, fusion_form = self._build_parameter_section_card("Fusion") + overlap_row = QHBoxLayout() + apply_compact_row_spacing(overlap_row) + self._registration_overlap_z_spin = QSpinBox() + self._registration_overlap_z_spin.setRange(0, 1_000_000) + self._registration_overlap_y_spin = QSpinBox() + self._registration_overlap_y_spin.setRange(0, 1_000_000) + self._registration_overlap_x_spin = QSpinBox() + self._registration_overlap_x_spin.setRange(0, 1_000_000) + overlap_row.addWidget(QLabel("z")) + overlap_row.addWidget(self._registration_overlap_z_spin) + overlap_row.addWidget(QLabel("y")) + overlap_row.addWidget(self._registration_overlap_y_spin) + overlap_row.addWidget(QLabel("x")) + overlap_row.addWidget(self._registration_overlap_x_spin) + overlap_widget = QWidget() + overlap_widget.setLayout(overlap_row) + fusion_form.addRow("overlap pad", overlap_widget) + self._register_parameter_hint( + self._registration_overlap_z_spin, + self._PARAMETER_HINTS["overlap_zyx"], + ) + self._register_parameter_hint( + self._registration_overlap_y_spin, + self._PARAMETER_HINTS["overlap_zyx"], + ) + self._register_parameter_hint( + self._registration_overlap_x_spin, + self._PARAMETER_HINTS["overlap_zyx"], + ) + + self._registration_blend_mode_combo = QComboBox() + self._registration_blend_mode_combo.addItem("Feather", "feather") + self._registration_blend_mode_combo.addItem("Average", "average") + fusion_form.addRow("blend mode", self._registration_blend_mode_combo) + self._register_parameter_hint( + self._registration_blend_mode_combo, + self._PARAMETER_HINTS["registration_blend_mode"], + ) + form.addRow(fusion_section) + def _rebuild_usegment3d_channel_checkboxes( self, *, @@ -12007,6 +12174,7 @@ def _on_operation_selection_changed(self) -> None: self._set_shear_parameter_enabled_state() self._set_particle_parameter_enabled_state() self._set_usegment3d_parameter_enabled_state() + 
self._set_registration_parameter_enabled_state() self._set_visualization_parameter_enabled_state() self._set_mip_export_parameter_enabled_state() @@ -12446,6 +12614,41 @@ def _set_usegment3d_parameter_enabled_state(self) -> None: ) self._usegment3d_postprocess_dtform_combo.setEnabled(postprocess_enabled) + def _set_registration_parameter_enabled_state(self) -> None: + """Enable/disable registration widgets based on selection and anchor mode. + + Parameters + ---------- + None + + Returns + ------- + None + Widget enabled states are updated in-place. + """ + registration_enabled = self._operation_checkboxes["registration"].isChecked() + widgets = ( + self._registration_channel_spin, + self._registration_type_combo, + self._registration_resolution_level_spin, + self._registration_anchor_mode_combo, + self._registration_overlap_z_spin, + self._registration_overlap_y_spin, + self._registration_overlap_x_spin, + self._registration_blend_mode_combo, + ) + for widget in widgets: + widget.setEnabled(registration_enabled) + + anchor_mode = ( + str(self._registration_anchor_mode_combo.currentData() or "central") + .strip() + .lower() + or "central" + ) + manual_anchor = registration_enabled and anchor_mode == "manual" + self._registration_anchor_position_spin.setEnabled(manual_anchor) + def _set_flatfield_parameter_enabled_state(self) -> None: """Enable/disable flatfield widgets based on selection and overlap mode. 
@@ -12624,6 +12827,9 @@ def _hydrate(self, initial: WorkflowConfig) -> None: usegment3d_params = dict( normalized_parameters.get("usegment3d", self._usegment3d_defaults) ) + registration_params = dict( + normalized_parameters.get("registration", self._registration_defaults) + ) visualization_params = dict( normalized_parameters.get("visualization", self._visualization_defaults) ) @@ -12754,6 +12960,11 @@ def _hydrate(self, initial: WorkflowConfig) -> None: self._usegment3d_resolution_level_spin.setMaximum( max(0, int(max_usegment3d_resolution_level)) ) + self._registration_channel_spin.setMaximum(channel_count - 1) + self._registration_resolution_level_spin.setMaximum( + max(0, int(max_usegment3d_resolution_level)) + ) + self._registration_anchor_position_spin.setMaximum(position_count - 1) requested_usegment_channels = usegment3d_params.get("channel_indices", []) if isinstance(requested_usegment_channels, str): requested_channel_values: list[Any] = self._split_csv_values( @@ -13343,6 +13554,107 @@ def _hydrate(self, initial: WorkflowConfig) -> None: dtform_index = self._usegment3d_postprocess_dtform_combo.count() - 1 self._usegment3d_postprocess_dtform_combo.setCurrentIndex(dtform_index) + self._registration_channel_spin.setValue( + max( + 0, + min( + int(self._registration_channel_spin.maximum()), + int(registration_params.get("registration_channel", 0)), + ), + ) + ) + registration_type = ( + str(registration_params.get("registration_type", "rigid")) + .strip() + .lower() + or "rigid" + ) + registration_type_index = self._registration_type_combo.findData( + registration_type + ) + if registration_type_index < 0: + registration_type_index = self._registration_type_combo.findData( + "rigid" + ) + if registration_type_index < 0: + registration_type_index = 0 + self._registration_type_combo.setCurrentIndex(registration_type_index) + + registration_resolution_level = max( + 0, + int(registration_params.get("input_resolution_level", 0)), + ) + 
self._registration_resolution_level_spin.setValue( + min( + int(self._registration_resolution_level_spin.maximum()), + int(registration_resolution_level), + ) + ) + + anchor_mode = ( + str(registration_params.get("anchor_mode", "central")) + .strip() + .lower() + or "central" + ) + if anchor_mode not in {"central", "manual"}: + anchor_mode = "central" + anchor_mode_index = self._registration_anchor_mode_combo.findData( + anchor_mode + ) + if anchor_mode_index < 0: + anchor_mode_index = self._registration_anchor_mode_combo.findData( + "central" + ) + if anchor_mode_index < 0: + anchor_mode_index = 0 + self._registration_anchor_mode_combo.setCurrentIndex(anchor_mode_index) + anchor_position = registration_params.get("anchor_position") + if anchor_position in {None, ""}: + parsed_anchor_position = 0 + else: + parsed_anchor_position = int(anchor_position) + self._registration_anchor_position_spin.setValue( + max( + 0, + min( + int(self._registration_anchor_position_spin.maximum()), + int(parsed_anchor_position), + ), + ) + ) + registration_overlap_zyx = registration_params.get("overlap_zyx", [8, 32, 32]) + if ( + not isinstance(registration_overlap_zyx, (tuple, list)) + or len(registration_overlap_zyx) != 3 + ): + registration_overlap_zyx = [8, 32, 32] + self._registration_overlap_z_spin.setValue( + max(0, int(registration_overlap_zyx[0])) + ) + self._registration_overlap_y_spin.setValue( + max(0, int(registration_overlap_zyx[1])) + ) + self._registration_overlap_x_spin.setValue( + max(0, int(registration_overlap_zyx[2])) + ) + registration_blend_mode = ( + str(registration_params.get("blend_mode", "feather")).strip().lower() + or "feather" + ) + registration_blend_index = self._registration_blend_mode_combo.findData( + registration_blend_mode + ) + if registration_blend_index < 0: + registration_blend_index = self._registration_blend_mode_combo.findData( + "feather" + ) + if registration_blend_index < 0: + registration_blend_index = 0 + 
self._registration_blend_mode_combo.setCurrentIndex( + registration_blend_index + ) + self._visualization_show_all_positions_checkbox.setChecked( bool(visualization_params.get("show_all_positions", False)) ) @@ -14169,6 +14481,58 @@ def _collect_usegment3d_parameters(self) -> Dict[str, Any]: "postprocess_dtform_method": dtform_method, } + def _collect_registration_parameters(self) -> Dict[str, Any]: + """Collect registration parameter values from widgets. + + Parameters + ---------- + None + + Returns + ------- + dict[str, Any] + Registration parameter mapping. + """ + anchor_mode = ( + str(self._registration_anchor_mode_combo.currentData() or "central") + .strip() + .lower() + or "central" + ) + anchor_position: Optional[int] + if anchor_mode == "manual": + anchor_position = int(self._registration_anchor_position_spin.value()) + else: + anchor_position = None + + return { + "chunk_basis": "3d", + "detect_2d_per_slice": False, + "use_map_overlap": True, + "overlap_zyx": [ + int(self._registration_overlap_z_spin.value()), + int(self._registration_overlap_y_spin.value()), + int(self._registration_overlap_x_spin.value()), + ], + "memory_overhead_factor": float( + self._registration_defaults.get("memory_overhead_factor", 2.5) + ), + "registration_channel": int(self._registration_channel_spin.value()), + "registration_type": str( + self._registration_type_combo.currentData() or "rigid" + ).strip() + or "rigid", + "input_resolution_level": int( + self._registration_resolution_level_spin.value() + ), + "anchor_mode": anchor_mode, + "anchor_position": anchor_position, + "blend_mode": str( + self._registration_blend_mode_combo.currentData() or "feather" + ).strip() + or "feather", + } + def _collect_visualization_parameters(self) -> Dict[str, Any]: """Collect visualization parameter values from widgets. 
@@ -14308,6 +14672,8 @@ def _collect_operation_parameters(self, operation_name: str) -> Dict[str, Any]: defaults.update(self._collect_particle_parameters()) elif operation_name == "usegment3d": defaults.update(self._collect_usegment3d_parameters()) + elif operation_name == "registration": + defaults.update(self._collect_registration_parameters()) elif operation_name == "visualization": defaults.update(self._collect_visualization_parameters()) elif operation_name == "mip_export": diff --git a/src/clearex/main.py b/src/clearex/main.py index 2da8c70..02869ac 100644 --- a/src/clearex/main.py +++ b/src/clearex/main.py @@ -79,6 +79,41 @@ from clearex.mip_export.pipeline import ( run_mip_export_analysis, ) +try: + from clearex.registration.pipeline import ( + run_registration_analysis, + ) +except ImportError: + + def run_registration_analysis(*, zarr_path, parameters, client, progress_callback): + """Fallback when the optional registration runtime module is unavailable. + + Parameters + ---------- + zarr_path : str + Canonical analysis-store path. + parameters : dict[str, Any] + Runtime parameter mapping. + client : Any + Dask client handle. + progress_callback : callable + Progress callback. + + Returns + ------- + None + This fallback always raises before returning. + + Raises + ------ + RuntimeError + Always raised to indicate missing registration implementation. + """ + del zarr_path, parameters, client, progress_callback + raise RuntimeError( + "registration analysis is unavailable: " + "could not import clearex.registration.pipeline." 
+ ) try: from clearex.usegment3d.pipeline import ( @@ -156,6 +191,7 @@ def run_usegment3d_analysis(*, zarr_path, parameters, client, progress_callback) "shear_transform", "particle_detection", "usegment3d", + "registration", "display_pyramid", "mip_export", } @@ -2144,6 +2180,7 @@ def _usegment3d_progress(percent: int, message: str) -> None: continue if operation_name == "registration": + registration_parameters = dict(operation_parameters) _emit_analysis_progress( operation_start, "Running registration workflow.", @@ -2155,23 +2192,232 @@ def _usegment3d_progress(percent: int, message: str) -> None: if provenance_store_path and is_zarr_store_path( provenance_store_path ): - logger.warning( - "Registration is enabled but is not yet integrated with " - "canonical 6D store inputs. Skipping registration." + progress_state = {"last_percent": -5} + + def _registration_progress( + percent: int, message: str + ) -> None: + """Throttle registration progress logs. + + Parameters + ---------- + percent : int + Progress percent. + message : str + Progress message. + + Returns + ------- + None + Logger side effects only. 
+ """ + last_percent = int(progress_state["last_percent"]) + if percent >= 100 or percent - last_percent >= 5: + progress_state["last_percent"] = int(percent) + logger.info( + f"[registration] {int(percent)}% - {message}" + ) + mapped = operation_start + int( + (max(0, min(100, int(percent))) / 100) + * max(1, operation_end - operation_start) + ) + _emit_analysis_progress( + mapped, + f"registration: {message}", + ) + + summary = run_registration_analysis( + zarr_path=provenance_store_path, + parameters=registration_parameters, + client=analysis_client, + progress_callback=_registration_progress, + ) + registration_component = str( + getattr(summary, "component", "results/registration/latest") + ) + registration_data_component = str( + getattr( + summary, + "data_component", + f"{registration_component}/data", + ) + ) + registration_affines_component = str( + getattr( + summary, + "affines_component", + f"{registration_component}/affines_tpx44", + ) + ) + registration_source_component = str( + getattr( + summary, + "source_component", + registration_parameters.get("input_source", "data"), + ) + ) + produced_components["registration"] = ( + registration_data_component + ) + output_records["registration"] = { + "component": registration_component, + "data_component": registration_data_component, + "affines_component": registration_affines_component, + "source_component": registration_source_component, + "pairwise_source_component": str( + getattr( + summary, + "pairwise_source_component", + registration_source_component, + ) + ), + "requested_source_component": str( + getattr( + summary, + "requested_source_component", + registration_parameters.get("input_source", "data"), + ) + ), + "requested_input_resolution_level": int( + getattr( + summary, + "requested_input_resolution_level", + registration_parameters.get( + "input_resolution_level", 0 + ), + ) + ), + "input_resolution_level": int( + getattr(summary, "input_resolution_level", 0) + ), + "registration_channel": 
int( + getattr(summary, "registration_channel", 0) + ), + "registration_type": str( + getattr( + summary, + "registration_type", + registration_parameters.get( + "registration_type", "rigid" + ), + ) + ), + "edge_count": int(getattr(summary, "edge_count", 0)), + "active_edge_count": int( + getattr(summary, "active_edge_count", 0) + ), + "dropped_edge_count": int( + getattr(summary, "dropped_edge_count", 0) + ), + "output_shape_tpczyx": list( + getattr(summary, "output_shape_tpczyx", ()) + ), + "output_chunks_tpczyx": list( + getattr(summary, "output_chunks_tpczyx", ()) + ), + "blend_mode": str( + getattr( + summary, + "blend_mode", + registration_parameters.get( + "blend_mode", "feather" + ), + ) + ), + "storage_policy": "latest_only", + } + logger.info( + "Registration completed: " + f"component={registration_component}, " + f"data_component={registration_data_component}, " + f"source={registration_source_component}, " + f"pairwise_source={getattr(summary, 'pairwise_source_component', registration_source_component)}, " + f"edges={getattr(summary, 'edge_count', 0)}, " + f"active_edges={getattr(summary, 'active_edge_count', 0)}, " + f"output_shape={getattr(summary, 'output_shape_tpczyx', ())}." 
) step_records.append( { "name": "registration", "parameters": { - **operation_parameters, - "status": "skipped", - "reason": "not_integrated_with_canonical_store", + **registration_parameters, + "component": registration_component, + "data_component": registration_data_component, + "affines_component": registration_affines_component, + "source_component": registration_source_component, + "pairwise_source_component": str( + getattr( + summary, + "pairwise_source_component", + registration_source_component, + ) + ), + "requested_source_component": str( + getattr( + summary, + "requested_source_component", + registration_parameters.get( + "input_source", "data" + ), + ) + ), + "requested_input_resolution_level": int( + getattr( + summary, + "requested_input_resolution_level", + registration_parameters.get( + "input_resolution_level", 0 + ), + ) + ), + "input_resolution_level": int( + getattr(summary, "input_resolution_level", 0) + ), + "registration_channel": int( + getattr(summary, "registration_channel", 0) + ), + "registration_type": str( + getattr( + summary, + "registration_type", + registration_parameters.get( + "registration_type", "rigid" + ), + ) + ), + "anchor_positions": list( + getattr(summary, "anchor_positions", ()) + ), + "edge_count": int( + getattr(summary, "edge_count", 0) + ), + "active_edge_count": int( + getattr(summary, "active_edge_count", 0) + ), + "dropped_edge_count": int( + getattr(summary, "dropped_edge_count", 0) + ), + "output_shape_tpczyx": list( + getattr(summary, "output_shape_tpczyx", ()) + ), + "output_chunks_tpczyx": list( + getattr(summary, "output_chunks_tpczyx", ()) + ), + "blend_mode": str( + getattr( + summary, + "blend_mode", + registration_parameters.get( + "blend_mode", "feather" + ), + ) + ), }, } ) _emit_analysis_progress( operation_end, - "Registration skipped (not yet integrated with canonical store).", + "Registration complete.", ) else: logger.warning( @@ -2181,7 +2427,7 @@ def _usegment3d_progress(percent: int, 
message: str) -> None: { "name": "registration", "parameters": { - **operation_parameters, + **registration_parameters, "status": "skipped", "reason": "no_zarr_store", }, diff --git a/src/clearex/registration/pipeline.py b/src/clearex/registration/pipeline.py new file mode 100644 index 0000000..40834b2 --- /dev/null +++ b/src/clearex/registration/pipeline.py @@ -0,0 +1,1799 @@ +# Copyright (c) 2021-2025 The University of Texas Southwestern Medical Center. +# All rights reserved. + +"""Chunked tile-registration workflow for canonical 6D analysis stores.""" + +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass +import json +import math +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, Sequence, Union + +import dask +from dask import delayed +import numpy as np +from scipy import ndimage, optimize +from scipy.spatial.transform import Rotation +import zarr + +try: + import ants +except ImportError: # pragma: no cover - exercised via runtime fallback tests + ants = None + +from clearex.io.experiment import load_navigate_experiment +from clearex.io.provenance import register_latest_output_reference +from clearex.workflow import SpatialCalibrationConfig, spatial_calibration_from_dict + +if TYPE_CHECKING: + from dask.distributed import Client + + +ProgressCallback = Callable[[int, str], None] +_ANTS_AFF_ITERATIONS = (2000, 1000, 500, 250) +_ANTS_AFF_SHRINK_FACTORS = (8, 4, 2, 1) +_ANTS_AFF_SMOOTHING_SIGMAS = (3, 2, 1, 0) +_ANTS_RANDOM_SAMPLING_RATE = 0.25 +_ABSOLUTE_RESIDUAL_THRESHOLD_PX = 3.5 +_RELATIVE_RESIDUAL_THRESHOLD = 2.5 +_WEIGHT_EPS = np.float32(1e-6) +_PERMUTE_ZYX_TO_XYZ = np.asarray( + [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]], dtype=np.float64 +) + + +@dataclass(frozen=True) +class RegistrationSummary: + """Summary metadata for one registration run. + + Attributes + ---------- + component : str + Output latest-group component. 
+ data_component : str + Fused output data component. + affines_component : str + Optimized affine dataset component. + source_component : str + Full-resolution source component used for fusion. + requested_source_component : str + Requested input source before resolution-level expansion. + pairwise_source_component : str + Effective selected-level source used for pairwise registration. + input_resolution_level : int + Effective registration level used for pairwise estimation. + requested_input_resolution_level : int + Requested registration level. + registration_channel : int + Channel used to estimate transforms. + registration_type : str + Registration transform family. + anchor_positions : tuple[int, ...] + Anchor tile index used for each timepoint. + positions : int + Number of source positions. + timepoints : int + Number of processed timepoints. + edge_count : int + Total candidate graph edges. + active_edge_count : int + Total active edges after pruning across timepoints. + dropped_edge_count : int + Total dropped edges across timepoints. + output_shape_tpczyx : tuple[int, int, int, int, int, int] + Output fused shape. + output_chunks_tpczyx : tuple[int, int, int, int, int, int] + Output chunk layout. + blend_mode : str + Blend mode applied during fusion. + """ + + component: str + data_component: str + affines_component: str + source_component: str + requested_source_component: str + pairwise_source_component: str + input_resolution_level: int + requested_input_resolution_level: int + registration_channel: int + registration_type: str + anchor_positions: tuple[int, ...] 
+ positions: int + timepoints: int + edge_count: int + active_edge_count: int + dropped_edge_count: int + output_shape_tpczyx: tuple[int, int, int, int, int, int] + output_chunks_tpczyx: tuple[int, int, int, int, int, int] + blend_mode: str + + +@dataclass(frozen=True) +class _EdgeSpec: + """Nominal overlap edge between two positions.""" + + fixed_position: int + moving_position: int + overlap_bbox_xyz: tuple[tuple[float, float], tuple[float, float], tuple[float, float]] + overlap_voxels: int + + +def _emit( + progress_callback: Optional[ProgressCallback], percent: int, message: str +) -> None: + """Emit progress when a callback is available.""" + if progress_callback is None: + return + progress_callback(int(percent), str(message)) + + +def _safe_float(value: Any, *, default: float = 0.0) -> float: + """Parse a float with fallback.""" + try: + return float(value) + except Exception: + return float(default) + + +def _looks_like_multiposition_header(row: Any) -> bool: + """Return whether a row resembles a Navigate multiposition header.""" + if not isinstance(row, (list, tuple)) or not row: + return False + labels = {str(value).strip().upper() for value in row} + return {"X", "Y", "Z"}.issubset(labels) + + +def _parse_multiposition_stage_rows(payload: Any) -> list[dict[str, float]]: + """Parse stage-coordinate rows from multiposition payloads.""" + if not isinstance(payload, list): + return [] + + rows = list(payload) + header_index: dict[str, int] = {} + if rows and _looks_like_multiposition_header(rows[0]): + header = rows.pop(0) + if isinstance(header, (list, tuple)): + for idx, value in enumerate(header): + header_index[str(value).strip().upper()] = int(idx) + + parsed: list[dict[str, float]] = [] + for row in rows: + if not isinstance(row, (list, tuple)): + continue + + def _value(field: str, fallback_index: int) -> float: + index = header_index.get(field, fallback_index) + if index < 0 or index >= len(row): + return 0.0 + return _safe_float(row[index], 
default=0.0) + + parsed.append( + { + "x": _value("X", 0), + "y": _value("Y", 1), + "z": _value("Z", 2), + "theta": _value("THETA", 3), + "f": _value("F", 4), + } + ) + return parsed + + +def _load_stage_rows(root_attrs: Mapping[str, Any]) -> list[dict[str, float]]: + """Load multiposition stage rows from experiment metadata.""" + source_experiment = root_attrs.get("source_experiment") + if not isinstance(source_experiment, str): + return [] + + experiment_path = Path(source_experiment).expanduser() + if not experiment_path.exists(): + return [] + + sidecar_path = experiment_path.parent / "multi_positions.yml" + if sidecar_path.exists(): + try: + text = sidecar_path.read_text() + try: + payload = json.loads(text) + except json.JSONDecodeError: + try: + import yaml # type: ignore[import-not-found] + + payload = yaml.safe_load(text) + except Exception: + payload = None + parsed = _parse_multiposition_stage_rows(payload) + if parsed: + return parsed + except Exception: + pass + + try: + experiment = load_navigate_experiment(experiment_path) + except Exception: + return [] + return _parse_multiposition_stage_rows(experiment.raw.get("MultiPositions")) + + +def _resolve_world_axis_delta( + *, + row: Mapping[str, float], + reference: Mapping[str, float], + binding: str, +) -> float: + """Resolve one world-axis delta from stage rows.""" + if binding == "none": + return 0.0 + sign = -1.0 if binding.startswith("-") else 1.0 + source_axis = binding[1:] + return sign * float(row[source_axis] - reference[source_axis]) + + +def _rotation_matrix_x(theta_deg: float) -> np.ndarray: + """Build a right-handed rotation matrix around X.""" + theta_rad = math.radians(float(theta_deg)) + cos_theta = math.cos(theta_rad) + sin_theta = math.sin(theta_rad) + return np.asarray( + [[1.0, 0.0, 0.0], [0.0, cos_theta, -sin_theta], [0.0, sin_theta, cos_theta]], + dtype=np.float64, + ) + + +def _position_centroid_anchor( + stage_rows: Sequence[Mapping[str, float]], + spatial_calibration: 
SpatialCalibrationConfig, + positions: Sequence[int], +) -> int: + """Return the position closest to the centroid in world space.""" + if not positions: + return 0 + stage_axis_map = spatial_calibration.stage_axis_map_by_world_axis() + reference = stage_rows[positions[0]] + coords: dict[int, np.ndarray] = {} + for position in positions: + row = stage_rows[position] + coords[int(position)] = np.asarray( + [ + _resolve_world_axis_delta( + row=row, reference=reference, binding=stage_axis_map["x"] + ), + _resolve_world_axis_delta( + row=row, reference=reference, binding=stage_axis_map["y"] + ), + _resolve_world_axis_delta( + row=row, reference=reference, binding=stage_axis_map["z"] + ), + ], + dtype=np.float64, + ) + centroid = np.mean(list(coords.values()), axis=0) + return min( + positions, + key=lambda position: float( + np.linalg.norm(coords[int(position)] - centroid, ord=2) + ), + ) + + +def _build_nominal_transforms_xyz( + stage_rows: Sequence[Mapping[str, float]], + spatial_calibration: SpatialCalibrationConfig, + *, + anchor_position: int, + positions: Sequence[int], +) -> dict[int, np.ndarray]: + """Build nominal stage-derived transforms in physical XYZ coordinates.""" + stage_axis_map = spatial_calibration.stage_axis_map_by_world_axis() + reference = stage_rows[anchor_position] + transforms: dict[int, np.ndarray] = {} + for position in positions: + row = stage_rows[position] + transform = np.eye(4, dtype=np.float64) + transform[:3, :3] = _rotation_matrix_x(float(row["theta"] - reference["theta"])) + transform[:3, 3] = np.asarray( + [ + _resolve_world_axis_delta( + row=row, reference=reference, binding=stage_axis_map["x"] + ), + _resolve_world_axis_delta( + row=row, reference=reference, binding=stage_axis_map["y"] + ), + _resolve_world_axis_delta( + row=row, reference=reference, binding=stage_axis_map["z"] + ), + ], + dtype=np.float64, + ) + transforms[int(position)] = transform + return transforms + + +def _extract_voxel_size_um_zyx( + root: 
zarr.hierarchy.Group, source_component: str +) -> tuple[float, float, float]: + """Extract voxel size in ``(z, y, x)`` order.""" + try: + source = root[source_component] + except Exception: + source = None + + if source is not None: + try: + payload = source.attrs.get("voxel_size_um_zyx") + if isinstance(payload, (list, tuple)) and len(payload) >= 3: + parsed = tuple(float(value) for value in payload[:3]) + if all(value > 0 for value in parsed): + return parsed # type: ignore[return-value] + except Exception: + pass + + try: + payload = root.attrs.get("voxel_size_um_zyx") + if isinstance(payload, (list, tuple)) and len(payload) >= 3: + parsed = tuple(float(value) for value in payload[:3]) + if all(value > 0 for value in parsed): + return parsed # type: ignore[return-value] + except Exception: + pass + + navigate = root.attrs.get("navigate_experiment") + if isinstance(navigate, Mapping): + xy_um = _safe_float(navigate.get("xy_pixel_size_um"), default=1.0) + z_um = _safe_float(navigate.get("z_step_um"), default=1.0) + if xy_um > 0 and z_um > 0: + return (float(z_um), float(xy_um), float(xy_um)) + return (1.0, 1.0, 1.0) + + +def _component_level_suffix(component: str) -> Optional[int]: + """Parse a trailing ``level_`` suffix.""" + token = str(component).strip().split("/")[-1] + if not token.startswith("level_"): + return None + try: + parsed = int(token.split("_", maxsplit=1)[1]) + except Exception: + return None + return parsed if parsed >= 0 else None + + +def _resolve_source_components_for_level( + *, + root: zarr.hierarchy.Group, + requested_source_component: str, + input_resolution_level: int, +) -> tuple[str, str, int]: + """Resolve full-resolution and effective selected-level source components.""" + requested = str(requested_source_component).strip() or "data" + full_resolution_source = requested + if requested.startswith("data_pyramid/level_"): + full_resolution_source = "data" + elif "_pyramid/level_" in requested: + full_resolution_source = 
requested.split("_pyramid/level_", maxsplit=1)[0] + + if full_resolution_source not in root: + raise ValueError( + f"registration input component '{full_resolution_source}' was not found." + ) + + effective_level = max(0, int(input_resolution_level)) + if effective_level <= 0: + return full_resolution_source, full_resolution_source, 0 + + direct_level = _component_level_suffix(requested) + if direct_level is not None and requested in root and direct_level == effective_level: + return full_resolution_source, requested, effective_level + + candidate_components = [ + f"{full_resolution_source}_pyramid/level_{effective_level}", + ] + if full_resolution_source == "data": + candidate_components.insert(0, f"data_pyramid/level_{effective_level}") + for candidate in candidate_components: + if candidate in root: + return full_resolution_source, candidate, effective_level + + raise ValueError( + "registration input_resolution_level=" + f"{effective_level} was requested for '{full_resolution_source}', " + "but no matching pyramid component exists." 
+ ) + + +def _pyramid_factor_zyx_for_level( + root: zarr.hierarchy.Group, *, level: int +) -> tuple[float, float, float]: + """Return per-axis pyramid factors in ``(z, y, x)`` order.""" + if level <= 0: + return (1.0, 1.0, 1.0) + + factors = root.attrs.get("data_pyramid_factors_tpczyx") + if isinstance(factors, (tuple, list)) and len(factors) > level: + entry = factors[level] + if isinstance(entry, (tuple, list)) and len(entry) >= 6: + try: + return ( + max(1.0, float(entry[3])), + max(1.0, float(entry[4])), + max(1.0, float(entry[5])), + ) + except Exception: + pass + uniform = float(2 ** int(level)) + return (uniform, uniform, uniform) + + +def _tile_extent_xyz( + shape_zyx: Sequence[int], voxel_size_um_zyx: Sequence[float] +) -> np.ndarray: + """Return physical tile extent in XYZ order.""" + return np.asarray( + [ + float(shape_zyx[2]) * float(voxel_size_um_zyx[2]), + float(shape_zyx[1]) * float(voxel_size_um_zyx[1]), + float(shape_zyx[0]) * float(voxel_size_um_zyx[0]), + ], + dtype=np.float64, + ) + + +def _transform_points_xyz(transform_xyz: np.ndarray, points_xyz: np.ndarray) -> np.ndarray: + """Apply a homogeneous affine to point rows.""" + homogeneous = np.concatenate( + [points_xyz.astype(np.float64), np.ones((points_xyz.shape[0], 1), dtype=np.float64)], + axis=1, + ) + return (transform_xyz @ homogeneous.T).T[:, :3] + + +def _bbox_from_points_xyz(points_xyz: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Return axis-aligned bounding box from point rows.""" + return ( + np.min(points_xyz, axis=0, initial=np.inf).astype(np.float64), + np.max(points_xyz, axis=0, initial=-np.inf).astype(np.float64), + ) + + +def _tile_bbox_xyz( + transform_xyz: np.ndarray, tile_extent_xyz: np.ndarray +) -> tuple[np.ndarray, np.ndarray]: + """Return transformed axis-aligned bounding box for one tile.""" + ex, ey, ez = (float(value) for value in tile_extent_xyz) + corners = np.asarray( + [ + [0.0, 0.0, 0.0], + [ex, 0.0, 0.0], + [0.0, ey, 0.0], + [0.0, 0.0, ez], + [ex, ey, 
0.0], + [ex, 0.0, ez], + [0.0, ey, ez], + [ex, ey, ez], + ], + dtype=np.float64, + ) + return _bbox_from_points_xyz(_transform_points_xyz(transform_xyz, corners)) + + +def _bbox_intersection_xyz( + left: tuple[np.ndarray, np.ndarray], + right: tuple[np.ndarray, np.ndarray], +) -> Optional[tuple[np.ndarray, np.ndarray]]: + """Return XYZ bbox intersection when overlap volume is positive.""" + minimum = np.maximum(left[0], right[0]) + maximum = np.minimum(left[1], right[1]) + if np.any(maximum <= minimum): + return None + return minimum, maximum + + +def _bbox_volume_voxels( + minimum_xyz: np.ndarray, + maximum_xyz: np.ndarray, + voxel_size_um_zyx: Sequence[float], +) -> int: + """Estimate overlap volume in voxels for one bbox.""" + shape_xyz = (maximum_xyz - minimum_xyz).astype(np.float64) + voxel_xyz = np.asarray( + [ + float(voxel_size_um_zyx[2]), + float(voxel_size_um_zyx[1]), + float(voxel_size_um_zyx[0]), + ], + dtype=np.float64, + ) + counts = np.maximum(1, np.floor(shape_xyz / np.maximum(voxel_xyz, 1e-6))).astype(int) + return int(np.prod(counts, dtype=np.int64)) + + +def _build_edge_specs( + nominal_transforms_xyz: Mapping[int, np.ndarray], + *, + positions: Sequence[int], + tile_extent_xyz: np.ndarray, + voxel_size_um_zyx: Sequence[float], +) -> list[_EdgeSpec]: + """Build overlap graph edges from nominal transformed tile boxes.""" + bboxes = { + int(position): _tile_bbox_xyz(nominal_transforms_xyz[int(position)], tile_extent_xyz) + for position in positions + } + edges: list[_EdgeSpec] = [] + ordered = [int(position) for position in positions] + for idx, fixed_position in enumerate(ordered): + for moving_position in ordered[idx + 1 :]: + overlap = _bbox_intersection_xyz( + bboxes[int(fixed_position)], bboxes[int(moving_position)] + ) + if overlap is None: + continue + overlap_voxels = _bbox_volume_voxels( + overlap[0], overlap[1], voxel_size_um_zyx + ) + edges.append( + _EdgeSpec( + fixed_position=int(fixed_position), + moving_position=int(moving_position), 
+ overlap_bbox_xyz=( + (float(overlap[0][0]), float(overlap[1][0])), + (float(overlap[0][1]), float(overlap[1][1])), + (float(overlap[0][2]), float(overlap[1][2])), + ), + overlap_voxels=int(overlap_voxels), + ) + ) + return edges + + +def _xyz_scale_diagonal(voxel_size_xyz: Sequence[float]) -> np.ndarray: + """Return diagonal scaling matrix for XYZ voxel sizes.""" + return np.diag( + [ + float(voxel_size_xyz[0]), + float(voxel_size_xyz[1]), + float(voxel_size_xyz[2]), + ] + ).astype(np.float64) + + +def _world_to_input_affine_zyx( + local_to_world_xyz: np.ndarray, + *, + reference_origin_xyz: np.ndarray, + voxel_size_um_zyx: Sequence[float], +) -> tuple[np.ndarray, np.ndarray]: + """Return ``ndimage.affine_transform`` mapping into ZYX input indices.""" + voxel_size_xyz = ( + float(voxel_size_um_zyx[2]), + float(voxel_size_um_zyx[1]), + float(voxel_size_um_zyx[0]), + ) + scale_xyz = _xyz_scale_diagonal(voxel_size_xyz) + scale_inv_xyz = np.diag([1.0 / value for value in voxel_size_xyz]).astype(np.float64) + world_to_local_xyz = np.linalg.inv(local_to_world_xyz) + matrix = ( + _PERMUTE_ZYX_TO_XYZ + @ scale_inv_xyz + @ world_to_local_xyz[:3, :3] + @ scale_xyz + @ _PERMUTE_ZYX_TO_XYZ + ) + offset = _PERMUTE_ZYX_TO_XYZ @ scale_inv_xyz @ ( + (world_to_local_xyz[:3, :3] @ reference_origin_xyz) + world_to_local_xyz[:3, 3] + ) + return matrix.astype(np.float64), offset.astype(np.float64) + + +def _resample_source_to_world_grid( + source_zyx: np.ndarray, + local_to_world_xyz: np.ndarray, + *, + reference_origin_xyz: np.ndarray, + reference_shape_zyx: tuple[int, int, int], + voxel_size_um_zyx: Sequence[float], + order: int, + cval: float, +) -> np.ndarray: + """Resample one source volume onto a world-aligned crop/output grid.""" + matrix, offset = _world_to_input_affine_zyx( + local_to_world_xyz, + reference_origin_xyz=reference_origin_xyz, + voxel_size_um_zyx=voxel_size_um_zyx, + ) + return ndimage.affine_transform( + np.asarray(source_zyx), + matrix=matrix, + offset=offset, 
+ output_shape=reference_shape_zyx, + order=int(order), + mode="constant", + cval=float(cval), + prefilter=bool(order > 1), + ).astype(np.float32, copy=False) + + +def _crop_from_overlap_bbox( + overlap_bbox_xyz: tuple[tuple[float, float], tuple[float, float], tuple[float, float]], + *, + voxel_size_um_zyx: Sequence[float], + overlap_zyx: Sequence[int], +) -> tuple[np.ndarray, tuple[int, int, int]]: + """Return crop origin and shape for one overlap bbox.""" + pad_xyz = np.asarray( + [ + float(overlap_zyx[2]) * float(voxel_size_um_zyx[2]), + float(overlap_zyx[1]) * float(voxel_size_um_zyx[1]), + float(overlap_zyx[0]) * float(voxel_size_um_zyx[0]), + ], + dtype=np.float64, + ) + minimum_xyz = np.asarray( + [overlap_bbox_xyz[0][0], overlap_bbox_xyz[1][0], overlap_bbox_xyz[2][0]], + dtype=np.float64, + ) - pad_xyz + maximum_xyz = np.asarray( + [overlap_bbox_xyz[0][1], overlap_bbox_xyz[1][1], overlap_bbox_xyz[2][1]], + dtype=np.float64, + ) + pad_xyz + voxel_xyz = np.asarray( + [ + float(voxel_size_um_zyx[2]), + float(voxel_size_um_zyx[1]), + float(voxel_size_um_zyx[0]), + ], + dtype=np.float64, + ) + size_xyz = np.maximum(voxel_xyz, maximum_xyz - minimum_xyz) + shape_xyz = np.maximum(1, np.ceil(size_xyz / np.maximum(voxel_xyz, 1e-6))).astype(int) + shape_zyx = (int(shape_xyz[2]), int(shape_xyz[1]), int(shape_xyz[0])) + return minimum_xyz, shape_zyx + + +def _ants_image_from_zyx( + volume_zyx: np.ndarray, *, voxel_size_um_zyx: Sequence[float] +) -> Any: + """Create an ANTs image from a ZYX NumPy array with XYZ spacing.""" + if ants is None: + raise RuntimeError( + "registration requires antspyx/ants to estimate tile transforms." 
+ ) + image = ants.from_numpy(np.asarray(volume_zyx, dtype=np.float32)) + image.set_spacing( + ( + float(voxel_size_um_zyx[2]), + float(voxel_size_um_zyx[1]), + float(voxel_size_um_zyx[0]), + ) + ) + image.set_origin((0.0, 0.0, 0.0)) + return image + + +def _ants_transform_to_matrix_xyz(transform: Any) -> np.ndarray: + """Convert an ANTs affine transform into a homogeneous XYZ matrix.""" + parameters = np.asarray(transform.parameters, dtype=np.float64) + if parameters.size < 12: + raise ValueError("Expected a 3D affine transform with 12 parameters.") + matrix = parameters[:9].reshape(3, 3) + translation = parameters[9:12] + center = np.asarray(transform.fixed_parameters, dtype=np.float64) + offset = translation + center - (matrix @ center) + affine = np.eye(4, dtype=np.float64) + affine[:3, :3] = matrix + affine[:3, 3] = offset + return affine + + +def _registration_type_to_ants(value: str) -> str: + """Map normalized registration type to ANTsPy value.""" + mapping = { + "translation": "Translation", + "rigid": "Rigid", + "similarity": "Similarity", + } + return mapping[str(value).strip().lower()] + + +def _register_pairwise_overlap( + *, + zarr_path: str, + source_component: str, + t_index: int, + registration_channel: int, + edge: _EdgeSpec, + nominal_fixed_transform_xyz: np.ndarray, + nominal_moving_transform_xyz: np.ndarray, + voxel_size_um_zyx: Sequence[float], + overlap_zyx: Sequence[int], + registration_type: str, +) -> dict[str, Any]: + """Register one nominal overlap crop and return a correction transform.""" + root = zarr.open_group(str(zarr_path), mode="r") + source = root[source_component] + fixed_source = np.asarray( + source[int(t_index), int(edge.fixed_position), int(registration_channel), :, :, :], + dtype=np.float32, + ) + moving_source = np.asarray( + source[int(t_index), int(edge.moving_position), int(registration_channel), :, :, :], + dtype=np.float32, + ) + crop_origin_xyz, crop_shape_zyx = _crop_from_overlap_bbox( + edge.overlap_bbox_xyz, + 
voxel_size_um_zyx=voxel_size_um_zyx, + overlap_zyx=overlap_zyx, + ) + fixed_crop = _resample_source_to_world_grid( + fixed_source, + nominal_fixed_transform_xyz, + reference_origin_xyz=crop_origin_xyz, + reference_shape_zyx=crop_shape_zyx, + voxel_size_um_zyx=voxel_size_um_zyx, + order=1, + cval=0.0, + ) + moving_crop = _resample_source_to_world_grid( + moving_source, + nominal_moving_transform_xyz, + reference_origin_xyz=crop_origin_xyz, + reference_shape_zyx=crop_shape_zyx, + voxel_size_um_zyx=voxel_size_um_zyx, + order=1, + cval=0.0, + ) + fixed_mask = np.asarray(fixed_crop > 0, dtype=np.float32) + moving_mask = np.asarray(moving_crop > 0, dtype=np.float32) + overlap_pixels = int(np.count_nonzero((fixed_mask > 0) & (moving_mask > 0))) + if overlap_pixels <= 0 or float(np.std(fixed_crop)) <= 1e-6 or float(np.std(moving_crop)) <= 1e-6: + return { + "fixed_position": int(edge.fixed_position), + "moving_position": int(edge.moving_position), + "success": False, + "reason": "insufficient_overlap_signal", + "correction_matrix_xyz": np.eye(4, dtype=np.float64).tolist(), + "overlap_voxels": int(edge.overlap_voxels), + "nominal_overlap_pixels": int(overlap_pixels), + } + + try: + fixed_image = _ants_image_from_zyx(fixed_crop, voxel_size_um_zyx=voxel_size_um_zyx) + moving_image = _ants_image_from_zyx( + moving_crop, voxel_size_um_zyx=voxel_size_um_zyx + ) + fixed_mask_image = _ants_image_from_zyx( + fixed_mask, voxel_size_um_zyx=voxel_size_um_zyx + ) + moving_mask_image = _ants_image_from_zyx( + moving_mask, voxel_size_um_zyx=voxel_size_um_zyx + ) + registration = ants.registration( + fixed=fixed_image, + moving=moving_image, + fixed_mask=fixed_mask_image, + moving_mask=moving_mask_image, + type_of_transform=_registration_type_to_ants(registration_type), + initial_transform="Identity", + aff_metric="mattes", + aff_sampling=32, + aff_random_sampling_rate=float(_ANTS_RANDOM_SAMPLING_RATE), + aff_iterations=_ANTS_AFF_ITERATIONS, + aff_shrink_factors=_ANTS_AFF_SHRINK_FACTORS, 
+ aff_smoothing_sigmas=_ANTS_AFF_SMOOTHING_SIGMAS, + mask_all_stages=True, + verbose=False, + ) + transform = ants.read_transform(registration["fwdtransforms"][0]) + correction_matrix_xyz = _ants_transform_to_matrix_xyz(transform) + success = True + reason = "" + except Exception as exc: + correction_matrix_xyz = np.eye(4, dtype=np.float64) + success = False + reason = str(exc) + + return { + "fixed_position": int(edge.fixed_position), + "moving_position": int(edge.moving_position), + "success": bool(success), + "reason": reason, + "correction_matrix_xyz": correction_matrix_xyz.tolist(), + "overlap_voxels": int(edge.overlap_voxels), + "nominal_overlap_pixels": int(overlap_pixels), + } + + +def _matrix_from_pose(params: np.ndarray, registration_type: str) -> np.ndarray: + """Build a homogeneous correction matrix from pose parameters.""" + matrix = np.eye(4, dtype=np.float64) + if registration_type == "translation": + matrix[:3, 3] = np.asarray(params[:3], dtype=np.float64) + return matrix + + rotation = Rotation.from_rotvec(np.asarray(params[:3], dtype=np.float64)).as_matrix() + translation = np.asarray(params[3:6], dtype=np.float64) + if registration_type == "rigid": + matrix[:3, :3] = rotation + matrix[:3, 3] = translation + return matrix + + scale = float(np.exp(float(params[6]))) + matrix[:3, :3] = rotation * scale + matrix[:3, 3] = translation + return matrix + + +def _pose_from_matrix(matrix_xyz: np.ndarray, registration_type: str) -> np.ndarray: + """Convert a homogeneous correction matrix into optimizer parameters.""" + translation = np.asarray(matrix_xyz[:3, 3], dtype=np.float64) + if registration_type == "translation": + return translation + + linear = np.asarray(matrix_xyz[:3, :3], dtype=np.float64) + if registration_type == "similarity": + scale = float(np.cbrt(max(np.linalg.det(linear), 1e-12))) + rotation_matrix = linear / max(scale, 1e-12) + rotation = Rotation.from_matrix(rotation_matrix) + return np.concatenate([rotation.as_rotvec(), translation, 
np.asarray([math.log(scale)])]) + + rotation = Rotation.from_matrix(linear) + return np.concatenate([rotation.as_rotvec(), translation]) + + +def _translation_residual_voxels( + measured_xyz: np.ndarray, + predicted_xyz: np.ndarray, + *, + voxel_size_um_zyx: Sequence[float], +) -> float: + """Return translation residual magnitude in selected-level voxels.""" + delta_xyz = np.asarray(predicted_xyz[:3, 3] - measured_xyz[:3, 3], dtype=np.float64) + mean_voxel = float(np.mean(np.asarray(voxel_size_um_zyx, dtype=np.float64))) + return float(np.linalg.norm(delta_xyz, ord=2) / max(mean_voxel, 1e-6)) + + +def _component_positions( + positions: Sequence[int], active_edge_indices: Sequence[int], edges: Sequence[_EdgeSpec] +) -> list[list[int]]: + """Return connected position components for the active edge graph.""" + adjacency: dict[int, set[int]] = {int(position): set() for position in positions} + for edge_index in active_edge_indices: + edge = edges[int(edge_index)] + adjacency[int(edge.fixed_position)].add(int(edge.moving_position)) + adjacency[int(edge.moving_position)].add(int(edge.fixed_position)) + + components: list[list[int]] = [] + visited: set[int] = set() + for position in positions: + node = int(position) + if node in visited: + continue + queue: deque[int] = deque([node]) + component: list[int] = [] + while queue: + current = queue.popleft() + if current in visited: + continue + visited.add(current) + component.append(int(current)) + for neighbor in sorted(adjacency.get(current, ())): + if neighbor not in visited: + queue.append(int(neighbor)) + components.append(component) + return components + + +def _solve_translation_component( + *, + component_positions: Sequence[int], + active_edge_indices: Sequence[int], + edge_results: Sequence[Mapping[str, Any]], + edges: Sequence[_EdgeSpec], + anchor_position: int, +) -> dict[int, np.ndarray]: + """Solve translation-only correction poses for one connected component.""" + solved = {int(position): np.eye(4, 
dtype=np.float64) for position in component_positions} + variable_positions = [ + int(position) for position in component_positions if int(position) != int(anchor_position) + ] + if not variable_positions: + return solved + + column_index = {position: idx for idx, position in enumerate(variable_positions)} + equations: list[list[float]] = [] + targets: list[float] = [] + for edge_index in active_edge_indices: + edge = edges[int(edge_index)] + fixed_position = int(edge.fixed_position) + moving_position = int(edge.moving_position) + if fixed_position not in column_index and fixed_position != int(anchor_position) and fixed_position not in component_positions: + continue + if moving_position not in column_index and moving_position != int(anchor_position) and moving_position not in component_positions: + continue + measurement = np.asarray( + edge_results[int(edge_index)]["correction_matrix_xyz"], dtype=np.float64 + ) + weight = math.sqrt( + max(1.0, float(edge_results[int(edge_index)].get("overlap_voxels", 1))) + ) + delta = np.asarray(measurement[:3, 3], dtype=np.float64) + for axis in range(3): + row = [0.0] * (len(variable_positions) * 3) + if fixed_position in column_index: + row[(column_index[fixed_position] * 3) + axis] -= float(weight) + if moving_position in column_index: + row[(column_index[moving_position] * 3) + axis] += float(weight) + equations.append(row) + targets.append(float(weight) * float(delta[axis])) + + if not equations: + return solved + + design = np.asarray(equations, dtype=np.float64) + target = np.asarray(targets, dtype=np.float64) + coefficients, *_ = np.linalg.lstsq(design, target, rcond=None) + for position in variable_positions: + base = column_index[position] * 3 + solved[int(position)][:3, 3] = np.asarray( + coefficients[base : base + 3], dtype=np.float64 + ) + return solved + + +def _solve_nonlinear_component( + *, + component_positions: Sequence[int], + active_edge_indices: Sequence[int], + edge_results: Sequence[Mapping[str, Any]], + 
edges: Sequence[_EdgeSpec], + registration_type: str, + anchor_position: int, + voxel_size_um_zyx: Sequence[float], +) -> dict[int, np.ndarray]: + """Solve rigid/similarity correction poses for one connected component.""" + solved = {int(position): np.eye(4, dtype=np.float64) for position in component_positions} + variable_positions = [ + int(position) for position in component_positions if int(position) != int(anchor_position) + ] + if not variable_positions: + return solved + + pose_size = 6 if registration_type == "rigid" else 7 + variable_index = {position: idx for idx, position in enumerate(variable_positions)} + initial = np.zeros(len(variable_positions) * pose_size, dtype=np.float64) + for edge_index in active_edge_indices: + edge = edges[int(edge_index)] + if int(edge.moving_position) not in variable_index: + continue + if int(edge.fixed_position) != int(anchor_position): + continue + measurement = np.asarray( + edge_results[int(edge_index)]["correction_matrix_xyz"], dtype=np.float64 + ) + base = variable_index[int(edge.moving_position)] * pose_size + initial[base : base + pose_size] = _pose_from_matrix( + measurement, registration_type + ) + + def _matrix_for(position: int, params: np.ndarray) -> np.ndarray: + if int(position) == int(anchor_position): + return np.eye(4, dtype=np.float64) + base = variable_index[int(position)] * pose_size + return _matrix_from_pose(params[base : base + pose_size], registration_type) + + def _residuals(params: np.ndarray) -> np.ndarray: + residual_values: list[float] = [] + mean_voxel = float(np.mean(np.asarray(voxel_size_um_zyx, dtype=np.float64))) + for edge_index in active_edge_indices: + edge = edges[int(edge_index)] + measured = np.asarray( + edge_results[int(edge_index)]["correction_matrix_xyz"], dtype=np.float64 + ) + weight = math.sqrt( + max(1.0, float(edge_results[int(edge_index)].get("overlap_voxels", 1))) + ) + fixed_matrix = _matrix_for(int(edge.fixed_position), params) + moving_matrix = 
_matrix_for(int(edge.moving_position), params) + predicted = np.linalg.inv(fixed_matrix) @ moving_matrix + + measured_linear = np.asarray(measured[:3, :3], dtype=np.float64) + predicted_linear = np.asarray(predicted[:3, :3], dtype=np.float64) + if registration_type == "similarity": + measured_scale = float(np.cbrt(max(np.linalg.det(measured_linear), 1e-12))) + predicted_scale = float(np.cbrt(max(np.linalg.det(predicted_linear), 1e-12))) + measured_rotation = measured_linear / max(measured_scale, 1e-12) + predicted_rotation = predicted_linear / max(predicted_scale, 1e-12) + scale_residual = math.log(max(predicted_scale, 1e-12) / max(measured_scale, 1e-12)) + else: + measured_rotation = measured_linear + predicted_rotation = predicted_linear + scale_residual = 0.0 + + delta_rotation = Rotation.from_matrix( + measured_rotation.T @ predicted_rotation + ).as_rotvec() + delta_translation = ( + np.asarray(predicted[:3, 3] - measured[:3, 3], dtype=np.float64) + / max(mean_voxel, 1e-6) + ) + residual_values.extend((weight * delta_rotation).tolist()) + residual_values.extend((weight * delta_translation).tolist()) + if registration_type == "similarity": + residual_values.append(float(weight) * float(scale_residual)) + return np.asarray(residual_values, dtype=np.float64) + + result = optimize.least_squares( + _residuals, + initial, + method="trf", + x_scale="jac", + max_nfev=200, + ) + solved = {int(position): np.eye(4, dtype=np.float64) for position in component_positions} + for position in variable_positions: + base = variable_index[int(position)] * pose_size + solved[int(position)] = _matrix_from_pose( + result.x[base : base + pose_size], registration_type + ) + return solved + + +def _solve_with_pruning( + *, + positions: Sequence[int], + edges: Sequence[_EdgeSpec], + edge_results: Sequence[Mapping[str, Any]], + anchor_position: int, + registration_type: str, + voxel_size_um_zyx: Sequence[float], +) -> tuple[dict[int, np.ndarray], np.ndarray, np.ndarray]: + """Solve 
correction transforms with simple bad-link pruning.""" + success_mask = np.asarray( + [bool(result.get("success", False)) for result in edge_results], dtype=bool + ) + active_mask = success_mask.copy() + residuals = np.full(len(edges), np.nan, dtype=np.float64) + solved: dict[int, np.ndarray] = { + int(position): np.eye(4, dtype=np.float64) for position in positions + } + if not np.any(active_mask): + return solved, active_mask, residuals + + while True: + solved = {int(position): np.eye(4, dtype=np.float64) for position in positions} + active_edge_indices = np.flatnonzero(active_mask) + components = _component_positions(positions, active_edge_indices, edges) + for component_positions in components: + component_edges = [ + int(edge_index) + for edge_index in active_edge_indices + if int(edges[int(edge_index)].fixed_position) in component_positions + and int(edges[int(edge_index)].moving_position) in component_positions + ] + component_anchor = ( + int(anchor_position) + if int(anchor_position) in component_positions + else int(component_positions[0]) + ) + if registration_type == "translation": + component_solution = _solve_translation_component( + component_positions=component_positions, + active_edge_indices=component_edges, + edge_results=edge_results, + edges=edges, + anchor_position=component_anchor, + ) + else: + component_solution = _solve_nonlinear_component( + component_positions=component_positions, + active_edge_indices=component_edges, + edge_results=edge_results, + edges=edges, + registration_type=registration_type, + anchor_position=component_anchor, + voxel_size_um_zyx=voxel_size_um_zyx, + ) + solved.update(component_solution) + + residuals[:] = np.nan + for edge_index in active_edge_indices: + edge = edges[int(edge_index)] + measured = np.asarray( + edge_results[int(edge_index)]["correction_matrix_xyz"], dtype=np.float64 + ) + predicted = np.linalg.inv(solved[int(edge.fixed_position)]) @ solved[ + int(edge.moving_position) + ] + 
residuals[int(edge_index)] = _translation_residual_voxels( + measured, predicted, voxel_size_um_zyx=voxel_size_um_zyx + ) + + active_residuals = residuals[np.isfinite(residuals)] + if active_residuals.size == 0: + break + worst_edge_index = int(np.nanargmax(residuals)) + worst_residual = float(residuals[worst_edge_index]) + mean_residual = float(np.mean(active_residuals)) + if ( + worst_residual <= float(_ABSOLUTE_RESIDUAL_THRESHOLD_PX) + or worst_residual <= max(float(_ABSOLUTE_RESIDUAL_THRESHOLD_PX), mean_residual * float(_RELATIVE_RESIDUAL_THRESHOLD)) + ): + break + + edge = edges[worst_edge_index] + component_positions = next( + component + for component in components + if int(edge.fixed_position) in component and int(edge.moving_position) in component + ) + component_active = [ + int(edge_index) + for edge_index in active_edge_indices + if int(edges[int(edge_index)].fixed_position) in component_positions + and int(edges[int(edge_index)].moving_position) in component_positions + ] + if len(component_active) <= max(0, len(component_positions) - 1): + break + active_mask[worst_edge_index] = False + + return solved, active_mask, residuals + + +def _blend_weight_volume( + shape_zyx: Sequence[int], + *, + blend_mode: str, + overlap_zyx: Sequence[int], +) -> np.ndarray: + """Build a separable edge-feathered weight volume.""" + shape = tuple(int(value) for value in shape_zyx) + if str(blend_mode).strip().lower() == "average": + return np.ones(shape, dtype=np.float32) + + profiles: list[np.ndarray] = [] + for axis_size, ramp_width in zip(shape, overlap_zyx, strict=False): + profile = np.ones(int(axis_size), dtype=np.float32) + width = max(0, min(int(ramp_width), max(0, int(axis_size // 2)))) + if width > 0: + ramp = 0.5 - 0.5 * np.cos( + np.linspace(0.0, np.pi, width, dtype=np.float32) + ) + profile[:width] = ramp + profile[-width:] = np.minimum(profile[-width:], ramp[::-1]) + profiles.append(profile) + return ( + profiles[0][:, np.newaxis, np.newaxis] + * 
profiles[1][np.newaxis, :, np.newaxis] + * profiles[2][np.newaxis, np.newaxis, :] + ).astype(np.float32) + + +def _cast_to_dtype(data: np.ndarray, dtype: np.dtype[Any]) -> np.ndarray: + """Cast float output to the requested dtype.""" + out_dtype = np.dtype(dtype) + if np.issubdtype(out_dtype, np.integer): + info = np.iinfo(out_dtype) + return np.clip(np.rint(data), info.min, info.max).astype(out_dtype, copy=False) + return np.asarray(data, dtype=out_dtype) + + +def _axis_chunk_bounds(size: int, chunk_size: int) -> list[tuple[int, int]]: + """Build contiguous chunk bounds for one axis.""" + return [ + (start, min(start + chunk_size, int(size))) + for start in range(0, int(size), int(chunk_size)) + ] + + +def _process_and_write_registration_chunk( + *, + zarr_path: str, + source_component: str, + output_component: str, + affines_component: str, + transformed_bboxes_component: str, + t_index: int, + c_index: int, + z_bounds: tuple[int, int], + y_bounds: tuple[int, int], + x_bounds: tuple[int, int], + output_origin_xyz: Sequence[float], + voxel_size_um_zyx: Sequence[float], + blend_mode: str, + overlap_zyx: Sequence[int], + output_dtype: str, +) -> int: + """Render one fused output chunk into the registration result store.""" + root = zarr.open_group(str(zarr_path), mode="r") + source = root[source_component] + z0, z1 = (int(z_bounds[0]), int(z_bounds[1])) + y0, y1 = (int(y_bounds[0]), int(y_bounds[1])) + x0, x1 = (int(x_bounds[0]), int(x_bounds[1])) + chunk_shape_zyx = (z1 - z0, y1 - y0, x1 - x0) + chunk_origin_xyz = np.asarray( + [ + float(output_origin_xyz[0]) + (float(x0) * float(voxel_size_um_zyx[2])), + float(output_origin_xyz[1]) + (float(y0) * float(voxel_size_um_zyx[1])), + float(output_origin_xyz[2]) + (float(z0) * float(voxel_size_um_zyx[0])), + ], + dtype=np.float64, + ) + chunk_bbox_xyz = np.asarray( + [ + [float(chunk_origin_xyz[0]), float(chunk_origin_xyz[0]) + (float(chunk_shape_zyx[2]) * float(voxel_size_um_zyx[2]))], + [float(chunk_origin_xyz[1]), 
float(chunk_origin_xyz[1]) + (float(chunk_shape_zyx[1]) * float(voxel_size_um_zyx[1]))], + [float(chunk_origin_xyz[2]), float(chunk_origin_xyz[2]) + (float(chunk_shape_zyx[0]) * float(voxel_size_um_zyx[0]))], + ], + dtype=np.float64, + ) + + chunk_sum = np.zeros(chunk_shape_zyx, dtype=np.float32) + chunk_weight = np.zeros(chunk_shape_zyx, dtype=np.float32) + position_count = int(source.shape[1]) + affines = root[affines_component] + transformed_bboxes = root[transformed_bboxes_component] + for position_index in range(position_count): + bbox_payload = np.asarray( + transformed_bboxes[int(t_index), int(position_index)], dtype=np.float64 + ) + bbox_min = bbox_payload[:3] + bbox_max = bbox_payload[3:] + if np.any(bbox_max <= chunk_bbox_xyz[:, 0]) or np.any(chunk_bbox_xyz[:, 1] <= bbox_min): + continue + source_volume = np.asarray( + source[int(t_index), int(position_index), int(c_index), :, :, :], + dtype=np.float32, + ) + transform_xyz = np.asarray( + affines[int(t_index), int(position_index)], dtype=np.float64 + ) + warped_volume = _resample_source_to_world_grid( + source_volume, + transform_xyz, + reference_origin_xyz=chunk_origin_xyz, + reference_shape_zyx=chunk_shape_zyx, + voxel_size_um_zyx=voxel_size_um_zyx, + order=1, + cval=0.0, + ) + weight_volume = _blend_weight_volume( + source_volume.shape, blend_mode=blend_mode, overlap_zyx=overlap_zyx + ) + warped_weight = _resample_source_to_world_grid( + weight_volume, + transform_xyz, + reference_origin_xyz=chunk_origin_xyz, + reference_shape_zyx=chunk_shape_zyx, + voxel_size_um_zyx=voxel_size_um_zyx, + order=1, + cval=0.0, + ) + chunk_sum += warped_volume * warped_weight + chunk_weight += warped_weight + + normalized = np.zeros(chunk_shape_zyx, dtype=np.float32) + np.divide( + chunk_sum, + np.maximum(chunk_weight, _WEIGHT_EPS), + out=normalized, + where=chunk_weight > 0, + ) + write_root = zarr.open_group(str(zarr_path), mode="a") + write_root[output_component][int(t_index), 0, int(c_index), z0:z1, y0:y1, x0:x1] = 
_cast_to_dtype( + normalized, np.dtype(output_dtype) + ) + return 1 + + +def _prepare_output_group( + *, + zarr_path: Union[str, Path], + source_component: str, + parameters: Mapping[str, Any], + output_shape_tpczyx: tuple[int, int, int, int, int, int], + output_chunks_tpczyx: tuple[int, int, int, int, int, int], + voxel_size_um_zyx: Sequence[float], + output_origin_xyz: Sequence[float], +) -> tuple[str, str, str]: + """Create latest registration output datasets.""" + root = zarr.open_group(str(zarr_path), mode="a") + results_group = root.require_group("results") + registration_group = results_group.require_group("registration") + if "latest" in registration_group: + del registration_group["latest"] + latest = registration_group.create_group("latest") + latest.create_dataset( + name="data", + shape=output_shape_tpczyx, + chunks=output_chunks_tpczyx, + dtype=root[source_component].dtype, + overwrite=True, + ) + latest["data"].attrs.update( + { + "axes": ["t", "p", "c", "z", "y", "x"], + "source_component": str(source_component), + "voxel_size_um_zyx": [float(value) for value in voxel_size_um_zyx], + "output_origin_xyz_um": [float(value) for value in output_origin_xyz], + "storage_policy": "latest_only", + } + ) + latest.attrs.update( + { + "storage_policy": "latest_only", + "source_component": str(source_component), + "parameters": {str(key): value for key, value in dict(parameters).items()}, + "output_shape_tpczyx": [int(value) for value in output_shape_tpczyx], + "output_chunks_tpczyx": [int(value) for value in output_chunks_tpczyx], + "voxel_size_um_zyx": [float(value) for value in voxel_size_um_zyx], + "output_origin_xyz_um": [float(value) for value in output_origin_xyz], + } + ) + latest.create_dataset( + name="affines_tpx44", + shape=(output_shape_tpczyx[0], int(root[source_component].shape[1]), 4, 4), + dtype=np.float64, + overwrite=True, + ) + return ( + "results/registration/latest", + "results/registration/latest/data", + 
"results/registration/latest/affines_tpx44", + ) + + +def _estimate_worker_thread_capacity(client: "Client") -> int: + """Estimate available worker thread capacity for throttled submission.""" + try: + info = client.scheduler_info() + workers = info.get("workers", {}) + thread_total = sum( + max(1, int(details.get("nthreads", 1))) for details in workers.values() + ) + return max(1, int(thread_total)) + except Exception: + return 1 + + +def run_registration_analysis( + *, + zarr_path: Union[str, Path], + parameters: Mapping[str, Any], + client: Optional["Client"] = None, + progress_callback: Optional[ProgressCallback] = None, +) -> RegistrationSummary: + """Run 3D tile registration for a canonical multiposition analysis store. + + Parameters + ---------- + zarr_path : str or pathlib.Path + Path to canonical analysis-store Zarr/N5 object. + parameters : mapping[str, Any] + Normalized registration parameters. + client : dask.distributed.Client, optional + Active Dask client for distributed execution. + progress_callback : callable, optional + Progress callback invoked as ``callback(percent, message)``. + + Returns + ------- + RegistrationSummary + Summary metadata for the completed registration run. + + Raises + ------ + ValueError + If source components or required stage metadata are missing. 
+ """ + root = zarr.open_group(str(zarr_path), mode="r") + requested_source_component = str(parameters.get("input_source", "data")).strip() or "data" + requested_resolution_level = max(0, int(parameters.get("input_resolution_level", 0))) + source_component, pairwise_source_component, effective_level = _resolve_source_components_for_level( + root=root, + requested_source_component=requested_source_component, + input_resolution_level=requested_resolution_level, + ) + full_source = root[source_component] + pairwise_source = root[pairwise_source_component] + source_shape_tpczyx = tuple(int(value) for value in full_source.shape) + pairwise_shape_tpczyx = tuple(int(value) for value in pairwise_source.shape) + if len(source_shape_tpczyx) != 6 or len(pairwise_shape_tpczyx) != 6: + raise ValueError( + "registration requires canonical 6D data (t,p,c,z,y,x). " + f"Input component '{source_component}' is incompatible." + ) + if source_shape_tpczyx[1] != pairwise_shape_tpczyx[1]: + raise ValueError("registration source position count mismatch between levels.") + + positions = [int(index) for index in range(int(source_shape_tpczyx[1]))] + channel_count = int(source_shape_tpczyx[2]) + registration_channel = max(0, int(parameters.get("registration_channel", 0))) + if registration_channel >= channel_count: + raise ValueError( + f"registration_channel={registration_channel} is out of bounds for {channel_count} channels." + ) + + root_attrs = dict(root.attrs) + spatial_calibration = spatial_calibration_from_dict(root_attrs.get("spatial_calibration")) + stage_rows = _load_stage_rows(root_attrs) + if len(positions) > 1 and len(stage_rows) < len(positions): + raise ValueError( + "registration requires multiposition stage metadata when more than one position is present." 
+ ) + if not stage_rows: + stage_rows = [{"x": 0.0, "y": 0.0, "z": 0.0, "theta": 0.0, "f": 0.0} for _ in positions] + + configured_anchor = parameters.get("anchor_position") + if str(parameters.get("anchor_mode", "central")).strip().lower() == "manual" and configured_anchor is not None: + anchor_position = int(configured_anchor) + else: + anchor_position = _position_centroid_anchor(stage_rows, spatial_calibration, positions) + if anchor_position < 0 or anchor_position >= len(positions): + raise ValueError("registration anchor_position is out of bounds.") + + full_voxel_size_um_zyx = _extract_voxel_size_um_zyx(root, source_component) + level_factor_zyx = _pyramid_factor_zyx_for_level(root, level=effective_level) + pairwise_voxel_size_um_zyx = ( + float(full_voxel_size_um_zyx[0]) * float(level_factor_zyx[0]), + float(full_voxel_size_um_zyx[1]) * float(level_factor_zyx[1]), + float(full_voxel_size_um_zyx[2]) * float(level_factor_zyx[2]), + ) + + nominal_transforms_xyz = _build_nominal_transforms_xyz( + stage_rows, + spatial_calibration, + anchor_position=anchor_position, + positions=positions, + ) + pairwise_tile_extent_xyz = _tile_extent_xyz( + pairwise_shape_tpczyx[3:], pairwise_voxel_size_um_zyx + ) + edge_specs = _build_edge_specs( + nominal_transforms_xyz, + positions=positions, + tile_extent_xyz=pairwise_tile_extent_xyz, + voxel_size_um_zyx=pairwise_voxel_size_um_zyx, + ) + edge_count = len(edge_specs) + _emit(progress_callback, 5, f"Prepared {edge_count} registration graph edges") + + correction_affines_tex44 = np.zeros( + (int(source_shape_tpczyx[0]), int(edge_count), 4, 4), dtype=np.float64 + ) + edge_status_te = np.zeros((int(source_shape_tpczyx[0]), int(edge_count)), dtype=np.uint8) + edge_residual_te = np.full((int(source_shape_tpczyx[0]), int(edge_count)), np.nan, dtype=np.float32) + anchor_positions: list[int] = [] + effective_corrections_tpx44 = np.repeat( + np.eye(4, dtype=np.float64)[np.newaxis, np.newaxis, :, :], + int(source_shape_tpczyx[0]), + 
axis=0, + ) + effective_corrections_tpx44 = np.repeat( + effective_corrections_tpx44, + int(source_shape_tpczyx[1]), + axis=1, + ) + successful_edge_count = 0 + + for t_index in range(int(source_shape_tpczyx[0])): + anchor_positions.append(int(anchor_position)) + if edge_count == 0: + continue + delayed_edges = [ + delayed(_register_pairwise_overlap)( + zarr_path=str(zarr_path), + source_component=pairwise_source_component, + t_index=int(t_index), + registration_channel=int(registration_channel), + edge=edge, + nominal_fixed_transform_xyz=nominal_transforms_xyz[int(edge.fixed_position)], + nominal_moving_transform_xyz=nominal_transforms_xyz[int(edge.moving_position)], + voxel_size_um_zyx=pairwise_voxel_size_um_zyx, + overlap_zyx=[int(value) for value in parameters.get("overlap_zyx", [8, 32, 32])], + registration_type=str(parameters.get("registration_type", "rigid")).strip().lower(), + ) + for edge in edge_specs + ] + if client is None: + pairwise_results = list(dask.compute(*delayed_edges, scheduler="processes")) + else: + from dask.distributed import as_completed + + pairwise_results = [] + futures = client.compute(delayed_edges) + completed = 0 + total = max(1, len(futures)) + for future in as_completed(futures): + pairwise_results.append(future.result()) + completed += 1 + _emit( + progress_callback, + 5 + int((completed / total) * 35), + f"Pairwise registration {completed}/{total} for t={t_index}", + ) + successful_edge_count += sum( + 1 for result in pairwise_results if bool(result.get("success", False)) + ) + + solved, active_mask, residuals = _solve_with_pruning( + positions=positions, + edges=edge_specs, + edge_results=pairwise_results, + anchor_position=anchor_position, + registration_type=str(parameters.get("registration_type", "rigid")).strip().lower(), + voxel_size_um_zyx=pairwise_voxel_size_um_zyx, + ) + for position_index in positions: + effective_corrections_tpx44[int(t_index), int(position_index)] = solved[int(position_index)] + for edge_index, 
result in enumerate(pairwise_results): + correction_affines_tex44[int(t_index), int(edge_index)] = np.asarray( + result["correction_matrix_xyz"], dtype=np.float64 + ) + edge_status_te[int(t_index), int(edge_index)] = ( + 1 if bool(active_mask[int(edge_index)]) else 0 + ) + edge_residual_te[int(t_index), int(edge_index)] = float( + residuals[int(edge_index)] + ) + + effective_transforms_tpx44 = np.zeros( + (int(source_shape_tpczyx[0]), int(source_shape_tpczyx[1]), 4, 4), + dtype=np.float64, + ) + transformed_bboxes_tpx6 = np.zeros( + (int(source_shape_tpczyx[0]), int(source_shape_tpczyx[1]), 6), + dtype=np.float64, + ) + full_tile_extent_xyz = _tile_extent_xyz(source_shape_tpczyx[3:], full_voxel_size_um_zyx) + all_bbox_mins: list[np.ndarray] = [] + all_bbox_maxs: list[np.ndarray] = [] + for t_index in range(int(source_shape_tpczyx[0])): + for position_index in positions: + effective_transform = ( + effective_corrections_tpx44[int(t_index), int(position_index)] + @ nominal_transforms_xyz[int(position_index)] + ) + effective_transforms_tpx44[int(t_index), int(position_index)] = effective_transform + bbox_min, bbox_max = _tile_bbox_xyz(effective_transform, full_tile_extent_xyz) + transformed_bboxes_tpx6[int(t_index), int(position_index), :3] = bbox_min + transformed_bboxes_tpx6[int(t_index), int(position_index), 3:] = bbox_max + all_bbox_mins.append(bbox_min) + all_bbox_maxs.append(bbox_max) + + if all_bbox_mins: + output_min_xyz = np.min(np.vstack(all_bbox_mins), axis=0) + output_max_xyz = np.max(np.vstack(all_bbox_maxs), axis=0) + else: + output_min_xyz = np.zeros(3, dtype=np.float64) + output_max_xyz = np.asarray(full_tile_extent_xyz, dtype=np.float64) + output_shape_xyz = np.maximum( + 1, + np.ceil( + (output_max_xyz - output_min_xyz) + / np.asarray( + [ + float(full_voxel_size_um_zyx[2]), + float(full_voxel_size_um_zyx[1]), + float(full_voxel_size_um_zyx[0]), + ], + dtype=np.float64, + ) + ).astype(int), + ) + output_shape_tpczyx = ( + int(source_shape_tpczyx[0]), 
+ 1, + int(source_shape_tpczyx[2]), + int(output_shape_xyz[2]), + int(output_shape_xyz[1]), + int(output_shape_xyz[0]), + ) + source_chunks = tuple(int(value) for value in full_source.chunks) + output_chunks_tpczyx = ( + 1, + 1, + 1, + max(1, min(int(source_chunks[3]), int(output_shape_tpczyx[3]))), + max(1, min(int(source_chunks[4]), int(output_shape_tpczyx[4]))), + max(1, min(int(source_chunks[5]), int(output_shape_tpczyx[5]))), + ) + component, data_component, affines_component = _prepare_output_group( + zarr_path=zarr_path, + source_component=source_component, + parameters=parameters, + output_shape_tpczyx=output_shape_tpczyx, + output_chunks_tpczyx=output_chunks_tpczyx, + voxel_size_um_zyx=full_voxel_size_um_zyx, + output_origin_xyz=output_min_xyz, + ) + write_root = zarr.open_group(str(zarr_path), mode="a") + latest_group = write_root["results/registration/latest"] + latest_group.create_dataset( + name="edges_pe2", + data=( + np.asarray( + [ + [int(edge.fixed_position), int(edge.moving_position)] + for edge in edge_specs + ], + dtype=np.int32, + ).reshape(-1, 2) + if edge_specs + else np.zeros((0, 2), dtype=np.int32) + ), + overwrite=True, + ) + latest_group.create_dataset( + name="pairwise_affines_tex44", + data=correction_affines_tex44, + overwrite=True, + ) + latest_group.create_dataset( + name="edge_status_te", + data=edge_status_te, + overwrite=True, + ) + latest_group.create_dataset( + name="edge_residual_te", + data=edge_residual_te, + overwrite=True, + ) + latest_group.create_dataset( + name="transformed_bboxes_tpx6", + data=transformed_bboxes_tpx6, + overwrite=True, + ) + latest_group["affines_tpx44"][:] = effective_transforms_tpx44 + + z_bounds = _axis_chunk_bounds(output_shape_tpczyx[3], output_chunks_tpczyx[3]) + y_bounds = _axis_chunk_bounds(output_shape_tpczyx[4], output_chunks_tpczyx[4]) + x_bounds = _axis_chunk_bounds(output_shape_tpczyx[5], output_chunks_tpczyx[5]) + fusion_tasks = [ + ( + int(t_index), + int(c_index), + z_chunk, + y_chunk, + 
x_chunk, + ) + for t_index in range(int(output_shape_tpczyx[0])) + for c_index in range(int(output_shape_tpczyx[2])) + for z_chunk in z_bounds + for y_chunk in y_bounds + for x_chunk in x_bounds + ] + _emit(progress_callback, 45, f"Prepared {len(fusion_tasks)} fusion chunk tasks") + + task_kwargs = [ + dict( + zarr_path=str(zarr_path), + source_component=source_component, + output_component=data_component, + affines_component=affines_component, + transformed_bboxes_component="results/registration/latest/transformed_bboxes_tpx6", + t_index=t_index, + c_index=c_index, + z_bounds=z_chunk, + y_bounds=y_chunk, + x_bounds=x_chunk, + output_origin_xyz=output_min_xyz, + voxel_size_um_zyx=full_voxel_size_um_zyx, + blend_mode=str(parameters.get("blend_mode", "feather")).strip().lower(), + overlap_zyx=[int(value) for value in parameters.get("overlap_zyx", [8, 32, 32])], + output_dtype=str(full_source.dtype), + ) + for t_index, c_index, z_chunk, y_chunk, x_chunk in fusion_tasks + ] + if client is None: + delayed_tasks = [ + delayed(_process_and_write_registration_chunk)(**kwargs) + for kwargs in task_kwargs + ] + dask.compute(*delayed_tasks, scheduler="processes") + else: + from dask.distributed import as_completed + + max_in_flight = max(16, int(_estimate_worker_thread_capacity(client)) * 2) + completion_queue = as_completed() + pending_count = 0 + completed = 0 + total = max(1, len(task_kwargs)) + for kwargs in task_kwargs: + completion_queue.add( + client.submit( + _process_and_write_registration_chunk, + **kwargs, + pure=False, + ) + ) + pending_count += 1 + if pending_count < max_in_flight: + continue + finished = next(completion_queue) + _ = finished.result() + pending_count -= 1 + completed += 1 + _emit( + progress_callback, + 45 + int((completed / total) * 50), + f"Fusion chunk {completed}/{total}", + ) + for finished in completion_queue: + _ = finished.result() + completed += 1 + _emit( + progress_callback, + 45 + int((completed / total) * 50), + f"Fusion chunk 
{completed}/{total}", + ) + + active_edge_count = int(np.count_nonzero(edge_status_te)) + dropped_edge_count = max(0, int(successful_edge_count) - int(active_edge_count)) + latest_group.attrs.update( + { + "requested_source_component": str(requested_source_component), + "source_component": str(source_component), + "pairwise_source_component": str(pairwise_source_component), + "requested_input_resolution_level": int(requested_resolution_level), + "input_resolution_level": int(effective_level), + "registration_channel": int(registration_channel), + "registration_type": str(parameters.get("registration_type", "rigid")), + "anchor_positions": [int(value) for value in anchor_positions], + "edge_count": int(edge_count), + "active_edge_count": int(active_edge_count), + "dropped_edge_count": int(dropped_edge_count), + "blend_mode": str(parameters.get("blend_mode", "feather")), + } + ) + + register_latest_output_reference( + zarr_path=zarr_path, + analysis_name="registration", + component=component, + metadata={ + "data_component": data_component, + "affines_component": affines_component, + "requested_source_component": requested_source_component, + "source_component": source_component, + "pairwise_source_component": pairwise_source_component, + "requested_input_resolution_level": int(requested_resolution_level), + "input_resolution_level": int(effective_level), + "registration_channel": int(registration_channel), + "registration_type": str(parameters.get("registration_type", "rigid")), + "anchor_positions": [int(value) for value in anchor_positions], + "edge_count": int(edge_count), + "active_edge_count": int(active_edge_count), + "dropped_edge_count": int(dropped_edge_count), + "blend_mode": str(parameters.get("blend_mode", "feather")), + "output_shape_tpczyx": [int(value) for value in output_shape_tpczyx], + "output_chunks_tpczyx": [int(value) for value in output_chunks_tpczyx], + "voxel_size_um_zyx": [float(value) for value in full_voxel_size_um_zyx], + 
"output_origin_xyz_um": [float(value) for value in output_min_xyz], + }, + ) + _emit(progress_callback, 100, "Registration complete") + return RegistrationSummary( + component=component, + data_component=data_component, + affines_component=affines_component, + source_component=source_component, + requested_source_component=requested_source_component, + pairwise_source_component=pairwise_source_component, + input_resolution_level=int(effective_level), + requested_input_resolution_level=int(requested_resolution_level), + registration_channel=int(registration_channel), + registration_type=str(parameters.get("registration_type", "rigid")).strip().lower(), + anchor_positions=tuple(int(value) for value in anchor_positions), + positions=int(source_shape_tpczyx[1]), + timepoints=int(source_shape_tpczyx[0]), + edge_count=int(edge_count), + active_edge_count=int(active_edge_count), + dropped_edge_count=int(dropped_edge_count), + output_shape_tpczyx=tuple(int(value) for value in output_shape_tpczyx), + output_chunks_tpczyx=tuple(int(value) for value in output_chunks_tpczyx), + blend_mode=str(parameters.get("blend_mode", "feather")).strip().lower(), + ) diff --git a/src/clearex/workflow.py b/src/clearex/workflow.py index 0a42ae6..342b5cb 100644 --- a/src/clearex/workflow.py +++ b/src/clearex/workflow.py @@ -64,6 +64,7 @@ "deconvolution": "results/deconvolution/latest/data", "shear_transform": "results/shear_transform/latest/data", "usegment3d": "results/usegment3d/latest/data", + "registration": "results/registration/latest/data", } ANALYSIS_KNOWN_OUTPUT_COMPONENTS: Dict[str, str] = { "data": "data", @@ -568,7 +569,7 @@ class WorkflowExecutionCancelled(RuntimeError): "roi_padding_zyx": [2, 2, 2], }, "particle_detection": { - "execution_order": 4, + "execution_order": 5, "input_source": "data", "force_rerun": False, "channel_index": 0, @@ -589,7 +590,7 @@ class WorkflowExecutionCancelled(RuntimeError): "min_distance_sigma": 10.0, }, "usegment3d": { - "execution_order": 5, + 
"execution_order": 6, "input_source": "data", "force_rerun": False, "chunk_basis": "3d", @@ -647,7 +648,7 @@ class WorkflowExecutionCancelled(RuntimeError): "output_dtype": "uint32", }, "registration": { - "execution_order": 6, + "execution_order": 4, "input_source": "data", "force_rerun": False, "chunk_basis": "3d", @@ -655,6 +656,12 @@ class WorkflowExecutionCancelled(RuntimeError): "use_map_overlap": True, "overlap_zyx": [8, 32, 32], "memory_overhead_factor": 2.5, + "registration_channel": 0, + "registration_type": "rigid", + "input_resolution_level": 0, + "anchor_mode": "central", + "anchor_position": None, + "blend_mode": "feather", }, "display_pyramid": { "execution_order": 7, @@ -1703,6 +1710,70 @@ def _normalize_usegment3d_parameters( return normalized +def _normalize_registration_parameters( + params: Dict[str, Any], +) -> Dict[str, Any]: + """Normalize registration runtime parameters. + + Parameters + ---------- + params : dict[str, Any] + Candidate registration parameters. + + Returns + ------- + dict[str, Any] + Normalized registration parameter mapping. + + Raises + ------ + ValueError + If required values are invalid. + """ + normalized = _normalize_common_operation_parameters("registration", params) + normalized["detect_2d_per_slice"] = False + normalized["use_map_overlap"] = bool(normalized.get("use_map_overlap", True)) + normalized["registration_channel"] = max( + 0, int(normalized.get("registration_channel", 0)) + ) + + registration_type = ( + str(normalized.get("registration_type", "rigid")).strip().lower() or "rigid" + ) + if registration_type not in {"translation", "rigid", "similarity"}: + raise ValueError( + "registration registration_type must be one of translation, rigid, or similarity." 
+ ) + normalized["registration_type"] = registration_type + + input_resolution_level = int(normalized.get("input_resolution_level", 0)) + if input_resolution_level < 0: + raise ValueError("registration input_resolution_level must be >= 0.") + normalized["input_resolution_level"] = input_resolution_level + + anchor_mode = ( + str(normalized.get("anchor_mode", "central")).strip().lower() or "central" + ) + if anchor_mode not in {"central", "manual"}: + raise ValueError("registration anchor_mode must be 'central' or 'manual'.") + normalized["anchor_mode"] = anchor_mode + + anchor_position = normalized.get("anchor_position") + if anchor_position in {None, ""}: + normalized["anchor_position"] = None + else: + parsed_anchor = int(anchor_position) + if parsed_anchor < 0: + raise ValueError("registration anchor_position must be >= 0.") + normalized["anchor_position"] = parsed_anchor + + blend_mode = str(normalized.get("blend_mode", "feather")).strip().lower() + if blend_mode not in {"average", "feather"}: + raise ValueError("registration blend_mode must be 'average' or 'feather'.") + normalized["blend_mode"] = blend_mode + return normalized + + def _normalize_shear_transform_parameters( params: Dict[str, Any], ) -> Dict[str, Any]: @@ -2214,6 +2285,10 @@ def normalize_analysis_operation_parameters( merged[operation_name] = _normalize_usegment3d_parameters( merged[operation_name] ) + elif operation_name == "registration": + merged[operation_name] = _normalize_registration_parameters( + merged[operation_name] + ) elif operation_name == "visualization": merged[operation_name] = _normalize_visualization_parameters( merged[operation_name] diff --git a/tests/gui/test_gui_execution.py b/tests/gui/test_gui_execution.py index cfc6110..08ad4f8 100644 --- a/tests/gui/test_gui_execution.py +++ b/tests/gui/test_gui_execution.py @@ -286,6 +286,69 @@ def test_analysis_dialog_scrolls_body_on_short_screens(monkeypatch) -> None: dialog.close() +def 
test_registration_operation_moves_to_preprocessing_tab() -> None: + if not hasattr(app_module, "AnalysisSelectionDialog"): + return + + tab_map = dict(app_module.AnalysisSelectionDialog._OPERATION_TABS) + + assert "registration" in tab_map["Preprocessing"] + assert "registration" not in tab_map.get("Postprocessing", ()) + + +def test_analysis_dialog_persists_registration_parameters( + monkeypatch: pytest.MonkeyPatch, +) -> None: + if not app_module.HAS_PYQT6: + return + + app = app_module.QApplication.instance() + if app is None: + app = app_module.QApplication([]) + + monkeypatch.setattr( + app_module, + "_save_last_used_dask_backend_config", + lambda _config: None, + ) + + dialog = app_module.AnalysisSelectionDialog( + initial=app_module.WorkflowConfig(file="/tmp/test/data_store.fake") + ) + dialog._persist_analysis_gui_state_for_target = lambda _target: None + dialog._operation_checkboxes["registration"].setChecked(True) + dialog._registration_type_combo.setCurrentIndex( + dialog._registration_type_combo.findData("similarity") + ) + dialog._registration_anchor_mode_combo.setCurrentIndex( + dialog._registration_anchor_mode_combo.findData("manual") + ) + dialog._registration_anchor_position_spin.setMaximum(3) + dialog._registration_anchor_position_spin.setValue(2) + dialog._registration_overlap_z_spin.setValue(5) + dialog._registration_overlap_y_spin.setValue(12) + dialog._registration_overlap_x_spin.setValue(20) + dialog._registration_blend_mode_combo.setCurrentIndex( + dialog._registration_blend_mode_combo.findData("average") + ) + dialog._registration_resolution_level_spin.setMaximum(3) + dialog._registration_resolution_level_spin.setValue(1) + app_module.AnalysisSelectionDialog._set_registration_parameter_enabled_state( + dialog + ) + + dialog._on_run() + + params = dialog.result_config.analysis_parameters["registration"] + assert dialog.result_config.registration is True + assert params["registration_type"] == "similarity" + assert params["anchor_mode"] == 
"manual" + assert params["anchor_position"] == 2 + assert params["overlap_zyx"] == [5, 12, 20] + assert params["blend_mode"] == "average" + assert params["input_resolution_level"] == 1 + + def test_zarr_dialog_scrolls_body_on_short_screens(monkeypatch) -> None: if not app_module.HAS_PYQT6: return diff --git a/tests/registration/test_pipeline.py b/tests/registration/test_pipeline.py new file mode 100644 index 0000000..da42407 --- /dev/null +++ b/tests/registration/test_pipeline.py @@ -0,0 +1,365 @@ +# Copyright (c) 2021-2025 The University of Texas Southwestern Medical Center. +# All rights reserved. + +from __future__ import annotations + +import json +from pathlib import Path + +import dask as dask_module +import numpy as np +import pytest +from scipy.spatial.transform import Rotation +import zarr + +import clearex.registration.pipeline as registration_pipeline +from clearex.workflow import SPATIAL_CALIBRATION_SCHEMA, SpatialCalibrationConfig + + +def _translation_matrix(x: float, y: float = 0.0, z: float = 0.0) -> np.ndarray: + """Build a homogeneous XYZ translation matrix.""" + matrix = np.eye(4, dtype=np.float64) + matrix[:3, 3] = np.asarray([x, y, z], dtype=np.float64) + return matrix + + +def _edge_result( + edge: registration_pipeline._EdgeSpec, + matrix_xyz: np.ndarray, + *, + success: bool = True, +) -> dict[str, object]: + """Build one synthetic pairwise edge result.""" + return { + "fixed_position": int(edge.fixed_position), + "moving_position": int(edge.moving_position), + "success": bool(success), + "reason": "", + "correction_matrix_xyz": np.asarray(matrix_xyz, dtype=np.float64).tolist(), + "overlap_voxels": int(edge.overlap_voxels), + "nominal_overlap_pixels": int(edge.overlap_voxels), + } + + +def _write_multiposition_sidecar(directory: Path, rows: list[list[float]]) -> Path: + """Write minimal Navigate experiment and multiposition sidecar files.""" + experiment_path = directory / "experiment.yml" + experiment_path.write_text("Saving:\n file_type: 
OME-ZARR\n", encoding="utf-8") + sidecar_path = directory / "multi_positions.yml" + payload = [["X", "Y", "Z", "Theta", "F"], *rows] + sidecar_path.write_text(json.dumps(payload), encoding="utf-8") + return experiment_path + + +def _create_registration_store( + tmp_path: Path, + *, + timepoints: int = 2, + positions: int = 3, + channels: int = 1, + shape_zyx: tuple[int, int, int] = (4, 4, 6), + include_pyramid: bool = True, +) -> Path: + """Create a synthetic canonical 6D store with overlapping multiposition tiles.""" + store_path = tmp_path / "registration_store.zarr" + root = zarr.open_group(str(store_path), mode="w") + data_shape = (timepoints, positions, channels, *shape_zyx) + data = root.create_dataset( + name="data", + shape=data_shape, + chunks=(1, 1, 1, *shape_zyx), + dtype="uint16", + overwrite=True, + ) + data.attrs["voxel_size_um_zyx"] = [1.0, 1.0, 1.0] + root.attrs["voxel_size_um_zyx"] = [1.0, 1.0, 1.0] + root.attrs["spatial_calibration"] = { + "schema": SPATIAL_CALIBRATION_SCHEMA, + "stage_axis_map_zyx": {"z": "+z", "y": "+y", "x": "+x"}, + "theta_mode": "rotate_zy_about_x", + } + experiment_path = _write_multiposition_sidecar( + tmp_path, + rows=[[0.0, 0.0, 0.0, 0.0, 0.0], [4.0, 0.0, 0.0, 0.0, 0.0], [8.0, 0.0, 0.0, 0.0, 0.0]][ + :positions + ], + ) + root.attrs["source_experiment"] = str(experiment_path) + root.attrs["data_pyramid_factors_tpczyx"] = [ + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 2, 2, 2], + ] + + for t_index in range(timepoints): + for position_index in range(positions): + volume = np.full( + shape_zyx, + fill_value=np.uint16((position_index + 1) * 10 + t_index), + dtype=np.uint16, + ) + data[t_index, position_index, 0] = volume + + if include_pyramid: + pyramid = root.require_group("data_pyramid") + level_1 = pyramid.create_dataset( + name="level_1", + shape=(timepoints, positions, channels, 2, 2, 3), + chunks=(1, 1, 1, 2, 2, 3), + dtype="uint16", + overwrite=True, + ) + level_1[:] = data[:, :, :, ::2, ::2, ::2] + + return store_path + + +def 
test_build_edge_specs_only_keeps_overlapping_neighbors() -> None: + nominal = { + 0: _translation_matrix(0.0), + 1: _translation_matrix(4.0), + 2: _translation_matrix(8.0), + } + + edges = registration_pipeline._build_edge_specs( + nominal, + positions=(0, 1, 2), + tile_extent_xyz=np.asarray([6.0, 4.0, 4.0], dtype=np.float64), + voxel_size_um_zyx=(1.0, 1.0, 1.0), + ) + + assert [(edge.fixed_position, edge.moving_position) for edge in edges] == [ + (0, 1), + (1, 2), + ] + + +def test_resolve_source_components_for_level_uses_requested_pyramid( + tmp_path: Path, +) -> None: + store_path = _create_registration_store(tmp_path) + root = zarr.open_group(str(store_path), mode="r") + + source_component, pairwise_component, level = ( + registration_pipeline._resolve_source_components_for_level( + root=root, + requested_source_component="data", + input_resolution_level=1, + ) + ) + + assert source_component == "data" + assert pairwise_component == "data_pyramid/level_1" + assert level == 1 + + +def test_resolve_source_components_for_level_rejects_missing_level( + tmp_path: Path, +) -> None: + store_path = _create_registration_store(tmp_path, include_pyramid=False) + root = zarr.open_group(str(store_path), mode="r") + + with pytest.raises(ValueError, match="input_resolution_level=1"): + registration_pipeline._resolve_source_components_for_level( + root=root, + requested_source_component="data", + input_resolution_level=1, + ) + + +def test_position_centroid_anchor_prefers_central_tile() -> None: + stage_rows = [ + {"x": 0.0, "y": 0.0, "z": 0.0, "theta": 0.0, "f": 0.0}, + {"x": 4.0, "y": 0.0, "z": 0.0, "theta": 0.0, "f": 0.0}, + {"x": 8.0, "y": 0.0, "z": 0.0, "theta": 0.0, "f": 0.0}, + ] + + anchor = registration_pipeline._position_centroid_anchor( + stage_rows, + SpatialCalibrationConfig(), + positions=(0, 1, 2), + ) + + assert anchor == 1 + + +def test_solve_with_pruning_recovers_translation_and_prunes_outlier() -> None: + edges = [ + registration_pipeline._EdgeSpec( + 0, 1, 
((0.0, 1.0), (0.0, 1.0), (0.0, 1.0)), 10_000 + ), + registration_pipeline._EdgeSpec( + 1, 2, ((0.0, 1.0), (0.0, 1.0), (0.0, 1.0)), 10_000 + ), + registration_pipeline._EdgeSpec( + 0, 2, ((0.0, 1.0), (0.0, 1.0), (0.0, 1.0)), 1 + ), + ] + results = [ + _edge_result(edges[0], _translation_matrix(1.0)), + _edge_result(edges[1], _translation_matrix(1.0)), + _edge_result(edges[2], _translation_matrix(20.0)), + ] + + solved, active_mask, residuals = registration_pipeline._solve_with_pruning( + positions=(0, 1, 2), + edges=edges, + edge_results=results, + anchor_position=0, + registration_type="translation", + voxel_size_um_zyx=(1.0, 1.0, 1.0), + ) + + assert active_mask.tolist() == [True, True, False] + assert solved[1][:3, 3] == pytest.approx([1.0, 0.0, 0.0]) + assert solved[2][:3, 3] == pytest.approx([2.0, 0.0, 0.0]) + assert np.isnan(residuals[2]) + + +def test_solve_with_pruning_preserves_disconnected_component_identity() -> None: + edges = [ + registration_pipeline._EdgeSpec(0, 1, ((0.0, 1.0), (0.0, 1.0), (0.0, 1.0)), 32), + registration_pipeline._EdgeSpec(1, 2, ((0.0, 1.0), (0.0, 1.0), (0.0, 1.0)), 32), + ] + results = [ + _edge_result(edges[0], _translation_matrix(1.0)), + _edge_result(edges[1], np.eye(4, dtype=np.float64), success=False), + ] + + solved, active_mask, _ = registration_pipeline._solve_with_pruning( + positions=(0, 1, 2), + edges=edges, + edge_results=results, + anchor_position=0, + registration_type="translation", + voxel_size_um_zyx=(1.0, 1.0, 1.0), + ) + + assert active_mask.tolist() == [True, False] + assert solved[1][:3, 3] == pytest.approx([1.0, 0.0, 0.0]) + assert solved[2] == pytest.approx(np.eye(4, dtype=np.float64)) + + +def test_solve_with_pruning_recovers_rigid_transform() -> None: + edge = registration_pipeline._EdgeSpec( + 0, 1, ((0.0, 1.0), (0.0, 1.0), (0.0, 1.0)), 48 + ) + measured = np.eye(4, dtype=np.float64) + measured[:3, :3] = Rotation.from_rotvec([0.0, 0.0, 0.15]).as_matrix() + measured[:3, 3] = np.asarray([1.25, -0.5, 0.0], 
dtype=np.float64) + + solved, active_mask, residuals = registration_pipeline._solve_with_pruning( + positions=(0, 1), + edges=[edge], + edge_results=[_edge_result(edge, measured)], + anchor_position=0, + registration_type="rigid", + voxel_size_um_zyx=(1.0, 1.0, 1.0), + ) + + assert active_mask.tolist() == [True] + assert residuals[0] == pytest.approx(0.0, abs=1e-5) + assert solved[1] == pytest.approx(measured, abs=1e-5) + + +def test_solve_with_pruning_recovers_similarity_transform() -> None: + edge = registration_pipeline._EdgeSpec( + 0, 1, ((0.0, 1.0), (0.0, 1.0), (0.0, 1.0)), 48 + ) + measured = registration_pipeline._matrix_from_pose( + np.asarray([0.0, 0.0, 0.1, 0.75, 0.5, 0.0, np.log(1.08)], dtype=np.float64), + "similarity", + ) + + solved, active_mask, residuals = registration_pipeline._solve_with_pruning( + positions=(0, 1), + edges=[edge], + edge_results=[_edge_result(edge, measured)], + anchor_position=0, + registration_type="similarity", + voxel_size_um_zyx=(1.0, 1.0, 1.0), + ) + + assert active_mask.tolist() == [True] + assert residuals[0] == pytest.approx(0.0, abs=1e-5) + assert solved[1] == pytest.approx(measured, abs=1e-5) + + +def test_run_registration_analysis_fuses_output_and_writes_metadata( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + store_path = _create_registration_store(tmp_path) + original_compute = dask_module.compute + + def _sync_compute(*args, **kwargs): + del kwargs + return original_compute(*args, scheduler="synchronous") + + def _fake_pairwise(**kwargs): + edge = kwargs["edge"] + t_index = int(kwargs["t_index"]) + correction = np.eye(4, dtype=np.float64) + if t_index == 1 and int(edge.fixed_position) == 1 and int(edge.moving_position) == 2: + correction[0, 3] = -1.0 + return { + "fixed_position": int(edge.fixed_position), + "moving_position": int(edge.moving_position), + "success": True, + "reason": "", + "correction_matrix_xyz": correction.tolist(), + "overlap_voxels": int(edge.overlap_voxels), + 
"nominal_overlap_pixels": int(edge.overlap_voxels), + } + + monkeypatch.setattr(registration_pipeline.dask, "compute", _sync_compute) + monkeypatch.setattr( + registration_pipeline, + "_register_pairwise_overlap", + _fake_pairwise, + ) + + progress_updates: list[tuple[int, str]] = [] + summary = registration_pipeline.run_registration_analysis( + zarr_path=store_path, + parameters={ + "input_source": "data", + "registration_channel": 0, + "registration_type": "rigid", + "input_resolution_level": 1, + "anchor_mode": "central", + "anchor_position": None, + "blend_mode": "feather", + "overlap_zyx": [0, 0, 2], + }, + client=None, + progress_callback=lambda percent, message: progress_updates.append( + (int(percent), str(message)) + ), + ) + + root = zarr.open_group(str(store_path), mode="r") + latest = root["results/registration/latest"] + data = latest["data"] + affines = latest["affines_tpx44"] + + assert summary.source_component == "data" + assert summary.pairwise_source_component == "data_pyramid/level_1" + assert summary.input_resolution_level == 1 + assert summary.anchor_positions == (1, 1) + assert summary.edge_count == 2 + assert summary.output_shape_tpczyx == (2, 1, 1, 4, 4, 14) + assert data.shape == (2, 1, 1, 4, 4, 14) + assert latest["edges_pe2"].shape == (2, 2) + assert latest["pairwise_affines_tex44"].shape == (2, 2, 4, 4) + assert latest["edge_status_te"].shape == (2, 2) + assert latest["edge_residual_te"].shape == (2, 2) + assert latest["transformed_bboxes_tpx6"].shape == (2, 3, 6) + assert affines[1, 2, 0, 3] == pytest.approx(3.0, abs=1e-6) + assert int(data[0, 0, 0, 0, 0, 4]) == 10 + assert int(data[0, 0, 0, 0, 0, 5]) == 20 + assert int(data[0, 0, 0, 0, 0, 8]) == 20 + assert int(data[0, 0, 0, 0, 0, 9]) == 30 + assert latest.attrs["pairwise_source_component"] == "data_pyramid/level_1" + assert latest.attrs["input_resolution_level"] == 1 + assert latest.attrs["blend_mode"] == "feather" + assert progress_updates + assert progress_updates[-1][0] == 100 diff 
--git a/tests/test_main.py b/tests/test_main.py index 2ec6b7f..3a913f4 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -665,7 +665,9 @@ def _cancel_on_progress(percent: int, message: str) -> None: ) -def test_run_workflow_registration_skips_without_crashing(monkeypatch) -> None: +def test_run_workflow_registration_starts_analysis_dask_startup( + monkeypatch, +) -> None: workloads: list[str] = [] def _fake_configure_dask_backend(*, workflow, logger, exit_stack, workload="io"): @@ -689,7 +691,127 @@ def _fake_configure_dask_backend(*, workflow, logger, exit_stack, workload="io") logger=_test_logger("clearex.test.main.registration"), ) - assert workloads == [] + assert workloads == ["analysis"] + + +def test_run_workflow_chains_registration_output_to_visualization( + tmp_path: Path, monkeypatch +) -> None: + store_path = tmp_path / "analysis_store_registration_chain.zarr" + root = main_module.zarr.open_group(str(store_path), mode="w") + root.create_dataset( + name="data", + shape=(1, 2, 1, 2, 2, 2), + chunks=(1, 1, 1, 2, 2, 2), + dtype="uint16", + overwrite=True, + ) + + def _fake_configure_dask_backend(*, workflow, logger, exit_stack, workload="io"): + del workflow, logger, exit_stack, workload + return None + + def _fake_registration(*, zarr_path, parameters, client, progress_callback): + del parameters, client, progress_callback + fake_root = main_module.zarr.open_group(str(zarr_path), mode="a") + latest = ( + fake_root.require_group("results") + .require_group("registration") + .require_group("latest") + ) + latest.create_dataset( + name="data", + shape=(1, 1, 1, 2, 2, 3), + chunks=(1, 1, 1, 2, 2, 3), + dtype="uint16", + overwrite=True, + ) + latest.create_dataset( + name="affines_tpx44", + shape=(1, 2, 4, 4), + chunks=(1, 1, 4, 4), + dtype="float64", + overwrite=True, + ) + return SimpleNamespace( + component="results/registration/latest", + data_component="results/registration/latest/data", + 
affines_component="results/registration/latest/affines_tpx44", + source_component="data", + pairwise_source_component="data", + requested_source_component="data", + requested_input_resolution_level=0, + input_resolution_level=0, + registration_channel=0, + registration_type="rigid", + anchor_positions=(0,), + edge_count=1, + active_edge_count=1, + dropped_edge_count=0, + output_shape_tpczyx=(1, 1, 1, 2, 2, 3), + output_chunks_tpczyx=(1, 1, 1, 2, 2, 3), + blend_mode="feather", + ) + + captured: dict[str, object] = {} + + def _fake_visualization(*, zarr_path, parameters, progress_callback): + del zarr_path, progress_callback + captured["input_source"] = str(parameters["input_source"]) + fake_root = main_module.zarr.open_group(str(store_path), mode="a") + latest = ( + fake_root.require_group("results") + .require_group("visualization") + .require_group("latest") + ) + latest.attrs["source_component"] = str(parameters["input_source"]) + return SimpleNamespace( + component="results/visualization/latest", + source_component=str(parameters["input_source"]), + source_components=(str(parameters["input_source"]),), + position_index=0, + overlay_points_count=0, + launch_mode="in_process", + viewer_pid=None, + keyframe_manifest_path="", + keyframe_count=0, + ) + + monkeypatch.setattr( + main_module, "_configure_dask_backend", _fake_configure_dask_backend + ) + monkeypatch.setattr(main_module, "run_registration_analysis", _fake_registration) + monkeypatch.setattr(main_module, "run_visualization_analysis", _fake_visualization) + monkeypatch.setattr(main_module, "is_navigate_experiment_file", lambda path: False) + + workflow = WorkflowConfig( + file=str(store_path), + prefer_dask=True, + registration=True, + visualization=True, + analysis_parameters={ + "registration": { + "execution_order": 1, + "input_source": "data", + }, + "visualization": { + "execution_order": 2, + "input_source": "registration", + }, + }, + ) + main_module._run_workflow( + workflow=workflow, + 
logger=_test_logger("clearex.test.main.registration_chain"), + ) + + assert captured["input_source"] == "results/registration/latest/data" + latest_ref = dict( + main_module.zarr.open_group(str(store_path), mode="r")["provenance"][ + "latest_outputs" + ]["registration"].attrs + ) + assert latest_ref["component"] == "results/registration/latest" def test_run_workflow_non_experiment_file_skips_io_dask_startup( diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 6d9fee3..1ff978e 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -178,12 +178,21 @@ def test_default_zarr_save_config(self): assert "shear_transform" in cfg.analysis_parameters assert cfg.analysis_parameters["shear_transform"]["execution_order"] == 3 assert cfg.analysis_parameters["shear_transform"]["interpolation"] == "linear" + assert "registration" in cfg.analysis_parameters + assert cfg.analysis_parameters["registration"]["execution_order"] == 4 + assert cfg.analysis_parameters["registration"]["input_source"] == "data" + assert cfg.analysis_parameters["registration"]["registration_channel"] == 0 + assert cfg.analysis_parameters["registration"]["registration_type"] == "rigid" + assert cfg.analysis_parameters["registration"]["input_resolution_level"] == 0 + assert cfg.analysis_parameters["registration"]["anchor_mode"] == "central" + assert cfg.analysis_parameters["registration"]["anchor_position"] is None + assert cfg.analysis_parameters["registration"]["blend_mode"] == "feather" assert "particle_detection" in cfg.analysis_parameters assert cfg.analysis_parameters["particle_detection"]["bg_sigma"] == 20.0 - assert cfg.analysis_parameters["particle_detection"]["execution_order"] == 4 + assert cfg.analysis_parameters["particle_detection"]["execution_order"] == 5 assert cfg.analysis_parameters["particle_detection"]["input_source"] == "data" assert "usegment3d" in cfg.analysis_parameters - assert cfg.analysis_parameters["usegment3d"]["execution_order"] == 5 + assert 
cfg.analysis_parameters["usegment3d"]["execution_order"] == 6 assert cfg.analysis_parameters["usegment3d"]["input_source"] == "data" assert cfg.analysis_parameters["usegment3d"]["all_channels"] is False assert cfg.analysis_parameters["usegment3d"]["channel_indices"] == [0] @@ -629,6 +638,49 @@ def test_normalizes_usegment3d_parameters(self): assert params["output_reference_space"] == "native_level" assert params["save_native_labels"] is True + def test_normalizes_registration_parameters(self): + cfg = WorkflowConfig( + analysis_parameters={ + "registration": { + "execution_order": "9", + "input_source": "deconvolution", + "memory_overhead_factor": "4.5", + "overlap_zyx": [2, 6, 10], + "registration_channel": "2", + "registration_type": "similarity", + "input_resolution_level": "3", + "anchor_mode": "manual", + "anchor_position": "4", + "blend_mode": "average", + "force_rerun": 1, + } + } + ) + params = cfg.analysis_parameters["registration"] + assert params["execution_order"] == 9 + assert params["input_source"] == "deconvolution" + assert params["memory_overhead_factor"] == 4.5 + assert params["overlap_zyx"] == [2, 6, 10] + assert params["registration_channel"] == 2 + assert params["registration_type"] == "similarity" + assert params["input_resolution_level"] == 3 + assert params["anchor_mode"] == "manual" + assert params["anchor_position"] == 4 + assert params["blend_mode"] == "average" + assert params["force_rerun"] is True + + def test_rejects_invalid_registration_resolution_level(self): + with pytest.raises(ValueError): + WorkflowConfig( + analysis_parameters={"registration": {"input_resolution_level": -1}} + ) + + def test_rejects_invalid_registration_type(self): + with pytest.raises(ValueError): + WorkflowConfig( + analysis_parameters={"registration": {"registration_type": "affine"}} + ) + def test_rejects_invalid_usegment3d_resolution_level(self): with pytest.raises(ValueError): WorkflowConfig( @@ -913,7 +965,8 @@ def 
test_normalize_analysis_operation_parameters_returns_defaults(): assert normalized["deconvolution"]["execution_order"] == 2 assert normalized["flatfield"]["execution_order"] == 1 assert normalized["shear_transform"]["execution_order"] == 3 - assert normalized["usegment3d"]["execution_order"] == 5 + assert normalized["registration"]["execution_order"] == 4 + assert normalized["usegment3d"]["execution_order"] == 6 assert normalized["display_pyramid"]["execution_order"] == 7 assert normalized["visualization"]["input_source"] == "data" assert normalized["visualization"]["show_all_positions"] is False @@ -995,6 +1048,12 @@ def test_resolve_analysis_input_component_prefers_same_run_outputs() -> None: assert resolved == "results/flatfield/latest/data_run_2" +def test_resolve_analysis_input_component_supports_registration_alias() -> None: + resolved = resolve_analysis_input_component("registration") + + assert resolved == "results/registration/latest/data" + + def test_validate_analysis_input_references_accepts_scheduled_chainable_input() -> None: issues = validate_analysis_input_references( execution_sequence=("flatfield", "deconvolution"), From a33b4dc97b6fc51e7e7281183eb9f097272f8523 Mon Sep 17 00:00:00 2001 From: Kevin Dean Date: Sat, 21 Mar 2026 15:58:41 -0500 Subject: [PATCH 2/4] GUI registration: source-aware channel/resolution dropdowns and manual anchor visibility --- src/clearex/gui/app.py | 599 ++++++++++++++++++++++++++++---- tests/gui/test_gui_execution.py | 176 +++++++++- 2 files changed, 694 insertions(+), 81 deletions(-) diff --git a/src/clearex/gui/app.py b/src/clearex/gui/app.py index 6fdc27f..ab2cd35 100644 --- a/src/clearex/gui/app.py +++ b/src/clearex/gui/app.py @@ -1237,6 +1237,261 @@ def _discover_available_operation_output_components( return available +def _parse_source_component_level( + *, + component: str, + full_resolution_source: str, +) -> Optional[int]: + """Parse one component path into a pyramid resolution level. 
+ + Parameters + ---------- + component : str + Candidate component token or full component path. + full_resolution_source : str + Full-resolution source component path associated with the token. + + Returns + ------- + int, optional + Parsed non-negative level when the component represents a pyramid level; + otherwise ``None``. + """ + token = str(component).strip() + if not token: + return None + + suffix_text: Optional[str] = None + if token.startswith("level_"): + suffix_text = token.split("_", maxsplit=1)[1] + elif full_resolution_source == "data" and token.startswith("data_pyramid/level_"): + suffix_text = token.split("data_pyramid/level_", maxsplit=1)[1] + else: + prefix = f"{full_resolution_source}_pyramid/level_" + if token.startswith(prefix): + suffix_text = token.split(prefix, maxsplit=1)[1] + + if suffix_text is None: + return None + suffix_text = str(suffix_text).split("/", maxsplit=1)[0] + try: + parsed = int(suffix_text) + except Exception: + return None + if parsed < 0: + return None + return parsed + + +def _full_resolution_source_component(source_component: str) -> str: + """Return the full-resolution source component for a requested path. + + Parameters + ---------- + source_component : str + Requested source component path. + + Returns + ------- + str + Corresponding full-resolution component path. + """ + requested = str(source_component).strip() or "data" + if requested.startswith("data_pyramid/level_"): + return "data" + if "_pyramid/level_" in requested: + return requested.split("_pyramid/level_", maxsplit=1)[0] + return requested + + +def _component_channel_count( + *, + root: Any, + component: str, +) -> Optional[int]: + """Return channel count for a 6D-like source component when available. + + Parameters + ---------- + root : Any + Open Zarr root group. + component : str + Candidate component path. + + Returns + ------- + int, optional + Positive channel count, or ``None`` if unavailable. 
+ """ + candidate = str(component).strip() + if not candidate or not _zarr_component_exists_in_root(root, candidate): + return None + try: + source = root[candidate] + except Exception: + return None + + shape = getattr(source, "shape", None) + if not isinstance(shape, (tuple, list)) or len(shape) < 3: + return None + try: + count = int(shape[2]) + except Exception: + return None + if count <= 0: + return None + return int(count) + + +def _discover_component_resolution_levels( + *, + root: Any, + source_component: str, +) -> tuple[int, ...]: + """Discover available resolution levels for one source component. + + Parameters + ---------- + root : Any + Open Zarr root group. + source_component : str + Resolved source component path requested by the user. + + Returns + ------- + tuple[int, ...] + Sorted unique resolution levels available for the requested source. + + Notes + ----- + This helper is GUI-facing and intentionally tolerant to missing metadata. + Resolution level ``0`` is always retained as a safe fallback. 
+ """ + requested = str(source_component).strip() or "data" + full_resolution_source = _full_resolution_source_component(requested) + + available_levels: set[int] = set() + if _zarr_component_exists_in_root(root, full_resolution_source): + available_levels.add(0) + + direct_level = _parse_source_component_level( + component=requested, + full_resolution_source=full_resolution_source, + ) + if direct_level is not None and _zarr_component_exists_in_root(root, requested): + available_levels.add(int(direct_level)) + + pyramid_group_component = ( + "data_pyramid" + if full_resolution_source == "data" + else f"{full_resolution_source}_pyramid" + ) + if _zarr_component_exists_in_root(root, pyramid_group_component): + try: + pyramid_group = root[pyramid_group_component] + except Exception: + pyramid_group = None + if pyramid_group is not None: + member_keys: set[str] = set() + try: + member_keys.update( + str(key).strip() for key in pyramid_group.group_keys() + ) + except Exception: + pass + try: + member_keys.update( + str(key).strip() for key in pyramid_group.array_keys() + ) + except Exception: + pass + for key in member_keys: + parsed_level = _parse_source_component_level( + component=key, + full_resolution_source=full_resolution_source, + ) + if parsed_level is None: + continue + component = f"{pyramid_group_component}/{key}" + if _zarr_component_exists_in_root(root, component): + available_levels.add(int(parsed_level)) + + metadata_components: list[str] = [] + if full_resolution_source == "data": + try: + data_levels = root.attrs.get("data_pyramid_levels") + except Exception: + data_levels = None + if isinstance(data_levels, (tuple, list)): + metadata_components.extend(str(value).strip() for value in data_levels) + try: + source_node = root[full_resolution_source] + except Exception: + source_node = None + if source_node is not None: + try: + pyramid_levels = source_node.attrs.get("pyramid_levels") + except Exception: + pyramid_levels = None + if 
isinstance(pyramid_levels, (tuple, list)): + metadata_components.extend(str(value).strip() for value in pyramid_levels) + for component in metadata_components: + if not component: + continue + parsed_level = _parse_source_component_level( + component=component, + full_resolution_source=full_resolution_source, + ) + if parsed_level is None: + continue + if _zarr_component_exists_in_root(root, component): + available_levels.add(int(parsed_level)) + + if not available_levels: + available_levels.add(0) + return tuple(sorted(available_levels)) + + +def _discover_component_channels( + *, + root: Any, + source_component: str, +) -> tuple[int, ...]: + """Discover available channel indices for one source component. + + Parameters + ---------- + root : Any + Open Zarr root group. + source_component : str + Resolved source component path requested by the user. + + Returns + ------- + tuple[int, ...] + Sorted channel indices available for the requested source. + + Notes + ----- + This helper is GUI-facing and intentionally tolerant to missing metadata. + Channel index ``0`` is always retained as a safe fallback. + """ + requested = str(source_component).strip() or "data" + full_resolution_source = _full_resolution_source_component(requested) + + channel_count = _component_channel_count( + root=root, component=full_resolution_source + ) + if channel_count is None: + channel_count = _component_channel_count(root=root, component=requested) + if channel_count is None: + channel_count = _component_channel_count(root=root, component="data") + + if channel_count is None: + return (0,) + return tuple(int(index) for index in range(int(channel_count))) + + def _build_input_source_options( *, operation_name: str, @@ -2014,8 +2269,7 @@ def _format_spatial_calibration_summary(config: SpatialCalibrationConfig) -> str Human-readable summary for setup and analysis dialogs. 
""" binding_labels = { - binding: label - for label, binding in _spatial_calibration_binding_choices() + binding: label for label, binding in _spatial_calibration_binding_choices() } lines = [ f"Canonical: {format_spatial_calibration(config)}", @@ -2026,8 +2280,7 @@ def _format_spatial_calibration_summary(config: SpatialCalibrationConfig) -> str strict=False, ): lines.append( - f"World {axis_name.upper()}: " - f"{binding_labels.get(binding, binding)}" + f"World {axis_name.upper()}: " f"{binding_labels.get(binding, binding)}" ) lines.append("Theta: Rotate Z/Y about world X") return "\n".join(lines) @@ -2889,12 +3142,8 @@ def _build_ui(self) -> None: self._defaults_button = _configure_fixed_height_button( QPushButton("Reset Defaults") ) - self._cancel_button = _configure_fixed_height_button( - QPushButton("Cancel") - ) - self._apply_button = _configure_fixed_height_button( - QPushButton("Apply") - ) + self._cancel_button = _configure_fixed_height_button(QPushButton("Cancel")) + self._apply_button = _configure_fixed_height_button(QPushButton("Apply")) self._apply_button.setObjectName("runButton") footer.addWidget(self._defaults_button) footer.addStretch(1) @@ -3107,12 +3356,8 @@ def _build_ui(self) -> None: self._defaults_button = _configure_fixed_height_button( QPushButton("Reset Identity") ) - self._cancel_button = _configure_fixed_height_button( - QPushButton("Cancel") - ) - self._apply_button = _configure_fixed_height_button( - QPushButton("Apply") - ) + self._cancel_button = _configure_fixed_height_button(QPushButton("Cancel")) + self._apply_button = _configure_fixed_height_button(QPushButton("Apply")) self._apply_button.setObjectName("runButton") footer.addWidget(self._defaults_button) footer.addStretch(1) @@ -3311,12 +3556,8 @@ def _build_ui(self) -> None: self._defaults_button = _configure_fixed_height_button( QPushButton("Reset Defaults") ) - self._cancel_button = _configure_fixed_height_button( - QPushButton("Cancel") - ) - self._apply_button = 
_configure_fixed_height_button( - QPushButton("Apply") - ) + self._cancel_button = _configure_fixed_height_button(QPushButton("Cancel")) + self._apply_button = _configure_fixed_height_button(QPushButton("Apply")) self._apply_button.setObjectName("runButton") footer.addWidget(self._defaults_button) footer.addStretch(1) @@ -3388,7 +3629,10 @@ def _build_local_cluster_page(self) -> QWidget: self._local_recommendation_label.setWordWrap(True) self._local_recommendation_label.setObjectName("metadataFieldValue") self._local_recommendation_label.setMinimumHeight( - max(28, int(self._local_recommendation_label.fontMetrics().height()) + 10) + max( + 28, + int(self._local_recommendation_label.fontMetrics().height()) + 10, + ) ) form.addRow("", self._local_recommendation_label) return page @@ -4838,9 +5082,7 @@ def _build_ui(self) -> None: spatial_button_row = QHBoxLayout() apply_row_spacing(spatial_button_row) spatial_button_row.addStretch(1) - self._spatial_calibration_button = QPushButton( - "Edit Spatial Calibration" - ) + self._spatial_calibration_button = QPushButton("Edit Spatial Calibration") spatial_button_row.addWidget(self._spatial_calibration_button) spatial_layout.addLayout(spatial_button_row) root.addWidget(spatial_group) @@ -7386,9 +7628,7 @@ def __init__(self, initial: WorkflowConfig) -> None: self._DEFAULT_REGISTRATION_PARAMETERS, ) ) - self._operation_defaults["registration"] = dict( - self._registration_defaults - ) + self._operation_defaults["registration"] = dict(self._registration_defaults) self._visualization_defaults = dict( self._operation_defaults.get("visualization", {}) ) @@ -7495,7 +7735,10 @@ def _build_analysis_scope_panel(self) -> QGroupBox: self._analysis_scope_summary_label.setObjectName("statusLabel") self._analysis_scope_summary_label.setWordWrap(True) self._analysis_scope_summary_label.setMinimumHeight( - max(28, int(self._analysis_scope_summary_label.fontMetrics().height()) + 10) + max( + 28, + 
int(self._analysis_scope_summary_label.fontMetrics().height()) + 10, + ) ) layout.addWidget(self._analysis_scope_summary_label) @@ -7507,7 +7750,10 @@ def _build_analysis_scope_panel(self) -> QGroupBox: self._analysis_state_source_label.setObjectName("statusLabel") self._analysis_state_source_label.setWordWrap(True) self._analysis_state_source_label.setMinimumHeight( - max(28, int(self._analysis_state_source_label.fontMetrics().height()) + 10) + max( + 28, + int(self._analysis_state_source_label.fontMetrics().height()) + 10, + ) ) restore_row.addWidget(self._analysis_state_source_label, 1) @@ -8499,6 +8745,10 @@ def _build_operation_panel(self, operation_name: str) -> QWidget: input_combo.currentIndexChanged.connect( self._on_visualization_input_source_changed ) + if operation_name == "registration": + input_combo.currentIndexChanged.connect( + self._on_registration_input_source_changed + ) if operation_name == "deconvolution": self._build_deconvolution_parameter_rows(form) @@ -9670,11 +9920,11 @@ def _build_registration_parameter_rows(self, form: QFormLayout) -> None: pairwise_section, pairwise_form = self._build_parameter_section_card( "Pairwise Registration" ) - self._registration_channel_spin = QSpinBox() - self._registration_channel_spin.setRange(0, 0) - pairwise_form.addRow("channel", self._registration_channel_spin) + self._registration_channel_combo = QComboBox() + self._registration_channel_combo.addItem("Channel 0", 0) + pairwise_form.addRow("channel", self._registration_channel_combo) self._register_parameter_hint( - self._registration_channel_spin, + self._registration_channel_combo, self._PARAMETER_HINTS["registration_channel"], ) @@ -9688,14 +9938,14 @@ def _build_registration_parameter_rows(self, form: QFormLayout) -> None: self._PARAMETER_HINTS["registration_type"], ) - self._registration_resolution_level_spin = QSpinBox() - self._registration_resolution_level_spin.setRange(0, 0) + self._registration_resolution_level_combo = QComboBox() + 
self._registration_resolution_level_combo.addItem("Level 0", 0) pairwise_form.addRow( "resolution level", - self._registration_resolution_level_spin, + self._registration_resolution_level_combo, ) self._register_parameter_hint( - self._registration_resolution_level_spin, + self._registration_resolution_level_combo, self._PARAMETER_HINTS["registration_resolution_level"], ) form.addRow(pairwise_section) @@ -9704,12 +9954,8 @@ def _build_registration_parameter_rows(self, form: QFormLayout) -> None: "Global Optimization" ) self._registration_anchor_mode_combo = QComboBox() - self._registration_anchor_mode_combo.addItem( - "Central tile", "central" - ) - self._registration_anchor_mode_combo.addItem( - "Manual tile", "manual" - ) + self._registration_anchor_mode_combo.addItem("Central tile", "central") + self._registration_anchor_mode_combo.addItem("Manual tile", "manual") global_form.addRow("anchor mode", self._registration_anchor_mode_combo) self._register_parameter_hint( self._registration_anchor_mode_combo, @@ -9718,8 +9964,9 @@ def _build_registration_parameter_rows(self, form: QFormLayout) -> None: self._registration_anchor_position_spin = QSpinBox() self._registration_anchor_position_spin.setRange(0, 0) + self._registration_anchor_position_label = QLabel("anchor tile") global_form.addRow( - "anchor tile", + self._registration_anchor_position_label, self._registration_anchor_position_spin, ) self._register_parameter_hint( @@ -10155,9 +10402,12 @@ def _sync_visualization_input_source_from_volume_layers(self) -> None: ).strip() or "data" ) - selected_value = str( - analysis_operation_for_output_component(component) or component - ).strip() or "data" + selected_value = ( + str( + analysis_operation_for_output_component(component) or component + ).strip() + or "data" + ) combo.blockSignals(True) combo_index = combo.findData(selected_value) if combo_index < 0: @@ -10230,6 +10480,196 @@ def _on_visualization_input_source_changed(self, _index: int) -> None: 
refresh_summary=True ) + def _registration_selected_input_source(self) -> str: + """Return the registration input-source selector value. + + Parameters + ---------- + None + + Returns + ------- + str + Selected registration input-source alias or component path. + """ + combo = self._operation_input_combos.get("registration") + if combo is None: + return "data" + selected = combo.currentData() + return (str(selected).strip() if selected is not None else "data") or "data" + + def _registration_available_resolution_levels(self) -> tuple[int, ...]: + """Return available registration resolution levels for current source. + + Parameters + ---------- + None + + Returns + ------- + tuple[int, ...] + Sorted level values available for the selected registration input. + """ + store_path = str(self._base_config.file or "").strip() + if not store_path or not is_zarr_store_path(store_path): + return (0,) + + requested_source = self._registration_selected_input_source() + resolved_source = resolve_analysis_input_component(requested_source) + try: + root = zarr.open_group(store_path, mode="r") + except Exception: + return (0,) + return _discover_component_resolution_levels( + root=root, + source_component=resolved_source, + ) + + def _registration_available_channels(self) -> tuple[int, ...]: + """Return available registration channel indices for current source. + + Parameters + ---------- + None + + Returns + ------- + tuple[int, ...] + Sorted channel values available for the selected registration input. 
+ """ + store_path = str(self._base_config.file or "").strip() + if not store_path or not is_zarr_store_path(store_path): + return (0,) + + requested_source = self._registration_selected_input_source() + resolved_source = resolve_analysis_input_component(requested_source) + try: + root = zarr.open_group(store_path, mode="r") + except Exception: + return (0,) + return _discover_component_channels( + root=root, + source_component=resolved_source, + ) + + def _refresh_registration_channel_options( + self, + *, + preferred_channel: Optional[int] = None, + ) -> None: + """Refresh registration channel combo options. + + Parameters + ---------- + preferred_channel : int, optional + Preferred selected channel after options are rebuilt. + + Returns + ------- + None + Combo items and selected index are updated in-place. + """ + combo = getattr(self, "_registration_channel_combo", None) + if combo is None: + return + + current_data = combo.currentData() + try: + current_channel = max(0, int(current_data)) + except (TypeError, ValueError): + current_channel = 0 + + if preferred_channel is None: + target_channel = current_channel + else: + target_channel = max(0, int(preferred_channel)) + + available_channels = self._registration_available_channels() + selected_channel = int(available_channels[0]) + for channel in available_channels: + if int(channel) <= target_channel: + selected_channel = int(channel) + + combo.blockSignals(True) + combo.clear() + for channel in available_channels: + channel_value = int(channel) + combo.addItem(f"Channel {channel_value}", channel_value) + selected_index = combo.findData(int(selected_channel)) + if selected_index < 0: + selected_index = 0 + combo.setCurrentIndex(selected_index) + combo.blockSignals(False) + + def _refresh_registration_resolution_level_options( + self, + *, + preferred_level: Optional[int] = None, + ) -> None: + """Refresh registration resolution-level combo options. 
+ + Parameters + ---------- + preferred_level : int, optional + Preferred selected level after options are rebuilt. + + Returns + ------- + None + Combo items and selected index are updated in-place. + """ + combo = getattr(self, "_registration_resolution_level_combo", None) + if combo is None: + return + + current_data = combo.currentData() + try: + current_level = max( + 0, + int(current_data), + ) + except (TypeError, ValueError): + current_level = 0 + + if preferred_level is None: + target_level = current_level + else: + target_level = max(0, int(preferred_level)) + + available_levels = self._registration_available_resolution_levels() + selected_level = int(available_levels[0]) + for level in available_levels: + if int(level) <= target_level: + selected_level = int(level) + + combo.blockSignals(True) + combo.clear() + for level in available_levels: + level_value = int(level) + combo.addItem(f"Level {level_value}", level_value) + selected_index = combo.findData(int(selected_level)) + if selected_index < 0: + selected_index = 0 + combo.setCurrentIndex(selected_index) + combo.blockSignals(False) + + def _on_registration_input_source_changed(self, _index: int) -> None: + """Refresh registration source-dependent choices after input changes. + + Parameters + ---------- + _index : int + Selected combo index (unused). + + Returns + ------- + None + Registration channel and resolution options are rebuilt in-place. 
+        self._refresh_registration_channel_options()
+        # Channel and level choices both depend on the selected input source.
+        self._refresh_registration_resolution_level_options()
@@ -12960,10 +13406,8 @@ def _hydrate(self, initial: WorkflowConfig) -> None: self._usegment3d_resolution_level_spin.setMaximum( max(0, int(max_usegment3d_resolution_level)) ) - self._registration_channel_spin.setMaximum(channel_count - 1) - self._registration_resolution_level_spin.setMaximum( - max(0, int(max_usegment3d_resolution_level)) - ) + self._refresh_registration_channel_options() + self._refresh_registration_resolution_level_options() self._registration_anchor_position_spin.setMaximum(position_count - 1) requested_usegment_channels = usegment3d_params.get("channel_indices", []) if isinstance(requested_usegment_channels, str): @@ -13554,13 +13998,10 @@ def _hydrate(self, initial: WorkflowConfig) -> None: dtform_index = self._usegment3d_postprocess_dtform_combo.count() - 1 self._usegment3d_postprocess_dtform_combo.setCurrentIndex(dtform_index) - self._registration_channel_spin.setValue( - max( + self._refresh_registration_channel_options( + preferred_channel=max( 0, - min( - int(self._registration_channel_spin.maximum()), - int(registration_params.get("registration_channel", 0)), - ), + int(registration_params.get("registration_channel", 0)), ) ) registration_type = ( @@ -13584,17 +14025,12 @@ def _hydrate(self, initial: WorkflowConfig) -> None: 0, int(registration_params.get("input_resolution_level", 0)), ) - self._registration_resolution_level_spin.setValue( - min( - int(self._registration_resolution_level_spin.maximum()), - int(registration_resolution_level), - ) + self._refresh_registration_resolution_level_options( + preferred_level=int(registration_resolution_level) ) anchor_mode = ( - str(registration_params.get("anchor_mode", "central")) - .strip() - .lower() + str(registration_params.get("anchor_mode", "central")).strip().lower() or "central" ) if anchor_mode not in {"central", "manual"}: @@ -13623,7 +14059,9 @@ def _hydrate(self, initial: WorkflowConfig) -> None: ), ) ) - registration_overlap_zyx = registration_params.get("overlap_zyx", [8, 32, 
32]) + registration_overlap_zyx = registration_params.get( + "overlap_zyx", [8, 32, 32] + ) if ( not isinstance(registration_overlap_zyx, (tuple, list)) or len(registration_overlap_zyx) != 3 @@ -14505,6 +14943,19 @@ def _collect_registration_parameters(self) -> Dict[str, Any]: else: anchor_position = None + resolution_level_data = ( + self._registration_resolution_level_combo.currentData() + ) + try: + input_resolution_level = max(0, int(resolution_level_data)) + except (TypeError, ValueError): + input_resolution_level = 0 + channel_data = self._registration_channel_combo.currentData() + try: + registration_channel = max(0, int(channel_data)) + except (TypeError, ValueError): + registration_channel = 0 + return { "chunk_basis": "3d", "detect_2d_per_slice": False, @@ -14517,14 +14968,12 @@ def _collect_registration_parameters(self) -> Dict[str, Any]: "memory_overhead_factor": float( self._registration_defaults.get("memory_overhead_factor", 2.5) ), - "registration_channel": int(self._registration_channel_spin.value()), + "registration_channel": int(registration_channel), "registration_type": str( self._registration_type_combo.currentData() or "rigid" ).strip() or "rigid", - "input_resolution_level": int( - self._registration_resolution_level_spin.value() - ), + "input_resolution_level": int(input_resolution_level), "anchor_mode": anchor_mode, "anchor_position": anchor_position, "blend_mode": str( diff --git a/tests/gui/test_gui_execution.py b/tests/gui/test_gui_execution.py index 08ad4f8..561696a 100644 --- a/tests/gui/test_gui_execution.py +++ b/tests/gui/test_gui_execution.py @@ -12,6 +12,7 @@ import clearex.gui.app as app_module from clearex.io.experiment import NavigateChannel, NavigateExperiment import pytest +import zarr def _make_navigate_experiment(path: Path) -> NavigateExperiment: @@ -43,6 +44,47 @@ def _make_navigate_experiment(path: Path) -> NavigateExperiment: ) +def _create_gui_analysis_store(tmp_path: Path) -> Path: + """Create a minimal analysis store 
for GUI parameter-selection tests.""" + store_path = tmp_path / "analysis_store.zarr" + root = zarr.open_group(str(store_path), mode="w") + data_shape = (1, 1, 2, 1, 4, 4) + data_chunks = (1, 1, 2, 1, 4, 4) + root.create_dataset( + name="data", + shape=data_shape, + chunks=data_chunks, + dtype="uint16", + overwrite=True, + ) + data_pyramid = root.require_group("data_pyramid") + data_pyramid.create_dataset( + name="level_1", + shape=(1, 1, 2, 1, 2, 2), + chunks=(1, 1, 2, 1, 2, 2), + dtype="uint16", + overwrite=True, + ) + root.attrs["data_pyramid_levels"] = ["data", "data_pyramid/level_1"] + root.attrs["data_pyramid_factors_tpczyx"] = [ + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 2, 2, 2], + ] + shear_latest = ( + root.require_group("results") + .require_group("shear_transform") + .require_group("latest") + ) + shear_latest.create_dataset( + name="data", + shape=(1, 1, 1, 1, 4, 4), + chunks=(1, 1, 1, 1, 4, 4), + dtype="uint16", + overwrite=True, + ) + return store_path + + def _install_fake_gui_runtime(monkeypatch): """Install fake Qt primitives for deterministic run-loop tests. 
@@ -298,6 +340,7 @@ def test_registration_operation_moves_to_preprocessing_tab() -> None: def test_analysis_dialog_persists_registration_parameters( monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, ) -> None: if not app_module.HAS_PYQT6: return @@ -312,8 +355,9 @@ def test_analysis_dialog_persists_registration_parameters( lambda _config: None, ) + store_path = _create_gui_analysis_store(tmp_path) dialog = app_module.AnalysisSelectionDialog( - initial=app_module.WorkflowConfig(file="/tmp/test/data_store.fake") + initial=app_module.WorkflowConfig(file=str(store_path)) ) dialog._persist_analysis_gui_state_for_target = lambda _target: None dialog._operation_checkboxes["registration"].setChecked(True) @@ -331,11 +375,13 @@ def test_analysis_dialog_persists_registration_parameters( dialog._registration_blend_mode_combo.setCurrentIndex( dialog._registration_blend_mode_combo.findData("average") ) - dialog._registration_resolution_level_spin.setMaximum(3) - dialog._registration_resolution_level_spin.setValue(1) - app_module.AnalysisSelectionDialog._set_registration_parameter_enabled_state( - dialog - ) + channel_1_index = dialog._registration_channel_combo.findData(1) + assert channel_1_index >= 0 + dialog._registration_channel_combo.setCurrentIndex(channel_1_index) + level_1_index = dialog._registration_resolution_level_combo.findData(1) + assert level_1_index >= 0 + dialog._registration_resolution_level_combo.setCurrentIndex(level_1_index) + app_module.AnalysisSelectionDialog._set_registration_parameter_enabled_state(dialog) dialog._on_run() @@ -346,9 +392,127 @@ def test_analysis_dialog_persists_registration_parameters( assert params["anchor_position"] == 2 assert params["overlap_zyx"] == [5, 12, 20] assert params["blend_mode"] == "average" + assert params["registration_channel"] == 1 assert params["input_resolution_level"] == 1 +def test_registration_resolution_levels_follow_selected_input_source( + tmp_path: Path, +) -> None: + if not app_module.HAS_PYQT6: + return + 
+ app = app_module.QApplication.instance() + if app is None: + app = app_module.QApplication([]) + + store_path = _create_gui_analysis_store(tmp_path) + dialog = app_module.AnalysisSelectionDialog( + initial=app_module.WorkflowConfig(file=str(store_path)) + ) + dialog._persist_analysis_gui_state_for_target = lambda _target: None + dialog._operation_checkboxes["registration"].setChecked(True) + input_combo = dialog._operation_input_combos["registration"] + level_combo = dialog._registration_resolution_level_combo + + data_index = input_combo.findData("data") + assert data_index >= 0 + input_combo.setCurrentIndex(data_index) + app.processEvents() + assert [int(level_combo.itemData(idx)) for idx in range(level_combo.count())] == [ + 0, + 1, + ] + + level_1_index = level_combo.findData(1) + assert level_1_index >= 0 + level_combo.setCurrentIndex(level_1_index) + + shear_index = input_combo.findData("shear_transform") + assert shear_index >= 0 + input_combo.setCurrentIndex(shear_index) + app.processEvents() + assert [int(level_combo.itemData(idx)) for idx in range(level_combo.count())] == [0] + assert int(level_combo.currentData()) == 0 + params = dialog._collect_registration_parameters() + assert params["input_resolution_level"] == 0 + + +def test_registration_channels_follow_selected_input_source( + tmp_path: Path, +) -> None: + if not app_module.HAS_PYQT6: + return + + app = app_module.QApplication.instance() + if app is None: + app = app_module.QApplication([]) + + store_path = _create_gui_analysis_store(tmp_path) + dialog = app_module.AnalysisSelectionDialog( + initial=app_module.WorkflowConfig(file=str(store_path)) + ) + dialog._persist_analysis_gui_state_for_target = lambda _target: None + dialog._operation_checkboxes["registration"].setChecked(True) + input_combo = dialog._operation_input_combos["registration"] + channel_combo = dialog._registration_channel_combo + + data_index = input_combo.findData("data") + assert data_index >= 0 + 
input_combo.setCurrentIndex(data_index) + app.processEvents() + assert [ + int(channel_combo.itemData(idx)) for idx in range(channel_combo.count()) + ] == [0, 1] + + channel_1_index = channel_combo.findData(1) + assert channel_1_index >= 0 + channel_combo.setCurrentIndex(channel_1_index) + + shear_index = input_combo.findData("shear_transform") + assert shear_index >= 0 + input_combo.setCurrentIndex(shear_index) + app.processEvents() + assert [ + int(channel_combo.itemData(idx)) for idx in range(channel_combo.count()) + ] == [0] + assert int(channel_combo.currentData()) == 0 + params = dialog._collect_registration_parameters() + assert params["registration_channel"] == 0 + + +def test_registration_anchor_tile_visibility_follows_anchor_mode( + tmp_path: Path, +) -> None: + if not app_module.HAS_PYQT6: + return + + app = app_module.QApplication.instance() + if app is None: + app = app_module.QApplication([]) + + store_path = _create_gui_analysis_store(tmp_path) + dialog = app_module.AnalysisSelectionDialog( + initial=app_module.WorkflowConfig(file=str(store_path)) + ) + dialog._persist_analysis_gui_state_for_target = lambda _target: None + dialog._operation_checkboxes["registration"].setChecked(True) + + dialog._registration_anchor_mode_combo.setCurrentIndex( + dialog._registration_anchor_mode_combo.findData("central") + ) + app_module.AnalysisSelectionDialog._set_registration_parameter_enabled_state(dialog) + assert dialog._registration_anchor_position_label.isHidden() is True + assert dialog._registration_anchor_position_spin.isHidden() is True + + dialog._registration_anchor_mode_combo.setCurrentIndex( + dialog._registration_anchor_mode_combo.findData("manual") + ) + app_module.AnalysisSelectionDialog._set_registration_parameter_enabled_state(dialog) + assert dialog._registration_anchor_position_label.isHidden() is False + assert dialog._registration_anchor_position_spin.isHidden() is False + + def test_zarr_dialog_scrolls_body_on_short_screens(monkeypatch) -> 
None: if not app_module.HAS_PYQT6: return From c978246ba716533907a477f47dbf753ccf412cfc Mon Sep 17 00:00:00 2001 From: Kevin Dean Date: Sat, 21 Mar 2026 16:27:53 -0500 Subject: [PATCH 3/4] Rename display pyramid UX, source-adjacent pyramids, and format codebase with Black --- docs/source/conf.py | 1 - .../registration/image_registration_class.py | 1 - .../registration/register_round_function.py | 1 - setup.py | 1 - src/clearex/context/__init__.py | 1 - src/clearex/detect/pipeline.py | 21 +- src/clearex/file_operations/__init__.py | 1 - src/clearex/filter/__init__.py | 1 - src/clearex/filter/filters.py | 4 +- src/clearex/filter/kernels.py | 2 +- src/clearex/flatfield/pipeline.py | 162 +++++++--- src/clearex/gui/app.py | 26 +- src/clearex/io/experiment.py | 95 +++--- src/clearex/io/log.py | 1 - src/clearex/io/provenance.py | 20 +- src/clearex/main.py | 43 ++- src/clearex/mip_export/pipeline.py | 32 +- src/clearex/preprocess/pad.py | 2 +- src/clearex/registration/pipeline.py | 286 +++++++++++++----- src/clearex/registration/tre.py | 1 - src/clearex/segmentation/pointsource.py | 7 +- src/clearex/shear/__init__.py | 1 - src/clearex/shear/pipeline.py | 26 +- src/clearex/stats/__init__.py | 2 +- src/clearex/usegment3d/pipeline.py | 9 +- src/clearex/visualization/pipeline.py | 166 +++++++--- src/clearex/workflow.py | 67 ++-- tests/__init__.py | 6 +- tests/conftest.py | 1 - tests/flatfield/test_pipeline.py | 79 +++-- tests/gui/test_gui_execution.py | 13 +- tests/io/test_experiment.py | 108 ++++--- tests/io/test_provenance.py | 9 +- tests/io/test_read.py | 1 - tests/mip_export/test_pipeline.py | 32 +- tests/registration/test_image_registration.py | 15 +- tests/registration/test_pipeline.py | 51 +++- tests/segmentation/__init__.py | 1 - tests/segmentation/test_pointsource.py | 7 +- tests/shear/test_pipeline.py | 24 +- tests/test_main.py | 15 +- tests/test_workflow.py | 18 +- tests/usegment3d/test_pipeline.py | 4 +- tests/visualization/test_pipeline.py | 61 +++- 44 files 
changed, 982 insertions(+), 443 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index c80a801..4b408ef 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -9,7 +9,6 @@ import sys import types - sys.path.insert(0, os.path.abspath("../../src")) diff --git a/examples/scripts/registration/image_registration_class.py b/examples/scripts/registration/image_registration_class.py index 158bbb9..66c0a1a 100644 --- a/examples/scripts/registration/image_registration_class.py +++ b/examples/scripts/registration/image_registration_class.py @@ -105,4 +105,3 @@ def main_class_based(): if __name__ == "__main__": main_class_based() - diff --git a/examples/scripts/registration/register_round_function.py b/examples/scripts/registration/register_round_function.py index 568c041..43e4dbd 100644 --- a/examples/scripts/registration/register_round_function.py +++ b/examples/scripts/registration/register_round_function.py @@ -73,4 +73,3 @@ def main_functional(): if __name__ == "__main__": main_functional() - diff --git a/setup.py b/setup.py index ccd594b..8bbb566 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,6 @@ from setuptools import setup - if sys.version_info[:2] != (3, 12): detected = ".".join(str(part) for part in sys.version_info[:3]) raise RuntimeError( diff --git a/src/clearex/context/__init__.py b/src/clearex/context/__init__.py index 59465b8..3074984 100755 --- a/src/clearex/context/__init__.py +++ b/src/clearex/context/__init__.py @@ -23,4 +23,3 @@ # IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. 
- diff --git a/src/clearex/detect/pipeline.py b/src/clearex/detect/pipeline.py index 2a80c63..b39320c 100644 --- a/src/clearex/detect/pipeline.py +++ b/src/clearex/detect/pipeline.py @@ -54,7 +54,14 @@ DetectionsArray = np.ndarray -RegionBounds = tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int]] +RegionBounds = tuple[ + tuple[int, int], + tuple[int, int], + tuple[int, int], + tuple[int, int], + tuple[int, int], + tuple[int, int], +] ProgressCallback = Callable[[int, str], None] _PARTICLE_COLUMNS = ( @@ -134,9 +141,7 @@ def _normalize_particle_parameters( normalized["remove_close_particles"] = bool( normalized.get("remove_close_particles", False) ) - normalized["min_distance_sigma"] = float( - normalized.get("min_distance_sigma", 10.0) - ) + normalized["min_distance_sigma"] = float(normalized.get("min_distance_sigma", 10.0)) overlap_zyx = normalized.get("overlap_zyx", [0, 0, 0]) if not isinstance(overlap_zyx, (tuple, list)) or len(overlap_zyx) != 3: raise ValueError( @@ -162,12 +167,13 @@ def _axis_chunk_bounds(size: int, chunk_size: int) -> list[tuple[int, int]]: Ordered ``(start, stop)`` bounds covering the full axis. """ return [ - (start, min(start + chunk_size, size)) - for start in range(0, size, chunk_size) + (start, min(start + chunk_size, size)) for start in range(0, size, chunk_size) ] -def _region_to_slices(region: RegionBounds) -> tuple[slice, slice, slice, slice, slice, slice]: +def _region_to_slices( + region: RegionBounds, +) -> tuple[slice, slice, slice, slice, slice, slice]: """Convert six-axis integer bounds into Python slices. Parameters @@ -635,6 +641,7 @@ def run_particle_detection_analysis( ValueError If required source data are missing or channel index is out of bounds. 
""" + def _emit(percent: int, message: str) -> None: if progress_callback is None: return diff --git a/src/clearex/file_operations/__init__.py b/src/clearex/file_operations/__init__.py index 59465b8..3074984 100755 --- a/src/clearex/file_operations/__init__.py +++ b/src/clearex/file_operations/__init__.py @@ -23,4 +23,3 @@ # IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. - diff --git a/src/clearex/filter/__init__.py b/src/clearex/filter/__init__.py index 59465b8..3074984 100755 --- a/src/clearex/filter/__init__.py +++ b/src/clearex/filter/__init__.py @@ -23,4 +23,3 @@ # IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. - diff --git a/src/clearex/filter/filters.py b/src/clearex/filter/filters.py index 2a53d81..cc076df 100755 --- a/src/clearex/filter/filters.py +++ b/src/clearex/filter/filters.py @@ -32,11 +32,11 @@ import numpy as np from skimage import filters as skfilters - # Local Imports + def fwhm_to_sigma(fwhm_px: float) -> float: - """ Convert from FWHM to sigma. + """Convert from FWHM to sigma. 
FWHM = 2*sqrt(2*ln2)*sigma ≈ 2.35482*sigma diff --git a/src/clearex/filter/kernels.py b/src/clearex/filter/kernels.py index 7b25977..adf072b 100755 --- a/src/clearex/filter/kernels.py +++ b/src/clearex/filter/kernels.py @@ -50,7 +50,7 @@ def make_3d_structured_element(radius: int, shape: str = "sphere") -> np.ndarray radius = int(radius) structured_element = np.zeros((radius, radius, radius)) - (z_len, y_len, x_len) = structured_element.shape + z_len, y_len, x_len = structured_element.shape if shape == "sphere": for i in range(int(z_len)): diff --git a/src/clearex/flatfield/pipeline.py b/src/clearex/flatfield/pipeline.py index 6c42107..44404df 100644 --- a/src/clearex/flatfield/pipeline.py +++ b/src/clearex/flatfield/pipeline.py @@ -36,7 +36,16 @@ from itertools import product import json from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, Sequence, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Mapping, + Optional, + Sequence, + Union, + cast, +) import dask import dask.array as da @@ -388,15 +397,14 @@ def _normalize_parameters(parameters: Mapping[str, Any]) -> dict[str, Any]: raise ValueError("flatfield smoothness_flatfield must be greater than zero.") normalized["working_size"] = max(1, int(normalized.get("working_size", 128))) normalized["is_timelapse"] = bool(normalized.get("is_timelapse", False)) - fit_mode = str(normalized.get("fit_mode", "tiled")).strip().lower().replace("-", "_") + fit_mode = ( + str(normalized.get("fit_mode", "tiled")).strip().lower().replace("-", "_") + ) if fit_mode not in {"tiled", "full_volume"}: raise ValueError("flatfield fit_mode must be 'tiled' or 'full_volume'.") normalized["fit_mode"] = fit_mode fit_tile_shape = normalized.get("fit_tile_shape_yx", [256, 256]) - if ( - not isinstance(fit_tile_shape, (tuple, list)) - or len(fit_tile_shape) != 2 - ): + if not isinstance(fit_tile_shape, (tuple, list)) or len(fit_tile_shape) != 2: raise ValueError( "flatfield 
fit_tile_shape_yx must define tile sizes in (y, x) order." ) @@ -681,8 +689,7 @@ def _axis_chunk_bounds(size: int, chunk_size: int) -> list[tuple[int, int]]: Ordered ``(start, stop)`` bounds covering the full axis. """ return [ - (start, min(start + chunk_size, size)) - for start in range(0, size, chunk_size) + (start, min(start + chunk_size, size)) for start in range(0, size, chunk_size) ] @@ -816,7 +823,9 @@ def _fit_basic_profile( ) -def _region_to_slices(region: RegionBounds) -> tuple[slice, slice, slice, slice, slice, slice]: +def _region_to_slices( + region: RegionBounds, +) -> tuple[slice, slice, slice, slice, slice, slice]: """Convert integer region bounds into six-axis Python slices. Parameters @@ -1266,6 +1275,7 @@ def _materialize_output_pyramid( ValueError If base output component shape/chunks are not canonical. """ + def _emit(percent: int, message: str) -> None: if progress_callback is None: return @@ -1385,9 +1395,7 @@ def _emit(percent: int, message: str) -> None: { "axes": ["t", "p", "c", "z", "y", "x"], "pyramid_level": int(level_index), - "downsample_factors_tpczyx": [ - int(value) for value in absolute_factors - ], + "downsample_factors_tpczyx": [int(value) for value in absolute_factors], "chunk_shape_tpczyx": [int(value) for value in level_chunks], "source_component": str(level_source_component), } @@ -1460,7 +1468,13 @@ def _checkpoint_dataset_specs( specs: dict[str, tuple[tuple[int, ...], Any]] = { "fit_profile_done_pc": ((p_count, c_count), np.bool_), "transform_done_tpcyx": ( - (t_count, p_count, c_count, int(transform_grid_yx[0]), int(transform_grid_yx[1])), + ( + t_count, + p_count, + c_count, + int(transform_grid_yx[0]), + int(transform_grid_yx[1]), + ), np.bool_, ), } @@ -1471,16 +1485,28 @@ def _checkpoint_dataset_specs( (p_count, c_count, int(fit_grid_yx[0]), int(fit_grid_yx[1])), np.bool_, ), - "fit_baseline_sum_pctz": ((p_count, c_count, t_count, z_count), np.float32), + "fit_baseline_sum_pctz": ( + (p_count, c_count, t_count, 
z_count), + np.float32, + ), "fit_baseline_count_pc": ((p_count, c_count), np.uint32), } ) if bool(blend_tiles_effective): specs.update( { - "fit_flatfield_sum_pcyx": ((p_count, c_count, y_count, x_count), np.float32), - "fit_darkfield_sum_pcyx": ((p_count, c_count, y_count, x_count), np.float32), - "fit_weight_sum_pcyx": ((p_count, c_count, y_count, x_count), np.float32), + "fit_flatfield_sum_pcyx": ( + (p_count, c_count, y_count, x_count), + np.float32, + ), + "fit_darkfield_sum_pcyx": ( + (p_count, c_count, y_count, x_count), + np.float32, + ), + "fit_weight_sum_pcyx": ( + (p_count, c_count, y_count, x_count), + np.float32, + ), } ) return specs @@ -1534,9 +1560,13 @@ def _checkpoint_is_compatible( int(shape_tpczyx[4]), int(shape_tpczyx[5]), ) - if not _has_dataset(latest_group, name="data", shape=shape_tpczyx, dtype=np.float32): + if not _has_dataset( + latest_group, name="data", shape=shape_tpczyx, dtype=np.float32 + ): return False - if tuple(int(v) for v in latest_group["data"].chunks) != tuple(int(v) for v in chunks_tpczyx): + if tuple(int(v) for v in latest_group["data"].chunks) != tuple( + int(v) for v in chunks_tpczyx + ): return False if not _has_dataset( latest_group, @@ -1567,7 +1597,9 @@ def _checkpoint_is_compatible( checkpoint_attrs = dict(checkpoint_group.attrs) if str(checkpoint_attrs.get("schema_version", "")) != RESUME_SCHEMA_VERSION: return False - if str(checkpoint_attrs.get("parameter_fingerprint", "")) != str(parameter_fingerprint): + if str(checkpoint_attrs.get("parameter_fingerprint", "")) != str( + parameter_fingerprint + ): return False if str(checkpoint_attrs.get("source_component", "")) != str(source_component): return False @@ -1966,7 +1998,9 @@ def _prepare_output_arrays( f"Input component '{source_component}' is incompatible." 
) - fit_mode = str(parameters.get("fit_mode", "tiled")).strip().lower().replace("-", "_") + fit_mode = ( + str(parameters.get("fit_mode", "tiled")).strip().lower().replace("-", "_") + ) fit_tile_shape_yx = cast( tuple[int, int], parameters.get("fit_tile_shape_yx", (256, 256)), @@ -1983,7 +2017,9 @@ def _prepare_output_arrays( parameters.get("blend_tiles", False) and parameters.get("use_map_overlap", True) ) parameter_payload = _resume_parameter_payload(parameters) - parameter_json, parameter_fingerprint = _resume_parameter_fingerprint(parameter_payload) + parameter_json, parameter_fingerprint = _resume_parameter_fingerprint( + parameter_payload + ) component = "results/flatfield/latest" data_component = "results/flatfield/latest/data" @@ -2115,7 +2151,9 @@ def _fit_profile( ValueError If the fitted baseline length cannot be reshaped back to ``(t, z)``. """ - fit_mode = str(parameters.get("fit_mode", "tiled")).strip().lower().replace("-", "_") + fit_mode = ( + str(parameters.get("fit_mode", "tiled")).strip().lower().replace("-", "_") + ) if fit_mode == "full_volume": return _fit_profile_full_volume( zarr_path=zarr_path, @@ -2297,7 +2335,9 @@ def _fit_profile_tiled( dtype=np.float32, ) _, _, y_read_count, x_read_count = volume_tzyx.shape - fit_images = volume_tzyx.reshape(t_count * z_count, y_read_count, x_read_count) + fit_images = volume_tzyx.reshape( + t_count * z_count, y_read_count, x_read_count + ) flatfield_tile, darkfield_tile, baseline = _fit_basic_profile( fit_images=fit_images, basic_class=basic_class, @@ -2322,11 +2362,15 @@ def _fit_profile_tiled( flatfield_sum[ y_read_start:y_read_stop, x_read_start:x_read_stop, - ] += flatfield_tile.astype(np.float64) * blend_weights_64 + ] += ( + flatfield_tile.astype(np.float64) * blend_weights_64 + ) darkfield_sum[ y_read_start:y_read_stop, x_read_start:x_read_stop, - ] += darkfield_tile.astype(np.float64) * blend_weights_64 + ] += ( + darkfield_tile.astype(np.float64) * blend_weights_64 + ) weight_sum[ 
y_read_start:y_read_stop, x_read_start:x_read_stop, @@ -2892,10 +2936,12 @@ def _transform_region( ) corrected = corrected - baseline[:, None, None, :, None, None] - corrected_core = corrected[_crop_slices_for_core( - read_region=read_region, - core_region=core_region, - )].astype(np.float32, copy=False) + corrected_core = corrected[ + _crop_slices_for_core( + read_region=read_region, + core_region=core_region, + ) + ].astype(np.float32, copy=False) write_root = zarr.open_group(str(zarr_path), mode="a") write_root[output_data_component][_region_to_slices(core_region)] = corrected_core @@ -3044,7 +3090,9 @@ def _execute_tasks( if client is None: for completed, task_input in enumerate(task_inputs, start=1): try: - result = dask.compute(build_task(task_input), scheduler="processes")[0] + result = dask.compute( + build_task(task_input), scheduler="processes" + )[0] except Exception as exc: # pragma: no cover - mirrored distributed path if handle_task_error is None or not bool( handle_task_error(task_input, exc) @@ -3063,14 +3111,18 @@ def _execute_tasks( delayed_tasks = [build_task(task_input) for task_input in task_inputs] futures = cast(list[Any], client.compute(delayed_tasks)) - future_to_input = {future: task_inputs[index] for index, future in enumerate(futures)} + future_to_input = { + future: task_inputs[index] for index, future in enumerate(futures) + } completed = 0 for future in as_completed(futures): task_input = future_to_input[future] try: result = future.result() except Exception as exc: - if handle_task_error is None or not bool(handle_task_error(task_input, exc)): + if handle_task_error is None or not bool( + handle_task_error(task_input, exc) + ): raise else: consume_result(task_input, result) @@ -3119,18 +3171,24 @@ def _execute_tasks( profile_count = int(len(profile_pairs)) _emit( 6, - "Resuming existing flatfield checkpoint" - if layout.resumed - else "Initialized fresh flatfield checkpoint", + ( + "Resuming existing flatfield checkpoint" + if 
layout.resumed + else "Initialized fresh flatfield checkpoint" + ), ) - checkpoint_group = zarr.open_group(str(zarr_path), mode="a")[layout.checkpoint_component] + checkpoint_group = zarr.open_group(str(zarr_path), mode="a")[ + layout.checkpoint_component + ] raw_fallback_records = checkpoint_group.attrs.get("fit_fallback_records", []) fit_fallback_records: list[dict[str, Any]] = [] if isinstance(raw_fallback_records, list): for record in raw_fallback_records: if isinstance(record, Mapping): - fit_fallback_records.append(cast(dict[str, Any], _to_jsonable(dict(record)))) + fit_fallback_records.append( + cast(dict[str, Any], _to_jsonable(dict(record))) + ) checkpoint_group.attrs.update( { "run_status": "running", @@ -3187,7 +3245,9 @@ def _consume_profile_result( profile=profile_result, ) position_index, channel_index = profile_key - fit_profile_done_array[position_index, channel_index] = np.bool_(True) + fit_profile_done_array[position_index, channel_index] = np.bool_( + True + ) _execute_tasks( task_inputs=pending_profile_pairs, @@ -3230,7 +3290,9 @@ def _consume_profile_result( if "position_index" in record and "channel_index" in record } tiles_per_profile = ( - int(len(tile_specs) // max(1, profile_count)) if profile_count > 0 else 0 + int(len(tile_specs) // max(1, profile_count)) + if profile_count > 0 + else 0 ) tile_done_array = checkpoint_group["fit_tile_done_pcyx"] tile_done_mask = np.asarray(tile_done_array, dtype=bool) @@ -3326,7 +3388,9 @@ def _record_profile_fallback( "channel_index": int(profile_key[1]), "fallback_mode": str(fallback_mode), "trigger_error": str(trigger_error), - "retry_error": None if retry_error is None else str(retry_error), + "retry_error": ( + None if retry_error is None else str(retry_error) + ), "recorded_utc": _utc_now_iso(), } fit_fallback_records.append(record) @@ -3423,8 +3487,12 @@ def _consume_tile_result( profile_key = (position_index, channel_index) if profile_key in fallback_profile_keys: return - y_start, y_stop = 
int(tile_result.y_bounds[0]), int(tile_result.y_bounds[1]) - x_start, x_stop = int(tile_result.x_bounds[0]), int(tile_result.x_bounds[1]) + y_start, y_stop = int(tile_result.y_bounds[0]), int( + tile_result.y_bounds[1] + ) + x_start, x_stop = int(tile_result.x_bounds[0]), int( + tile_result.x_bounds[1] + ) profile_selection = ( slice(position_index, position_index + 1), slice(channel_index, channel_index + 1), @@ -3442,14 +3510,18 @@ def _consume_tile_result( baseline_sum_array[profile_selection], dtype=np.float32, ) - baseline_sum += np.asarray(tile_result.baseline_tz, dtype=np.float32)[ + baseline_sum += np.asarray( + tile_result.baseline_tz, dtype=np.float32 + )[ None, None, :, :, ] baseline_sum_array[profile_selection] = baseline_sum - baseline_count = int(baseline_count_array[position_index, channel_index]) + baseline_count = int( + baseline_count_array[position_index, channel_index] + ) baseline_count_array[position_index, channel_index] = np.uint32( baseline_count + 1 ) diff --git a/src/clearex/gui/app.py b/src/clearex/gui/app.py index ab2cd35..bbcdcff 100644 --- a/src/clearex/gui/app.py +++ b/src/clearex/gui/app.py @@ -7076,10 +7076,10 @@ class AnalysisSelectionDialog(QDialog): "flatfield", "deconvolution", "shear_transform", + "display_pyramid", + "registration", "particle_detection", "usegment3d", - "registration", - "display_pyramid", "visualization", "mip_export", ) @@ -7100,17 +7100,23 @@ class AnalysisSelectionDialog(QDialog): "particle_detection": "Particle Detection", "usegment3d": "uSegment3D", "registration": "Registration", - "display_pyramid": "Display Pyramid", + "display_pyramid": "Pyramidal Downsampling", "visualization": "Napari", "mip_export": "MIP Export", } _OPERATION_TABS: tuple[tuple[str, tuple[str, ...]], ...] 
= ( ( "Preprocessing", - ("flatfield", "deconvolution", "shear_transform", "registration"), + ( + "flatfield", + "deconvolution", + "shear_transform", + "display_pyramid", + "registration", + ), ), ("Segmentation", ("particle_detection", "usegment3d")), - ("Visualization", ("display_pyramid", "visualization", "mip_export")), + ("Visualization", ("visualization", "mip_export")), ) _OPERATION_OUTPUT_COMPONENTS: Dict[str, str] = { "flatfield": "results/flatfield/latest/data", @@ -7125,7 +7131,7 @@ class AnalysisSelectionDialog(QDialog): "results/particle_detection/latest/detections" ) _DEFAULT_USEGMENT3D_PARAMETERS: Dict[str, Any] = { - "execution_order": 6, + "execution_order": 7, "input_source": "data", "force_rerun": False, "chunk_basis": "3d", @@ -7182,7 +7188,7 @@ class AnalysisSelectionDialog(QDialog): "output_dtype": "uint32", } _DEFAULT_REGISTRATION_PARAMETERS: Dict[str, Any] = { - "execution_order": 4, + "execution_order": 5, "input_source": "data", "force_rerun": False, "chunk_basis": "3d", @@ -7516,9 +7522,9 @@ class AnalysisSelectionDialog(QDialog): "from multi_positions.yml." ), "display_pyramid": ( - "Prepare reusable display pyramids and stored per-channel 1/95 " - "contrast limits for the selected source component before napari " - "launch." + "Prepare reusable pyramidal downsampling levels and stored " + "per-channel 1/95 contrast limits for the selected source " + "component." 
), "use_multiscale": ( "When enabled, napari uses existing display pyramids for 2D " diff --git a/src/clearex/io/experiment.py b/src/clearex/io/experiment.py index 7a23145..3aa9d46 100644 --- a/src/clearex/io/experiment.py +++ b/src/clearex/io/experiment.py @@ -34,7 +34,16 @@ from datetime import datetime, timezone from itertools import islice, product from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional, Sequence, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterator, + Optional, + Sequence, + Union, +) import json import math import os @@ -350,7 +359,10 @@ def _expected_pyramid_components( None This helper does not raise custom exceptions. """ - return ["data", *[f"data_pyramid/level_{idx}" for idx in range(1, len(level_factors))]] + return [ + "data", + *[f"data_pyramid/level_{idx}" for idx in range(1, len(level_factors))], + ] def _has_expected_pyramid_structure( @@ -589,7 +601,9 @@ def has_complete_canonical_data_store( if not _has_expected_pyramid_structure(root=root, level_factors=level_factors): return False required_components = ( - _expected_pyramid_components(level_factors) if level_factors is not None else ["data"] + _expected_pyramid_components(level_factors) + if level_factors is not None + else ["data"] ) record = _read_ingestion_progress_record(root) @@ -979,6 +993,7 @@ def _parse_navigate_bdv_setup_index_map( None Parsing is best-effort and falls back to ``None``. """ + def _candidate_bdv_xml_paths(path: Path) -> list[Path]: """Return candidate BDV XML sidecar paths for one source path. @@ -1790,8 +1805,7 @@ def _axis_chunk_bounds(size: int, chunk_size: int) -> list[tuple[int, int]]: Ordered ``(start, stop)`` bounds covering the full axis. 
""" return [ - (start, min(start + chunk_size, size)) - for start in range(0, size, chunk_size) + (start, min(start + chunk_size, size)) for start in range(0, size, chunk_size) ] @@ -1876,7 +1890,9 @@ def _estimate_write_batch_region_count( int Recommended number of chunk regions per submission batch. """ - chunk_bytes = max(1, math.prod(int(value) for value in chunks_tpczyx) * int(dtype_itemsize)) + chunk_bytes = max( + 1, math.prod(int(value) for value in chunks_tpczyx) * int(dtype_itemsize) + ) target_batch_bytes = 512 << 20 return max(1, min(64, target_batch_bytes // chunk_bytes)) @@ -1987,14 +2003,15 @@ def _should_use_source_aligned_plane_writes( """ min_chunks, max_chunks = _dask_chunk_min_max_tpczyx(array) source_is_single_plane_z = min_chunks[3] == 1 and max_chunks[3] == 1 - source_is_full_lateral = ( - min_chunks[4] >= int(shape_tpczyx[4]) and min_chunks[5] >= int(shape_tpczyx[5]) - ) - target_splits_lateral = ( - int(target_chunks_tpczyx[4]) < int(shape_tpczyx[4]) - or int(target_chunks_tpczyx[5]) < int(shape_tpczyx[5]) + source_is_full_lateral = min_chunks[4] >= int(shape_tpczyx[4]) and min_chunks[ + 5 + ] >= int(shape_tpczyx[5]) + target_splits_lateral = int(target_chunks_tpczyx[4]) < int(shape_tpczyx[4]) or int( + target_chunks_tpczyx[5] + ) < int(shape_tpczyx[5]) + return bool( + source_is_single_plane_z and source_is_full_lateral and target_splits_lateral ) - return bool(source_is_single_plane_z and source_is_full_lateral and target_splits_lateral) def _detect_runtime_memory_bytes() -> int: @@ -2290,7 +2307,11 @@ def _estimate_source_aligned_submission_batch_count( cpu_count = max(1, int(os.cpu_count() or 1)) effective_worker_count = max( 1, - int(worker_count) if worker_count is not None and int(worker_count) > 0 else cpu_count, + ( + int(worker_count) + if worker_count is not None and int(worker_count) > 0 + else cpu_count + ), ) per_worker_memory_limit = ( @@ -2448,8 +2469,8 @@ def _write_dask_array_in_batches( completed_regions = 
int(start_region) region_iter = islice( _iter_tpczyx_chunk_regions( - shape_tpczyx=shape_tpczyx, - chunks_tpczyx=chunks_tpczyx, + shape_tpczyx=shape_tpczyx, + chunks_tpczyx=chunks_tpczyx, ), int(start_region), None, @@ -2569,7 +2590,9 @@ def _write_dask_array_source_aligned_plane_batches( detected_worker_count = worker_count detected_worker_memory_limit_bytes = worker_memory_limit_bytes if detected_worker_count is None or detected_worker_memory_limit_bytes is None: - auto_worker_count, auto_worker_memory_limit = _detect_client_worker_resources(client) + auto_worker_count, auto_worker_memory_limit = _detect_client_worker_resources( + client + ) if detected_worker_count is None: detected_worker_count = auto_worker_count if detected_worker_memory_limit_bytes is None: @@ -2666,9 +2689,9 @@ def _component_matches_shape_and_chunks( actual_chunks = tuple(int(size) for size in target.chunks) except Exception: return False - return actual_shape == tuple(int(v) for v in shape_tpczyx) and actual_chunks == tuple( - int(v) for v in chunks_tpczyx - ) + return actual_shape == tuple( + int(v) for v in shape_tpczyx + ) and actual_chunks == tuple(int(v) for v in chunks_tpczyx) def _create_ingestion_progress_record( @@ -2795,9 +2818,7 @@ def _ingestion_progress_record_matches( return False record_source_component = record.get("source_component") normalized_record_source_component = ( - "" - if record_source_component is None - else str(record_source_component).strip() + "" if record_source_component is None else str(record_source_component).strip() ) normalized_source_component = ( "" if source_component is None else str(source_component).strip() @@ -2809,7 +2830,9 @@ def _ingestion_progress_record_matches( if str(record.get("write_mode", "")).strip() != str(write_mode): return False try: - record_shape = tuple(int(value) for value in record.get("canonical_shape_tpczyx", [])) + record_shape = tuple( + int(value) for value in record.get("canonical_shape_tpczyx", []) + ) record_chunks 
= tuple(int(value) for value in record.get("chunks_tpczyx", [])) record_factors = tuple( tuple(int(value) for value in level) @@ -3486,7 +3509,9 @@ def _emit_progress(percent: int, message: str) -> None: if (not force_rebuild) and has_complete_canonical_data_store(store_path): _emit_progress(100, "Canonical data store is already complete") data_root = zarr.open_group(str(store_path), mode="r") - data_chunks = tuple(int(value) for value in (data_root["data"].chunks or normalized_chunks)) + data_chunks = tuple( + int(value) for value in (data_root["data"].chunks or normalized_chunks) + ) return MaterializedDataStore( source_path=source_resolved, store_path=store_path, @@ -3582,7 +3607,8 @@ def _write_canonical_component( write_mode = ( "source_aligned_plane_batches" - if use_source_aligned_plane_writes and source_aligned_z_batch_depth is not None + if use_source_aligned_plane_writes + and source_aligned_z_batch_depth is not None else "chunk_region_batches" ) if write_mode == "source_aligned_plane_batches": @@ -3606,8 +3632,7 @@ def _write_canonical_component( base_start_region = 0 if ( not force_rebuild - and - checkpoint_resume_supported + and checkpoint_resume_supported and existing_progress_record is not None and _ingestion_progress_record_matches( record=existing_progress_record, @@ -3786,7 +3811,9 @@ def _persist_level_progress( completed = int(payload.get("completed_regions", 0)) except Exception: continue - start_regions_by_component[str(component)] = max(0, int(completed)) + start_regions_by_component[str(component)] = max( + 0, int(completed) + ) _materialize_data_pyramid( store_path=store_path, @@ -5201,9 +5228,7 @@ def create_dask_client( ) worker_count = ( - max(1, int(n_workers)) - if n_workers is not None - else len(assigned_gpu_ids) + max(1, int(n_workers)) if n_workers is not None else len(assigned_gpu_ids) ) worker_gpu_ids = [ str(assigned_gpu_ids[idx % len(assigned_gpu_ids)]) @@ -5309,9 +5334,9 @@ def _library_path_env_vars_for_platform( Ordered 
environment variable names used for runtime library lookup. """ effective_os_name = str(os.name if os_name is None else os_name).strip().lower() - effective_platform = str( - sys.platform if platform is None else platform - ).strip().lower() + effective_platform = ( + str(sys.platform if platform is None else platform).strip().lower() + ) if effective_os_name == "nt": return ("PATH",) if effective_platform == "darwin": diff --git a/src/clearex/io/log.py b/src/clearex/io/log.py index fd60e9e..6fabbc0 100644 --- a/src/clearex/io/log.py +++ b/src/clearex/io/log.py @@ -34,7 +34,6 @@ from datetime import datetime from typing import Any, TypeVar - # Local Imports # Third Party Imports diff --git a/src/clearex/io/provenance.py b/src/clearex/io/provenance.py index 60ba764..77864ef 100644 --- a/src/clearex/io/provenance.py +++ b/src/clearex/io/provenance.py @@ -238,9 +238,13 @@ def _git_metadata(repo_root: Path) -> Dict[str, Any]: Dictionary with commit, branch, dirty state, and remote URL. """ commit = _git_command(repo_root=repo_root, args=["rev-parse", "HEAD"]) - branch = _git_command(repo_root=repo_root, args=["rev-parse", "--abbrev-ref", "HEAD"]) + branch = _git_command( + repo_root=repo_root, args=["rev-parse", "--abbrev-ref", "HEAD"] + ) status = _git_command(repo_root=repo_root, args=["status", "--porcelain"]) - remote = _git_command(repo_root=repo_root, args=["config", "--get", "remote.origin.url"]) + remote = _git_command( + repo_root=repo_root, args=["config", "--get", "remote.origin.url"] + ) dirty: Optional[bool] if status is None: @@ -274,7 +278,9 @@ def _clearex_version() -> Optional[str]: return None -def _input_summary(workflow: WorkflowConfig, image_info: Optional[ImageInfo]) -> Dict[str, Any]: +def _input_summary( + workflow: WorkflowConfig, image_info: Optional[ImageInfo] +) -> Dict[str, Any]: """Build input-summary metadata for provenance records. 
Parameters @@ -980,7 +986,9 @@ def persist_run_provenance( else _default_steps(workflow) ) output_records = ( - _to_jsonable(dict(outputs)) if outputs is not None else _default_outputs(workflow) + _to_jsonable(dict(outputs)) + if outputs is not None + else _default_outputs(workflow) ) workflow_payload = { @@ -1011,9 +1019,7 @@ def persist_run_provenance( "spatial_calibration_text": format_spatial_calibration( workflow.spatial_calibration ), - "spatial_calibration_explicit": bool( - workflow.spatial_calibration_explicit - ), + "spatial_calibration_explicit": bool(workflow.spatial_calibration_explicit), "analysis_parameters": _to_jsonable(workflow.analysis_parameters), "analysis_output_policy": "latest_only", } diff --git a/src/clearex/main.py b/src/clearex/main.py index 02869ac..11ede39 100644 --- a/src/clearex/main.py +++ b/src/clearex/main.py @@ -36,7 +36,6 @@ import math import os - # Third Party Imports import zarr @@ -79,6 +78,7 @@ from clearex.mip_export.pipeline import ( run_mip_export_analysis, ) + try: from clearex.registration.pipeline import ( run_registration_analysis, @@ -115,6 +115,7 @@ def run_registration_analysis(*, zarr_path, parameters, client, progress_callbac "could not import clearex.registration.pipeline." ) + try: from clearex.usegment3d.pipeline import ( run_usegment3d_analysis, @@ -691,9 +692,7 @@ def _resolve_persisted_dask_backend_settings_path() -> Path: Full JSON path used for persisted Dask backend settings. """ return ( - Path.home() - / _CLEAREX_SETTINGS_DIR_NAME - / _CLEAREX_DASK_BACKEND_SETTINGS_FILE + Path.home() / _CLEAREX_SETTINGS_DIR_NAME / _CLEAREX_DASK_BACKEND_SETTINGS_FILE ).expanduser() @@ -1612,9 +1611,11 @@ def _emit_analysis_progress(percent: int, message: str) -> None: "Matching provenance run found for %s, but required output " "components are missing (%s). 
Re-running operation.", operation_name, - ", ".join(missing_components) - if missing_components - else "none", + ( + ", ".join(missing_components) + if missing_components + else "none" + ), ) if ( @@ -2194,9 +2195,7 @@ def _usegment3d_progress(percent: int, message: str) -> None: ): progress_state = {"last_percent": -5} - def _registration_progress( - percent: int, message: str - ) -> None: + def _registration_progress(percent: int, message: str) -> None: """Throttle registration progress logs. Parameters @@ -2444,6 +2443,7 @@ def _registration_progress( if provenance_store_path and is_zarr_store_path( provenance_store_path ): + def _display_pyramid_progress( percent: int, message: str ) -> None: @@ -2774,6 +2774,7 @@ def _visualization_progress(percent: int, message: str) -> None: if provenance_store_path and is_zarr_store_path( provenance_store_path ): + def _mip_export_progress(percent: int, message: str) -> None: """Map MIP-export progress into workflow-scale progress. @@ -2871,12 +2872,9 @@ def _mip_export_progress(percent: int, message: str) -> None: current_operation_name, provenance_store_path, ) - if ( - current_operation_name - and ( - not step_records - or str(step_records[-1].get("name")) != current_operation_name - ) + if current_operation_name and ( + not step_records + or str(step_records[-1].get("name")) != current_operation_name ): step_records.append( { @@ -2901,14 +2899,11 @@ def _mip_export_progress(percent: int, message: str) -> None: current_resolved_source, current_operation_parameters, ) - if ( - current_operation_name - and ( - not step_records - or str(step_records[-1].get("name")) != current_operation_name - or str(step_records[-1].get("parameters", {}).get("status", "")) - not in {"failed", "cancelled"} - ) + if current_operation_name and ( + not step_records + or str(step_records[-1].get("name")) != current_operation_name + or str(step_records[-1].get("parameters", {}).get("status", "")) + not in {"failed", "cancelled"} ): reason = ( 
"missing_input_dependency" diff --git a/src/clearex/mip_export/pipeline.py b/src/clearex/mip_export/pipeline.py index 284e351..7f0b7a2 100644 --- a/src/clearex/mip_export/pipeline.py +++ b/src/clearex/mip_export/pipeline.py @@ -259,9 +259,7 @@ def _normalize_parameters(parameters: Mapping[str, Any]) -> dict[str, Any]: normalized["resample_z_to_lateral"] = _coerce_bool( normalized.get("resample_z_to_lateral", True) ) - normalized["output_directory"] = str( - normalized.get("output_directory", "") - ).strip() + normalized["output_directory"] = str(normalized.get("output_directory", "")).strip() return normalized @@ -480,7 +478,9 @@ def _resample_axis_linear_to_uint16( dst_flat[:, :] = _to_uint16(repeated) return - sample_positions = np.linspace(0.0, float(src_len - 1), num=dst_len, dtype=np.float64) + sample_positions = np.linspace( + 0.0, float(src_len - 1), num=dst_len, dtype=np.float64 + ) lower = np.floor(sample_positions).astype(np.int64) upper = np.minimum(lower + 1, src_len - 1) weight_upper = (sample_positions - lower).astype(np.float32) @@ -576,9 +576,7 @@ def _resample_tiff_projection_z_axis_if_needed( metadata=_ome_metadata_for_projection_output( output_shape=tuple(int(v) for v in output_shape), axes=tuple(str(axis) for axis in axes), - voxel_size_um_zyx=tuple( - float(v) for v in effective_voxel_size_um_zyx - ), + voxel_size_um_zyx=tuple(float(v) for v in effective_voxel_size_um_zyx), image_name=str(output_path.stem), ), ) @@ -1131,9 +1129,7 @@ def _choose_preserved_tile_shape( ) budget = max(1, int(max_read_bytes)) while read_bytes > budget: - largest_axis = int( - np.argmax(np.asarray(tile_lengths, dtype=np.int64)) - ) + largest_axis = int(np.argmax(np.asarray(tile_lengths, dtype=np.int64))) if tile_lengths[largest_axis] <= 1: break tile_lengths[largest_axis] = max(1, tile_lengths[largest_axis] // 2) @@ -1181,7 +1177,9 @@ def _iter_tile_slices( ] -def _tile_slices_to_bounds(tile_slices: tuple[slice, ...]) -> tuple[tuple[int, int], ...]: +def 
_tile_slices_to_bounds( + tile_slices: tuple[slice, ...], +) -> tuple[tuple[int, int], ...]: """Convert tile slices into serializable integer bounds. Parameters @@ -1912,7 +1910,9 @@ def _run_export_task( ) stored_dtype = str(np.dtype(np.uint16)) else: - stored_dtype = str(np.dtype(getattr(output_target, "dtype", source_array.dtype))) + stored_dtype = str( + np.dtype(getattr(output_target, "dtype", source_array.dtype)) + ) effective_voxel_size_um_zyx = tuple(float(v) for v in voxel_size_um_zyx) if _is_tiff_export_format(export_format): @@ -2075,7 +2075,9 @@ def _emit(percent: int, message: str) -> None: metadata=_ome_metadata_for_projection_output( output_shape=tuple(int(v) for v in planned.output_shape), axes=tuple(str(axis) for axis in planned.axes), - voxel_size_um_zyx=tuple(float(v) for v in voxel_size_um_zyx), + voxel_size_um_zyx=tuple( + float(v) for v in voxel_size_um_zyx + ), image_name=str(planned.output_path.stem), ), ) @@ -2105,7 +2107,9 @@ def _emit(percent: int, message: str) -> None: {"axes": [str(axis) for axis in planned.axes]} ) stored_dtypes[int(planned.task_index)] = str( - np.dtype(getattr(output_target, "dtype", source_array.dtype)) + np.dtype( + getattr(output_target, "dtype", source_array.dtype) + ) ) finally: _close_zarr_store(output_root) diff --git a/src/clearex/preprocess/pad.py b/src/clearex/preprocess/pad.py index ac0c884..0431de4 100755 --- a/src/clearex/preprocess/pad.py +++ b/src/clearex/preprocess/pad.py @@ -46,7 +46,7 @@ def add_median_border(image_data: np.ndarray) -> np.ndarray: np.ndarray The image data with a border added. 
""" - (z_len, y_len, x_len) = image_data.shape + z_len, y_len, x_len = image_data.shape median_intensity = np.median(image_data) padded_image_data = np.full((z_len + 2, y_len + 2, x_len + 2), median_intensity) padded_image_data[1 : z_len + 1, 1 : y_len + 1, 1 : x_len + 1] = image_data diff --git a/src/clearex/registration/pipeline.py b/src/clearex/registration/pipeline.py index 40834b2..de0c7df 100644 --- a/src/clearex/registration/pipeline.py +++ b/src/clearex/registration/pipeline.py @@ -118,7 +118,9 @@ class _EdgeSpec: fixed_position: int moving_position: int - overlap_bbox_xyz: tuple[tuple[float, float], tuple[float, float], tuple[float, float]] + overlap_bbox_xyz: tuple[ + tuple[float, float], tuple[float, float], tuple[float, float] + ] overlap_voxels: int @@ -386,7 +388,11 @@ def _resolve_source_components_for_level( return full_resolution_source, full_resolution_source, 0 direct_level = _component_level_suffix(requested) - if direct_level is not None and requested in root and direct_level == effective_level: + if ( + direct_level is not None + and requested in root + and direct_level == effective_level + ): return full_resolution_source, requested, effective_level candidate_components = [ @@ -442,10 +448,15 @@ def _tile_extent_xyz( ) -def _transform_points_xyz(transform_xyz: np.ndarray, points_xyz: np.ndarray) -> np.ndarray: +def _transform_points_xyz( + transform_xyz: np.ndarray, points_xyz: np.ndarray +) -> np.ndarray: """Apply a homogeneous affine to point rows.""" homogeneous = np.concatenate( - [points_xyz.astype(np.float64), np.ones((points_xyz.shape[0], 1), dtype=np.float64)], + [ + points_xyz.astype(np.float64), + np.ones((points_xyz.shape[0], 1), dtype=np.float64), + ], axis=1, ) return (transform_xyz @ homogeneous.T).T[:, :3] @@ -507,7 +518,9 @@ def _bbox_volume_voxels( ], dtype=np.float64, ) - counts = np.maximum(1, np.floor(shape_xyz / np.maximum(voxel_xyz, 1e-6))).astype(int) + counts = np.maximum(1, np.floor(shape_xyz / np.maximum(voxel_xyz, 
1e-6))).astype( + int + ) return int(np.prod(counts, dtype=np.int64)) @@ -520,7 +533,9 @@ def _build_edge_specs( ) -> list[_EdgeSpec]: """Build overlap graph edges from nominal transformed tile boxes.""" bboxes = { - int(position): _tile_bbox_xyz(nominal_transforms_xyz[int(position)], tile_extent_xyz) + int(position): _tile_bbox_xyz( + nominal_transforms_xyz[int(position)], tile_extent_xyz + ) for position in positions } edges: list[_EdgeSpec] = [] @@ -574,7 +589,9 @@ def _world_to_input_affine_zyx( float(voxel_size_um_zyx[0]), ) scale_xyz = _xyz_scale_diagonal(voxel_size_xyz) - scale_inv_xyz = np.diag([1.0 / value for value in voxel_size_xyz]).astype(np.float64) + scale_inv_xyz = np.diag([1.0 / value for value in voxel_size_xyz]).astype( + np.float64 + ) world_to_local_xyz = np.linalg.inv(local_to_world_xyz) matrix = ( _PERMUTE_ZYX_TO_XYZ @@ -583,8 +600,13 @@ def _world_to_input_affine_zyx( @ scale_xyz @ _PERMUTE_ZYX_TO_XYZ ) - offset = _PERMUTE_ZYX_TO_XYZ @ scale_inv_xyz @ ( - (world_to_local_xyz[:3, :3] @ reference_origin_xyz) + world_to_local_xyz[:3, 3] + offset = ( + _PERMUTE_ZYX_TO_XYZ + @ scale_inv_xyz + @ ( + (world_to_local_xyz[:3, :3] @ reference_origin_xyz) + + world_to_local_xyz[:3, 3] + ) ) return matrix.astype(np.float64), offset.astype(np.float64) @@ -618,7 +640,9 @@ def _resample_source_to_world_grid( def _crop_from_overlap_bbox( - overlap_bbox_xyz: tuple[tuple[float, float], tuple[float, float], tuple[float, float]], + overlap_bbox_xyz: tuple[ + tuple[float, float], tuple[float, float], tuple[float, float] + ], *, voxel_size_um_zyx: Sequence[float], overlap_zyx: Sequence[int], @@ -632,14 +656,20 @@ def _crop_from_overlap_bbox( ], dtype=np.float64, ) - minimum_xyz = np.asarray( - [overlap_bbox_xyz[0][0], overlap_bbox_xyz[1][0], overlap_bbox_xyz[2][0]], - dtype=np.float64, - ) - pad_xyz - maximum_xyz = np.asarray( - [overlap_bbox_xyz[0][1], overlap_bbox_xyz[1][1], overlap_bbox_xyz[2][1]], - dtype=np.float64, - ) + pad_xyz + minimum_xyz = ( + 
np.asarray( + [overlap_bbox_xyz[0][0], overlap_bbox_xyz[1][0], overlap_bbox_xyz[2][0]], + dtype=np.float64, + ) + - pad_xyz + ) + maximum_xyz = ( + np.asarray( + [overlap_bbox_xyz[0][1], overlap_bbox_xyz[1][1], overlap_bbox_xyz[2][1]], + dtype=np.float64, + ) + + pad_xyz + ) voxel_xyz = np.asarray( [ float(voxel_size_um_zyx[2]), @@ -649,7 +679,9 @@ def _crop_from_overlap_bbox( dtype=np.float64, ) size_xyz = np.maximum(voxel_xyz, maximum_xyz - minimum_xyz) - shape_xyz = np.maximum(1, np.ceil(size_xyz / np.maximum(voxel_xyz, 1e-6))).astype(int) + shape_xyz = np.maximum(1, np.ceil(size_xyz / np.maximum(voxel_xyz, 1e-6))).astype( + int + ) shape_zyx = (int(shape_xyz[2]), int(shape_xyz[1]), int(shape_xyz[0])) return minimum_xyz, shape_zyx @@ -716,11 +748,15 @@ def _register_pairwise_overlap( root = zarr.open_group(str(zarr_path), mode="r") source = root[source_component] fixed_source = np.asarray( - source[int(t_index), int(edge.fixed_position), int(registration_channel), :, :, :], + source[ + int(t_index), int(edge.fixed_position), int(registration_channel), :, :, : + ], dtype=np.float32, ) moving_source = np.asarray( - source[int(t_index), int(edge.moving_position), int(registration_channel), :, :, :], + source[ + int(t_index), int(edge.moving_position), int(registration_channel), :, :, : + ], dtype=np.float32, ) crop_origin_xyz, crop_shape_zyx = _crop_from_overlap_bbox( @@ -749,7 +785,11 @@ def _register_pairwise_overlap( fixed_mask = np.asarray(fixed_crop > 0, dtype=np.float32) moving_mask = np.asarray(moving_crop > 0, dtype=np.float32) overlap_pixels = int(np.count_nonzero((fixed_mask > 0) & (moving_mask > 0))) - if overlap_pixels <= 0 or float(np.std(fixed_crop)) <= 1e-6 or float(np.std(moving_crop)) <= 1e-6: + if ( + overlap_pixels <= 0 + or float(np.std(fixed_crop)) <= 1e-6 + or float(np.std(moving_crop)) <= 1e-6 + ): return { "fixed_position": int(edge.fixed_position), "moving_position": int(edge.moving_position), @@ -761,7 +801,9 @@ def 
_register_pairwise_overlap( } try: - fixed_image = _ants_image_from_zyx(fixed_crop, voxel_size_um_zyx=voxel_size_um_zyx) + fixed_image = _ants_image_from_zyx( + fixed_crop, voxel_size_um_zyx=voxel_size_um_zyx + ) moving_image = _ants_image_from_zyx( moving_crop, voxel_size_um_zyx=voxel_size_um_zyx ) @@ -814,7 +856,9 @@ def _matrix_from_pose(params: np.ndarray, registration_type: str) -> np.ndarray: matrix[:3, 3] = np.asarray(params[:3], dtype=np.float64) return matrix - rotation = Rotation.from_rotvec(np.asarray(params[:3], dtype=np.float64)).as_matrix() + rotation = Rotation.from_rotvec( + np.asarray(params[:3], dtype=np.float64) + ).as_matrix() translation = np.asarray(params[3:6], dtype=np.float64) if registration_type == "rigid": matrix[:3, :3] = rotation @@ -838,7 +882,9 @@ def _pose_from_matrix(matrix_xyz: np.ndarray, registration_type: str) -> np.ndar scale = float(np.cbrt(max(np.linalg.det(linear), 1e-12))) rotation_matrix = linear / max(scale, 1e-12) rotation = Rotation.from_matrix(rotation_matrix) - return np.concatenate([rotation.as_rotvec(), translation, np.asarray([math.log(scale)])]) + return np.concatenate( + [rotation.as_rotvec(), translation, np.asarray([math.log(scale)])] + ) rotation = Rotation.from_matrix(linear) return np.concatenate([rotation.as_rotvec(), translation]) @@ -857,7 +903,9 @@ def _translation_residual_voxels( def _component_positions( - positions: Sequence[int], active_edge_indices: Sequence[int], edges: Sequence[_EdgeSpec] + positions: Sequence[int], + active_edge_indices: Sequence[int], + edges: Sequence[_EdgeSpec], ) -> list[list[int]]: """Return connected position components for the active edge graph.""" adjacency: dict[int, set[int]] = {int(position): set() for position in positions} @@ -896,9 +944,13 @@ def _solve_translation_component( anchor_position: int, ) -> dict[int, np.ndarray]: """Solve translation-only correction poses for one connected component.""" - solved = {int(position): np.eye(4, dtype=np.float64) for 
position in component_positions} + solved = { + int(position): np.eye(4, dtype=np.float64) for position in component_positions + } variable_positions = [ - int(position) for position in component_positions if int(position) != int(anchor_position) + int(position) + for position in component_positions + if int(position) != int(anchor_position) ] if not variable_positions: return solved @@ -910,9 +962,17 @@ def _solve_translation_component( edge = edges[int(edge_index)] fixed_position = int(edge.fixed_position) moving_position = int(edge.moving_position) - if fixed_position not in column_index and fixed_position != int(anchor_position) and fixed_position not in component_positions: + if ( + fixed_position not in column_index + and fixed_position != int(anchor_position) + and fixed_position not in component_positions + ): continue - if moving_position not in column_index and moving_position != int(anchor_position) and moving_position not in component_positions: + if ( + moving_position not in column_index + and moving_position != int(anchor_position) + and moving_position not in component_positions + ): continue measurement = np.asarray( edge_results[int(edge_index)]["correction_matrix_xyz"], dtype=np.float64 @@ -955,9 +1015,13 @@ def _solve_nonlinear_component( voxel_size_um_zyx: Sequence[float], ) -> dict[int, np.ndarray]: """Solve rigid/similarity correction poses for one connected component.""" - solved = {int(position): np.eye(4, dtype=np.float64) for position in component_positions} + solved = { + int(position): np.eye(4, dtype=np.float64) for position in component_positions + } variable_positions = [ - int(position) for position in component_positions if int(position) != int(anchor_position) + int(position) + for position in component_positions + if int(position) != int(anchor_position) ] if not variable_positions: return solved @@ -1003,11 +1067,17 @@ def _residuals(params: np.ndarray) -> np.ndarray: measured_linear = np.asarray(measured[:3, :3], 
dtype=np.float64) predicted_linear = np.asarray(predicted[:3, :3], dtype=np.float64) if registration_type == "similarity": - measured_scale = float(np.cbrt(max(np.linalg.det(measured_linear), 1e-12))) - predicted_scale = float(np.cbrt(max(np.linalg.det(predicted_linear), 1e-12))) + measured_scale = float( + np.cbrt(max(np.linalg.det(measured_linear), 1e-12)) + ) + predicted_scale = float( + np.cbrt(max(np.linalg.det(predicted_linear), 1e-12)) + ) measured_rotation = measured_linear / max(measured_scale, 1e-12) predicted_rotation = predicted_linear / max(predicted_scale, 1e-12) - scale_residual = math.log(max(predicted_scale, 1e-12) / max(measured_scale, 1e-12)) + scale_residual = math.log( + max(predicted_scale, 1e-12) / max(measured_scale, 1e-12) + ) else: measured_rotation = measured_linear predicted_rotation = predicted_linear @@ -1016,10 +1086,9 @@ def _residuals(params: np.ndarray) -> np.ndarray: delta_rotation = Rotation.from_matrix( measured_rotation.T @ predicted_rotation ).as_rotvec() - delta_translation = ( - np.asarray(predicted[:3, 3] - measured[:3, 3], dtype=np.float64) - / max(mean_voxel, 1e-6) - ) + delta_translation = np.asarray( + predicted[:3, 3] - measured[:3, 3], dtype=np.float64 + ) / max(mean_voxel, 1e-6) residual_values.extend((weight * delta_rotation).tolist()) residual_values.extend((weight * delta_translation).tolist()) if registration_type == "similarity": @@ -1033,7 +1102,9 @@ def _residuals(params: np.ndarray) -> np.ndarray: x_scale="jac", max_nfev=200, ) - solved = {int(position): np.eye(4, dtype=np.float64) for position in component_positions} + solved = { + int(position): np.eye(4, dtype=np.float64) for position in component_positions + } for position in variable_positions: base = variable_index[int(position)] * pose_size solved[int(position)] = _matrix_from_pose( @@ -1105,9 +1176,10 @@ def _solve_with_pruning( measured = np.asarray( edge_results[int(edge_index)]["correction_matrix_xyz"], dtype=np.float64 ) - predicted = 
np.linalg.inv(solved[int(edge.fixed_position)]) @ solved[ - int(edge.moving_position) - ] + predicted = ( + np.linalg.inv(solved[int(edge.fixed_position)]) + @ solved[int(edge.moving_position)] + ) residuals[int(edge_index)] = _translation_residual_voxels( measured, predicted, voxel_size_um_zyx=voxel_size_um_zyx ) @@ -1118,9 +1190,11 @@ def _solve_with_pruning( worst_edge_index = int(np.nanargmax(residuals)) worst_residual = float(residuals[worst_edge_index]) mean_residual = float(np.mean(active_residuals)) - if ( - worst_residual <= float(_ABSOLUTE_RESIDUAL_THRESHOLD_PX) - or worst_residual <= max(float(_ABSOLUTE_RESIDUAL_THRESHOLD_PX), mean_residual * float(_RELATIVE_RESIDUAL_THRESHOLD)) + if worst_residual <= float( + _ABSOLUTE_RESIDUAL_THRESHOLD_PX + ) or worst_residual <= max( + float(_ABSOLUTE_RESIDUAL_THRESHOLD_PX), + mean_residual * float(_RELATIVE_RESIDUAL_THRESHOLD), ): break @@ -1128,7 +1202,8 @@ def _solve_with_pruning( component_positions = next( component for component in components - if int(edge.fixed_position) in component and int(edge.moving_position) in component + if int(edge.fixed_position) in component + and int(edge.moving_position) in component ) component_active = [ int(edge_index) @@ -1159,9 +1234,7 @@ def _blend_weight_volume( profile = np.ones(int(axis_size), dtype=np.float32) width = max(0, min(int(ramp_width), max(0, int(axis_size // 2)))) if width > 0: - ramp = 0.5 - 0.5 * np.cos( - np.linspace(0.0, np.pi, width, dtype=np.float32) - ) + ramp = 0.5 - 0.5 * np.cos(np.linspace(0.0, np.pi, width, dtype=np.float32)) profile[:width] = ramp profile[-width:] = np.minimum(profile[-width:], ramp[::-1]) profiles.append(profile) @@ -1224,9 +1297,21 @@ def _process_and_write_registration_chunk( ) chunk_bbox_xyz = np.asarray( [ - [float(chunk_origin_xyz[0]), float(chunk_origin_xyz[0]) + (float(chunk_shape_zyx[2]) * float(voxel_size_um_zyx[2]))], - [float(chunk_origin_xyz[1]), float(chunk_origin_xyz[1]) + (float(chunk_shape_zyx[1]) * 
float(voxel_size_um_zyx[1]))], - [float(chunk_origin_xyz[2]), float(chunk_origin_xyz[2]) + (float(chunk_shape_zyx[0]) * float(voxel_size_um_zyx[0]))], + [ + float(chunk_origin_xyz[0]), + float(chunk_origin_xyz[0]) + + (float(chunk_shape_zyx[2]) * float(voxel_size_um_zyx[2])), + ], + [ + float(chunk_origin_xyz[1]), + float(chunk_origin_xyz[1]) + + (float(chunk_shape_zyx[1]) * float(voxel_size_um_zyx[1])), + ], + [ + float(chunk_origin_xyz[2]), + float(chunk_origin_xyz[2]) + + (float(chunk_shape_zyx[0]) * float(voxel_size_um_zyx[0])), + ], ], dtype=np.float64, ) @@ -1242,7 +1327,9 @@ def _process_and_write_registration_chunk( ) bbox_min = bbox_payload[:3] bbox_max = bbox_payload[3:] - if np.any(bbox_max <= chunk_bbox_xyz[:, 0]) or np.any(chunk_bbox_xyz[:, 1] <= bbox_min): + if np.any(bbox_max <= chunk_bbox_xyz[:, 0]) or np.any( + chunk_bbox_xyz[:, 1] <= bbox_min + ): continue source_volume = np.asarray( source[int(t_index), int(position_index), int(c_index), :, :, :], @@ -1283,8 +1370,8 @@ def _process_and_write_registration_chunk( where=chunk_weight > 0, ) write_root = zarr.open_group(str(zarr_path), mode="a") - write_root[output_component][int(t_index), 0, int(c_index), z0:z1, y0:y1, x0:x1] = _cast_to_dtype( - normalized, np.dtype(output_dtype) + write_root[output_component][int(t_index), 0, int(c_index), z0:z1, y0:y1, x0:x1] = ( + _cast_to_dtype(normalized, np.dtype(output_dtype)) ) return 1 @@ -1390,12 +1477,18 @@ def run_registration_analysis( If source components or required stage metadata are missing. 
""" root = zarr.open_group(str(zarr_path), mode="r") - requested_source_component = str(parameters.get("input_source", "data")).strip() or "data" - requested_resolution_level = max(0, int(parameters.get("input_resolution_level", 0))) - source_component, pairwise_source_component, effective_level = _resolve_source_components_for_level( - root=root, - requested_source_component=requested_source_component, - input_resolution_level=requested_resolution_level, + requested_source_component = ( + str(parameters.get("input_source", "data")).strip() or "data" + ) + requested_resolution_level = max( + 0, int(parameters.get("input_resolution_level", 0)) + ) + source_component, pairwise_source_component, effective_level = ( + _resolve_source_components_for_level( + root=root, + requested_source_component=requested_source_component, + input_resolution_level=requested_resolution_level, + ) ) full_source = root[source_component] pairwise_source = root[pairwise_source_component] @@ -1418,20 +1511,29 @@ def run_registration_analysis( ) root_attrs = dict(root.attrs) - spatial_calibration = spatial_calibration_from_dict(root_attrs.get("spatial_calibration")) + spatial_calibration = spatial_calibration_from_dict( + root_attrs.get("spatial_calibration") + ) stage_rows = _load_stage_rows(root_attrs) if len(positions) > 1 and len(stage_rows) < len(positions): raise ValueError( "registration requires multiposition stage metadata when more than one position is present." 
) if not stage_rows: - stage_rows = [{"x": 0.0, "y": 0.0, "z": 0.0, "theta": 0.0, "f": 0.0} for _ in positions] + stage_rows = [ + {"x": 0.0, "y": 0.0, "z": 0.0, "theta": 0.0, "f": 0.0} for _ in positions + ] configured_anchor = parameters.get("anchor_position") - if str(parameters.get("anchor_mode", "central")).strip().lower() == "manual" and configured_anchor is not None: + if ( + str(parameters.get("anchor_mode", "central")).strip().lower() == "manual" + and configured_anchor is not None + ): anchor_position = int(configured_anchor) else: - anchor_position = _position_centroid_anchor(stage_rows, spatial_calibration, positions) + anchor_position = _position_centroid_anchor( + stage_rows, spatial_calibration, positions + ) if anchor_position < 0 or anchor_position >= len(positions): raise ValueError("registration anchor_position is out of bounds.") @@ -1464,8 +1566,12 @@ def run_registration_analysis( correction_affines_tex44 = np.zeros( (int(source_shape_tpczyx[0]), int(edge_count), 4, 4), dtype=np.float64 ) - edge_status_te = np.zeros((int(source_shape_tpczyx[0]), int(edge_count)), dtype=np.uint8) - edge_residual_te = np.full((int(source_shape_tpczyx[0]), int(edge_count)), np.nan, dtype=np.float32) + edge_status_te = np.zeros( + (int(source_shape_tpczyx[0]), int(edge_count)), dtype=np.uint8 + ) + edge_residual_te = np.full( + (int(source_shape_tpczyx[0]), int(edge_count)), np.nan, dtype=np.float32 + ) anchor_positions: list[int] = [] effective_corrections_tpx44 = np.repeat( np.eye(4, dtype=np.float64)[np.newaxis, np.newaxis, :, :], @@ -1490,11 +1596,19 @@ def run_registration_analysis( t_index=int(t_index), registration_channel=int(registration_channel), edge=edge, - nominal_fixed_transform_xyz=nominal_transforms_xyz[int(edge.fixed_position)], - nominal_moving_transform_xyz=nominal_transforms_xyz[int(edge.moving_position)], + nominal_fixed_transform_xyz=nominal_transforms_xyz[ + int(edge.fixed_position) + ], + nominal_moving_transform_xyz=nominal_transforms_xyz[ 
+ int(edge.moving_position) + ], voxel_size_um_zyx=pairwise_voxel_size_um_zyx, - overlap_zyx=[int(value) for value in parameters.get("overlap_zyx", [8, 32, 32])], - registration_type=str(parameters.get("registration_type", "rigid")).strip().lower(), + overlap_zyx=[ + int(value) for value in parameters.get("overlap_zyx", [8, 32, 32]) + ], + registration_type=str(parameters.get("registration_type", "rigid")) + .strip() + .lower(), ) for edge in edge_specs ] @@ -1524,11 +1638,15 @@ def run_registration_analysis( edges=edge_specs, edge_results=pairwise_results, anchor_position=anchor_position, - registration_type=str(parameters.get("registration_type", "rigid")).strip().lower(), + registration_type=str(parameters.get("registration_type", "rigid")) + .strip() + .lower(), voxel_size_um_zyx=pairwise_voxel_size_um_zyx, ) for position_index in positions: - effective_corrections_tpx44[int(t_index), int(position_index)] = solved[int(position_index)] + effective_corrections_tpx44[int(t_index), int(position_index)] = solved[ + int(position_index) + ] for edge_index, result in enumerate(pairwise_results): correction_affines_tex44[int(t_index), int(edge_index)] = np.asarray( result["correction_matrix_xyz"], dtype=np.float64 @@ -1548,7 +1666,9 @@ def run_registration_analysis( (int(source_shape_tpczyx[0]), int(source_shape_tpczyx[1]), 6), dtype=np.float64, ) - full_tile_extent_xyz = _tile_extent_xyz(source_shape_tpczyx[3:], full_voxel_size_um_zyx) + full_tile_extent_xyz = _tile_extent_xyz( + source_shape_tpczyx[3:], full_voxel_size_um_zyx + ) all_bbox_mins: list[np.ndarray] = [] all_bbox_maxs: list[np.ndarray] = [] for t_index in range(int(source_shape_tpczyx[0])): @@ -1557,8 +1677,12 @@ def run_registration_analysis( effective_corrections_tpx44[int(t_index), int(position_index)] @ nominal_transforms_xyz[int(position_index)] ) - effective_transforms_tpx44[int(t_index), int(position_index)] = effective_transform - bbox_min, bbox_max = _tile_bbox_xyz(effective_transform, 
full_tile_extent_xyz) + effective_transforms_tpx44[int(t_index), int(position_index)] = ( + effective_transform + ) + bbox_min, bbox_max = _tile_bbox_xyz( + effective_transform, full_tile_extent_xyz + ) transformed_bboxes_tpx6[int(t_index), int(position_index), :3] = bbox_min transformed_bboxes_tpx6[int(t_index), int(position_index), 3:] = bbox_max all_bbox_mins.append(bbox_min) @@ -1683,7 +1807,9 @@ def run_registration_analysis( output_origin_xyz=output_min_xyz, voxel_size_um_zyx=full_voxel_size_um_zyx, blend_mode=str(parameters.get("blend_mode", "feather")).strip().lower(), - overlap_zyx=[int(value) for value in parameters.get("overlap_zyx", [8, 32, 32])], + overlap_zyx=[ + int(value) for value in parameters.get("overlap_zyx", [8, 32, 32]) + ], output_dtype=str(full_source.dtype), ) for t_index, c_index, z_chunk, y_chunk, x_chunk in fusion_tasks @@ -1786,7 +1912,9 @@ def run_registration_analysis( input_resolution_level=int(effective_level), requested_input_resolution_level=int(requested_resolution_level), registration_channel=int(registration_channel), - registration_type=str(parameters.get("registration_type", "rigid")).strip().lower(), + registration_type=str(parameters.get("registration_type", "rigid")) + .strip() + .lower(), anchor_positions=tuple(int(value) for value in anchor_positions), positions=int(source_shape_tpczyx[1]), timepoints=int(source_shape_tpczyx[0]), diff --git a/src/clearex/registration/tre.py b/src/clearex/registration/tre.py index 034af05..eec5c54 100644 --- a/src/clearex/registration/tre.py +++ b/src/clearex/registration/tre.py @@ -42,7 +42,6 @@ intensity_weighted_centroids, ) - TransformModel = SimilarityTransform | AffineTransform diff --git a/src/clearex/segmentation/pointsource.py b/src/clearex/segmentation/pointsource.py index 509b826..9a63af8 100755 --- a/src/clearex/segmentation/pointsource.py +++ b/src/clearex/segmentation/pointsource.py @@ -30,8 +30,11 @@ from skimage.feature import blob_log from scipy.ndimage import 
gaussian_filter -from clearex.detect.particles import remove_close_blobs, \ - eliminate_insignificant_point_sources +from clearex.detect.particles import ( + remove_close_blobs, + eliminate_insignificant_point_sources, +) + # Local Imports from clearex.plot.images import mips from clearex.preprocess.scale import resize_data diff --git a/src/clearex/shear/__init__.py b/src/clearex/shear/__init__.py index e2c4a55..4dac7dd 100644 --- a/src/clearex/shear/__init__.py +++ b/src/clearex/shear/__init__.py @@ -6,4 +6,3 @@ from .pipeline import ShearTransformSummary, run_shear_transform_analysis __all__ = ["ShearTransformSummary", "run_shear_transform_analysis"] - diff --git a/src/clearex/shear/pipeline.py b/src/clearex/shear/pipeline.py index e5bb4db..71f52b9 100644 --- a/src/clearex/shear/pipeline.py +++ b/src/clearex/shear/pipeline.py @@ -32,7 +32,16 @@ from dataclasses import dataclass from itertools import product from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Optional, Sequence, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterator, + Mapping, + Optional, + Sequence, + Union, +) # Third Party Imports import ants @@ -932,11 +941,17 @@ def _source_bounds_for_output_region( pad_z, pad_y, pad_x = roi_padding_zyx src_z0 = max(0, int(np.floor(min_xyz[2] / z_um)) - int(pad_z)) - src_z1 = min(int(source_shape_zyx[0]), int(np.ceil(max_xyz[2] / z_um)) + int(pad_z) + 1) + src_z1 = min( + int(source_shape_zyx[0]), int(np.ceil(max_xyz[2] / z_um)) + int(pad_z) + 1 + ) src_y0 = max(0, int(np.floor(min_xyz[1] / y_um)) - int(pad_y)) - src_y1 = min(int(source_shape_zyx[1]), int(np.ceil(max_xyz[1] / y_um)) + int(pad_y) + 1) + src_y1 = min( + int(source_shape_zyx[1]), int(np.ceil(max_xyz[1] / y_um)) + int(pad_y) + 1 + ) src_x0 = max(0, int(np.floor(min_xyz[0] / x_um)) - int(pad_x)) - src_x1 = min(int(source_shape_zyx[2]), int(np.ceil(max_xyz[0] / x_um)) + int(pad_x) + 1) + src_x1 = min( + int(source_shape_zyx[2]), 
int(np.ceil(max_xyz[0] / x_um)) + int(pad_x) + 1 + ) if src_z0 >= src_z1 or src_y0 >= src_y1 or src_x0 >= src_x1: return None @@ -1202,8 +1217,7 @@ def _emit(percent: int, message: str) -> None: normalized["shear_yz"] = float(np.tan(np.deg2rad(estimated_shear_yz_deg))) _emit( 4, - "Auto-estimated shear_yz_deg=" - f"{float(estimated_shear_yz_deg):.3f}", + "Auto-estimated shear_yz_deg=" f"{float(estimated_shear_yz_deg):.3f}", ) else: _emit(4, "Auto-estimation failed; using configured shear parameters") diff --git a/src/clearex/stats/__init__.py b/src/clearex/stats/__init__.py index d2ba29c..3074984 100644 --- a/src/clearex/stats/__init__.py +++ b/src/clearex/stats/__init__.py @@ -22,4 +22,4 @@ # BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER # IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -# POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file +# POSSIBILITY OF SUCH DAMAGE. 
diff --git a/src/clearex/usegment3d/pipeline.py b/src/clearex/usegment3d/pipeline.py index 4ac1e77..1ffe183 100644 --- a/src/clearex/usegment3d/pipeline.py +++ b/src/clearex/usegment3d/pipeline.py @@ -336,10 +336,7 @@ def _summarize_client_worker_state(client: "Client") -> str: try: scheduler_info = client.scheduler_info() except Exception as exc: - return ( - "scheduler_info_unavailable=" - f"{type(exc).__name__}: {exc}" - ) + return "scheduler_info_unavailable=" f"{type(exc).__name__}: {exc}" worker_infos = dict(scheduler_info.get("workers", {})) active_workers = len(worker_infos) @@ -1415,7 +1412,9 @@ def _emit(percent: int, message: str) -> None: output_array = root[data_component] output_array.attrs.update( { - "voxel_size_um_zyx": [float(value) for value in output_voxel_size_um_zyx], + "voxel_size_um_zyx": [ + float(value) for value in output_voxel_size_um_zyx + ], "scale_tczyx": [ 1.0, 1.0, diff --git a/src/clearex/visualization/pipeline.py b/src/clearex/visualization/pipeline.py index b3fe69d..b750c32 100644 --- a/src/clearex/visualization/pipeline.py +++ b/src/clearex/visualization/pipeline.py @@ -32,7 +32,6 @@ import argparse from dataclasses import dataclass from datetime import datetime, timezone -import hashlib import json import math from pathlib import Path @@ -57,7 +56,6 @@ spatial_calibration_to_dict, ) - ProgressCallback = Callable[[int, str], None] _AXIS_LABELS_TCZYX = ("t", "c", "z", "y", "x") _TPCZYX_TO_TCZYX = (0, 2, 3, 4, 5) @@ -715,12 +713,16 @@ def _extract_scale_tczyx_from_navigate_raw( ) lateral_from_fov = _first_positive_float( ( - (fov_x / img_x) - if fov_x is not None and img_x is not None and img_x > 0 - else None, - (fov_y / img_y) - if fov_y is not None and img_y is not None and img_y > 0 - else None, + ( + (fov_x / img_x) + if fov_x is not None and img_x is not None and img_x > 0 + else None + ), + ( + (fov_y / img_y) + if fov_y is not None and img_y is not None and img_y > 0 + else None + ), ) ) @@ -2056,14 +2058,81 @@ def 
_component_matches_shape_chunks( return array_chunks == tuple(int(v) for v in chunks_tpczyx) -def _visualization_multiscale_cache_prefix(source_component: str) -> str: - """Return deterministic component prefix used for prepared display pyramids.""" +def _display_pyramid_level_component( + *, + source_component: str, + level: int, +) -> str: + """Return canonical component path for one generated display-pyramid level. + + Parameters + ---------- + source_component : str + Base source component path. + level : int + Pyramid level index (must be >= 1). + + Returns + ------- + str + Target component path for this level. + + Raises + ------ + ValueError + If ``level`` is less than 1. + + Notes + ----- + Raw data keeps canonical root-level ``data_pyramid`` naming to remain + compatible with existing analysis readers. Derived components are written + alongside their source at ``_pyramid/level_``. + """ + level_index = int(level) + if level_index < 1: + raise ValueError("display-pyramid level index must be >= 1.") component = str(source_component).strip() or "data" - safe_name = re.sub(r"[^A-Za-z0-9._-]+", "_", component).strip("_") - if not safe_name: - safe_name = "data" - digest = hashlib.sha1(component.encode("utf-8")).hexdigest()[:10] - return f"results/display_pyramid/by_component/{safe_name}_{digest}" + if component == "data": + return f"data_pyramid/level_{level_index}" + return f"{component}_pyramid/level_{level_index}" + + +def _is_source_pyramid_layout_compatible( + *, + source_component: str, + source_components: Sequence[str], +) -> bool: + """Return whether a discovered multiscale set follows canonical layout. + + Parameters + ---------- + source_component : str + Base source component path. + source_components : sequence[str] + Ordered component paths discovered for multiscale rendering. + + Returns + ------- + bool + ``True`` when all levels follow canonical source-adjacent pyramid + component naming. 
+ """ + components = [str(item).strip() for item in source_components] + if not components: + return False + + base_component = str(source_component).strip() or "data" + if components[0] != base_component: + return False + + for level_index, component in enumerate(components[1:], start=1): + expected = _display_pyramid_level_component( + source_component=base_component, + level=level_index, + ) + if str(component).strip() != expected: + return False + return True def _collect_existing_multiscale_components( @@ -2181,8 +2250,6 @@ def _build_visualization_multiscale_components( factor_payload: list[list[int]] = [ [int(value) for value in level_factors_tpczyx[0]] ] - cache_prefix = _visualization_multiscale_cache_prefix(str(source_component)) - prior_component = str(source_component) prior_factors = tuple(int(value) for value in level_factors_tpczyx[0]) for level_index, absolute_factors in enumerate(level_factors_tpczyx[1:], start=1): @@ -2212,7 +2279,10 @@ def _build_visualization_multiscale_components( int(absolute_factors[5]), ) - level_component = f"{cache_prefix}/level_{level_index}" + level_component = _display_pyramid_level_component( + source_component=str(source_component), + level=int(level_index), + ) source_level = da.from_zarr(str(zarr_path), component=source_level_component) downsampled = _downsample_tpczyx_by_stride(source_level, downsample_factors) level_shape = tuple(int(size) for size in tuple(downsampled.shape)) @@ -2283,9 +2353,7 @@ def _build_visualization_multiscale_components( if not isinstance(component_map, dict): component_map = {} component_map[str(source_component)] = [str(item) for item in level_paths] - root.attrs[_DISPLAY_PYRAMID_ROOT_MAP_ATTR] = _sanitize_metadata_value( - component_map - ) + root.attrs[_DISPLAY_PYRAMID_ROOT_MAP_ATTR] = _sanitize_metadata_value(component_map) legacy_component_map = root.attrs.get(_LEGACY_DISPLAY_PYRAMID_ROOT_MAP_ATTR) if not isinstance(legacy_component_map, dict): legacy_component_map = {} @@ -2514,9 
+2582,7 @@ def _load_display_contrast_limits_by_channel( attrs = dict(root[str(source_component)].attrs) except Exception: return None - return _coerce_contrast_limits_by_channel( - attrs.get(_DISPLAY_CONTRAST_LIMITS_ATTR) - ) + return _coerce_contrast_limits_by_channel(attrs.get(_DISPLAY_CONTRAST_LIMITS_ATTR)) def _save_display_contrast_metadata( @@ -2602,9 +2668,7 @@ def _sample_channel_for_display_contrast( if sampled_voxels <= target: return sampled - tp_stride = int( - max(1, math.ceil(math.sqrt(float(sampled_voxels) / float(target)))) - ) + tp_stride = int(max(1, math.ceil(math.sqrt(float(sampled_voxels) / float(target))))) return sampled[::tp_stride, ::tp_stride, :, :, :] @@ -3291,10 +3355,13 @@ def _resolve_layer_contrast_limits_by_channel( key = str(component).strip() or "data" cached = layer_display_contrast_cache.get(key) if cached is None: - cached = _load_display_contrast_limits_by_channel( - root=scale_root, - source_component=key, - ) or tuple() + cached = ( + _load_display_contrast_limits_by_channel( + root=scale_root, + source_component=key, + ) + or tuple() + ) layer_display_contrast_cache[key] = cached if len(cached) >= int(channel_count): return tuple(cached[: int(channel_count)]) @@ -3419,7 +3486,9 @@ def _resolve_layer_contrast_limits_by_channel( display_strategy = "none" if not display_level_arrays: continue - is_multiscale = int(effective_ndisplay) < 3 and len(display_level_arrays) > 1 + is_multiscale = ( + int(effective_ndisplay) < 3 and len(display_level_arrays) > 1 + ) channel_count = max(1, int(display_level_arrays[0].shape[1])) requested_channels = tuple( int(index) @@ -3764,9 +3833,9 @@ def _pop_keyframe(_viewer: Any = None) -> None: napari.run() _persist_keyframes() return { - "keyframe_manifest_path": str(manifest_path) - if manifest_path is not None - else None, + "keyframe_manifest_path": ( + str(manifest_path) if manifest_path is not None else None + ), "keyframe_count": int(len(keyframes)), "renderer": dict(renderer_info), } @@ 
-3862,11 +3931,26 @@ def _emit(percent: int, message: str) -> None: root=root, source_component=source_component, ) - reused_existing_levels = bool(len(existing_components) > 1 and not force_rerun) + has_multiscale_levels = bool(len(existing_components) > 1) + layout_compatible = _is_source_pyramid_layout_compatible( + source_component=source_component, + source_components=existing_components, + ) + reused_existing_levels = bool( + has_multiscale_levels and not force_rerun and layout_compatible + ) + should_rebuild_for_layout = bool( + has_multiscale_levels and not force_rerun and not layout_compatible + ) if reused_existing_levels: _emit(30, f"Reusing existing display pyramid for {source_component}") source_components = tuple(str(item) for item in existing_components) else: + if should_rebuild_for_layout: + _emit( + 15, + "Existing display pyramid uses legacy layout; rebuilding source-adjacent levels.", + ) _emit(30, f"Preparing display pyramid for {source_component}") level_factors = _resolve_visualization_pyramid_factors_tpczyx( root_attrs=root_attrs, @@ -3884,8 +3968,7 @@ def _emit(percent: int, message: str) -> None: source_attrs=source_attrs, ) persisted_factors = [ - list(row) - for row in level_factors[: max(1, len(tuple(source_components)))] + list(row) for row in level_factors[: max(1, len(tuple(source_components)))] ] root[str(source_component)].attrs[_DISPLAY_PYRAMID_LEVELS_ATTR] = [ str(item) for item in source_components @@ -4070,9 +4153,10 @@ def _save_visualization_metadata( } if viewer_pid is not None: payload["viewer_pid"] = int(viewer_pid) - if display_mode_fallback_reason is not None and str( - display_mode_fallback_reason - ).strip(): + if ( + display_mode_fallback_reason is not None + and str(display_mode_fallback_reason).strip() + ): payload["display_mode_fallback_reason"] = str(display_mode_fallback_reason) if keyframe_manifest_path is not None and str(keyframe_manifest_path).strip(): payload["keyframe_manifest_path"] = 
str(keyframe_manifest_path) diff --git a/src/clearex/workflow.py b/src/clearex/workflow.py index 342b5cb..df74087 100644 --- a/src/clearex/workflow.py +++ b/src/clearex/workflow.py @@ -29,8 +29,17 @@ import math import os import subprocess -from typing import Any, Collection, Dict, Literal, Mapping, Optional, Sequence, Tuple, Union - +from typing import ( + Any, + Collection, + Dict, + Literal, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) ChunkSpec = Optional[Union[int, Tuple[int, ...]]] ZarrAxisSpec = Tuple[int, int, int, int, int, int] @@ -50,10 +59,10 @@ "flatfield", "deconvolution", "shear_transform", + "display_pyramid", + "registration", "particle_detection", "usegment3d", - "registration", - "display_pyramid", "visualization", "mip_export", ) @@ -357,7 +366,9 @@ def _validate_analysis_input_reference( resolved_component=resolved_component, ) if producer_operation in order_map: - consumer_index = order_map.get(str(reference.consumer_operation).strip(), -1) + consumer_index = order_map.get( + str(reference.consumer_operation).strip(), -1 + ) producer_index = order_map[producer_operation] if producer_index >= consumer_index: return AnalysisInputDependencyIssue( @@ -569,7 +580,7 @@ class WorkflowExecutionCancelled(RuntimeError): "roi_padding_zyx": [2, 2, 2], }, "particle_detection": { - "execution_order": 5, + "execution_order": 6, "input_source": "data", "force_rerun": False, "channel_index": 0, @@ -590,7 +601,7 @@ class WorkflowExecutionCancelled(RuntimeError): "min_distance_sigma": 10.0, }, "usegment3d": { - "execution_order": 6, + "execution_order": 7, "input_source": "data", "force_rerun": False, "chunk_basis": "3d", @@ -648,7 +659,7 @@ class WorkflowExecutionCancelled(RuntimeError): "output_dtype": "uint32", }, "registration": { - "execution_order": 4, + "execution_order": 5, "input_source": "data", "force_rerun": False, "chunk_basis": "3d", @@ -664,7 +675,7 @@ class WorkflowExecutionCancelled(RuntimeError): "blend_mode": "feather", }, 
"display_pyramid": { - "execution_order": 7, + "execution_order": 4, "input_source": "data", "force_rerun": False, "chunk_basis": "3d", @@ -1611,9 +1622,11 @@ def _normalize_usegment3d_parameters( shape_zyx = normalized["aggregation_tile_shape_zyx"] overlap_values = [max(0, int(v)) for v in tile_overlap_zyx] ratio_candidates = [ - (float(overlap_values[idx]) / float(shape_zyx[idx])) - if int(shape_zyx[idx]) > 0 - else 0.0 + ( + (float(overlap_values[idx]) / float(shape_zyx[idx])) + if int(shape_zyx[idx]) > 0 + else 0.0 + ) for idx in range(3) ] tile_overlap_ratio = max(ratio_candidates) if ratio_candidates else 0.0 @@ -2352,14 +2365,14 @@ def selected_analysis_operations( selected.append("deconvolution") if shear_transform: selected.append("shear_transform") + if display_pyramid: + selected.append("display_pyramid") + if registration: + selected.append("registration") if particle_detection: selected.append("particle_detection") if usegment3d: selected.append("usegment3d") - if registration: - selected.append("registration") - if display_pyramid: - selected.append("display_pyramid") if visualization: selected.append("visualization") if mip_export: @@ -4309,9 +4322,7 @@ def parse_spatial_calibration( axis_name, binding = item.split("=", 1) key = str(axis_name).strip().lower() if key not in SPATIAL_CALIBRATION_WORLD_AXES: - raise ValueError( - "Spatial calibration world axes must be z, y, or x." - ) + raise ValueError("Spatial calibration world axes must be z, y, or x.") if key in assignments: raise ValueError( f"Spatial calibration world axis '{key}' is assigned more than once." 
@@ -4364,16 +4375,12 @@ def normalize_spatial_calibration( ) theta_mode = ( - str( - value.get("theta_mode", SPATIAL_CALIBRATION_DEFAULT_THETA_MODE) - ).strip() + str(value.get("theta_mode", SPATIAL_CALIBRATION_DEFAULT_THETA_MODE)).strip() or SPATIAL_CALIBRATION_DEFAULT_THETA_MODE ) schema = str(value.get("schema", SPATIAL_CALIBRATION_SCHEMA)).strip() if schema and schema != SPATIAL_CALIBRATION_SCHEMA: - raise ValueError( - f"Unsupported spatial calibration schema '{schema}'." - ) + raise ValueError(f"Unsupported spatial calibration schema '{schema}'.") if "stage_axis_map_zyx" in value: stage_axis_payload = value.get("stage_axis_map_zyx") @@ -4683,9 +4690,7 @@ def __post_init__(self) -> None: ) if selected_target is None: selected_target = self.analysis_targets[0] - self.analysis_selected_experiment_path = ( - selected_target.experiment_path - ) + self.analysis_selected_experiment_path = selected_target.experiment_path self.file = selected_target.store_path else: self.analysis_selected_experiment_path = ( @@ -4727,9 +4732,9 @@ def selected_analysis_target(self) -> Optional[AnalysisTarget]: Selected target resolved from ``analysis_targets`` and ``analysis_selected_experiment_path``. 
""" - selected_experiment_path = ( - str(self.analysis_selected_experiment_path or "").strip() - ) + selected_experiment_path = str( + self.analysis_selected_experiment_path or "" + ).strip() if not selected_experiment_path: return None for target in self.analysis_targets: diff --git a/tests/__init__.py b/tests/__init__.py index 04112ee..6af90c0 100755 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -53,7 +53,11 @@ def download_test_registration_data() -> str: # The zip extracts files directly into downloaded_data/ if output_dir.exists() and output_dir.is_dir(): # Verify that the expected files exist - expected_files = ["cropped_fixed.tif", "cropped_moving.tif", "GenericAffine.mat"] + expected_files = [ + "cropped_fixed.tif", + "cropped_moving.tif", + "GenericAffine.mat", + ] files_exist = all((output_dir / f).exists() for f in expected_files) if files_exist: diff --git a/tests/conftest.py b/tests/conftest.py index 675e45d..3641c2a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,6 @@ import pytest - if sys.platform == "darwin": # Use Qt's offscreen platform during tests so transient dialogs/viewers do # not flash on the interactive desktop. 
Callers can still override this by diff --git a/tests/flatfield/test_pipeline.py b/tests/flatfield/test_pipeline.py index bcb214c..80cbc4d 100644 --- a/tests/flatfield/test_pipeline.py +++ b/tests/flatfield/test_pipeline.py @@ -20,9 +20,7 @@ def test_run_flatfield_analysis_writes_latest_results( ) -> None: store_path = tmp_path / "analysis_store.zarr" root = zarr.open_group(str(store_path), mode="w") - data = np.arange(2 * 1 * 2 * 3 * 4 * 4, dtype=np.uint16).reshape( - (2, 1, 2, 3, 4, 4) - ) + data = np.arange(2 * 1 * 2 * 3 * 4 * 4, dtype=np.uint16).reshape((2, 1, 2, 3, 4, 4)) root.create_dataset( name="data", data=data, @@ -96,7 +94,9 @@ def fit(self, images, skip_shape_warning=False) -> None: assert baseline.shape == (1, 2, 2, 3) assert np.allclose(flatfield, 2.0) assert np.allclose(darkfield, 0.5) - assert np.array_equal(baseline[0, 0], np.asarray([[0, 1, 2], [3, 4, 5]], dtype=np.float32)) + assert np.array_equal( + baseline[0, 0], np.asarray([[0, 1, 2], [3, 4, 5]], dtype=np.float32) + ) assert np.isclose(corrected[0, 0, 0, 0, 0, 0], (data[0, 0, 0, 0, 0, 0] - 0.5) / 2.0) assert latest.attrs["source_component"] == "data" assert latest.attrs["data_component"] == "results/flatfield/latest/data" @@ -113,9 +113,7 @@ def test_run_flatfield_analysis_materializes_output_pyramid( ) -> None: store_path = tmp_path / "analysis_store_pyramid.zarr" root = zarr.open_group(str(store_path), mode="w") - data = np.arange(1 * 1 * 1 * 2 * 4 * 4, dtype=np.uint16).reshape( - (1, 1, 1, 2, 4, 4) - ) + data = np.arange(1 * 1 * 1 * 2 * 4 * 4, dtype=np.uint16).reshape((1, 1, 1, 2, 4, 4)) root.create_dataset( name="data", data=data, @@ -206,7 +204,9 @@ def test_run_flatfield_analysis_defaults_darkfield_off( root = zarr.open_group(str(store_path), mode="w") root.create_dataset( name="data", - data=np.arange(1 * 1 * 1 * 2 * 2 * 2, dtype=np.uint16).reshape((1, 1, 1, 2, 2, 2)), + data=np.arange(1 * 1 * 1 * 2 * 2 * 2, dtype=np.uint16).reshape( + (1, 1, 1, 2, 2, 2) + ), chunks=(1, 1, 1, 2, 2, 2), 
overwrite=True, ) @@ -337,7 +337,9 @@ def fit(self, images, skip_shape_warning=False) -> None: ) -def test_run_flatfield_analysis_tiled_fit_blends_tiles(tmp_path: Path, monkeypatch) -> None: +def test_run_flatfield_analysis_tiled_fit_blends_tiles( + tmp_path: Path, monkeypatch +) -> None: store_path = tmp_path / "analysis_store_tiled_blend.zarr" root = zarr.open_group(str(store_path), mode="w") data = np.arange(1 * 1 * 1 * 1 * 4 * 4, dtype=np.uint16).reshape((1, 1, 1, 1, 4, 4)) @@ -474,8 +476,16 @@ def _failing_transform(**kwargs): output_root = zarr.open_group(str(store_path), mode="r") checkpoint = output_root["results"]["flatfield"]["latest"]["checkpoint"] - assert int(np.count_nonzero(np.asarray(checkpoint["fit_tile_done_pcyx"], dtype=bool))) == 4 - assert int(np.count_nonzero(np.asarray(checkpoint["transform_done_tpcyx"], dtype=bool))) == 2 + assert ( + int(np.count_nonzero(np.asarray(checkpoint["fit_tile_done_pcyx"], dtype=bool))) + == 4 + ) + assert ( + int( + np.count_nonzero(np.asarray(checkpoint["transform_done_tpcyx"], dtype=bool)) + ) + == 2 + ) assert fit_calls["count"] == 4 resumed_transform_trace = tmp_path / "resumed_transform_calls.txt" @@ -579,7 +589,9 @@ def fit(self, images, skip_shape_warning=False) -> None: first_run_fit_calls = int(fit_calls["count"]) assert first_run_fit_calls == 4 - first_latest = zarr.open_group(str(store_path), mode="r")["results"]["flatfield"]["latest"] + first_latest = zarr.open_group(str(store_path), mode="r")["results"]["flatfield"][ + "latest" + ] first_fingerprint = str(first_latest.attrs["resume_parameter_fingerprint"]) client = create_dask_client(n_workers=1, threads_per_worker=1, processes=False) @@ -592,7 +604,9 @@ def fit(self, images, skip_shape_warning=False) -> None: finally: client.close() - second_latest = zarr.open_group(str(store_path), mode="r")["results"]["flatfield"]["latest"] + second_latest = zarr.open_group(str(store_path), mode="r")["results"]["flatfield"][ + "latest" + ] second_fingerprint = 
str(second_latest.attrs["resume_parameter_fingerprint"]) assert fit_calls["count"] == first_run_fit_calls assert second_fingerprint == first_fingerprint @@ -610,7 +624,9 @@ def fit(self, images, skip_shape_warning=False) -> None: finally: client.close() - third_latest = zarr.open_group(str(store_path), mode="r")["results"]["flatfield"]["latest"] + third_latest = zarr.open_group(str(store_path), mode="r")["results"]["flatfield"][ + "latest" + ] third_fingerprint = str(third_latest.attrs["resume_parameter_fingerprint"]) assert fit_calls["count"] == first_run_fit_calls + 4 assert third_fingerprint != first_fingerprint @@ -681,9 +697,16 @@ def fit(self, images, skip_shape_warning=False) -> None: expected_chunk_shape = tuple(int(v) for v in checkpoint_array.chunks) chunk_root = ( - store_path / "results" / "flatfield" / "latest" / "checkpoint" / "fit_baseline_sum_pctz" + store_path + / "results" + / "flatfield" + / "latest" + / "checkpoint" + / "fit_baseline_sum_pctz" ) - chunk_files = [p for p in chunk_root.rglob("*") if p.is_file() and p.name != "attributes.json"] + chunk_files = [ + p for p in chunk_root.rglob("*") if p.is_file() and p.name != "attributes.json" + ] assert chunk_files, "Expected at least one written checkpoint chunk in N5 store." 
for chunk_file in chunk_files: with chunk_file.open("rb") as handle: @@ -754,8 +777,12 @@ def fit(self, images, skip_shape_warning=False) -> None: client.close() writable_root = zarr.open_group(str(store_path), mode="a") - malformed = writable_root["results"]["flatfield"]["latest"]["checkpoint"]["fit_baseline_sum_pctz"] - malformed[0, 0, :, :] = np.asarray(malformed[0, 0, :, :], dtype=np.float32) + np.float32(1.0) + malformed = writable_root["results"]["flatfield"]["latest"]["checkpoint"][ + "fit_baseline_sum_pctz" + ] + malformed[0, 0, :, :] = np.asarray( + malformed[0, 0, :, :], dtype=np.float32 + ) + np.float32(1.0) with pytest.raises(AssertionError, match="Expected chunk of shape"): np.asarray(malformed[0:1, 0:1, :, :], dtype=np.float32) @@ -834,7 +861,9 @@ def _full_volume_profile( ) monkeypatch.setattr(flatfield_pipeline, "_fit_profile_tile", _failing_tile) - monkeypatch.setattr(flatfield_pipeline, "_fit_profile_full_volume", _full_volume_profile) + monkeypatch.setattr( + flatfield_pipeline, "_fit_profile_full_volume", _full_volume_profile + ) client = create_dask_client(n_workers=1, threads_per_worker=1, processes=False) try: @@ -856,7 +885,9 @@ def _full_volume_profile( finally: client.close() - latest = zarr.open_group(str(store_path), mode="r")["results"]["flatfield"]["latest"] + latest = zarr.open_group(str(store_path), mode="r")["results"]["flatfield"][ + "latest" + ] checkpoint = latest["checkpoint"] fallback_records = list(checkpoint.attrs["fit_fallback_records"]) @@ -914,7 +945,9 @@ def _failing_full_volume( raise RuntimeError("full-volume fallback failed") monkeypatch.setattr(flatfield_pipeline, "_fit_profile_tile", _always_failing_tile) - monkeypatch.setattr(flatfield_pipeline, "_fit_profile_full_volume", _failing_full_volume) + monkeypatch.setattr( + flatfield_pipeline, "_fit_profile_full_volume", _failing_full_volume + ) client = create_dask_client(n_workers=1, threads_per_worker=1, processes=False) try: @@ -936,7 +969,9 @@ def 
_failing_full_volume( finally: client.close() - latest = zarr.open_group(str(store_path), mode="r")["results"]["flatfield"]["latest"] + latest = zarr.open_group(str(store_path), mode="r")["results"]["flatfield"][ + "latest" + ] checkpoint = latest["checkpoint"] fallback_records = list(checkpoint.attrs["fit_fallback_records"]) diff --git a/tests/gui/test_gui_execution.py b/tests/gui/test_gui_execution.py index 561696a..2b899ea 100644 --- a/tests/gui/test_gui_execution.py +++ b/tests/gui/test_gui_execution.py @@ -334,6 +334,7 @@ def test_registration_operation_moves_to_preprocessing_tab() -> None: tab_map = dict(app_module.AnalysisSelectionDialog._OPERATION_TABS) + assert "display_pyramid" in tab_map["Preprocessing"] assert "registration" in tab_map["Preprocessing"] assert "registration" not in tab_map.get("Postprocessing", ()) @@ -2011,10 +2012,16 @@ def test_analysis_selection_dialog_uses_napari_and_visualization_labels() -> Non dialog_cls = app_module.AnalysisSelectionDialog assert dialog_cls._OPERATION_LABELS["visualization"] == "Napari" - assert dialog_cls._OPERATION_LABELS["display_pyramid"] == "Display Pyramid" + assert dialog_cls._OPERATION_LABELS["display_pyramid"] == "Pyramidal Downsampling" assert ( - "Visualization", - ("display_pyramid", "visualization", "mip_export"), + "Preprocessing", + ( + "flatfield", + "deconvolution", + "shear_transform", + "display_pyramid", + "registration", + ), ) in dialog_cls._OPERATION_TABS diff --git a/tests/io/test_experiment.py b/tests/io/test_experiment.py index f892ec9..207ea2c 100644 --- a/tests/io/test_experiment.py +++ b/tests/io/test_experiment.py @@ -112,8 +112,7 @@ def _write_bdv_xml( setup_blocks = [] for setup_index in sorted(setup_channel_tile): channel_index, tile_index = setup_channel_tile[setup_index] - setup_blocks.append( - f""" + setup_blocks.append(f""" {setup_index} {setup_index} @@ -129,8 +128,7 @@ def _write_bdv_xml( 0 -""".rstrip() - ) +""".rstrip()) xml = f""" @@ -185,7 +183,9 @@ def 
test_library_path_env_vars_for_platform_linux() -> None: ) == ("LD_LIBRARY_PATH",) -def test_build_library_path_environment_updates_merges_discovered_and_inherited() -> None: +def test_build_library_path_environment_updates_merges_discovered_and_inherited() -> ( + None +): updates = experiment_module._build_library_path_environment_updates( ["/cuda/runtime/lib", "/cuda/cudnn/lib"], env={"LD_LIBRARY_PATH": "/cluster/custom/lib"}, @@ -424,7 +424,9 @@ def test_initialize_analysis_store_backfills_identity_spatial_calibration( experiment = load_navigate_experiment(experiment_path) store_path = default_analysis_store_path(experiment) - initialize_analysis_store(experiment=experiment, zarr_path=store_path, overwrite=True) + initialize_analysis_store( + experiment=experiment, zarr_path=store_path, overwrite=True + ) calibration = load_store_spatial_calibration(store_path) @@ -486,7 +488,9 @@ def test_write_zyx_block_numpy(tmp_path: Path): _write_minimal_experiment(experiment_path, save_directory=tmp_path, file_type="H5") experiment = load_navigate_experiment(experiment_path) store_path = default_analysis_store_path(experiment) - initialize_analysis_store(experiment=experiment, zarr_path=store_path, overwrite=True) + initialize_analysis_store( + experiment=experiment, zarr_path=store_path, overwrite=True + ) block = np.ones((4, 8, 16), dtype=np.uint16) write_zyx_block( @@ -597,7 +601,9 @@ def test_resolve_data_store_path_uses_experiment_directory_for_non_zarr(tmp_path save_directory.mkdir(parents=True, exist_ok=True) experiment_path = experiment_dir / "experiment.yml" - _write_minimal_experiment(experiment_path, save_directory=save_directory, file_type="H5") + _write_minimal_experiment( + experiment_path, save_directory=save_directory, file_type="H5" + ) experiment = load_navigate_experiment(experiment_path) source_path = save_directory / "source.npy" @@ -608,9 +614,13 @@ def test_resolve_data_store_path_uses_experiment_directory_for_non_zarr(tmp_path assert resolved == 
(experiment_dir / "data_store.zarr").resolve() -def test_materialize_experiment_data_store_creates_data_store_for_non_zarr(tmp_path: Path): +def test_materialize_experiment_data_store_creates_data_store_for_non_zarr( + tmp_path: Path, +): experiment_path = tmp_path / "experiment.yml" - _write_minimal_experiment(experiment_path, save_directory=tmp_path, file_type="TIFF") + _write_minimal_experiment( + experiment_path, save_directory=tmp_path, file_type="TIFF" + ) experiment = load_navigate_experiment(experiment_path) source_data = np.arange(24, dtype=np.uint16).reshape(2, 3, 4) @@ -642,7 +652,9 @@ def test_materialize_experiment_data_store_batches_chunk_writes( tmp_path: Path, monkeypatch ): experiment_path = tmp_path / "experiment.yml" - _write_minimal_experiment(experiment_path, save_directory=tmp_path, file_type="TIFF") + _write_minimal_experiment( + experiment_path, save_directory=tmp_path, file_type="TIFF" + ) experiment = load_navigate_experiment(experiment_path) source_data = np.arange(24, dtype=np.uint16).reshape(2, 3, 4) @@ -681,7 +693,9 @@ def test_materialize_experiment_data_store_resumes_after_interrupted_base_write( monkeypatch, ): experiment_path = tmp_path / "experiment.yml" - _write_minimal_experiment(experiment_path, save_directory=tmp_path, file_type="TIFF") + _write_minimal_experiment( + experiment_path, save_directory=tmp_path, file_type="TIFF" + ) experiment = load_navigate_experiment(experiment_path) source_data = np.arange(24, dtype=np.uint16).reshape(2, 3, 4) @@ -749,7 +763,9 @@ def test_materialize_experiment_data_store_handles_multibatch_base_and_pyramid( tmp_path: Path, ): experiment_path = tmp_path / "experiment.yml" - _write_minimal_experiment(experiment_path, save_directory=tmp_path, file_type="TIFF") + _write_minimal_experiment( + experiment_path, save_directory=tmp_path, file_type="TIFF" + ) experiment = load_navigate_experiment(experiment_path) source_data = np.arange(5 * 17 * 17, dtype=np.uint16).reshape(5, 17, 17) @@ -781,7 +797,9 
@@ def test_materialize_experiment_data_store_reuses_existing_zarr_store(tmp_path: source_data = np.arange(24, dtype=np.uint16).reshape(2, 3, 4) source_store = tmp_path / "source.ome.zarr" source_root = zarr.open_group(str(source_store), mode="w") - source_root.create_dataset("raw", data=source_data, chunks=(1, 3, 4), overwrite=True) + source_root.create_dataset( + "raw", data=source_data, chunks=(1, 3, 4), overwrite=True + ) source_root["raw"].attrs["_ARRAY_DIMENSIONS"] = ["z", "y", "x"] materialized = materialize_experiment_data_store( @@ -860,7 +878,9 @@ def test_has_complete_canonical_data_store_requires_completed_progress_record( tmp_path: Path, ): experiment_path = tmp_path / "experiment.yml" - _write_minimal_experiment(experiment_path, save_directory=tmp_path, file_type="TIFF") + _write_minimal_experiment( + experiment_path, save_directory=tmp_path, file_type="TIFF" + ) experiment = load_navigate_experiment(experiment_path) source_data = np.arange(24, dtype=np.uint16).reshape(2, 3, 4) @@ -902,7 +922,9 @@ def test_materialize_experiment_data_store_reuses_complete_store_by_default_and_ tmp_path: Path, ): experiment_path = tmp_path / "experiment.yml" - _write_minimal_experiment(experiment_path, save_directory=tmp_path, file_type="TIFF") + _write_minimal_experiment( + experiment_path, save_directory=tmp_path, file_type="TIFF" + ) experiment = load_navigate_experiment(experiment_path) source_data = np.arange(24, dtype=np.uint16).reshape(2, 3, 4) @@ -943,7 +965,9 @@ def test_materialize_experiment_data_store_reuses_complete_store_by_default_and_ assert rebuilt_root.attrs["data_pyramid_levels"] == ["data", "data_pyramid/level_1"] -def test_materialize_experiment_data_store_handles_same_component_rewrite(tmp_path: Path): +def test_materialize_experiment_data_store_handles_same_component_rewrite( + tmp_path: Path, +): experiment_path = tmp_path / "experiment.yml" _write_minimal_experiment( experiment_path, save_directory=tmp_path, file_type="OME-ZARR" @@ -953,7 
+977,9 @@ def test_materialize_experiment_data_store_handles_same_component_rewrite(tmp_pa source_data = np.arange(24, dtype=np.uint16).reshape(2, 3, 4) source_store = tmp_path / "source_data.zarr" source_root = zarr.open_group(str(source_store), mode="w") - source_root.create_dataset("data", data=source_data, chunks=(1, 3, 4), overwrite=True) + source_root.create_dataset( + "data", data=source_data, chunks=(1, 3, 4), overwrite=True + ) source_root["data"].attrs["_ARRAY_DIMENSIONS"] = ["z", "y", "x"] materialize_experiment_data_store( @@ -1011,9 +1037,7 @@ def test_materialize_experiment_data_store_stacks_tiff_positions_and_channels( for position_index in range(3): for channel_index in range(2): - loaded = np.array( - root["data"][0, position_index, channel_index, :, :, :] - ) + loaded = np.array(root["data"][0, position_index, channel_index, :, :, :]) assert np.array_equal( loaded, expected_blocks[(position_index, channel_index)], @@ -1093,10 +1117,10 @@ def test_materialize_experiment_data_store_stacks_bdv_h5_setups( for position_index in range(2): for channel_index in range(2): - loaded = np.array( - root["data"][0, position_index, channel_index, :, :, :] + loaded = np.array(root["data"][0, position_index, channel_index, :, :, :]) + assert np.array_equal( + loaded, expected_blocks[(position_index, channel_index)] ) - assert np.array_equal(loaded, expected_blocks[(position_index, channel_index)]) def test_materialize_experiment_data_store_stacks_bdv_n5_setups( @@ -1176,10 +1200,10 @@ def test_materialize_experiment_data_store_stacks_bdv_n5_setups( for position_index in range(2): for channel_index in range(2): - loaded = np.array( - root["data"][0, position_index, channel_index, :, :, :] + loaded = np.array(root["data"][0, position_index, channel_index, :, :, :]) + assert np.array_equal( + loaded, expected_blocks[(position_index, channel_index)] ) - assert np.array_equal(loaded, expected_blocks[(position_index, channel_index)]) def 
test_materialize_experiment_data_store_stacks_bdv_ome_zarr_setups( @@ -1261,10 +1285,10 @@ def test_materialize_experiment_data_store_stacks_bdv_ome_zarr_setups( for position_index in range(2): for channel_index in range(2): - loaded = np.array( - root["data"][0, position_index, channel_index, :, :, :] + loaded = np.array(root["data"][0, position_index, channel_index, :, :, :]) + assert np.array_equal( + loaded, expected_blocks[(position_index, channel_index)] ) - assert np.array_equal(loaded, expected_blocks[(position_index, channel_index)]) def test_should_use_source_aligned_plane_writes_detects_plane_chunk_pattern(): @@ -1305,10 +1329,14 @@ def test_materialize_experiment_data_store_uses_source_aligned_plane_writes( source_data = np.arange(4 * 8 * 10, dtype=np.uint16).reshape(4, 8, 10) source_store = tmp_path / "source.ome.zarr" source_root = zarr.open_group(str(source_store), mode="w") - source_root.create_dataset("raw", data=source_data, chunks=(1, 8, 10), overwrite=True) + source_root.create_dataset( + "raw", data=source_data, chunks=(1, 8, 10), overwrite=True + ) source_root["raw"].attrs["_ARRAY_DIMENSIONS"] = ["z", "y", "x"] - original_source_writer = experiment_module._write_dask_array_source_aligned_plane_batches + original_source_writer = ( + experiment_module._write_dask_array_source_aligned_plane_batches + ) original_chunk_writer = experiment_module._write_dask_array_in_batches writer_calls = {"source_aligned": 0, "chunk_batched": 0} @@ -1342,7 +1370,9 @@ def _count_chunk_writer(**kwargs): assert writer_calls["source_aligned"] == 1 assert writer_calls["chunk_batched"] == 0 assert np.array_equal(np.array(root["data"][0, 0, 0, :, :, :]), source_data) - assert root.attrs["materialization_write_strategy"] == "source_aligned_plane_batches" + assert ( + root.attrs["materialization_write_strategy"] == "source_aligned_plane_batches" + ) assert root.attrs["source_aligned_z_batch_depth"] == 2 assert root.attrs["source_aligned_worker_count"] is None assert 
root.attrs["source_aligned_worker_memory_limit_bytes"] is None @@ -1362,10 +1392,14 @@ def test_materialize_experiment_data_store_falls_back_to_chunk_batched_writes( source_data = np.arange(4 * 8 * 10, dtype=np.uint16).reshape(4, 8, 10) source_store = tmp_path / "source.ome.zarr" source_root = zarr.open_group(str(source_store), mode="w") - source_root.create_dataset("raw", data=source_data, chunks=(2, 8, 10), overwrite=True) + source_root.create_dataset( + "raw", data=source_data, chunks=(2, 8, 10), overwrite=True + ) source_root["raw"].attrs["_ARRAY_DIMENSIONS"] = ["z", "y", "x"] - original_source_writer = experiment_module._write_dask_array_source_aligned_plane_batches + original_source_writer = ( + experiment_module._write_dask_array_source_aligned_plane_batches + ) original_chunk_writer = experiment_module._write_dask_array_in_batches writer_calls = {"source_aligned": 0, "chunk_batched": 0} @@ -1415,8 +1449,8 @@ def scheduler_info(self): } } - worker_count, worker_memory_limit = experiment_module._detect_client_worker_resources( - _FakeClient() + worker_count, worker_memory_limit = ( + experiment_module._detect_client_worker_resources(_FakeClient()) ) assert worker_count == 2 diff --git a/tests/io/test_provenance.py b/tests/io/test_provenance.py index 0ca3d0a..db588f9 100644 --- a/tests/io/test_provenance.py +++ b/tests/io/test_provenance.py @@ -100,7 +100,10 @@ def test_persist_run_provenance_hash_chain(tmp_path: Path): assert "display_pyramid" in record_1["workflow"]["selected_analyses"] assert "usegment3d" in record_1["workflow"]["selected_analyses"] assert record_1["workflow"]["spatial_calibration_text"] == "z=+z,y=+y,x=+x" - assert record_1["workflow"]["zarr_chunks_ptczyx"] == "p=1, t=1, c=1, z=256, y=256, x=256" + assert ( + record_1["workflow"]["zarr_chunks_ptczyx"] + == "p=1, t=1, c=1, z=256, y=256, x=256" + ) assert "z=1,2,4,8" in record_1["workflow"]["zarr_pyramid_ptczyx"] valid, issues = verify_provenance_chain(store_path) @@ -361,4 +364,6 @@ def 
test_latest_analysis_gui_state_round_trip(tmp_path: Path) -> None: assert loaded["source"] == "unit_test" assert loaded["updated_utc"] is not None assert loaded["workflow"]["flatfield"] is True - assert loaded["workflow"]["analysis_parameters"]["flatfield"]["execution_order"] == 1 + assert ( + loaded["workflow"]["analysis_parameters"]["flatfield"]["execution_order"] == 1 + ) diff --git a/tests/io/test_read.py b/tests/io/test_read.py index 327d858..9b8c01e 100644 --- a/tests/io/test_read.py +++ b/tests/io/test_read.py @@ -51,7 +51,6 @@ ) from tests import download_test_registration_data - # ============================================================================= # Test ImageInfo Dataclass # ============================================================================= diff --git a/tests/mip_export/test_pipeline.py b/tests/mip_export/test_pipeline.py index 3758f50..bfe598d 100644 --- a/tests/mip_export/test_pipeline.py +++ b/tests/mip_export/test_pipeline.py @@ -156,7 +156,9 @@ def test_run_mip_export_analysis_writes_uint16_ome_tiff_outputs_with_calibration assert yz_pixels["PhysicalSizeY"] == "2.0" latest_attrs = dict( - zarr.open_group(str(store_path), mode="r")["results"]["mip_export"]["latest"].attrs + zarr.open_group(str(store_path), mode="r")["results"]["mip_export"][ + "latest" + ].attrs ) assert latest_attrs["exported_files"] == 6 assert latest_attrs["export_format"] == "ome-tiff" @@ -174,9 +176,7 @@ def test_run_mip_export_analysis_writes_multi_position_zarr_outputs( ) -> None: store_path = tmp_path / "mip_zarr_store.zarr" root = zarr.open_group(str(store_path), mode="w") - data = np.arange(2 * 3 * 2 * 4 * 3 * 5, dtype=np.uint16).reshape( - (2, 3, 2, 4, 3, 5) - ) + data = np.arange(2 * 3 * 2 * 4 * 3 * 5, dtype=np.uint16).reshape((2, 3, 2, 4, 3, 5)) root.create_dataset( name="data", data=data, @@ -219,17 +219,23 @@ def test_run_mip_export_analysis_writes_multi_position_zarr_outputs( assert list(yz_root["data"].attrs["axes"]) == ["p", "z", "y"] 
-@pytest.mark.parametrize(("projection", "expected_axis"), [("xy", 0), ("xz", 1), ("yz", 2)]) +@pytest.mark.parametrize( + ("projection", "expected_axis"), [("xy", 0), ("xz", 1), ("yz", 2)] +) def test_run_export_task_reads_single_position_source_in_blocks( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, projection: str, expected_axis: int, ) -> None: - data = np.arange(1 * 1 * 1 * 4 * 6 * 8, dtype=np.float32).reshape((1, 1, 1, 4, 6, 8)) + data = np.arange(1 * 1 * 1 * 4 * 6 * 8, dtype=np.float32).reshape( + (1, 1, 1, 4, 6, 8) + ) guarded = _GuardedSourceArray(data, chunks=(1, 1, 1, 2, 3, 4)) fake_group = {"data": guarded} - monkeypatch.setattr(pipeline.zarr, "open_group", lambda *_args, **_kwargs: fake_group) + monkeypatch.setattr( + pipeline.zarr, "open_group", lambda *_args, **_kwargs: fake_group + ) task = pipeline._MipExportTask( projection=projection, @@ -258,10 +264,14 @@ def test_run_export_task_reads_multi_position_source_in_blocks( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, ) -> None: - data = np.arange(1 * 3 * 1 * 4 * 6 * 8, dtype=np.float32).reshape((1, 3, 1, 4, 6, 8)) + data = np.arange(1 * 3 * 1 * 4 * 6 * 8, dtype=np.float32).reshape( + (1, 3, 1, 4, 6, 8) + ) guarded = _GuardedSourceArray(data, chunks=(1, 3, 1, 2, 3, 4)) fake_group = {"data": guarded} - monkeypatch.setattr(pipeline.zarr, "open_group", lambda *_args, **_kwargs: fake_group) + monkeypatch.setattr( + pipeline.zarr, "open_group", lambda *_args, **_kwargs: fake_group + ) task = pipeline._MipExportTask( projection="xy", @@ -293,9 +303,7 @@ def test_run_mip_export_analysis_distributed_writes_expected_outputs( store_path = tmp_path / "mip_distributed_store.zarr" root = zarr.open_group(str(store_path), mode="w") - data = np.arange(1 * 2 * 1 * 4 * 6 * 8, dtype=np.uint16).reshape( - (1, 2, 1, 4, 6, 8) - ) + data = np.arange(1 * 2 * 1 * 4 * 6 * 8, dtype=np.uint16).reshape((1, 2, 1, 4, 6, 8)) root.create_dataset( name="data", data=data, diff --git 
a/tests/registration/test_image_registration.py b/tests/registration/test_image_registration.py index d5a0811..538d1a6 100644 --- a/tests/registration/test_image_registration.py +++ b/tests/registration/test_image_registration.py @@ -150,8 +150,16 @@ def test_register_uses_instance_attributes(self): # Mock the internal methods to avoid actual registration mock_image = MagicMock() mock_mask = MagicMock() - with patch.object(reg, '_perform_linear_registration', return_value=(mock_image, mock_mask)): - with patch.object(reg, '_perform_nonlinear_registration', return_value=(mock_image, mock_mask)): + with patch.object( + reg, + "_perform_linear_registration", + return_value=(mock_image, mock_mask), + ): + with patch.object( + reg, + "_perform_nonlinear_registration", + return_value=(mock_image, mock_mask), + ): reg.register() # Verify that logging was initialized @@ -176,7 +184,7 @@ def test_register_round_creates_and_calls_image_registration(self): np.save(moving_path, moving_arr) # Mock ImageRegistration to avoid actual registration - with patch('clearex.registration.ImageRegistration') as MockReg: + with patch("clearex.registration.ImageRegistration") as MockReg: mock_instance = MockReg.return_value register_round( @@ -200,4 +208,3 @@ def test_register_round_creates_and_calls_image_registration(self): # Verify register was called mock_instance.register.assert_called_once() - diff --git a/tests/registration/test_pipeline.py b/tests/registration/test_pipeline.py index da42407..91ad7fb 100644 --- a/tests/registration/test_pipeline.py +++ b/tests/registration/test_pipeline.py @@ -80,9 +80,11 @@ def _create_registration_store( } experiment_path = _write_multiposition_sidecar( tmp_path, - rows=[[0.0, 0.0, 0.0, 0.0, 0.0], [4.0, 0.0, 0.0, 0.0, 0.0], [8.0, 0.0, 0.0, 0.0, 0.0]][ - :positions - ], + rows=[ + [0.0, 0.0, 0.0, 0.0, 0.0], + [4.0, 0.0, 0.0, 0.0, 0.0], + [8.0, 0.0, 0.0, 0.0, 0.0], + ][:positions], ) root.attrs["source_experiment"] = str(experiment_path) 
root.attrs["data_pyramid_factors_tpczyx"] = [ @@ -166,6 +168,39 @@ def test_resolve_source_components_for_level_rejects_missing_level( ) +def test_resolve_source_components_for_level_uses_source_adjacent_pyramid( + tmp_path: Path, +) -> None: + store_path = tmp_path / "registration_source_adjacent_store.zarr" + root = zarr.open_group(str(store_path), mode="w") + root.create_dataset( + name="results/shear_transform/latest/data", + shape=(1, 2, 1, 4, 4, 4), + chunks=(1, 1, 1, 4, 4, 4), + dtype="uint16", + overwrite=True, + ) + root.create_dataset( + name="results/shear_transform/latest/data_pyramid/level_1", + shape=(1, 2, 1, 2, 2, 2), + chunks=(1, 1, 1, 2, 2, 2), + dtype="uint16", + overwrite=True, + ) + + source_component, pairwise_component, level = ( + registration_pipeline._resolve_source_components_for_level( + root=root, + requested_source_component="results/shear_transform/latest/data", + input_resolution_level=1, + ) + ) + + assert source_component == "results/shear_transform/latest/data" + assert pairwise_component == "results/shear_transform/latest/data_pyramid/level_1" + assert level == 1 + + def test_position_centroid_anchor_prefers_central_tile() -> None: stage_rows = [ {"x": 0.0, "y": 0.0, "z": 0.0, "theta": 0.0, "f": 0.0}, @@ -190,9 +225,7 @@ def test_solve_with_pruning_recovers_translation_and_prunes_outlier() -> None: registration_pipeline._EdgeSpec( 1, 2, ((0.0, 1.0), (0.0, 1.0), (0.0, 1.0)), 10_000 ), - registration_pipeline._EdgeSpec( - 0, 2, ((0.0, 1.0), (0.0, 1.0), (0.0, 1.0)), 1 - ), + registration_pipeline._EdgeSpec(0, 2, ((0.0, 1.0), (0.0, 1.0), (0.0, 1.0)), 1), ] results = [ _edge_result(edges[0], _translation_matrix(1.0)), @@ -298,7 +331,11 @@ def _fake_pairwise(**kwargs): edge = kwargs["edge"] t_index = int(kwargs["t_index"]) correction = np.eye(4, dtype=np.float64) - if t_index == 1 and int(edge.fixed_position) == 1 and int(edge.moving_position) == 2: + if ( + t_index == 1 + and int(edge.fixed_position) == 1 + and 
int(edge.moving_position) == 2 + ): correction[0, 3] = -1.0 return { "fixed_position": int(edge.fixed_position), diff --git a/tests/segmentation/__init__.py b/tests/segmentation/__init__.py index 59465b8..3074984 100755 --- a/tests/segmentation/__init__.py +++ b/tests/segmentation/__init__.py @@ -23,4 +23,3 @@ # IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. - diff --git a/tests/segmentation/test_pointsource.py b/tests/segmentation/test_pointsource.py index 4dbd5b8..ddf7de6 100755 --- a/tests/segmentation/test_pointsource.py +++ b/tests/segmentation/test_pointsource.py @@ -26,8 +26,11 @@ import numpy as np import pytest -from clearex.detect.particles import remove_close_blobs, sort_by_point_source_intensity, \ - eliminate_insignificant_point_sources +from clearex.detect.particles import ( + remove_close_blobs, + sort_by_point_source_intensity, + eliminate_insignificant_point_sources, +) def test_remove_close_blobs_isotropic(): diff --git a/tests/shear/test_pipeline.py b/tests/shear/test_pipeline.py index 93a98c0..d135aa7 100644 --- a/tests/shear/test_pipeline.py +++ b/tests/shear/test_pipeline.py @@ -126,9 +126,7 @@ def test_run_shear_transform_auto_estimate_updates_applied_shear( def test_run_shear_transform_identity_preserves_data(tmp_path: Path) -> None: store_path = tmp_path / "shear_identity.zarr" root = zarr.open_group(str(store_path), mode="w") - data = np.arange(1 * 1 * 1 * 4 * 4 * 4, dtype=np.uint16).reshape( - (1, 1, 1, 4, 4, 4) - ) + data = np.arange(1 * 1 * 1 * 4 * 4 * 4, dtype=np.uint16).reshape((1, 1, 1, 4, 4, 4)) root.create_dataset( name="data", data=data, @@ -157,7 +155,9 @@ def test_run_shear_transform_identity_preserves_data(tmp_path: Path) -> None: ) output = np.asarray( - zarr.open_group(str(store_path), mode="r")["results/shear_transform/latest/data"] + zarr.open_group(str(store_path), mode="r")[ + 
"results/shear_transform/latest/data" + ] ) assert summary.data_component == "results/shear_transform/latest/data" assert output.shape == data.shape @@ -194,7 +194,9 @@ def test_run_shear_transform_emits_larger_bounds_for_nonzero_shear( ) output = np.asarray( - zarr.open_group(str(store_path), mode="r")["results/shear_transform/latest/data"] + zarr.open_group(str(store_path), mode="r")[ + "results/shear_transform/latest/data" + ] ) assert output.shape == summary.output_shape_tpczyx assert np.max(output) > 0.0 @@ -229,7 +231,9 @@ def test_run_shear_transform_linear_normalizes_edge_support(tmp_path: Path) -> N ) output = np.asarray( - zarr.open_group(str(store_path), mode="r")["results/shear_transform/latest/data"] + zarr.open_group(str(store_path), mode="r")[ + "results/shear_transform/latest/data" + ] ) positive = output[output > 0.0] assert positive.size > 0 @@ -244,9 +248,7 @@ def test_run_shear_transform_identity_with_distributed_client( store_path = tmp_path / "shear_identity_distributed.zarr" root = zarr.open_group(str(store_path), mode="w") - data = np.arange(1 * 1 * 1 * 4 * 4 * 4, dtype=np.uint16).reshape( - (1, 1, 1, 4, 4, 4) - ) + data = np.arange(1 * 1 * 1 * 4 * 4 * 4, dtype=np.uint16).reshape((1, 1, 1, 4, 4, 4)) root.create_dataset( name="data", data=data, @@ -282,7 +284,9 @@ def test_run_shear_transform_identity_with_distributed_client( ) output = np.asarray( - zarr.open_group(str(store_path), mode="r")["results/shear_transform/latest/data"] + zarr.open_group(str(store_path), mode="r")[ + "results/shear_transform/latest/data" + ] ) assert output.shape == data.shape assert summary.output_shape_tpczyx == data.shape diff --git a/tests/test_main.py b/tests/test_main.py index 3a913f4..d2c42e1 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1194,7 +1194,9 @@ def test_run_workflow_experiment_input_persists_explicit_identity_spatial_calibr monkeypatch.setattr(main_module, "_configure_dask_backend", lambda **kwargs: None) 
monkeypatch.setattr(main_module, "is_navigate_experiment_file", lambda path: True) - monkeypatch.setattr(main_module, "load_navigate_experiment", lambda path: experiment) + monkeypatch.setattr( + main_module, "load_navigate_experiment", lambda path: experiment + ) monkeypatch.setattr( main_module, "resolve_experiment_data_path", lambda experiment: source_path ) @@ -1271,7 +1273,9 @@ def test_run_workflow_experiment_input_without_override_preserves_store_mapping( monkeypatch.setattr(main_module, "_configure_dask_backend", lambda **kwargs: None) monkeypatch.setattr(main_module, "is_navigate_experiment_file", lambda path: True) - monkeypatch.setattr(main_module, "load_navigate_experiment", lambda path: experiment) + monkeypatch.setattr( + main_module, "load_navigate_experiment", lambda path: experiment + ) monkeypatch.setattr( main_module, "resolve_experiment_data_path", lambda experiment: source_path ) @@ -1894,7 +1898,9 @@ def _fake_flatfield(*, zarr_path, parameters, client, progress_callback): def _should_not_run(*args, **kwargs): del args, kwargs - raise AssertionError("deconvolution should not run when upstream output is missing") + raise AssertionError( + "deconvolution should not run when upstream output is missing" + ) monkeypatch.setattr( main_module, "_configure_dask_backend", _fake_configure_dask_backend @@ -1939,8 +1945,7 @@ def _should_not_run(*args, **kwargs): assert any( str(step.get("name")) == "deconvolution" and dict(step.get("parameters", {})).get("status") == "failed" - and dict(step.get("parameters", {})).get("reason") - == "missing_input_dependency" + and dict(step.get("parameters", {})).get("reason") == "missing_input_dependency" and dict(step.get("parameters", {})).get("requested_input") == "flatfield" and dict(step.get("parameters", {})).get("resolved_input") == "results/flatfield/latest/data" diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 1ff978e..8d23a16 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ 
-179,7 +179,7 @@ def test_default_zarr_save_config(self): assert cfg.analysis_parameters["shear_transform"]["execution_order"] == 3 assert cfg.analysis_parameters["shear_transform"]["interpolation"] == "linear" assert "registration" in cfg.analysis_parameters - assert cfg.analysis_parameters["registration"]["execution_order"] == 4 + assert cfg.analysis_parameters["registration"]["execution_order"] == 5 assert cfg.analysis_parameters["registration"]["input_source"] == "data" assert cfg.analysis_parameters["registration"]["registration_channel"] == 0 assert cfg.analysis_parameters["registration"]["registration_type"] == "rigid" @@ -189,10 +189,10 @@ def test_default_zarr_save_config(self): assert cfg.analysis_parameters["registration"]["blend_mode"] == "feather" assert "particle_detection" in cfg.analysis_parameters assert cfg.analysis_parameters["particle_detection"]["bg_sigma"] == 20.0 - assert cfg.analysis_parameters["particle_detection"]["execution_order"] == 5 + assert cfg.analysis_parameters["particle_detection"]["execution_order"] == 6 assert cfg.analysis_parameters["particle_detection"]["input_source"] == "data" assert "usegment3d" in cfg.analysis_parameters - assert cfg.analysis_parameters["usegment3d"]["execution_order"] == 6 + assert cfg.analysis_parameters["usegment3d"]["execution_order"] == 7 assert cfg.analysis_parameters["usegment3d"]["input_source"] == "data" assert cfg.analysis_parameters["usegment3d"]["all_channels"] is False assert cfg.analysis_parameters["usegment3d"]["channel_indices"] == [0] @@ -202,7 +202,7 @@ def test_default_zarr_save_config(self): ) assert cfg.analysis_parameters["usegment3d"]["save_native_labels"] is False assert "display_pyramid" in cfg.analysis_parameters - assert cfg.analysis_parameters["display_pyramid"]["execution_order"] == 7 + assert cfg.analysis_parameters["display_pyramid"]["execution_order"] == 4 assert cfg.analysis_parameters["display_pyramid"]["input_source"] == "data" assert 
cfg.analysis_parameters["visualization"]["show_all_positions"] is False assert cfg.analysis_parameters["visualization"]["position_index"] == 0 @@ -965,9 +965,9 @@ def test_normalize_analysis_operation_parameters_returns_defaults(): assert normalized["deconvolution"]["execution_order"] == 2 assert normalized["flatfield"]["execution_order"] == 1 assert normalized["shear_transform"]["execution_order"] == 3 - assert normalized["registration"]["execution_order"] == 4 - assert normalized["usegment3d"]["execution_order"] == 6 - assert normalized["display_pyramid"]["execution_order"] == 7 + assert normalized["registration"]["execution_order"] == 5 + assert normalized["usegment3d"]["execution_order"] == 7 + assert normalized["display_pyramid"]["execution_order"] == 4 assert normalized["visualization"]["input_source"] == "data" assert normalized["visualization"]["show_all_positions"] is False assert normalized["visualization"]["use_multiscale"] is True @@ -1030,10 +1030,10 @@ def test_analysis_operation_order_contains_expected_keys(): "flatfield", "deconvolution", "shear_transform", + "display_pyramid", + "registration", "particle_detection", "usegment3d", - "registration", - "display_pyramid", "visualization", "mip_export", ) diff --git a/tests/usegment3d/test_pipeline.py b/tests/usegment3d/test_pipeline.py index 6638646..539bbf1 100644 --- a/tests/usegment3d/test_pipeline.py +++ b/tests/usegment3d/test_pipeline.py @@ -254,9 +254,7 @@ def test_remap_usegment3d_runtime_stdout_line( raw_line: str, expected: str, ) -> None: - assert ( - usegment_pipeline._remap_usegment3d_runtime_stdout_line(raw_line) == expected - ) + assert usegment_pipeline._remap_usegment3d_runtime_stdout_line(raw_line) == expected class _PreprocessPrintRuntimeModule(_FakeRuntimeModule): diff --git a/tests/visualization/test_pipeline.py b/tests/visualization/test_pipeline.py index 90e299f..8148e68 100644 --- a/tests/visualization/test_pipeline.py +++ b/tests/visualization/test_pipeline.py @@ -373,7 +373,7 
@@ def test_run_display_pyramid_analysis_materializes_levels_and_contrast_metadata( latest_attrs = dict(output_root["results"]["display_pyramid"]["latest"].attrs) assert len(source_attrs["display_pyramid_levels"]) > 1 assert all( - str(component).startswith("results/display_pyramid/by_component/") + str(component).startswith("results/shear_transform/latest/data_pyramid/") or str(component) == "results/shear_transform/latest/data" for component in source_attrs["display_pyramid_levels"] ) @@ -447,6 +447,61 @@ def _unexpected_rechunk(self, *args, **kwargs): assert len(summary.source_components) > 1 +def test_run_display_pyramid_analysis_rebuilds_legacy_component_layout( + tmp_path: Path, +) -> None: + store_path = tmp_path / "analysis_store.zarr" + root = zarr.open_group(str(store_path), mode="w") + root.create_dataset( + name="results/shear_transform/latest/data", + shape=(1, 1, 1, 8, 8, 8), + chunks=(1, 1, 1, 4, 4, 4), + dtype="uint16", + overwrite=True, + ) + legacy_level_component = ( + "results/display_pyramid/by_component/legacy_shear_cache/level_1" + ) + root.create_dataset( + name=legacy_level_component, + shape=(1, 1, 1, 4, 4, 4), + chunks=(1, 1, 1, 4, 4, 4), + dtype="uint16", + overwrite=True, + ) + root["results/shear_transform/latest/data"].attrs["display_pyramid_levels"] = [ + "results/shear_transform/latest/data", + legacy_level_component, + ] + root.attrs["display_pyramid_levels_by_component"] = { + "results/shear_transform/latest/data": [ + "results/shear_transform/latest/data", + legacy_level_component, + ] + } + + summary = run_display_pyramid_analysis( + zarr_path=store_path, + parameters={"input_source": "results/shear_transform/latest/data"}, + ) + + assert summary.reused_existing_levels is False + assert len(summary.source_components) > 1 + assert summary.source_components[0] == "results/shear_transform/latest/data" + assert summary.source_components[1].startswith( + "results/shear_transform/latest/data_pyramid/level_" + ) + + output_root = 
zarr.open_group(str(store_path), mode="r") + assert ( + "results/shear_transform/latest/data_pyramid/level_1" in output_root + ), "Expected source-adjacent level_1 pyramid after migration." + source_attrs = dict(output_root["results/shear_transform/latest/data"].attrs) + assert source_attrs["display_pyramid_levels"][1].startswith( + "results/shear_transform/latest/data_pyramid/level_" + ) + + def test_run_visualization_analysis_uses_experiment_spacing_when_available( tmp_path: Path, monkeypatch ) -> None: @@ -1078,7 +1133,9 @@ def run(self) -> None: viewer = fake_napari.viewer assert viewer is not None assert len(viewer.image_calls) == 2 - assert all(str(kwargs["blending"]) == "translucent" for kwargs in viewer.image_calls) + assert all( + str(kwargs["blending"]) == "translucent" for kwargs in viewer.image_calls + ) def test_launch_napari_viewer_requests_3d_display_mode( From 8abebea4d74e7cccfb697f902c3befe236baff96 Mon Sep 17 00:00:00 2001 From: Kevin Dean Date: Sat, 21 Mar 2026 16:54:53 -0500 Subject: [PATCH 4/4] Display pyramid: add per-level progress updates and chunk-estimate logging --- src/clearex/visualization/pipeline.py | 258 +++++++++++++++++++++++++- tests/visualization/test_pipeline.py | 59 ++++++ 2 files changed, 311 insertions(+), 6 deletions(-) diff --git a/src/clearex/visualization/pipeline.py b/src/clearex/visualization/pipeline.py index b750c32..4c2503d 100644 --- a/src/clearex/visualization/pipeline.py +++ b/src/clearex/visualization/pipeline.py @@ -33,16 +33,19 @@ from dataclasses import dataclass from datetime import datetime, timezone import json +import logging import math from pathlib import Path import re import subprocess import sys import threading +import time from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union # Third Party Imports import dask.array as da +from dask.callbacks import Callback import numpy as np import zarr @@ -79,6 +82,8 @@ _DISPLAY_CONTRAST_PERCENTILES_ATTR = "display_contrast_percentiles" 
_DISPLAY_CONTRAST_LEVEL_SOURCE_ATTR = "display_contrast_source_component" _DISPLAY_CONTRAST_SAMPLE_TARGET_VOXELS = 2_000_000 +_DISPLAY_PYRAMID_BUILD_PROGRESS_TASK_STEP = 1_024 +_DISPLAY_PYRAMID_BUILD_PROGRESS_MIN_INTERVAL_SECONDS = 1.5 _SOFTWARE_RENDERER_HINTS = ( "llvmpipe", "softpipe", @@ -104,6 +109,7 @@ "tesla", "rtx", ) +_LOGGER = logging.getLogger(__name__) @dataclass(frozen=True) @@ -2058,6 +2064,189 @@ def _component_matches_shape_chunks( return array_chunks == tuple(int(v) for v in chunks_tpczyx) +def _estimate_chunk_region_count_tpczyx( + *, + shape_tpczyx: Sequence[int], + chunks_tpczyx: Sequence[int], +) -> int: + """Estimate chunk-region count for one canonical ``(t,p,c,z,y,x)`` array. + + Parameters + ---------- + shape_tpczyx : sequence[int] + Array shape in canonical axis order. + chunks_tpczyx : sequence[int] + Chunk shape in canonical axis order. + + Returns + ------- + int + Estimated chunk-region count. Returns ``0`` for invalid payloads. + """ + if len(tuple(shape_tpczyx)) != 6 or len(tuple(chunks_tpczyx)) != 6: + return 0 + region_counts = [ + int(max(1, math.ceil(float(max(1, int(size))) / float(max(1, int(chunk)))))) + for size, chunk in zip(shape_tpczyx, chunks_tpczyx, strict=False) + ] + return int(math.prod(region_counts)) + + +def _write_display_pyramid_level_with_progress( + *, + level_array: da.Array, + zarr_path: Union[str, Path], + level_component: str, + source_component: str, + level_index: int, + total_levels: int, + absolute_factors_tpczyx: Sequence[int], + level_shape_tpczyx: Sequence[int], + level_chunks_tpczyx: Sequence[int], + progress_callback: Optional[ProgressCallback] = None, + progress_start: int = 0, + progress_end: int = 100, +) -> None: + """Write one display-pyramid level and emit incremental progress updates. + + Parameters + ---------- + level_array : dask.array.Array + Prepared downsampled level data. + zarr_path : str or pathlib.Path + Analysis-store path. + level_component : str + Target level component path. 
+ source_component : str + Source component path used to derive this level. + level_index : int + One-based level index currently being written. + total_levels : int + Total number of generated levels. + absolute_factors_tpczyx : sequence[int] + Absolute level factors in canonical order. + level_shape_tpczyx : sequence[int] + Target level shape in canonical order. + level_chunks_tpczyx : sequence[int] + Target level chunks in canonical order. + progress_callback : callable, optional + Callback receiving ``(percent, message)`` updates. + progress_start : int, default=0 + Start percent for this level write stage. + progress_end : int, default=100 + End percent for this level write stage. + + Returns + ------- + None + Writes the level into the configured Zarr/N5 store. + """ + + def _emit(percent: int, message: str) -> None: + if progress_callback is None: + return + progress_callback(int(percent), str(message)) + + normalized_shape = tuple(int(max(1, int(v))) for v in level_shape_tpczyx) + normalized_chunks = tuple(int(max(1, int(v))) for v in level_chunks_tpczyx) + estimated_chunks = _estimate_chunk_region_count_tpczyx( + shape_tpczyx=normalized_shape, + chunks_tpczyx=normalized_chunks, + ) + + write_graph = da.to_zarr( + level_array, + str(zarr_path), + component=level_component, + overwrite=True, + compute=False, + ) + total_tasks = max(1, int(len(getattr(write_graph, "dask", {})))) + progress_start_int = int(progress_start) + progress_end_int = int(max(progress_start_int, int(progress_end))) + progress_span = max(1, progress_end_int - progress_start_int) + + _LOGGER.info( + "[display_pyramid] writing level %d/%d component=%s source=%s " + "shape=%s chunks=%s estimated_chunks=%d tasks=%d factors_tpczyx=%s", + int(level_index), + int(total_levels), + str(level_component), + str(source_component), + normalized_shape, + normalized_chunks, + int(estimated_chunks), + int(total_tasks), + [int(value) for value in absolute_factors_tpczyx], + ) + _emit( + 
progress_start_int, + "Writing pyramid level " + f"{int(level_index)}/{int(total_levels)} " + f"(estimated_chunks={int(estimated_chunks):,})", + ) + + if progress_callback is None or total_tasks <= 1: + da.compute(write_graph) + _emit( + progress_end_int, + f"Wrote pyramid level {int(level_index)}/{int(total_levels)}", + ) + return + + class _TaskProgressCallback(Callback): + """Throttle task-level updates for UI-friendly progress reporting.""" + + def __init__(self) -> None: + super().__init__() + self._completed = 0 + self._last_emitted_completed = 0 + self._last_emitted_time = 0.0 + + def _posttask( + self, + key: object, + result: object, + dsk: object, + state: object, + worker_id: object, + ) -> None: + del key, result, dsk, state, worker_id + self._completed += 1 + now = float(time.monotonic()) + due_by_count = ( + int(self._completed) - int(self._last_emitted_completed) + ) >= int(_DISPLAY_PYRAMID_BUILD_PROGRESS_TASK_STEP) + due_by_time = (float(now) - float(self._last_emitted_time)) >= float( + _DISPLAY_PYRAMID_BUILD_PROGRESS_MIN_INTERVAL_SECONDS + ) + is_final = int(self._completed) >= int(total_tasks) + if not is_final and not due_by_count and not due_by_time: + return + + fraction = max( + 0.0, + min(1.0, float(self._completed) / float(max(1, int(total_tasks)))), + ) + mapped_percent = progress_start_int + int(round(fraction * progress_span)) + _emit( + mapped_percent, + "Writing pyramid level " + f"{int(level_index)}/{int(total_levels)} " + f"(tasks={int(self._completed):,}/{int(total_tasks):,})", + ) + self._last_emitted_completed = int(self._completed) + self._last_emitted_time = float(now) + + with _TaskProgressCallback(): + da.compute(write_graph) + + _emit( + progress_end_int, + f"Wrote pyramid level {int(level_index)}/{int(total_levels)}", + ) + + def _display_pyramid_level_component( *, source_component: str, @@ -2222,6 +2411,9 @@ def _build_visualization_multiscale_components( root: zarr.hierarchy.Group, source_component: str, 
level_factors_tpczyx: tuple[tuple[int, int, int, int, int, int], ...], + progress_callback: Optional[ProgressCallback] = None, + progress_start: int = 30, + progress_end: int = 68, ) -> tuple[str, ...]: """Materialize reusable display-pyramid levels for one source component. @@ -2235,6 +2427,13 @@ def _build_visualization_multiscale_components( Base component path. level_factors_tpczyx : tuple[tuple[int, int, int, int, int, int], ...] Absolute level factors including base level. + progress_callback : callable, optional + Callback receiving ``(percent, message)`` updates while levels are + generated. + progress_start : int, default=30 + Start percent for level-generation progress. + progress_end : int, default=68 + End percent for level-generation progress. Returns ------- @@ -2252,7 +2451,24 @@ def _build_visualization_multiscale_components( ] prior_component = str(source_component) prior_factors = tuple(int(value) for value in level_factors_tpczyx[0]) + total_generated_levels = max(1, int(len(level_factors_tpczyx) - 1)) + progress_start_int = int(progress_start) + progress_end_int = int(max(progress_start_int, int(progress_end))) + progress_span = max(0, int(progress_end_int - progress_start_int)) + + def _emit(percent: int, message: str) -> None: + if progress_callback is None: + return + progress_callback(int(percent), str(message)) + for level_index, absolute_factors in enumerate(level_factors_tpczyx[1:], start=1): + level_progress_start = progress_start_int + int( + ((int(level_index) - 1) / float(total_generated_levels)) + * float(progress_span) + ) + level_progress_end = progress_start_int + int( + (int(level_index) / float(total_generated_levels)) * float(progress_span) + ) all_relative = all( int(current) % int(previous) == 0 for current, previous in zip(absolute_factors, prior_factors, strict=False) @@ -2312,12 +2528,39 @@ def _build_visualization_multiscale_components( dtype=source_dtype.name, overwrite=True, ) - da.to_zarr( - downsampled, - 
str(zarr_path), - component=level_component, - overwrite=True, - compute=True, + _write_display_pyramid_level_with_progress( + level_array=downsampled, + zarr_path=zarr_path, + level_component=level_component, + source_component=source_level_component, + level_index=int(level_index), + total_levels=int(total_generated_levels), + absolute_factors_tpczyx=absolute_factors, + level_shape_tpczyx=level_shape, + level_chunks_tpczyx=level_chunks, + progress_callback=progress_callback, + progress_start=int(level_progress_start), + progress_end=int(level_progress_end), + ) + else: + estimated_chunks = _estimate_chunk_region_count_tpczyx( + shape_tpczyx=level_shape, + chunks_tpczyx=level_chunks, + ) + _LOGGER.info( + "[display_pyramid] reusing existing level %d/%d component=%s " + "shape=%s chunks=%s estimated_chunks=%d factors_tpczyx=%s", + int(level_index), + int(total_generated_levels), + str(level_component), + tuple(int(v) for v in level_shape), + tuple(int(v) for v in level_chunks), + int(estimated_chunks), + [int(value) for value in absolute_factors], + ) + _emit( + int(level_progress_end), + f"Reusing existing pyramid level {int(level_index)}/{int(total_generated_levels)}", ) root[level_component].attrs.update( @@ -3961,6 +4204,9 @@ def _emit(percent: int, message: str) -> None: root=root, source_component=source_component, level_factors_tpczyx=level_factors, + progress_callback=_emit, + progress_start=32, + progress_end=68, ) level_factors = _resolve_visualization_pyramid_factors_tpczyx( diff --git a/tests/visualization/test_pipeline.py b/tests/visualization/test_pipeline.py index 8148e68..b372bb9 100644 --- a/tests/visualization/test_pipeline.py +++ b/tests/visualization/test_pipeline.py @@ -447,6 +447,65 @@ def _unexpected_rechunk(self, *args, **kwargs): assert len(summary.source_components) > 1 +def test_run_display_pyramid_analysis_emits_level_write_progress( + tmp_path: Path, +) -> None: + store_path = tmp_path / "analysis_store.zarr" + root = 
zarr.open_group(str(store_path), mode="w") + root.create_dataset( + name="results/shear_transform/latest/data", + shape=(1, 1, 1, 16, 16, 16), + chunks=(1, 1, 1, 4, 4, 4), + dtype="uint16", + overwrite=True, + ) + + events: list[tuple[int, str]] = [] + run_display_pyramid_analysis( + zarr_path=store_path, + parameters={"input_source": "results/shear_transform/latest/data"}, + progress_callback=lambda percent, message: events.append( + (int(percent), str(message)) + ), + ) + + level_messages = [ + str(message) for _, message in events if "pyramid level" in message + ] + assert any("Writing pyramid level 1/" in message for message in level_messages) + assert any("Wrote pyramid level 1/" in message for message in level_messages) + assert any( + 30 < int(percent) < 70 + for percent, message in events + if "pyramid level" in str(message) + ) + + +def test_run_display_pyramid_analysis_logs_level_chunk_estimates( + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + store_path = tmp_path / "analysis_store.zarr" + root = zarr.open_group(str(store_path), mode="w") + root.create_dataset( + name="results/shear_transform/latest/data", + shape=(1, 1, 1, 16, 16, 16), + chunks=(1, 1, 1, 4, 4, 4), + dtype="uint16", + overwrite=True, + ) + + with caplog.at_level("INFO", logger="clearex.visualization.pipeline"): + run_display_pyramid_analysis( + zarr_path=store_path, + parameters={"input_source": "results/shear_transform/latest/data"}, + ) + + messages = [record.getMessage() for record in caplog.records] + assert any("writing level 1/" in message for message in messages) + assert any("estimated_chunks=" in message for message in messages) + + def test_run_display_pyramid_analysis_rebuilds_legacy_component_layout( tmp_path: Path, ) -> None: