diff --git a/detectree2/models/train.py b/detectree2/models/train.py index c39c9bf8..49df8e3a 100644 --- a/detectree2/models/train.py +++ b/detectree2/models/train.py @@ -328,7 +328,7 @@ def __init__(self, eval_period, model, data_loader, img_per_dataset=6): eval_period (int): The number of iterations between evaluations. model (torch.nn.Module): The model to evaluate. data_loader (torch.utils.data.DataLoader): The data loader for evaluation. - patience (int): The number of evaluation periods to wait for improvement before early stopping. + img_per_dataset (int): The number of images per dataset to visualize. """ self._model = model self._period = eval_period @@ -339,11 +339,11 @@ def __init__(self, eval_period, model, data_loader, img_per_dataset=6): def after_step(self): """ - Hook to be called after each training iteration to evaluate the model and manage checkpoints. + Hook to be called after each training iteration to visualize model predictions. - - Evaluates the model at regular intervals. - - Saves the best model checkpoint based on the AP50 metric. - - Implements early stopping if the AP50 does not improve after a set number of evaluations. + This hook runs at regular intervals to perform inference on a sample of the validation dataset. + It then visualizes the predictions and logs the resulting images to the event storage, + making them accessible through tools like TensorBoard. """ next_iter = self.trainer.iter + 1 is_final = next_iter == self.trainer.max_iter @@ -703,9 +703,9 @@ def get_tree_dicts(directory: str, class_mapping: Optional[Dict[str, int]] = Non """Get the tree dictionaries. Args: - directory: Path to directory - classes: List of classes to include - classes_at: Signifies which column (if any) corresponds to the class labels + directory: Path to directory containing geojson annotation files. + class_mapping: A dictionary mapping class labels from geojson properties to category indices. 
+ If None, all annotations are assigned to category 0 (tree). Returns: List of dictionaries corresponding to segmentations of trees. Each dictionary includes @@ -919,12 +919,13 @@ def remove_registered_data(name="tree"): MetadataCatalog.remove(name + "_" + d) -def register_test_data(test_location, name="tree"): +def register_test_data(test_location, name="tree", class_mapping_file=None): """Register data for testing. Args: test_location: directory containing test data name: string to name data + class_mapping_file: Path to the class mapping file (json or pickle). """ d = "test" @@ -993,7 +994,6 @@ def setup_cfg( base_lr: base learning rate weight_decay: weight decay for optimizer max_iter: maximum number of iterations - num_classes: number of classes eval_period: number of iterations between evaluations out_dir: directory to save outputs resize: resize strategy for images diff --git a/detectree2/preprocessing/tiling.py b/detectree2/preprocessing/tiling.py index 54ee7202..f6da3d87 100644 --- a/detectree2/preprocessing/tiling.py +++ b/detectree2/preprocessing/tiling.py @@ -138,19 +138,27 @@ def process_tile(img_path: str, """Process a single tile for making predictions. Args: - img_path: Path to the orthomosaic - out_dir: Output directory - buffer: Overlapping buffer of tiles in meters (UTM) - tile_width: Tile width in meters - tile_height: Tile height in meters - dtype_bool: Flag to edit dtype to prevent black tiles - minx: Minimum x coordinate of tile - miny: Minimum y coordinate of tile - crs: Coordinate reference system - tilename: Name of the tile + img_path: Path to the orthomosaic. + out_dir: Output directory. + buffer: Overlapping buffer of tiles in meters (UTM). + tile_width: Tile width in meters. + tile_height: Tile height in meters. + dtype_bool: Flag to edit dtype to prevent black tiles. + minx: Minimum x coordinate of the tile. + miny: Minimum y coordinate of the tile. + crs: Coordinate reference system. + tilename: Name of the tile. 
+ crowns: Crown polygons as a GeoDataFrame used to skip tiles if coverage is below `threshold`. + threshold: Minimum fraction [0,1] of tile coverage by `crowns` required to avoid skipping the tile. + nan_threshold: Maximum proportion [0,1] of the tile that can be nodata or NaN values before skipping. + mask_gdf: A GeoDataFrame containing polygons that act as masks for the tile. Only the interior is kept, the rest of the image will become nodata. + additional_nodata: List of additional pixel values to treat as nodata. + image_statistics: A list of dictionaries where each dictionary contains information about the pixel distribution of that band. One list element per band. + ignore_bands_indices: List of integer indices of bands to ignore during processing. + use_convex_mask: If True, creates a convex mask around crown polygons to exclude areas outside of annotated training crowns. Returns: - None + A tuple containing the rasterio dataset, output path root, overlapping crowns, and tile parameters (minx, miny, buffer), or None if the tile is skipped. """ try: with rasterio.open(img_path) as data: @@ -290,22 +298,30 @@ def process_tile_ms(img_path: str, image_statistics: List[Dict[str, float]] = None, ignore_bands_indices: List[int] = [], use_convex_mask: bool = True): - """Process a single tile for making predictions. + """Process a single multispectral tile for making predictions. Args: - img_path: Path to the orthomosaic - out_dir: Output directory - buffer: Overlapping buffer of tiles in meters (UTM) - tile_width: Tile width in meters - tile_height: Tile height in meters - dtype_bool: Flag to edit dtype to prevent black tiles - minx: Minimum x coordinate of tile - miny: Minimum y coordinate of tile - crs: Coordinate reference system - tilename: Name of the tile + img_path: Path to the orthomosaic. + out_dir: Output directory. + buffer: Overlapping buffer of tiles in meters (UTM). + tile_width: Tile width in meters. + tile_height: Tile height in meters. 
+ dtype_bool: Flag to edit dtype to prevent black tiles. + minx: Minimum x coordinate of the tile. + miny: Minimum y coordinate of the tile. + crs: Coordinate reference system. + tilename: Name of the tile. + crowns: Crown polygons as a GeoDataFrame used to skip tiles if coverage is below `threshold`. + threshold: Minimum fraction [0,1] of tile coverage by `crowns` required to avoid skipping the tile. + nan_threshold: Maximum proportion [0,1] of the tile that can be nodata or NaN values before skipping. + mask_gdf: A GeoDataFrame containing polygons that act as masks for the tile. Only the interior is kept, the rest of the image will become nodata. + additional_nodata: List of additional pixel values to treat as nodata. + image_statistics: A list of dictionaries where each dictionary contains information about the pixel distribution of that band. One list element per band. + ignore_bands_indices: List of integer indices of bands to ignore during processing. + use_convex_mask: If True, creates a convex mask around crown polygons to exclude areas outside of annotated crowns. Returns: - None + A tuple containing the rasterio dataset, output path root, overlapping crowns, and tile parameters (minx, miny, buffer), or None if the tile is skipped. """ try: with rasterio.open(img_path) as data: @@ -457,19 +473,26 @@ def process_tile_train( """Process a single tile for training data. 
Args: - img_path: Path to the orthomosaic - out_dir: Output directory - buffer: Overlapping buffer of tiles in meters (UTM) - tile_width: Tile width in meters - tile_height: Tile height in meters - dtype_bool: Flag to edit dtype to prevent black tiles - minx: Minimum x coordinate of tile - miny: Minimum y coordinate of tile - crs: Coordinate reference system - tilename: Name of the tile - crowns: Crown polygons as a geopandas dataframe - threshold: Min proportion of the tile covered by crowns to be accepted {0,1} - nan_theshold: Max proportion of tile covered by nans + img_path: Path to the orthomosaic. + out_dir: Output directory. + buffer: Overlapping buffer of tiles in meters (UTM). + tile_width: Tile width in meters. + tile_height: Tile height in meters. + dtype_bool: Flag to edit dtype to prevent black tiles. + minx: Minimum x coordinate of tile. + miny: Minimum y coordinate of tile. + crs: Coordinate reference system. + tilename: Name of the tile. + crowns: Crown polygons as a GeoDataFrame. + threshold: Min proportion of the tile covered by crowns to be accepted [0,1]. + nan_threshold: Max proportion of tile covered by NaNs. + mode: Type of the raster data ("rgb" or "ms"). + class_column: Name of the column in `crowns` DataFrame for class-based tiling. + mask_gdf: A GeoDataFrame containing polygons that act as masks for the tile. Only the interior is kept, the rest of the image will become nodata. + additional_nodata: List of additional pixel values to treat as nodata. + image_statistics: A list of dictionaries where each dictionary contains information about the pixel distribution of that band. One list element per band. + ignore_bands_indices: List of integer indices of bands to ignore during processing. + use_convex_mask: If True, creates a convex mask around crown polygons to exclude areas outside of annotated crowns.
Returns: None @@ -544,7 +567,20 @@ def _calculate_tile_placements( tile_placement: str = "grid", overlapping_tiles: bool = False, ) -> List[Tuple[int, int]]: - """Internal method for calculating the placement of tiles""" + """Internal method for calculating the placement of tiles. + + Args: + img_path: Path to the orthomosaic. + buffer: Overlapping buffer of tiles in meters (UTM). + tile_width: Tile width in meters. + tile_height: Tile height in meters. + crowns: Crown polygons as a GeoDataFrame. Required for 'adaptive' placement. + tile_placement: Strategy for placing tiles ('grid' or 'adaptive'). + overlapping_tiles: If True, generates additional tiles offset by half a tile size. + + Returns: + A list of (minx, miny) coordinates giving the minimum-x, minimum-y corner of each tile. + """ if tile_placement == "grid": with rasterio.open(img_path) as data: @@ -622,17 +658,18 @@ def calculate_image_statistics(file_path, min_windows=100, mode="rgb", ignore_bands_indices: List[int] = []): - """ - Calculate statistics for a raster using either whole image or sampled windows. + """Calculate statistics for a raster using either whole image or sampled windows. - Parameters: - - file_path: str, path to the raster file. - - values_to_ignore: list, values to ignore in statistics (e.g., NaN, custom values). - - window_size: int, size of square window for sampling. - - min_windows: int, minimum number of valid windows to include in statistics. + Args: + file_path: Path to the raster file. + values_to_ignore: Values to ignore in statistics (e.g., NaN, custom values). + window_size: Size of square window for sampling. + min_windows: Minimum number of valid windows to include in statistics. + mode: Type of the raster data ("rgb" or "ms"). + ignore_bands_indices: List of integer indices of bands to ignore during statistics calculation. Returns: - - List of dictionaries containing statistics for each band. + List of dictionaries containing statistics for each band.
""" if values_to_ignore is None: values_to_ignore = [] @@ -769,26 +806,34 @@ def tile_data( ) -> None: """Tiles up orthomosaic and corresponding crowns (if supplied) into training/prediction tiles. - Tiles up large rasters into manageable tiles for training and prediction. If crowns are not supplied, the function - will tile up the entire landscape for prediction. If crowns are supplied, the function will tile these with the image - and skip tiles without a minimum coverage of crowns. The 'threshold' can be varied to ensure good coverage of - crowns across a training tile. Tiles that do not have sufficient coverage are skipped. + Tiles up large rasters into manageable tiles for training and prediction. If crowns are not + supplied, the function will tile up the entire landscape for prediction. If crowns are supplied, + the function will tile these with the image and skip tiles without a minimum coverage of crowns. + The 'threshold' can be varied to ensure good coverage of crowns across a training tile. Tiles + that do not have sufficient coverage are skipped. Args: - img_path: Path to the orthomosaic - out_dir: Output directory - buffer: Overlapping buffer of tiles in meters (UTM) - tile_width: Tile width in meters - tile_height: Tile height in meters - crowns: Crown polygons as a GeoPandas DataFrame - threshold: Minimum proportion of the tile covered by crowns to be accepted [0,1] - nan_threshold: Maximum proportion of tile covered by NaNs [0,1] - dtype_bool: Flag to edit dtype to prevent black tiles - mode: Type of the raster data ("rgb" or "ms") - class_column: Name of the column in `crowns` DataFrame for class-based tiling + img_path: Path to the orthomosaic. + out_dir: Output directory. + buffer: Overlapping buffer of tiles in meters (UTM). + tile_width: Tile width in meters. + tile_height: Tile height in meters. + crowns: Crown polygons as a GeoDataFrame. + threshold: Minimum proportion of the tile covered by crowns to be accepted [0,1]. 
+ nan_threshold: Maximum proportion of the tile covered by NaNs [0,1]. + dtype_bool: Flag to edit dtype to prevent black tiles. + mode: Type of the raster data ("rgb" or "ms"). + class_column: Name of the column in `crowns` DataFrame for class-based tiling. tile_placement: Strategy for placing tiles. "grid" for fixed grid placement based on the bounds of the input image, optimized for speed. "adaptive" for dynamic placement of tiles based on crowns, adjusts based on data features for better coverage. + mask_path: Path to a mask file to use for tiling. + multithreaded: Flag to enable multithreaded processing. + random_subset: Number of randomly selected tiles to attempt to process per image. If -1, all tiles are processed. + additional_nodata: List of additional pixel values to treat as nodata. + overlapping_tiles: Flag to enable overlapping tiles to generate more training data. More useful for training the detection part of the Mask R-CNN model. + ignore_bands_indices: List of integer indices of bands to ignore during processing. + use_convex_mask: If True, creates a convex mask around crown polygons to exclude areas outside of the crowns. Returns: None @@ -1154,7 +1199,7 @@ def image_details(fileroot): fileroot: image filename without file extension Returns: - Box structure + A list of two tuples representing the bounding box with buffer: [(xmin, xmax), (ymin, ymax)]. """ image_info = fileroot.split("_") minx = int(image_info[-5]) @@ -1171,11 +1216,11 @@ def is_overlapping_box(test_boxes_array, train_box): """Check if the train box overlaps with any of the test boxes. Args: - test_boxes_array: - train_box: + test_boxes_array: A list of bounding boxes to check against. + train_box: The bounding box to test for overlap. Returns: - Boolean + True if `train_box` overlaps with any box in `test_boxes_array`, False otherwise.
""" for test_box in test_boxes_array: test_box_x = test_box[0] @@ -1245,6 +1290,7 @@ def to_traintest_folders( # noqa: C901 test_frac: fraction of tiles to be used for testing folds: number of folds to split the data into strict: if True, training/validation files will be removed if there is any overlap with test files (inc buffer) + seed: Random seed for shuffling to ensure reproducibility. Returns: None