diff --git a/src/virtual_stain_flow/datasets/bbox_schema.py b/src/virtual_stain_flow/datasets/bbox_schema.py new file mode 100644 index 0000000..4e00d7d --- /dev/null +++ b/src/virtual_stain_flow/datasets/bbox_schema.py @@ -0,0 +1,188 @@ +""" +bbox_schema.py + +This module defines a schema and accessor for bounding box (bbox) metadata +defining crops in raw images to be extracted and returned by a dataset. +For the purpose of extensibility, the schema additionally defines a rotation +center and angle. Intend to be used by a dataset class as the source of truth +for bbox metadata, column definition and accessor. +""" + +from __future__ import annotations +from dataclasses import dataclass +from typing import Tuple +import numpy as np +import pandas as pd + +@dataclass(frozen=True) +class BBoxSchema: + """ + Centralized bbox column name definitions with flexible aliasing. + This class defines standard names for bounding box columns and allows + for flexible aliasing and prefixing to accommodate different dataframe + naming conventions. + """ + prefix: str = "" + + # mapping canonical keys used by the accessor to possible column names + # in the dataframe + _column_map = { + 'xmin': ['x_min', 'xmin', 'left', 'x1'], + 'ymin': ['y_min', 'ymin', 'top', 'y1'], + 'xmax': ['x_max', 'xmax', 'right', 'x2'], + 'ymax': ['y_max', 'ymax', 'bottom', 'y2'], + 'cx': ['box_x_center', 'cx', 'center_x'], + 'cy': ['box_y_center', 'cy', 'center_y'], + 'rcx': ['rot_x_center', 'rot_cx', 'rcx'], + 'rcy': ['rot_y_center', 'rot_cy', 'rcy'], + 'angle': ['angle', 'rotation', 'theta'] + } + + def __getattr__(self, name: str) -> str: + """ + Dynamic access: schema.xmin, schema.cx, etc. + This is for making easier access to prefixed column names. + Alternatively this could have been implemented with properties + but since we have a lot of fields we wish to access this is the + more compact approach. + """ + if name in self._column_map: + return f"{self.prefix}{name}" + raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'") + + def find_column(self, df: pd.DataFrame, key: str) -> str: + """Find actual column name in DataFrame for given key.""" + for alias in self._column_map.get(key, []): + for variant in [alias, f"{self.prefix}{alias}", alias.upper(), alias.lower()]: + if variant in df.columns: + return variant + raise ValueError(f"No column found for key '{key}' in {list(df.columns)}") + + @property + def bbox_cols(self) -> Tuple[str, str, str, str]: + return (self.xmin, self.ymin, self.xmax, self.ymax) + +class BBoxRowView: + """ + Row accessor for bbox data. + Useful for dataset class to access a single bbox defined crop selection. + """ + def __init__(self, row: pd.Series, accessor: BBoxAccessor): + self._row, self._acc = row, accessor + + @property + def bbox(self) -> Tuple[int, int, int, int]: + return tuple(int(self._row[self._acc._cols[k]]) for k in ['xmin', 'ymin', 'xmax', 'ymax']) + + @property + def center(self) -> Tuple[float, float]: + return (float(self._row[self._acc._cols['cx']]), + float(self._row[self._acc._cols['cy']])) + + @property + def rot_center(self) -> Tuple[float, float]: + return (float(self._row[self._acc._cols['rcx']]), + float(self._row[self._acc._cols['rcy']])) + + @property + def angle(self) -> float: + return float(self._row[self._acc._cols['angle']]) + +@pd.api.extensions.register_dataframe_accessor("bbox") +class BBoxAccessor: + """ + Pandas accessor for bbox operations. + This accessor provides methods to ensure required bbox columns exist, + create missing ones, and access bbox data in a structured way. + 1. ensure_columns(): Ensures required columns exist, creates missing ones. + 2. row(i): Returns a BBoxRowView for the i-th row. + 3. coords(i): Returns bbox coordinates for the i-th row. + 4. centers(i): Returns bbox center for the i-th row. + 5. rot_centers(i): Returns rotation center for the i-th row. + 6. angle_of(i): Returns rotation angle for the i-th row. + """ + def __init__(self, df: pd.DataFrame): + self._df = df + self._schema = BBoxSchema() + self._cols = {} + + def __call__(self, schema: BBoxSchema) -> BBoxAccessor: + acc = BBoxAccessor(self._df) + acc._schema = schema + acc._cols = self._cols.copy() # Preserve column mapping + return acc + + def ensure_columns(self) -> pd.DataFrame: + """Ensure required columns exist, create missing ones.""" + df, s = self._df, self._schema + + # Find required bbox columns + required = ['xmin', 'ymin', 'xmax', 'ymax'] + for key in required: + self._cols[key] = s.find_column(df, key) + df[self._cols[key]] = df[self._cols[key]].astype(int) + + # Create/find centers + for key, calc in [('cx', lambda: (df[self._cols['xmin']] + df[self._cols['xmax']]) / 2), + ('cy', lambda: (df[self._cols['ymin']] + df[self._cols['ymax']]) / 2)]: + try: + self._cols[key] = s.find_column(df, key) + except ValueError: + # Create column with proper name + col_name = getattr(s, key) + df[col_name] = calc() + self._cols[key] = col_name + + # Create/find rotation centers and angle + for key, default_key in [('rcx', 'cx'), + ('rcy', 'cy')]: + try: + self._cols[key] = s.find_column(df, key) + except ValueError: + # Create column with proper name + col_name = getattr(s, key) + df[col_name] = df[self._cols[default_key]] + self._cols[key] = col_name + df[self._cols[key]] = df[self._cols[key]].astype(float) + + # Handle angle + try: + self._cols['angle'] = s.find_column(df, 'angle') + except ValueError: + col_name = s.angle + df[col_name] = 0.0 + self._cols['angle'] = col_name + df[self._cols['angle']] = df[self._cols['angle']].astype(float) + + self._ensure_cols_mapped() + + return df + + def _ensure_cols_mapped(self): + """Ensure column mapping is established.""" + if not self._cols: + # Direct mapping for columns that exist in the dataframe + for key in ['xmin', 'ymin', 'xmax', 'ymax', 'cx', 'cy', + 'rcx', 'rcy', 'angle']: + if key in self._df.columns: + self._cols[key] = key + else: + try: + self._cols[key] = self._schema.find_column(self._df, key) + except ValueError: + pass + + def row(self, i: int) -> BBoxRowView: + return BBoxRowView(self._df.iloc[i], self) + + def coords(self, i: int) -> Tuple[int, int, int, int]: + return self.row(i).bbox + + def centers(self, i: int) -> Tuple[float, float]: + return self.row(i).center + + def rot_centers(self, i: int) -> Tuple[float, float]: + return self.row(i).rot_center + + def angle_of(self, i: int) -> float: + return self.row(i).angle diff --git a/src/virtual_stain_flow/datasets/image_utils.py b/src/virtual_stain_flow/datasets/image_utils.py new file mode 100644 index 0000000..ee63751 --- /dev/null +++ b/src/virtual_stain_flow/datasets/image_utils.py @@ -0,0 +1,90 @@ +""" +image_utils.py + +This module centralizes image cropping and rotation operations so that + the dataset can focus on data handling logic. +The primary method `crop_and_rotate_image` is intended to be used by datasets + that need to crop and optionally rotate images based on bounding box annotations. +""" + +from typing import Tuple, Optional +import numpy as np +import cv2 + + +def crop_and_rotate_image( + image: np.ndarray, + bbox: Tuple[int, int, int, int], + rcx: Optional[float] = None, + rcy: Optional[float] = None, + angle: float = 0.0, + min_angle: float = 1e-3 +) -> np.ndarray: + """ + Crop and optionally rotate an image according to bounding box and rotation parameters. + This is the primary image processing method to be used by datasets that need to + crop and rotate images based on bounding box annotations. + + :param image: Input image as numpy array with shape (C, H, W) or (C, H, W, K) + :param bbox: Bounding box coordinates (xmin, ymin, xmax, ymax) + :param rcx: Rotation center x coordinate (optional) + :param rcy: Rotation center y coordinate (optional) + :param angle: Rotation angle in degrees + :param min_angle: Minimum angle threshold for rotation + :return: Cropped (and possibly rotated) image + """ + xmin, ymin, xmax, ymax = bbox + + # Fast path: no rotation needed + if angle == 0.0 or abs(angle) < min_angle or rcx is None or rcy is None: + return image[:, ymin:ymax, xmin:xmax] + + # Prepare image for cv2 (convert from CHW to HWC format) + cv_image = _prepare_image_for_cv2(image) + + # Apply rotation + M = cv2.getRotationMatrix2D(center=(rcx, rcy), angle=angle, scale=1.0) + rotated_cv = cv2.warpAffine( + cv_image, M, (cv_image.shape[1], cv_image.shape[0]), + flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101 + ) + + # Convert back to original format and crop + rotated_image = _restore_image_format(rotated_cv, image.shape) + return rotated_image[:, ymin:ymax, xmin:xmax] + + +def _prepare_image_for_cv2(image: np.ndarray) -> np.ndarray: + """ + Convert image from (C, H, W) or (C, H, W, K) to OpenCV format. + Internal helper method used by crop_and_rotate_image. + """ + if image.ndim == 3: + # (C, H, W) -> (H, W, C) + return np.transpose(image, (1, 2, 0)) + elif image.ndim == 4: + # (C, H, W, K) -> (H, W, C*K) + C, H, W, K = image.shape + return image.transpose(1, 2, 0, 3).reshape(H, W, C * K) + else: + raise ValueError(f"Unsupported image dimensions: {image.ndim}. " + f"Expected 3D (C, H, W) or 4D (C, H, W, K).") + + +def _restore_image_format(cv_image: np.ndarray, original_shape: Tuple[int, ...]) -> np.ndarray: + """ + Convert image back from OpenCV format to original format. + Internal helper method used by crop_and_rotate_image. + """ + # Handle single channel case + if cv_image.ndim == 2: + cv_image = cv_image[:, :, np.newaxis] + + if len(original_shape) == 3: + # Convert back to (C, H, W) + return np.transpose(cv_image, (2, 0, 1)) + else: # len(original_shape) == 4 + # Convert back to (C, H, W, K) + C, H, W, K = original_shape + rotated_H, rotated_W = cv_image.shape[:2] + return cv_image.reshape(rotated_H, rotated_W, C, K).transpose(2, 0, 1, 3) diff --git a/tests/datasets/test_bbox_schema.py b/tests/datasets/test_bbox_schema.py new file mode 100644 index 0000000..3d00ed3 --- /dev/null +++ b/tests/datasets/test_bbox_schema.py @@ -0,0 +1,481 @@ +# test_bbox_schema.py +import numpy as np +import pandas as pd +import pytest + +from virtual_stain_flow.datasets.bbox_schema import BBoxSchema, BBoxAccessor, BBoxRowView + + +class TestBBoxSchema: + """Test suite for BBoxSchema class.""" + + def test_init_default(self): + """Test BBoxSchema initialization with default values.""" + schema = BBoxSchema() + assert schema.prefix == "" + + def test_init_with_prefix(self): + """Test BBoxSchema initialization with custom prefix.""" + schema = BBoxSchema(prefix="bbox_") + assert schema.prefix == "bbox_" + + def test_dynamic_attribute_access(self): + """Test dynamic attribute access for column names.""" + schema = BBoxSchema() + + # Test standard column names + assert schema.xmin == "xmin" + assert schema.ymin == "ymin" + assert schema.xmax == "xmax" + assert schema.ymax == "ymax" + assert schema.cx == "cx" + assert schema.cy == "cy" + assert schema.rcx == "rcx" + assert schema.rcy == "rcy" + assert schema.angle == "angle" + + def test_dynamic_attribute_access_with_prefix(self): + """Test dynamic attribute access with prefix.""" + schema = BBoxSchema(prefix="bbox_") + + assert schema.xmin == "bbox_xmin" + assert schema.ymin == "bbox_ymin" + assert schema.xmax == "bbox_xmax" + assert schema.ymax == "bbox_ymax" + assert schema.cx == "bbox_cx" + assert schema.cy == "bbox_cy" + assert schema.rcx == "bbox_rcx" + assert schema.rcy == "bbox_rcy" + assert schema.angle == "bbox_angle" + + def test_invalid_attribute_access(self): + """Test that accessing invalid attributes raises AttributeError.""" + schema = BBoxSchema() + + with pytest.raises(AttributeError, match="has no attribute 'invalid_attr'"): + _ = schema.invalid_attr + + def test_bbox_cols_property(self): + """Test bbox_cols property returns correct tuple.""" + schema = BBoxSchema() + assert schema.bbox_cols == ("xmin", "ymin", "xmax", "ymax") + + def test_bbox_cols_property_with_prefix(self): + """Test bbox_cols property with prefix.""" + schema = BBoxSchema(prefix="crop_") + assert schema.bbox_cols == ("crop_xmin", "crop_ymin", "crop_xmax", "crop_ymax") + + def test_find_column_exact_match(self): + """Test find_column with exact column name match.""" + df = pd.DataFrame({"xmin": [1, 2], "ymin": [3, 4]}) + schema = BBoxSchema() + + assert schema.find_column(df, "xmin") == "xmin" + assert schema.find_column(df, "ymin") == "ymin" + + def test_find_column_alias_match(self): + """Test find_column with alias matches.""" + df = pd.DataFrame({"x_min": [1, 2], "top": [3, 4], "right": [5, 6]}) + schema = BBoxSchema() + + assert schema.find_column(df, "xmin") == "x_min" + assert schema.find_column(df, "ymin") == "top" + assert schema.find_column(df, "xmax") == "right" + + def test_find_column_case_variations(self): + """Test find_column with case variations.""" + df = pd.DataFrame({"XMIN": [1, 2], "xmax": [3, 4]}) + schema = BBoxSchema() + + assert schema.find_column(df, "xmin") == "XMIN" + assert schema.find_column(df, "xmax") == "xmax" + + def test_find_column_with_prefix(self): + """Test find_column with prefix in schema.""" + df = pd.DataFrame({"bbox_xmin": [1, 2], "bbox_ymin": [3, 4]}) + schema = BBoxSchema(prefix="bbox_") + + assert schema.find_column(df, "xmin") == "bbox_xmin" + assert schema.find_column(df, "ymin") == "bbox_ymin" + + def test_find_column_not_found(self): + """Test find_column raises ValueError when column not found.""" + df = pd.DataFrame({"other_col": [1, 2]}) + schema = BBoxSchema() + + with pytest.raises(ValueError, match="No column found for key 'xmin'"): + schema.find_column(df, "xmin") + + def test_frozen_dataclass(self): + """Test that BBoxSchema is frozen (immutable).""" + schema = BBoxSchema(prefix="test_") + + with pytest.raises(Exception): # FrozenInstanceError or similar + schema.prefix = "new_prefix" + + +class TestBBoxAccessor: + """Test suite for BBoxAccessor class.""" + + @pytest.fixture + def basic_bbox_df(self): + """Basic DataFrame with bbox columns.""" + return pd.DataFrame({ + "xmin": [10, 20], + "ymin": [15, 25], + "xmax": [50, 60], + "ymax": [55, 65], + "angle": [0.0, 45.0], + "rcx": [30.0, 40.0], + "rcy": [35.0, 45.0], + }) + + @pytest.fixture + def minimal_bbox_df(self): + """Minimal DataFrame with only required bbox columns.""" + return pd.DataFrame({ + "xmin": [5, 15], + "ymin": [10, 20], + "xmax": [45, 55], + "ymax": [50, 60], + }) + + def test_accessor_registration(self, basic_bbox_df): + """Test that the bbox accessor is properly registered with pandas.""" + assert hasattr(basic_bbox_df, "bbox") + assert callable(basic_bbox_df.bbox) + + def test_accessor_initialization(self, basic_bbox_df): + """Test accessor initialization.""" + accessor = basic_bbox_df.bbox + + assert isinstance(accessor, BBoxAccessor) + assert accessor._df is basic_bbox_df + + def test_accessor_with_custom_schema(self, basic_bbox_df): + """Test accessor with custom schema.""" + custom_schema = BBoxSchema(prefix="bbox_") + accessor = basic_bbox_df.bbox(custom_schema) + + assert accessor._schema is custom_schema + assert accessor._schema.prefix == "bbox_" + + def test_ensure_columns_with_complete_df(self, basic_bbox_df): + """Test ensure_columns with DataFrame that has all columns.""" + accessor = basic_bbox_df.bbox + result_df = accessor.ensure_columns() + + # Should return DataFrame with all required columns + required_cols = ["xmin", "ymin", "xmax", "ymax", "cx", "cy", "rcx", "rcy", "angle"] + for col in required_cols: + assert col in result_df.columns + + # Original bbox columns should be preserved + np.testing.assert_array_equal(result_df["xmin"], [10, 20]) + np.testing.assert_array_equal(result_df["ymin"], [15, 25]) + + def test_ensure_columns_creates_missing_centers(self, minimal_bbox_df): + """Test ensure_columns creates missing center columns.""" + accessor = minimal_bbox_df.bbox + result_df = accessor.ensure_columns() + + # Should create cx and cy columns + assert "cx" in result_df.columns + assert "cy" in result_df.columns + + # Check calculated values + expected_cx = [(5 + 45) / 2, (15 + 55) / 2] # [25.0, 35.0] + expected_cy = [(10 + 50) / 2, (20 + 60) / 2] # [30.0, 40.0] + + np.testing.assert_array_equal(result_df["cx"], expected_cx) + np.testing.assert_array_equal(result_df["cy"], expected_cy) + + def test_ensure_columns_creates_missing_rotation_centers(self, minimal_bbox_df): + """Test ensure_columns creates missing rotation center columns.""" + accessor = minimal_bbox_df.bbox + result_df = accessor.ensure_columns() + + # Should create rcx and rcy columns defaulting to cx and cy + assert "rcx" in result_df.columns + assert "rcy" in result_df.columns + + # Should default to center values + np.testing.assert_array_equal(result_df["rcx"], result_df["cx"]) + np.testing.assert_array_equal(result_df["rcy"], result_df["cy"]) + + def test_ensure_columns_creates_missing_angle(self, minimal_bbox_df): + """Test ensure_columns creates missing angle column.""" + accessor = minimal_bbox_df.bbox + result_df = accessor.ensure_columns() + + # Should create angle column with default 0.0 + assert "angle" in result_df.columns + np.testing.assert_array_equal(result_df["angle"], [0.0, 0.0]) + + def test_ensure_columns_preserves_existing_values(self, basic_bbox_df): + """Test ensure_columns preserves existing column values.""" + accessor = basic_bbox_df.bbox + result_df = accessor.ensure_columns() + + # Existing values should be preserved + np.testing.assert_array_equal(result_df["angle"], [0.0, 45.0]) + np.testing.assert_array_equal(result_df["rcx"], [30.0, 40.0]) + np.testing.assert_array_equal(result_df["rcy"], [35.0, 45.0]) + + def test_ensure_columns_handles_alternative_names(self): + """Test ensure_columns works with alternative column names.""" + df = pd.DataFrame({ + "x_min": [10, 20], + "y_min": [15, 25], + "x_max": [50, 60], + "y_max": [55, 65], + "rotation": [30.0, 60.0], # Alternative to "angle" + }) + + accessor = df.bbox + result_df = accessor.ensure_columns() + + # Should find and use alternative names + assert accessor._cols["xmin"] == "x_min" + assert accessor._cols["angle"] == "rotation" + + # Values should be preserved + np.testing.assert_array_equal(result_df["rotation"], [30.0, 60.0]) + + def test_coords_method(self, basic_bbox_df): + """Test coords method returns correct bbox coordinates.""" + accessor = basic_bbox_df.bbox + accessor.ensure_columns() + + coords_0 = accessor.coords(0) + coords_1 = accessor.coords(1) + + assert coords_0 == (10, 15, 50, 55) + assert coords_1 == (20, 25, 60, 65) + + def test_centers_method(self, basic_bbox_df): + """Test centers method returns correct center coordinates.""" + accessor = basic_bbox_df.bbox + accessor.ensure_columns() + + centers_0 = accessor.centers(0) + centers_1 = accessor.centers(1) + + # Calculated from bbox coordinates + expected_cx_0 = (10 + 50) / 2 # 30.0 + expected_cy_0 = (15 + 55) / 2 # 35.0 + expected_cx_1 = (20 + 60) / 2 # 40.0 + expected_cy_1 = (25 + 65) / 2 # 45.0 + + assert centers_0 == (expected_cx_0, expected_cy_0) + assert centers_1 == (expected_cx_1, expected_cy_1) + + def test_rot_centers_method(self, basic_bbox_df): + """Test rot_centers method returns correct rotation centers.""" + accessor = basic_bbox_df.bbox + accessor.ensure_columns() + + rot_centers_0 = accessor.rot_centers(0) + rot_centers_1 = accessor.rot_centers(1) + + assert rot_centers_0 == (30.0, 35.0) + assert rot_centers_1 == (40.0, 45.0) + + def test_angle_of_method(self, basic_bbox_df): + """Test angle_of method returns correct angles.""" + accessor = basic_bbox_df.bbox + accessor.ensure_columns() + + angle_0 = accessor.angle_of(0) + angle_1 = accessor.angle_of(1) + + assert angle_0 == 0.0 + assert angle_1 == 45.0 + + def test_row_method_returns_bbox_row_view(self, basic_bbox_df): + """Test row method returns BBoxRowView instance.""" + accessor = basic_bbox_df.bbox + accessor.ensure_columns() + + row_view = accessor.row(0) + + assert isinstance(row_view, BBoxRowView) + assert row_view._acc is accessor + + +class TestBBoxRowView: + """Test suite for BBoxRowView class.""" + + @pytest.fixture + def setup_row_view(self): + """Set up a BBoxRowView for testing.""" + df = pd.DataFrame({ + "xmin": [10, 20], + "ymin": [15, 25], + "xmax": [50, 60], + "ymax": [55, 65], + "angle": [0.0, 45.0], + "rcx": [30.0, 40.0], + "rcy": [35.0, 45.0], + }) + + accessor = df.bbox + accessor.ensure_columns() + + return accessor.row(0), accessor.row(1) + + def test_bbox_property(self, setup_row_view): + """Test bbox property returns correct coordinates.""" + row_view_0, row_view_1 = setup_row_view + + assert row_view_0.bbox == (10, 15, 50, 55) + assert row_view_1.bbox == (20, 25, 60, 65) + + def test_center_property(self, setup_row_view): + """Test center property returns correct center coordinates.""" + row_view_0, row_view_1 = setup_row_view + + # Calculated from bbox coordinates + expected_center_0 = (30.0, 35.0) # (10+50)/2, (15+55)/2 + expected_center_1 = (40.0, 45.0) # (20+60)/2, (25+65)/2 + + assert row_view_0.center == expected_center_0 + assert row_view_1.center == expected_center_1 + + def test_rot_center_property(self, setup_row_view): + """Test rot_center property returns correct rotation centers.""" + row_view_0, row_view_1 = setup_row_view + + assert row_view_0.rot_center == (30.0, 35.0) + assert row_view_1.rot_center == (40.0, 45.0) + + def test_angle_property(self, setup_row_view): + """Test angle property returns correct angles.""" + row_view_0, row_view_1 = setup_row_view + + assert row_view_0.angle == 0.0 + assert row_view_1.angle == 45.0 + + +class TestBBoxSchemaIntegration: + """Integration tests for bbox schema components.""" + + def test_end_to_end_workflow(self): + """Test complete workflow from DataFrame to accessing bbox data.""" + # Start with DataFrame using alternative column names + df = pd.DataFrame({ + "left": [5, 15, 25], + "top": [10, 20, 30], + "right": [35, 45, 55], + "bottom": [40, 50, 60], + "rotation": [0.0, 30.0, 60.0], + }) + + # Use custom schema with prefix + schema = BBoxSchema(prefix="bbox_") + accessor = df.bbox(schema) + + # Ensure columns creates missing ones + result_df = accessor.ensure_columns() + + # Should have all required columns + assert "bbox_cx" in result_df.columns + assert "bbox_cy" in result_df.columns + assert "bbox_rcx" in result_df.columns + assert "bbox_rcy" in result_df.columns + + # Should preserve original angle column + assert accessor._cols["angle"] == "rotation" + + # Access data through accessor methods + coords = accessor.coords(1) + centers = accessor.centers(1) + rot_centers = accessor.rot_centers(1) + angle = accessor.angle_of(1) + + assert coords == (15, 20, 45, 50) + assert centers == (30.0, 35.0) # (15+45)/2, (20+50)/2 + assert rot_centers == (30.0, 35.0) # Should default to centers + assert angle == 30.0 + + def test_missing_required_columns_error(self): + """Test that missing required columns raise appropriate errors.""" + # DataFrame missing required bbox columns + df = pd.DataFrame({ + "some_col": [1, 2, 3], + "other_col": [4, 5, 6], + }) + + accessor = df.bbox + + with pytest.raises(ValueError, match="No column found for key 'xmin'"): + accessor.ensure_columns() + + def test_schema_column_mapping_priority(self): + """Test column mapping priority follows the order in _column_map.""" + # DataFrame with multiple possible column names + df = pd.DataFrame({ + "xmin": [10, 20], + "x_min": [100, 200], # This should be used since it's first in _column_map + "ymin": [15, 25], + "ymax": [55, 65], + "xmax": [50, 60], + }) + + accessor = df.bbox + accessor.ensure_columns() + + # Should use "x_min" since it appears first in the _column_map for 'xmin' + assert accessor._cols["xmin"] == "x_min" + + # Verify values come from the correct column + coords = accessor.coords(0) + assert coords[0] == 100 # Should be from "x_min" column, not "xmin" + + def test_data_type_enforcement(self): + """Test that ensure_columns enforces appropriate data types.""" + df = pd.DataFrame({ + "xmin": ["10", "20"], # String values + "ymin": ["15", "25"], + "xmax": ["50", "60"], + "ymax": ["55", "65"], + }) + + accessor = df.bbox + result_df = accessor.ensure_columns() + + # Integer columns should be converted to int + assert result_df["xmin"].dtype in [np.int32, np.int64] + assert result_df["ymin"].dtype in [np.int32, np.int64] + + # Float columns should be float + assert result_df["cx"].dtype in [np.float32, np.float64] + assert result_df["angle"].dtype in [np.float32, np.float64] + + def test_accessor_state_preservation(self): + """Test that accessor state is properly preserved across operations.""" + df = pd.DataFrame({ + "xmin": [10, 20], + "ymin": [15, 25], + "xmax": [50, 60], + "ymax": [55, 65], + }) + + # Create accessor with custom schema + custom_schema = BBoxSchema(prefix="test_") + accessor = df.bbox(custom_schema) + + # Ensure columns + accessor.ensure_columns() + + # Verify schema is preserved + assert accessor._schema.prefix == "test_" + + # Verify column mappings are preserved + assert "xmin" in accessor._cols + assert "cx" in accessor._cols + + # Multiple method calls should work consistently + coords_1 = accessor.coords(0) + coords_2 = accessor.coords(0) + assert coords_1 == coords_2 diff --git a/tests/datasets/test_image_utils.py b/tests/datasets/test_image_utils.py new file mode 100644 index 0000000..1e64249 --- /dev/null +++ b/tests/datasets/test_image_utils.py @@ -0,0 +1,413 @@ +# test_image_utils.py + +import numpy as np +import pytest + +from virtual_stain_flow.datasets.image_utils import ( + crop_and_rotate_image, + _prepare_image_for_cv2, + _restore_image_format, +) + + +class TestCropAndRotateImage: + """Test suite for crop_and_rotate_image function.""" + + @pytest.fixture + def sample_image_3d(self): + """Create a 3D test image (C, H, W) with distinguishable patterns.""" + # Create 100x100 image with 2 channels + image = np.zeros((2, 100, 100), dtype=np.float32) + + # Channel 0: checkerboard pattern + for i in range(100): + for j in range(100): + if (i // 10 + j // 10) % 2 == 0: + image[0, i, j] = 255 + else: + image[0, i, j] = 100 + + # Channel 1: gradient pattern + for i in range(100): + image[1, i, :] = i * 2.55 # 0 to 255 gradient + + return image + + @pytest.fixture + def sample_image_4d(self): + """Create a 4D test image (C, H, W, K) for testing.""" + # Create 50x50 image with 2 channels and 3 additional dimensions + image = np.zeros((2, 50, 50, 3), dtype=np.float32) + + # Fill with distinguishable patterns + for c in range(2): + for k in range(3): + # Different pattern for each channel and k dimension + value = (c + 1) * (k + 1) * 50 + image[c, :, :, k] = value + + return image + + def test_crop_only_no_rotation(self, sample_image_3d): + """Test cropping without rotation (angle=0).""" + bbox = (20, 30, 60, 70) # xmin, ymin, xmax, ymax + + result = crop_and_rotate_image(sample_image_3d, bbox, angle=0.0) + + # Check dimensions + expected_height = 70 - 30 # 40 + expected_width = 60 - 20 # 40 + assert result.shape == (2, expected_height, expected_width) + + # Check content matches manual crop + expected = sample_image_3d[:, 30:70, 20:60] + np.testing.assert_array_equal(result, expected) + + def test_crop_only_small_angle_below_threshold(self, sample_image_3d): + """Test that very small angles below threshold don't trigger rotation.""" + bbox = (10, 10, 50, 50) + small_angle = 1e-4 # Below default min_angle of 1e-3 + + result = crop_and_rotate_image( + sample_image_3d, bbox, rcx=30.0, rcy=30.0, angle=small_angle + ) + + # Should be same as no rotation + expected = sample_image_3d[:, 10:50, 10:50] + np.testing.assert_array_equal(result, expected) + + def test_crop_only_none_rotation_centers(self, sample_image_3d): + """Test that None rotation centers prevent rotation even with angle.""" + bbox = (10, 10, 50, 50) + + result = crop_and_rotate_image( + sample_image_3d, bbox, rcx=None, rcy=None, angle=45.0 + ) + + # Should be same as no rotation + expected = sample_image_3d[:, 10:50, 10:50] + np.testing.assert_array_equal(result, expected) + + def test_crop_with_rotation_applied(self, sample_image_3d): + """Test that rotation is applied when conditions are met.""" + bbox = (20, 20, 80, 80) + rcx, rcy = 50.0, 50.0 # Center of rotation + angle = 90.0 # 90 degree rotation + + result = crop_and_rotate_image( + sample_image_3d, bbox, rcx=rcx, rcy=rcy, angle=angle + ) + + # Should have same dimensions as crop + expected_height = 80 - 20 # 60 + expected_width = 80 - 20 # 60 + assert result.shape == (2, expected_height, expected_width) + + # Content should be different from no-rotation case + no_rotation_result = crop_and_rotate_image(sample_image_3d, bbox, angle=0.0) + assert not np.array_equal(result, no_rotation_result) + + def test_crop_4d_image(self, sample_image_4d): + """Test cropping works with 4D images.""" + bbox = (10, 10, 40, 40) + + result = crop_and_rotate_image(sample_image_4d, bbox, angle=0.0) + + # Check dimensions + expected_height = 40 - 10 # 30 + expected_width = 40 - 10 # 30 + assert result.shape == (2, expected_height, expected_width, 3) + + # Check content + expected = sample_image_4d[:, 10:40, 10:40, :] + np.testing.assert_array_equal(result, expected) + + def test_crop_4d_image_with_rotation(self, sample_image_4d): + """Test cropping with rotation works on 4D images.""" + bbox = (5, 5, 45, 45) + + result = crop_and_rotate_image( + sample_image_4d, bbox, rcx=25.0, rcy=25.0, angle=45.0 + ) + + # Should have correct dimensions + expected_height = 45 - 5 # 40 + expected_width = 45 - 5 # 40 + assert result.shape == (2, expected_height, expected_width, 3) + + def test_different_angles_produce_different_results(self, sample_image_3d): + """Test that different rotation angles produce different results.""" + bbox = (25, 25, 75, 75) + rcx, rcy = 50.0, 50.0 + + result_0 = crop_and_rotate_image(sample_image_3d, bbox, rcx, rcy, 0.0) + result_45 = crop_and_rotate_image(sample_image_3d, bbox, rcx, rcy, 45.0) + result_90 = crop_and_rotate_image(sample_image_3d, bbox, rcx, rcy, 90.0) + + # All should have same shape + assert result_0.shape == result_45.shape == result_90.shape + + # But different content + assert not np.array_equal(result_0, result_45) + assert not np.array_equal(result_45, result_90) + assert not np.array_equal(result_0, result_90) + + def test_custom_min_angle_threshold(self, sample_image_3d): + """Test custom min_angle threshold parameter.""" + bbox = (10, 10, 50, 50) + small_angle = 0.5 + + # With default threshold (1e-3), should trigger rotation + result_default = crop_and_rotate_image( + sample_image_3d, bbox, rcx=30.0, rcy=30.0, angle=small_angle + ) + + # With higher threshold, should not trigger rotation + result_high_thresh = crop_and_rotate_image( + sample_image_3d, bbox, rcx=30.0, rcy=30.0, angle=small_angle, min_angle=1.0 + ) + + # High threshold result should match no rotation + expected_no_rotation = sample_image_3d[:, 10:50, 10:50] + np.testing.assert_array_equal(result_high_thresh, expected_no_rotation) + + # Default threshold result should be different (rotated) + assert not np.array_equal(result_default, expected_no_rotation) + + def test_negative_angle_rotation(self, sample_image_3d): + """Test that negative angles work correctly.""" + bbox = (20, 20, 60, 60) + rcx, rcy = 40.0, 40.0 + + result_pos = crop_and_rotate_image(sample_image_3d, bbox, rcx, rcy, 30.0) + result_neg = crop_and_rotate_image(sample_image_3d, bbox, rcx, rcy, -30.0) + + # Should produce different results + assert not np.array_equal(result_pos, result_neg) + + # Both should have same shape + assert result_pos.shape == result_neg.shape + + def test_edge_case_zero_area_crop(self, sample_image_3d): + """Test edge case where crop area is zero or very small.""" + # Same min and max coordinates + bbox = (50, 50, 50, 50) + + result = crop_and_rotate_image(sample_image_3d, bbox) + + # Should return empty array + assert result.shape == (2, 0, 0) + + def test_edge_case_out_of_bounds_crop(self, sample_image_3d): + """Test cropping coordinates that go beyond image boundaries.""" + # Bbox that extends beyond 100x100 image + bbox = (80, 80, 120, 120) + + # Should not raise error, cv2 handles boundary conditions + result = crop_and_rotate_image(sample_image_3d, bbox) + + # Should have expected crop dimensions + expected_height = 120 - 80 # 40, but limited by image boundary + expected_width = 120 - 80 # 40, but limited by image boundary + assert result.shape[1] <= expected_height + assert result.shape[2] <= expected_width + + +class TestPrepareImageForCv2: + """Test suite for _prepare_image_for_cv2 helper function.""" + + def test_3d_image_conversion(self): + """Test conversion of 3D image from (C, H, W) to (H, W, C).""" + # Create test image (2 channels, 10x20) + image = np.random.rand(2, 10, 20).astype(np.float32) + + result = _prepare_image_for_cv2(image) + + # Should be transposed to (H, W, C) + assert result.shape == (10, 20, 2) + + # Content should be preserved + np.testing.assert_array_equal(result[:, :, 0], image[0, :, :]) + np.testing.assert_array_equal(result[:, :, 1], image[1, :, :]) + + def test_4d_image_conversion(self): + """Test conversion of 4D image from (C, H, W, K) to (H, W, C*K).""" + # Create test image (2 channels, 10x20, 3 additional dims) + image = np.random.rand(2, 10, 20, 3).astype(np.float32) + + result = _prepare_image_for_cv2(image) + + # Should be reshaped to (H, W, C*K) + assert result.shape == (10, 20, 6) # 2*3 = 6 + + # Verify content mapping + # First channel, first K should map to first output channel + np.testing.assert_array_equal(result[:, :, 0], image[0, :, :, 0]) + # Second channel, first K should map to appropriate output channel + np.testing.assert_array_equal(result[:, :, 3], image[1, :, :, 0]) + + def test_unsupported_dimensions(self): + """Test that unsupported image dimensions raise ValueError.""" + # 2D image (missing channel dimension) + image_2d = np.random.rand(10, 20) + with pytest.raises(ValueError, match="Unsupported image dimensions: 2"): + _prepare_image_for_cv2(image_2d) + + # 5D image (too many dimensions) + image_5d = np.random.rand(2, 10, 20, 3, 4) + with pytest.raises(ValueError, match="Unsupported image dimensions: 5"): + _prepare_image_for_cv2(image_5d) + + def test_single_channel_3d(self): + """Test conversion of single-channel 3D image.""" + image = np.random.rand(1, 15, 25).astype(np.float32) + + result = _prepare_image_for_cv2(image) + + assert result.shape == (15, 25, 1) + np.testing.assert_array_equal(result[:, :, 0], image[0, :, :]) + + +class TestRestoreImageFormat: + """Test suite for _restore_image_format helper function.""" + + def test_restore_3d_format(self): + """Test restoring 3D image format from OpenCV (H, W, C) to (C, H, W).""" + original_shape = (2, 10, 20) + cv_image = np.random.rand(10, 20, 2).astype(np.float32) + + result = _restore_image_format(cv_image, original_shape) + + assert result.shape == original_shape + + # Check content preservation + np.testing.assert_array_equal(result[0, :, :], cv_image[:, :, 0]) + np.testing.assert_array_equal(result[1, :, :], cv_image[:, :, 1]) + + def test_restore_4d_format(self): + """Test restoring 4D image format from OpenCV to (C, H, W, K).""" + original_shape = (2, 10, 20, 3) + cv_image = np.random.rand(10, 20, 6).astype(np.float32) # 2*3 = 6 channels + + result = _restore_image_format(cv_image, original_shape) + + assert result.shape == original_shape + + # Check content mapping - first channel, first K + np.testing.assert_array_equal(result[0, :, :, 0], cv_image[:, :, 0]) + # Second channel, first K + np.testing.assert_array_equal(result[1, :, :, 0], cv_image[:, :, 3]) + + def test_restore_single_channel_from_2d(self): + """Test restoring from 2D OpenCV image (single channel case).""" + original_shape = (1, 15, 25) + cv_image = np.random.rand(15, 25).astype(np.float32) # 2D + + result = _restore_image_format(cv_image, original_shape) + + assert result.shape == original_shape + np.testing.assert_array_equal(result[0, :, :], cv_image) + + def test_roundtrip_3d_conversion(self): + """Test that prepare -> restore roundtrip preserves 3D image.""" + original = np.random.rand(3, 12, 18).astype(np.float32) + + cv_format = _prepare_image_for_cv2(original) + restored = _restore_image_format(cv_format, original.shape) + + np.testing.assert_array_almost_equal(original, restored) + + def test_roundtrip_4d_conversion(self): + """Test that prepare -> restore roundtrip preserves 4D image.""" + original = np.random.rand(2, 8, 12, 4).astype(np.float32) + + cv_format = _prepare_image_for_cv2(original) + restored = _restore_image_format(cv_format, original.shape) + + np.testing.assert_array_almost_equal(original, restored) + + +class TestImageUtilsIntegration: + """Integration tests for the image_utils module.""" + + def test_geometric_rotation_properties(self): + """Test that rotations have expected geometric properties.""" + # Create a simple pattern that's easy to verify + image = np.zeros((1, 20, 20), dtype=np.float32) + # Put a bright pixel at (5, 10) - off center + image[0, 5, 10] = 255.0 + + bbox = (0, 0, 20, 20) # Full image + center_x, center_y = 10.0, 10.0 # Image center + + # 180 degree rotation should move the pixel to (15, 10) + # (symmetric around center) + result_180 = crop_and_rotate_image( + image, bbox, center_x, center_y, 180.0 + ) + + # The bright pixel should now be at approximately (15, 10) + # Due to interpolation, check the region around expected position + bright_region = result_180[0, 14:17, 9:12] + assert np.max(bright_region) > 200 # Should find bright pixel nearby + + def test_rotation_preserves_image_statistics_approximately(self): + """Test that rotation preserves approximate image statistics.""" + # Create image with known statistics + np.random.seed(42) + image = np.random.rand(2, 50, 50).astype(np.float32) * 100 + + bbox = (5, 5, 45, 45) + center_x, center_y = 25.0, 25.0 + + original_crop = crop_and_rotate_image(image, bbox, angle=0.0) + rotated_crop = crop_and_rotate_image(image, bbox, center_x, center_y, 45.0) + + # Mean should be approximately preserved (within tolerance due to interpolation) + original_mean = np.mean(original_crop) + rotated_mean = np.mean(rotated_crop) + + # Allow 5% tolerance for interpolation effects + np.testing.assert_allclose(original_mean, rotated_mean, rtol=0.05) + + def test_multiple_rotations_composition(self): + """Test that multiple small rotations approximate one large rotation.""" + image = np.random.rand(1, 40, 40).astype(np.float32) * 255 + bbox = (5, 5, 35, 35) + center_x, center_y = 20.0, 20.0 + + # Single 90 degree rotation + result_90 = crop_and_rotate_image(image, bbox, center_x, center_y, 90.0) + + # Four 22.5 degree rotations (approximately 90 degrees) + temp_image = image.copy() + for _ in range(4): + temp_crop = crop_and_rotate_image(temp_image, bbox, center_x, center_y, 22.5) + # For this test, we'd need to reconstruct full image, which is complex + # So just test that the operation doesn't fail + assert temp_crop.shape == (1, 30, 30) + + def test_cropping_different_regions_same_image(self): + """Test cropping different regions from the same image.""" + # Create image with spatial variation + image = np.zeros((1, 100, 100), dtype=np.float32) + for i in range(100): + for j in range(100): + image[0, i, j] = i + j # Diagonal gradient + + # Crop different regions + bbox1 = (10, 10, 30, 30) # Top-left region + bbox2 = (70, 70, 90, 90) # Bottom-right region + + crop1 = crop_and_rotate_image(image, bbox1) + crop2 = crop_and_rotate_image(image, bbox2) + + # Should have same dimensions + assert crop1.shape == crop2.shape + + # But different content (different means due to gradient) + assert np.mean(crop1) < np.mean(crop2) # crop1 from top-left should have smaller values + + # Verify specific content + assert crop1[0, 0, 0] == 20.0 # i=10, j=10 + assert crop2[0, 0, 0] == 140.0 # i=70, j=70