diff --git a/pyproject.toml b/pyproject.toml index ed482f72..bd9efc56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ include-package-data = true [build-system] build-backend = "setuptools.build_meta" requires = [ - "setuptools", + "setuptools>=61.0" ] [project.scripts] diff --git a/requirements.txt b/requirements.txt index f6d336cb..893199b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,10 +2,10 @@ einops faster-coco-eval graphviz hydra-core +lazy-loader lightning loguru numpy -opencv-python Pillow pycocotools requests diff --git a/train_kwcoco_demo.sh b/train_kwcoco_demo.sh new file mode 100644 index 00000000..92a7b0b4 --- /dev/null +++ b/train_kwcoco_demo.sh @@ -0,0 +1,160 @@ +#!/bin/bash +__doc__=" +YOLO Training Tutorial with KWCOCO DemoData +=========================================== + +This demonstrates an end-to-end YOLO pipeline on toydata generated with kwcoco. +" + +# Define where we will store results +BUNDLE_DPATH=$HOME/demo-yolo-kwcoco-train +mkdir -p "$BUNDLE_DPATH" + +echo " +Generate Toy Data +----------------- + +Now that we know where the data and our intermediate files will go, lets +generate the data we will use to train and evaluate with. + +The kwcoco package comes with a commandline utility called 'kwcoco toydata' to +accomplish this. +" + +# Define the names of the kwcoco files to generate +TRAIN_FPATH=$BUNDLE_DPATH/vidshapes_rgb_train/data.kwcoco.json +VALI_FPATH=$BUNDLE_DPATH/vidshapes_rgb_vali/data.kwcoco.json +TEST_FPATH=$BUNDLE_DPATH/vidshapes_rgb_test/data.kwcoco.json + +# Generate toy datasets using the "kwcoco toydata" tool +kwcoco toydata vidshapes32-frames10 --dst "$TRAIN_FPATH" +kwcoco toydata vidshapes4-frames10 --dst "$VALI_FPATH" +kwcoco toydata vidshapes2-frames6 --dst "$TEST_FPATH" + +# Ensure legacy COCO structure for now +kwcoco conform "$TRAIN_FPATH" --inplace --legacy=True +kwcoco conform "$VALI_FPATH" --inplace --legacy=True +kwcoco conform "$TEST_FPATH" --inplace --legacy=True + + +echo " +Create the YOLO Configuration +----------------------------- + +Constructing the YOLO configuration is not entirely kwcoco aware +so we need to set +" +# In the current version we need to write configs to the repo itself. +# Its a bit gross, but this should be somewhat robust. +# Find where the yolo repo is installed (we need to be careful that this is the +# our fork of the WongKinYiu variant +REPO_DPATH=$(python -c "import yolo, pathlib; print(pathlib.Path(yolo.__file__).parent.parent)") +MODULE_DPATH=$(python -c "import yolo, pathlib; print(pathlib.Path(yolo.__file__).parent)") +CONFIG_DPATH=$(python -c "import yolo.config, pathlib; print(pathlib.Path(yolo.config.__file__).parent / 'dataset')") +echo "REPO_DPATH = $REPO_DPATH" +echo "MODULE_DPATH = $MODULE_DPATH" +echo "CONFIG_DPATH = $CONFIG_DPATH" + +DATASET_CONFIG_FPATH=$CONFIG_DPATH/kwcoco-demo.yaml + +# Hack to construct the class part of the YAML +CLASS_YAML=$(python -c "if 1: + import kwcoco + train_fpath = kwcoco.CocoDataset('$TRAIN_FPATH') + categories = train_fpath.categories().objs + # It would be nice to have better class introspection, but in the meantime + # do the same sorting as yolo.tools.data_conversion.discretize_categories + categories = sorted(categories, key=lambda cat: cat['id']) + class_num = len(categories) + class_list = [c['name'] for c in categories] + print(f'class_num: {class_num}') + print(f'class_list: {class_list}') +") + + +CONFIG_YAML=" +path: $BUNDLE_DPATH +train: $TRAIN_FPATH +validation: $VALI_FPATH + +$CLASS_YAML +" +echo "$CONFIG_YAML" > "$DATASET_CONFIG_FPATH" + +TRAIN_DPATH="$BUNDLE_DPATH/kwcoco-demo-train-dir" +# This might only work in development mode, otherwise we will get site packages +# That still might be fine, but we do want to fix this to run anywhere. +cd "$REPO_DPATH" +export CUDA_VISIBLE_DEVICES="1," +LOG_BATCH_VIZ_TO_DISK=1 python -m yolo.lazy \ + task=train \ + dataset=kwcoco-demo \ + use_tensorboard=True \ + use_wandb=False \ + out_path="$TRAIN_DPATH" \ + name=kwcoco-demo \ + cpu_num=0 \ + device=0 \ + accelerator=auto \ + task.data.batch_size=2 \ + "image_size=[640, 640]" \ + task.optimizer.args.lr=0.03 + + +### show how to run inference + +BUNDLE_DPATH=$HOME/demo-yolo-kwcoco-train +TRAIN_DPATH="$BUNDLE_DPATH/kwcoco-demo-train-dir" +TEST_FPATH=$BUNDLE_DPATH/vidshapes_rgb_test/data.kwcoco.json +# Grab a checkpoint +CKPT_FPATH=$(python -c "if 1: + import pathlib + root_dpath = pathlib.Path('$TRAIN_DPATH') / 'train/kwcoco-demo' + checkpoints = sorted(root_dpath.glob('lightning_logs/*/checkpoints/*')) + print(checkpoints[-1]) +") +echo "CKPT_FPATH = $CKPT_FPATH" + +export DISABLE_RICH_HANDLER=1 +export CUDA_VISIBLE_DEVICES="1," +python yolo/lazy.py \ + task.data.source="$TEST_FPATH" \ + task=inference \ + dataset=kwcoco-demo \ + use_wandb=False \ + out_path=kwcoco-demo-inference \ + name=kwcoco-inference-test \ + cpu_num=8 \ + weight="\"$CKPT_FPATH\"" \ + accelerator=auto \ + task.nms.min_confidence=0.01 \ + task.nms.min_iou=0.3 \ + task.nms.max_bbox=10 + + +### Show how to run validation + +# Grab a checkpoint +BUNDLE_DPATH=$HOME/demo-yolo-kwcoco-train +TRAIN_DPATH="$BUNDLE_DPATH/kwcoco-demo-train-dir" +CKPT_FPATH=$(python -c "if 1: + import pathlib + ckpt_dpath = pathlib.Path('$TRAIN_DPATH') / 'train/kwcoco-demo/checkpoints' + checkpoints = sorted(ckpt_dpath.glob('*')) + print(checkpoints[-1]) +") +echo "CKPT_FPATH = $CKPT_FPATH" + +#DISABLE_RICH_HANDLER=1 +LOG_BATCH_VIZ_TO_DISK=1 python -m yolo.lazy \ + task=validation \ + dataset=kwcoco-demo \ + use_wandb=False \ + out_path="$TRAIN_DPATH" \ + name=kwcoco-demo \ + cpu_num=0 \ + device=0 \ + weight="'$CKPT_FPATH'" \ + accelerator=auto \ + "task.data.batch_size=2" \ + "image_size=[224,224]" diff --git a/yolo/__init__.py b/yolo/__init__.py index b4b98d7f..77da671a 100644 --- a/yolo/__init__.py +++ b/yolo/__init__.py @@ -1,33 +1,78 @@ -from yolo.config.config import Config, NMSConfig -from yolo.model.yolo import create_model -from yolo.tools.data_loader import AugmentationComposer, create_dataloader -from yolo.tools.drawer import draw_bboxes -from yolo.tools.solver import TrainModel -from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, create_converter -from yolo.utils.deploy_utils import FastModelLoader -from yolo.utils.logging_utils import ( - ImageLogger, - YOLORichModelSummary, - YOLORichProgressBar, +""" +The MIT YOLO rewrite +""" + +__autogen__ = """ +mkinit ~/code/YOLO-v9/yolo/__init__.py --nomods --write --lazy-loader + +# Check to see how long it takes to run a simple help command +time python -m yolo.lazy --help +""" + +__submodules__ = { + 'config.config': ['Config', 'NMSConfig'], + 'model.yolo': ['create_model'], + 'tools.data_loader': ['AugmentationComposer', 'create_dataloader'], + 'tools.drawer': ['draw_bboxes'], + 'tools.solver': ['TrainModel'], + 'utils.bounding_box_utils': ['Anc2Box', 'Vec2Box', 'bbox_nms', 'create_converter'], + 'utils.deploy_utils': ['FastModelLoader'], + 'utils.logging_utils': [ + 'ImageLogger', 'YOLORichModelSummary', + 'YOLORichProgressBar', + 'validate_log_directory' + ], + 'utils.model_utils': ['PostProcess'], +} + + +import lazy_loader + + +__getattr__, __dir__, __all__ = lazy_loader.attach( + __name__, + submodules={}, + submod_attrs={ + 'config.config': [ + 'Config', + 'NMSConfig', + ], + 'model.yolo': [ + 'create_model', + ], + 'tools.data_loader': [ + 'AugmentationComposer', + 'create_dataloader', + ], + 'tools.drawer': [ + 'draw_bboxes', + ], + 'tools.solver': [ + 'TrainModel', + ], + 'utils.bounding_box_utils': [ + 'Anc2Box', + 'Vec2Box', + 'bbox_nms', + 'create_converter', + ], + 'utils.deploy_utils': [ + 'FastModelLoader', + ], + 'utils.logging_utils': [ + 'ImageLogger', + 'YOLORichModelSummary', + 'YOLORichProgressBar', + 'validate_log_directory', + ], + 'utils.model_utils': [ + 'PostProcess', + ], + }, ) -from yolo.utils.model_utils import PostProcess - -all = [ - "create_model", - "Config", - "YOLORichProgressBar", - "NMSConfig", - "YOLORichModelSummary", - "validate_log_directory", - "draw_bboxes", - "Vec2Box", - "Anc2Box", - "bbox_nms", - "create_converter", - "AugmentationComposer", - "ImageLogger", - "create_dataloader", - "FastModelLoader", - "TrainModel", - "PostProcess", -] + +__all__ = ['Anc2Box', 'AugmentationComposer', 'Config', 'FastModelLoader', + 'ImageLogger', 'NMSConfig', 'PostProcess', 'TrainModel', 'Vec2Box', + 'YOLORichModelSummary', 'YOLORichProgressBar', 'bbox_nms', + 'create_converter', 'create_dataloader', 'create_model', + 'draw_bboxes', 'validate_log_directory'] diff --git a/yolo/config/config.py b/yolo/config/config.py index 313e4354..50c2423c 100644 --- a/yolo/config/config.py +++ b/yolo/config/config.py @@ -157,6 +157,7 @@ class Config: use_tensorboard: bool weight: Optional[str] + accelerator: str @dataclass diff --git a/yolo/config/config.yaml b/yolo/config/config.yaml index e1a6483d..97f96d7a 100644 --- a/yolo/config/config.yaml +++ b/yolo/config/config.yaml @@ -10,3 +10,4 @@ defaults: - dataset: coco - model: v9-c - general + - trainer: yolo diff --git a/yolo/config/task/train.yaml b/yolo/config/task/train.yaml index e383c352..051f91b2 100644 --- a/yolo/config/task/train.yaml +++ b/yolo/config/task/train.yaml @@ -14,7 +14,8 @@ data: data_augment: Mosaic: 1 # MixUp: 1 - # HorizontalFlip: 0.5 + HorizontalFlip: 0.0 + VerticalFlip: 0.0 RandomCrop: 1 RemoveOutliers: 1e-8 diff --git a/yolo/config/trainer/default.yaml b/yolo/config/trainer/default.yaml new file mode 100644 index 00000000..0e1a5e9b --- /dev/null +++ b/yolo/config/trainer/default.yaml @@ -0,0 +1,47 @@ +# Expose most of the lighting trainer options from the config. + +accelerator : auto # Supports passing different accelerator types ("cpu", "gpu", "tpu", "hpu", "mps", "auto") as well as custom accelerator instances. +strategy : auto # Supports different training strategies with aliases as well custom strategies. Default: ``"auto"``. +devices : auto # The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value ``-1`` to indicate all available devices should be used, or ``"auto"`` for automatic selection based on the chosen accelerator. Default: ``"auto"``. +num_nodes : 1 # Number of GPU nodes for distributed training. Default: ``1``. +precision : null # Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'), 16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed'). Can be used on CPU, GPU, TPUs, or HPUs. Default: ``'32-true'``. +fast_dev_run : False # Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es) of train, val and test to find any bugs (ie: a sort of unit test). Default: ``False``. +min_epochs : null # Force training for at least these many epochs. Disabled by default (None). +max_steps : -1 # Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1`` and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set ``max_epochs`` to ``-1``. +min_steps : null # Force training for at least these number of steps. Disabled by default (``None``). +max_time : null # Stop training after this amount of time has passed. Disabled by default (``None``). The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes seconds), as a :class:`datetime.timedelta`, or a dictionary with keys that will be passed to :class:`datetime.timedelta`. +limit_train_batches : null # How much of training dataset to check (float = fraction, int = num_batches). Default: ``1.0``. +limit_val_batches : null # How much of validation dataset to check (float = fraction, int = num_batches). Default: ``1.0``. +limit_test_batches : null # How much of test dataset to check (float = fraction, int = num_batches). Default: ``1.0``. +limit_predict_batches : null # How much of prediction dataset to check (float = fraction, int = num_batches). Default: ``1.0``. +overfit_batches : 0.0 # Overfit a fraction of training/validation data (float) or a set number of batches (int). Default: ``0.0``. +val_check_interval : null # How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training batches. An ``int`` value can only be higher than the number of training batches when ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches across epochs or during iteration-based training. Default: ``1.0``. +check_val_every_n_epoch : 1 # Perform a validation loop after every `N` training epochs. If ``None``, validation will be done solely based on the number of training batches, requiring ``val_check_interval`` to be an integer value. Default: ``1``. +num_sanity_val_steps : null # Sanity check runs n validation batches before starting the training routine. Set it to `-1` to run all batches in all validation dataloaders. Default: ``2``. +log_every_n_steps : null # How often to log within steps. Default: ``50``. +enable_checkpointing : null # If ``True``, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in :paramref:`~lightning.pytorch.trainer.trainer.Trainer.callbacks`. Default: ``True``. +enable_model_summary : null # Whether to enable model summarization by default. Default: ``True``. +accumulate_grad_batches : 1 # Accumulates gradients over k batches before stepping the optimizer. Default: 1. +gradient_clip_val : null # The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before. Default: ``None``. +gradient_clip_algorithm : null # The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"`` to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will be set to ``"norm"``. +deterministic : null # If ``True``, sets whether PyTorch operations must use deterministic algorithms. Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``. +benchmark : null # The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to. The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used (``False`` if not manually set). If :paramref:`~lightning.pytorch.trainer.trainer.Trainer.deterministic` is set to ``True``, this will default to ``False``. Override to manually set a different value. Default: ``None``. +inference_mode : True # Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during evaluation (``validate``/``test``/``predict``). +use_distributed_sampler : True # Whether to wrap the DataLoader's sampler with :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed sampler was already added, Lightning will not replace the existing one. For iterable-style datasets, we don't do this automatically. +profiler : null # To profile individual steps during training and assist in identifying bottlenecks. Default: ``None``. +detect_anomaly : False # Enable anomaly detection for the autograd engine. Default: ``False``. +barebones : False # Whether to run in "barebones mode", where all features that may impact raw speed are disabled. This is meant for analyzing the Trainer overhead and is discouraged during regular training runs. The following features are deactivated: :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_checkpointing`, :paramref:`~lightning.pytorch.trainer.trainer.Trainer.logger`, :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_progress_bar`, :paramref:`~lightning.pytorch.trainer.trainer.Trainer.log_every_n_steps`, :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_model_summary`, :paramref:`~lightning.pytorch.trainer.trainer.Trainer.num_sanity_val_steps`, :paramref:`~lightning.pytorch.trainer.trainer.Trainer.fast_dev_run`, :paramref:`~lightning.pytorch.trainer.trainer.Trainer.detect_anomaly`, :paramref:`~lightning.pytorch.trainer.trainer.Trainer.profiler`, :meth:`~lightning.pytorch.core.LightningModule.log`, :meth:`~lightning.pytorch.core.LightningModule.log_dict`. +plugins : null # Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins. Default: ``None``. +sync_batchnorm : False # Synchronize batch norm layers between process groups/whole world. Default: ``False``. +reload_dataloaders_every_n_epochs: 0 # Set to a positive integer to reload dataloaders every n epochs. Default: ``0``. + + +### EXPOSED ELSEWHERE +#default_root_dir : null # Default path for logs and weights when no logger/ckpt_callback passed. Default: ``os.getcwd()``. Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/' +#enable_progress_bar : null # Whether to enable to progress bar by default. Default: ``True``. +#max_epochs : null # Stop training once this number of epochs is reached. Disabled by default (None). If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``. To enable infinite training, set ``max_epochs = -1``. + + +### UNSUPPORTED +# callbacks : null # Add a callback or list of callbacks. Default: ``None``. +# logger : null # Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses the default ``TensorBoardLogger`` if it is installed, otherwise ``CSVLogger``. ``False`` will disable logging. If multiple loggers are provided, local files (checkpoints, profiler traces, etc.) are saved in the ``log_dir`` of the first logger. Default: ``True``. diff --git a/yolo/config/trainer/yolo.yaml b/yolo/config/trainer/yolo.yaml new file mode 100644 index 00000000..569a5ccc --- /dev/null +++ b/yolo/config/trainer/yolo.yaml @@ -0,0 +1,10 @@ +defaults: + - default + +num_sanity_val_steps: 0 +precision: "16-mixed" +log_every_n_steps: 1 +gradient_clip_val: 10 +gradient_clip_algorithm: "value" +#deterministic: True +deterministic: False diff --git a/yolo/lazy.py b/yolo/lazy.py index 0f1cc55b..a164018c 100644 --- a/yolo/lazy.py +++ b/yolo/lazy.py @@ -2,33 +2,40 @@ from pathlib import Path import hydra -from lightning import Trainer +from omegaconf.dictconfig import DictConfig +# FIXME: messing with sys.path is a bad idea. Factor this out. project_root = Path(__file__).resolve().parent.parent sys.path.append(str(project_root)) -from yolo.config.config import Config -from yolo.tools.solver import InferenceModel, TrainModel, ValidateModel -from yolo.utils.logging_utils import setup - @hydra.main(config_path="config", config_name="config", version_base=None) -def main(cfg: Config): +def main(cfg: DictConfig): + from yolo.utils.logging_utils import setup callbacks, loggers, save_path = setup(cfg) - trainer = Trainer( - accelerator="auto", - max_epochs=getattr(cfg.task, "epoch", None), - precision="16-mixed", + from yolo.utils.trainer import YoloTrainer as Trainer + from yolo.tools.solver import InferenceModel, TrainModel, ValidateModel + + trainer_kwargs = dict( + ### + # Not Allowed to be overwritten (FIXME: can we fix this) callbacks=callbacks, logger=loggers, - log_every_n_steps=1, - gradient_clip_val=10, - gradient_clip_algorithm="value", - deterministic=True, - enable_progress_bar=not getattr(cfg, "quite", False), + ### + # Uses a non-standard configuration location (Should we refactor this?) default_root_dir=save_path, + max_epochs=getattr(cfg.task, "epoch", None), + enable_progress_bar=not getattr(cfg, "quite", False), ) + if len(cfg.trainer.keys() & trainer_kwargs.keys()) > 0: + unsupported = set(cfg.trainer.keys() & trainer_kwargs.keys()) + raise AssertionError( + f'Cannot specify unsupported trainer args: {unsupported!r} ' + 'in the trainer config' + ) + trainer_kwargs.update(cfg.trainer) + trainer = Trainer(**trainer_kwargs) if cfg.task.task == "train": model = TrainModel(cfg) diff --git a/yolo/model/__init__.py b/yolo/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/yolo/model/yolo.py b/yolo/model/yolo.py index 42d72208..4aa34437 100644 --- a/yolo/model/yolo.py +++ b/yolo/model/yolo.py @@ -32,7 +32,7 @@ def __init__(self, model_cfg: ModelConfig, class_num: int = 80): def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]): self.layer_index = {} output_dim, layer_idx = [3], 1 - logger.info(f":tractor: Building YOLO") + logger.info(":tractor: Building YOLO") for arch_name in model_arch: if model_arch[arch_name]: logger.info(f" :building_construction: Building {arch_name}") @@ -135,27 +135,99 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]): weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False) if "model_state_dict" in weights: weights = weights["model_state_dict"] + if "state_dict" in weights: + weights = weights["state_dict"] - model_state_dict = self.model.state_dict() + if 0: + # Debug the state of the model and the loaded weights + import networkx as nx + graph_src = nx.DiGraph() + for key in list(weights.keys()): + graph_src.add_node(key) + graph_src.add_node('__root__') + for key in list(weights.keys()): + parts = key.split('.') + graph_src.add_edge('__root__', parts[0]) + for i in range(1, len(parts)): + parent = '.'.join(parts[:i - 1]) + child = '.'.join(parts[:i]) + graph_src.add_edge(parent, child) + nx.write_network_text(graph_src, max_depth=4, sources=['__root__']) - # TODO1: autoload old version weight - # TODO2: weight transform if num_class difference + graph_dst = nx.DiGraph() + dst_weights = self.state_dict() + for key in list(dst_weights.keys()): + graph_dst.add_node(key) + graph_dst.add_node('__root__') + for key in list(dst_weights.keys()): + parts = key.split('.') + graph_dst.add_edge('__root__', parts[0]) + for i in range(1, len(parts)): + parent = '.'.join(parts[:i - 1]) + child = '.'.join(parts[:i]) + graph_dst.add_edge(parent, child) + nx.write_network_text(graph_dst, max_depth=3, sources=['__root__']) - error_dict = {"Mismatch": set(), "Not Found": set()} - for model_key, model_weight in model_state_dict.items(): - if model_key not in weights: - error_dict["Not Found"].add(tuple(model_key.split(".")[:-2])) - continue - if model_weight.shape != weights[model_key].shape: - error_dict["Mismatch"].add(tuple(model_key.split(".")[:-2])) - continue - model_state_dict[model_key] = weights[model_key] + USE_TORCH_LIBERATOR = False + if USE_TORCH_LIBERATOR: - for error_name, error_set in error_dict.items(): - for weight_name in error_set: - logger.warning(f":warning: Weight {error_name} for key: {'.'.join(weight_name)}") + # Torch liberator will figure out the mapping in most cases but it + # is slow. + HACK_DONT_LOAD_EMA_WEIGHTS = True + if HACK_DONT_LOAD_EMA_WEIGHTS: + for key in list(weights.keys()): + if key.startswith('ema.model'): + weights.pop(key) + from torch_liberator.initializer import load_partial_state + load_partial_state(self, weights, verbose=3) - self.model.load_state_dict(model_state_dict) + else: + # TODO1: autoload old version weight + # TODO2: weight transform if num_class difference + + model_state_dict = self.model.state_dict() + + CHECK_FOR_WEIGHT_MUNGING = True + if CHECK_FOR_WEIGHT_MUNGING: + # Handle the simple case of weight munging ourselves + src_keys = list(weights.keys()) + dst_keys = list(model_state_dict.keys()) + src_roots = {p.split('.')[0] for p in src_keys} + dst_roots = {p.split('.')[0] for p in dst_keys} + if len(src_roots & dst_roots) == 0: + src_prefixes = {tuple(p.split('.')[0:2]) for p in src_keys} + if src_prefixes == {('ema', 'model'), ('model', 'model')}: + logger.warning(":warning: Munging weights") + munged_weights = {} + for key in list(weights.keys()): + prefix = 'model.model.' + if key.startswith(prefix): + new_key = key[len(prefix):] + munged_weights[new_key] = weights[key] + logger.warning(f":warning: Munged {len(munged_weights)} / {len(weights)} tensors") + weights = munged_weights + + error_dict = {"Mismatch": set(), "Not Found": set()} + for model_key, model_weight in model_state_dict.items(): + if model_key not in weights: + error_dict["Not Found"].add(tuple(model_key.split(".")[:-2])) + continue + if model_weight.shape != weights[model_key].shape: + error_dict["Mismatch"].add(tuple(model_key.split(".")[:-2])) + continue + model_state_dict[model_key] = weights[model_key] + + for error_name, error_set in error_dict.items(): + for weight_name in error_set: + logger.warning(f":warning: Weight {error_name} for key: {'.'.join(weight_name)}") + + for error_name, error_set in error_dict.items(): + if len(error_set) == 0: + logger.info(f":white_check_mark: Num: weight {error_name}: {len(error_set)}") + else: + logger.warning(f":warning: Num: weight {error_name}: {len(error_set)}") + + self.model.load_state_dict(model_state_dict) def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO: @@ -167,10 +239,13 @@ def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, Returns: YOLO: An instance of the model defined by the given configuration. """ + logger.info = print + logger.info('CREATE MODEL') OmegaConf.set_struct(model_cfg, False) model = YOLO(model_cfg, class_num) if weight_path: - if weight_path == True: + logger.info('🏋 Initializing weights') + if weight_path is True: weight_path = Path("weights") / f"{model_cfg.name}.pt" elif isinstance(weight_path, str): weight_path = Path(weight_path) @@ -179,8 +254,10 @@ def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, logger.info(f"🌐 Weight {weight_path} not found, try downloading") prepare_weight(weight_path=weight_path) if weight_path.exists(): + logger.info(f'🏋 Loading weights from {weight_path}') model.save_load_weights(weight_path) logger.info(":white_check_mark: Success load model & weight") else: - logger.info(":white_check_mark: Success load model") + logger.info(":white_check_mark: Success load model without weights") + logger.info('CREATED MODEL') return model diff --git a/yolo/tools/data_loader.py b/yolo/tools/data_loader.py index c44f00c6..72bc292a 100644 --- a/yolo/tools/data_loader.py +++ b/yolo/tools/data_loader.py @@ -12,7 +12,7 @@ from torch.utils.data import DataLoader, Dataset from yolo.config.config import DataConfig, DatasetConfig -from yolo.tools.data_augmentation import * +from yolo.tools import data_augmentation from yolo.tools.data_augmentation import AugmentationComposer from yolo.tools.dataset_preparation import prepare_dataset from yolo.utils.dataset_utils import ( @@ -33,10 +33,12 @@ def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str self.dynamic_shape = getattr(data_cfg, "dynamic_shape", False) self.base_size = mean(self.image_size) - transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()] + transforms = [getattr(data_augmentation, aug)(prob) for aug, prob in augment_cfg.items() if prob] self.transform = AugmentationComposer(transforms, self.image_size, self.base_size) self.transform.get_more_data = self.get_more_data - self.img_paths, self.bboxes, self.ratios = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name)) + + data = self.load_data(Path(dataset_cfg.path), phase_name) + self.img_paths, self.bboxes, self.ratios = tensorlize(data) def load_data(self, dataset_path: Path, phase_name: str): """ @@ -49,7 +51,7 @@ def load_data(self, dataset_path: Path, phase_name: str): Returns: dict: The loaded data from the cache for the specified phase. """ - cache_path = dataset_path / f"{phase_name}.cache" + cache_path = dataset_path / f"{phase_name}-v1.cache" if not cache_path.exists(): logger.info(f":factory: Generating {phase_name} cache") @@ -81,47 +83,113 @@ def filter_data(self, dataset_path: Path, phase_name: str, sort_image: bool = Fa list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor. """ images_path = dataset_path / "images" / phase_name + labels_path, data_type = locate_label_paths(dataset_path, phase_name) - images_list = sorted([p.name for p in Path(images_path).iterdir() if p.is_file()]) - if data_type == "json": - annotations_index, image_info_dict = create_image_metadata(labels_path) + logger.warning(f"Idenfitied input dataset type as: {data_type}") - data = [] - valid_inputs = 0 - for image_name in track(images_list, description="Filtering data"): - if not image_name.lower().endswith((".jpg", ".jpeg", ".png")): - continue - image_id = Path(image_name).stem + if data_type == 'kwcoco': + """ + More robust data handling that only depends on paths within the + specified manifest file. - if data_type == "json": - image_info = image_info_dict.get(image_id, None) - if image_info is None: - continue - annotations = annotations_index.get(image_info["id"], []) - image_seg_annotations = scale_segmentation(annotations, image_info) - elif data_type == "txt": - label_path = labels_path / f"{image_id}.txt" - if not label_path.is_file(): - continue - with open(label_path, "r") as file: - image_seg_annotations = [list(map(float, line.strip().split())) for line in file] - else: - image_seg_annotations = [] + Principles: - labels = self.load_valid_labels(image_id, image_seg_annotations) + * Dont glob for the images, let the dataset tell you where they are. + + * A Dataset should be referenced as a single URI to a manifest. + The manifest should either contain relevant data or point to + paths for everything. + """ + import kwcoco + coco_dset = kwcoco.CocoDataset(labels_path) + + from yolo.tools.data_conversion import discretize_categories + id_to_idx = discretize_categories(coco_dset.dataset.get("categories", [])) if "categories" in coco_dset.dataset else None + + total_images = coco_dset.n_images - img_path = images_path / image_name if sort_image: - with Image.open(img_path) as img: - width, height = img.size - else: - width, height = 0, 1 - data.append((img_path, labels, width / height)) - valid_inputs += 1 + # Ensure all images have populated sizes + coco_dset._ensure_imgsize() + + # FIXME: do empty images make sense for training YOLO? + ALLOW_EMPTY_IMAGES = 0 + + # Build the expected output + data = [] + valid_inputs = 0 + for coco_img in coco_dset.images().coco_images_iter(): + image_info = coco_img.img + img_path = coco_img.primary_image_filepath() + + if sort_image: + width, height = coco_img['width'], coco_img['height'] + else: + width, height = 0, 1 + + annotations = coco_img.annots().objs + + # Handle filtering as done in + # :func:`dataset_utils.organize_annotations_by_image` + modified_annotations = [] + for anno in annotations: + if id_to_idx: + anno["category_id"] = id_to_idx[anno["category_id"]] + if anno.get("iscrowd", False): # TODO: make configurable + continue + modified_annotations.append(anno) + annotations = modified_annotations + + if ALLOW_EMPTY_IMAGES or len(annotations): + + image_seg_annotations = scale_segmentation(annotations, image_info) + labels = self.load_valid_labels(None, image_seg_annotations) + + data.append((img_path, labels, width / height)) + valid_inputs += 1 + + else: + images_list = sorted([p.name for p in Path(images_path).iterdir() if p.is_file()]) + if data_type == "json": + annotations_index, image_info_dict = create_image_metadata(labels_path) + + data = [] + valid_inputs = 0 + for image_name in track(images_list, description="Filtering data"): + if not image_name.lower().endswith((".jpg", ".jpeg", ".png")): + continue + image_id = Path(image_name).stem + + if data_type == "json": + image_info = image_info_dict.get(image_id, None) + if image_info is None: + continue + annotations = annotations_index.get(image_info["id"], []) + image_seg_annotations = scale_segmentation(annotations, image_info) + elif data_type == "txt": + label_path = labels_path / f"{image_id}.txt" + if not label_path.is_file(): + continue + with open(label_path, "r") as file: + image_seg_annotations = [list(map(float, line.strip().split())) for line in file] + else: + image_seg_annotations = [] + + labels = self.load_valid_labels(image_id, image_seg_annotations) + + img_path = images_path / image_name + if sort_image: + with Image.open(img_path) as img: + width, height = img.size + else: + width, height = 0, 1 + data.append((img_path, labels, width / height)) + valid_inputs += 1 + total_images = len(images_list) data = sorted(data, key=lambda x: x[2], reverse=True) - logger.info(f"Recorded {valid_inputs}/{len(images_list)} valid inputs") + logger.info(f"Recorded {valid_inputs}/{total_images} valid inputs") return data def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Tensor, None]: @@ -139,8 +207,18 @@ def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Te bboxes = [] for seg_data in seg_data_one_img: cls = seg_data[0] - points = np.array(seg_data[1:]).reshape(-1, 2) - valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2) + # This seems like an incorrect check. Putting my fix inside an if + # in case I don't understand why it is this way. + FIX_INCORRECT_CHECK = 1 + if FIX_INCORRECT_CHECK: + points = np.array(seg_data[1:]).reshape(-1, 2) + # This probably should just be a clamp / clip operation + # but I'm keeping it similar to the original + flags = (points >= 0).all(axis=1) & (points <= 1).all(axis=1) + valid_points = points[flags] + else: + points = np.array(seg_data[1:]).reshape(-1, 2) + valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2) if valid_points.size > 1: bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)]) bboxes.append(bbox) @@ -154,9 +232,31 @@ def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Te def get_data(self, idx): img_path, bboxes = self.img_paths[idx], self.bboxes[idx] valid_mask = bboxes[:, 0] != -1 - with Image.open(img_path) as img: + + # TODO: we can load an overview here to make this much more efficent + USE_OVERVIEW_HACK = 0 + if USE_OVERVIEW_HACK: + # Can leverage overviews to load images faster if they exist. + import delayed_image + delayed = delayed_image.DelayedLoad(img_path) + delayed._load_metadata() + scale_factor = self.base_size / max(delayed.shape[0:2]) + delayed = delayed.scale(scale_factor) + delayed = delayed.optimize() + # Peel off the top warp to only get the overviews + delayed = delayed.subdata + imdata = delayed.finalize() + img = Image.fromarray(imdata) img = img.convert("RGB") - return img, torch.from_numpy(bboxes[valid_mask]), img_path + # import kwimage + # imdata = kwimage.imread(img_path, overview=1, backend='gdal') + else: + with Image.open(img_path) as img: + img = img.convert("RGB") + + valid_boxes = bboxes[valid_mask] + valid_boxes = torch.from_numpy(valid_boxes) + return img, valid_boxes, img_path def get_more_data(self, num: int = 1): indices = torch.randint(0, len(self), (num,)) @@ -201,13 +301,10 @@ def collate_fn(batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor] - A list of tensors, each corresponding to bboxes for each image in the batch. """ batch_size = len(batch) - target_sizes = [item[1].size(0) for item in batch] # TODO: Improve readability of these process # TODO: remove maxBbox or reduce loss function memory usage - batch_targets = torch.zeros(batch_size, min(max(target_sizes), 100), 5) - batch_targets[:, :, 0] = -1 - for idx, target_size in enumerate(target_sizes): - batch_targets[idx, : min(target_size, 100)] = batch[idx][1][:100] + labels = [batch_item[1] for batch_item in batch] + batch_targets = pack_targets(labels, max_targets=100) batch_images, _, batch_reverse, batch_path = zip(*batch) batch_images = torch.stack(batch_images) @@ -216,6 +313,40 @@ def collate_fn(batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor] return batch_size, batch_images, batch_targets, batch_reverse, batch_path +def pack_targets(labels, max_targets=100): + """ + Collate truth bounding boxes into a fixed size tensor with -1 padding + + Args: + labels (List[Tensor]): list of target boxes in the form + left_x, top_y, right_x, bottom_y, class_index + + Returns: + Tensor: with shape (batch, targets, 5): + The packed tensor with at most max_targets per batch item. + A -1 indicates when there is no target + + Example: + >>> import kwimage + >>> kwimage.Detections.random(1).tensor().data + >>> target_sizes = [3, 1, 0, 2] + >>> batch_dets = [kwimage.Detections.random(s).tensor() for s in target_sizes] + >>> # Construct YOLO-packed batch labels from kwimage Detections + >>> labels = [torch.concat([det.boxes.to_ltrb().data, + >>> det.class_idxs[:, None]], dim=1) + >>> for det in batch_dets] + >>> batch_targets = pack_targets(labels, max_targets=10) + >>> assert batch_targets.shape == (4, 3, 5) + """ + target_sizes = [targets.size(0) for targets in labels] + batch_size = len(labels) + batch_targets = torch.zeros(batch_size, min(max(target_sizes), max_targets), 5) + batch_targets[:, :, 0] = -1 + for idx, target_size in enumerate(target_sizes): + batch_targets[idx, :min(target_size, max_targets)] = labels[idx][:max_targets] + return batch_targets + + def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"): if task == "inference": return StreamDataLoader(data_cfg) @@ -242,6 +373,17 @@ def __init__(self, data_cfg: DataConfig): self.transform = AugmentationComposer([], data_cfg.image_size) self.stop_event = Event() + self.known_length = None + + self._is_coco = str(self.source).endswith(('.zip', '.json')) + if self._is_coco: + # Prevent race conditions and ensure the coco file is loaded before + # we start a thread (OR improve thread architecture) + import kwcoco + self.coco_dset = kwcoco.CocoDataset(self.source) + self.known_length = self.coco_dset.n_images + ... + if self.is_stream: import cv2 @@ -253,26 +395,48 @@ def __init__(self, data_cfg: DataConfig): self.thread.start() def load_source(self): - if self.source.is_dir(): # image folder + if self._is_coco: + self.process_preloaded_coco() + elif self.source.is_dir(): # image folder self.load_image_folder(self.source) elif any(self.source.suffix.lower().endswith(ext) for ext in [".mp4", ".avi", ".mkv"]): # Video file self.load_video_file(self.source) - else: # Single image + else: + # Single image self.process_image(self.source) + def process_preloaded_coco(self): + coco_dset = self.coco_dset + for image_id in coco_dset.images(): + if self.stop_event.is_set(): + break + classes = coco_dset.object_categories() # todo: cache? + coco_img = coco_dset.coco_image(image_id) + file_path = coco_img.primary_image_filepath() + metadata = { + 'img': coco_img.img, + 'classes': classes, + } + self.process_image(file_path, metadata) + def load_image_folder(self, folder): folder_path = Path(folder) + # FIXME: This will just yield as many images as it can before the + # dataloader len function is called, and at that point it will + # only process up to the the number of images that were already + # loaded at that point, even though this function is doing + # more work in the background. for file_path in folder_path.rglob("*"): if self.stop_event.is_set(): break if file_path.suffix.lower() in [".jpg", ".jpeg", ".png", ".bmp"]: self.process_image(file_path) - def process_image(self, image_path): + def process_image(self, image_path, metadata=None): image = Image.open(image_path).convert("RGB") if image is None: raise ValueError(f"Error loading image: {image_path}") - self.process_frame(image) + self.process_frame(image, metadata) def load_video_file(self, video_path): import cv2 @@ -285,7 +449,7 @@ def load_video_file(self, video_path): self.process_frame(frame) cap.release() - def process_frame(self, frame): + def process_frame(self, frame, metadata=None): if isinstance(frame, np.ndarray): # TODO: we don't need cv2 import cv2 @@ -296,10 +460,11 @@ def process_frame(self, frame): frame, _, rev_tensor = self.transform(frame, torch.zeros(0, 5)) frame = frame[None] rev_tensor = rev_tensor[None] + item = (frame, rev_tensor, origin_frame, metadata) if not self.is_stream: - self.queue.put((frame, rev_tensor, origin_frame)) + self.queue.put(item) else: - self.current_frame = (frame, rev_tensor, origin_frame) + self.current_frame = item def __iter__(self) -> Generator[Tensor, None, None]: return self @@ -327,4 +492,7 @@ def stop(self): self.thread.join(timeout=1) def __len__(self): - return self.queue.qsize() if not self.is_stream else 0 + if self.known_length is None: + return self.queue.qsize() if not self.is_stream else 0 + else: + return self.known_length diff --git a/yolo/tools/loss_functions.py b/yolo/tools/loss_functions.py index 79fe1cf9..d4be22c7 100644 --- a/yolo/tools/loss_functions.py +++ b/yolo/tools/loss_functions.py @@ -15,6 +15,7 @@ def __init__(self) -> None: super().__init__() # TODO: Refactor the device, should be assign by config # TODO: origin v9 assing pos_weight == 1? + # TODO: Add ability to specify class weights self.bce = BCEWithLogitsLoss(reduction="none") def forward(self, predicts_cls: Tensor, targets_cls: Tensor, cls_norm: Tensor) -> Any: @@ -68,6 +69,8 @@ def forward( class YOLOLoss: def __init__(self, loss_cfg: LossConfig, vec2box: Vec2Box, class_num: int = 80, reg_max: int = 16) -> None: + # TODO: refactor to know what the class labels actually are instead of + # just the number. self.class_num = class_num self.vec2box = vec2box @@ -107,8 +110,25 @@ def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Ten class DualLoss: + """ + Example: + >>> import torch + >>> from yolo.tools.loss_functions import DualLoss + >>> from yolo.utils.bounding_box_utils import Vec2Box + >>> from yolo.utils.config_utils import build_config + >>> cfg = build_config(overrides=['task=train']) + >>> device = 'cpu' + >>> vec2box = Vec2Box(model=None, anchor_cfg=cfg.model.anchor, image_size=cfg.image_size, device=device) + >>> self = DualLoss(cfg, vec2box) + >>> targets = torch.zeros(1, 20, 5, device=device) + >>> aux_predicts = [torch.zeros(1, 8400, *cn, device=device) for cn in [(80,), (4, 16), (4,)]] + >>> main_predicts = [torch.zeros(1, 8400, *cn, device=device) for cn in [(80,), (4, 16), (4,)]] + >>> loss, loss_dict = self(aux_predicts, main_predicts, targets) + """ def __init__(self, cfg: Config, vec2box) -> None: loss_cfg = cfg.task.loss + # TODO: refactor to know what the class labels actually are instead of + # just the number. self.loss = YOLOLoss(loss_cfg, vec2box, class_num=cfg.dataset.class_num, reg_max=cfg.model.anchor.reg_max) self.aux_rate = loss_cfg.aux diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py index c20b1ab3..aee1f426 100644 --- a/yolo/tools/solver.py +++ b/yolo/tools/solver.py @@ -16,6 +16,8 @@ class BaseModel(LightningModule): def __init__(self, cfg: Config): super().__init__() + # TODO: refactor to know what the class labels actually are instead of + # just the number. self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight) def forward(self, x): @@ -47,15 +49,20 @@ def val_dataloader(self): def validation_step(self, batch, batch_idx): batch_size, images, targets, rev_tensor, img_paths = batch H, W = images.shape[2:] - predicts = self.post_process(self.ema(images), image_size=[W, H]) + raw_predicts = self.ema(images) + predicts = self.post_process(raw_predicts, image_size=[W, H]) mAP = self.metric( [to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets] ) - return predicts, mAP + outputs = { + 'predicts': predicts, + 'mAP': mAP, + } + return outputs def on_validation_epoch_end(self): epoch_metrics = self.metric.compute() - del epoch_metrics["classes"] + epoch_metrics.pop("classes", None) self.log_dict(epoch_metrics, prog_bar=True, sync_dist=True, rank_zero_only=True) self.log_dict( {"PyCOCO/AP @ .5:.95": epoch_metrics["map"], "PyCOCO/AP @ .5": epoch_metrics["map_50"]}, @@ -69,6 +76,13 @@ class TrainModel(ValidateModel): def __init__(self, cfg: Config): super().__init__(cfg) self.cfg = cfg + + # Flag that lets plugins communicate with the model + self.request_draw = False + + # TODO: if we defer creating the model until the dataset is loaded, we + # can introspect the number of categories and other things to make user + # configuration have less interdependencies and thus be more robust. self.train_loader = create_dataloader(self.cfg.task.data, self.cfg.dataset, self.cfg.task.task) def setup(self, stage): @@ -86,10 +100,11 @@ def on_train_epoch_start(self): def training_step(self, batch, batch_idx): lr_dict = self.trainer.optimizers[0].next_batch() + batch_size, images, targets, *_ = batch - predicts = self(images) - aux_predicts = self.vec2box(predicts["AUX"]) - main_predicts = self.vec2box(predicts["Main"]) + raw_predicts = self(images) + aux_predicts = self.vec2box(raw_predicts["AUX"]) + main_predicts = self.vec2box(raw_predicts["Main"]) loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets) self.log_dict( loss_item, @@ -99,7 +114,16 @@ def training_step(self, batch, batch_idx): rank_zero_only=True, ) self.log_dict(lr_dict, prog_bar=False, logger=True, on_epoch=False, rank_zero_only=True) - return loss * batch_size + total_loss = loss * batch_size + stage = self.trainer.state.stage.value + self.log(f'{stage}_loss', total_loss, prog_bar=True, batch_size=batch_size) + output = {} + output['loss'] = total_loss + if self.request_draw: + H, W = images.shape[2:] + predicts = self.post_process(raw_predicts, image_size=[W, H]) + output['predicts'] = predicts + return output def configure_optimizers(self): optimizer = create_optimizer(self.model, self.cfg.task.optimizer) @@ -114,6 +138,28 @@ def __init__(self, cfg: Config): # TODO: Add FastModel self.predict_loader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task) + if getattr(self.predict_loader, '_is_coco', False): + # Setup a kwcoco file to write to if the user requests it. + self.pred_dset = self.predict_loader.coco_dset.copy() + self.pred_dset.reroot(absolute=True) + + def on_predict_start(self, *args, **kwargs): + import rich + import ubelt as ub + out_dpath = ub.Path(self.trainer.default_root_dir).absolute() + rich.print(f'Predict in: [link={out_dpath}]{out_dpath}[/link]') + if self.predict_loader._is_coco: + out_coco_fpath = out_dpath / 'pred.kwcoco.zip' + self.pred_dset.fpath = out_coco_fpath + rich.print(f'Coco prediction is enabled in: {self.pred_dset.fpath}') + + def on_predict_end(self, *args, **kwargs): + print('[InferenceModel] on_predict_end') + dset = self.pred_dset + print(f'dset.fpath={dset.fpath}') + dset.dump() + print('Finished prediction') + def setup(self, stage): self.vec2box = create_converter( self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device @@ -124,9 +170,29 @@ def predict_dataloader(self): return self.predict_loader def predict_step(self, batch, batch_idx): - images, rev_tensor, origin_frame = batch + + images, rev_tensor, origin_frame, metadata = batch + + assert metadata is not None + img = metadata['img'] + classes = metadata['classes'] + image_id = img['id'] predicts = self.post_process(self(images), rev_tensor=rev_tensor) + + WRITE_TO_COCO = 1 + if WRITE_TO_COCO: + from yolo.utils.kwcoco_utils import tensor_to_kwimage + dset = self.pred_dset + for yolo_annot_tensor in predicts: + pred_dets = tensor_to_kwimage(yolo_annot_tensor, classes=classes).numpy() + anns = list(pred_dets.to_coco(dset=dset)) + print(f"Detected {len(anns)} boxes") + for ann in anns: + ann['image_id'] = image_id + dset.add_annotation(**ann) + img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list) + if getattr(self.predict_loader, "is_stream", None): fps = self._display_stream(img) else: diff --git a/yolo/utils/bounding_box_utils.py b/yolo/utils/bounding_box_utils.py index 0357bfdc..4e30c3e5 100644 --- a/yolo/utils/bounding_box_utils.py +++ b/yolo/utils/bounding_box_utils.py @@ -120,6 +120,15 @@ def generate_anchors(image_size: List[int], strides: List[int]): Returns: all_anchors [HW x 2]: all_scalers [HW]: The index of the best targets for each anchors + + Example: + >>> from yolo.utils.bounding_box_utils import * # NOQA + >>> from collections import Counter + >>> image_size = (640, 640) + >>> strides = [8, 16, 32] + >>> all_anchors, all_scalers = generate_anchors(image_size, strides) + >>> histogram = Counter(map(int, all_scalers)) + >>> assert histogram == {8: 6400, 16: 1600, 32: 400} """ W, H = image_size anchors = [] @@ -142,6 +151,39 @@ def generate_anchors(image_size: List[int], strides: List[int]): class BoxMatcher: + """ + Example: + >>> from yolo.utils.bounding_box_utils import * # NOQA + >>> import torch + >>> from yolo.utils.bounding_box_utils import BoxMatcher + >>> from yolo.utils.bounding_box_utils import Vec2Box + >>> from yolo.utils.config_utils import build_config + >>> from yolo.tools.data_loader import pack_targets + >>> import kwimage + >>> cfg = build_config(overrides=['task=train']) + >>> match_cfg = cfg.task.loss.matcher + >>> device = 'cpu' + >>> vec2box = Vec2Box(model=None, anchor_cfg=cfg.model.anchor, image_size=cfg.image_size, device=device) + >>> reg_max = cfg.model.anchor['reg_max'] + >>> C = class_num = 5 + >>> B = batch_size = 4 + >>> A = num_anchors = vec2box.anchor_grid.shape[0] + >>> # Build the Box Matcher + >>> self = BoxMatcher(match_cfg, class_num, vec2box, reg_max) + >>> # Generate random targets (TODO: ensure scales agree with what is used in the forward pass) + >>> target_sizes = [3, 1, 0, 2] + >>> batch_dets = [kwimage.Detections.random(s).tensor() for s in target_sizes] + >>> labels = [torch.concat([det.boxes.to_ltrb().data, + >>> det.class_idxs[:, None]], dim=1) + >>> for det in batch_dets] + >>> target = pack_targets(labels, max_targets=10) + >>> # Generate random predictions (TODO: ensure scales agree with what is used in the forward pass) + >>> predict_cls = torch.rand(B, A, C) + >>> predict_bbox = kwimage.Boxes.random(B * A).to_ltrb().tensor().data.view(B, A, 4) + >>> predict = (predict_bbox, predict_cls) + >>> # Call function + >>> anchor_matched_targets, valid_mask = self(target, predict) + """ def __init__(self, cfg: MatcherConfig, class_num: int, vec2box, reg_max: int) -> None: self.class_num = class_num self.vec2box = vec2box @@ -336,6 +378,22 @@ def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tens class Vec2Box: + """ + Example: + >>> import torch + >>> from yolo.utils.bounding_box_utils import Vec2Box + >>> from yolo.utils.config_utils import build_config + >>> cfg = build_config(overrides=['task=train']) + >>> match_cfg = cfg.task.loss.matcher + >>> device = 'cpu' + >>> vec2box = Vec2Box(model=None, anchor_cfg=cfg.model.anchor, image_size=cfg.image_size, device=device) + >>> # TODO: document form of predicts + >>> B, C, h, w = 2, 3, 7, 11 + >>> A = 1 + >>> R = 1 + >>> predicts = [] + + """ def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device): self.device = device diff --git a/yolo/utils/callbacks/__init__.py b/yolo/utils/callbacks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/yolo/utils/callbacks/tensorboard_plotter.py b/yolo/utils/callbacks/tensorboard_plotter.py new file mode 100644 index 00000000..92f2bb16 --- /dev/null +++ b/yolo/utils/callbacks/tensorboard_plotter.py @@ -0,0 +1,581 @@ +#!/usr/bin/env python3 +r""" +Parses an existing tensorboard event file and draws the plots as pngs on disk +in the monitor/tensorboard directory. + +Derived from netharn/mixins.py for dumping tensorboard plots to disk +""" +import scriptconfig as scfg +import os +import ubelt as ub +from lightning.pytorch.callbacks import Callback + + +__all__ = ['TensorboardPlotter'] + + +# TODO: can move the callback to its own file and have the CLI variant with +# core logic live separately for faster response times when using the CLI (i.e. +# avoid lightning import overhead). +class TensorboardPlotter(Callback): + """ + Asynchronously dumps PNGs to disk visualize tensorboard scalars. + exit + + Example: + >>> # xdoctest: +REQUIRES(module:tensorboard) + >>> from geowatch.utils.lightning_ext import demo + >>> from geowatch.monkey import monkey_lightning + >>> import pytorch_lightning as pl + >>> import pandas as pd + >>> monkey_lightning.disable_lightning_hardware_warnings() + >>> self = demo.LightningToyNet2d(num_train=55) + >>> default_root_dir = ub.Path.appdir('lightning_ext/tests/TensorboardPlotter').ensuredir() + >>> # + >>> trainer = pl.Trainer(callbacks=[TensorboardPlotter()], + >>> default_root_dir=default_root_dir, + >>> max_epochs=3, accelerator='cpu', devices=1) + >>> trainer.fit(self) + >>> train_dpath = trainer.logger.log_dir + >>> print('trainer.logger.log_dir = {!r}'.format(train_dpath)) + >>> data = read_tensorboard_scalars(train_dpath) + >>> for key in data.keys(): + >>> d = data[key] + >>> df = pd.DataFrame({key: d['ydata'], 'step': d['xdata'], 'wall': d['wall']}) + >>> print(df) + """ + + def _on_epoch_end(self, trainer, logs=None, serial=False): + # The following function draws the tensorboard result. This might take + # a some non-trivial amount of time so we attempt to run in a separate + # process. + from kwutil import util_environ + if util_environ.envflag('DISABLE_TENSORBOARD_PLOTTER'): + return + + if trainer.global_rank != 0: + return + + # train_dpath = trainer.logger.log_dir + train_dpath = trainer.log_dir + if train_dpath is None: + import warnings + warnings.warn('The trainer logdir is not set. Cannot dump a batch plot') + return + + func = _dump_measures + + model = trainer.model + # TODO: get step number + if hasattr(model, 'get_cfgstr'): + model_cfgstr = model.get_cfgstr() + else: + # from geowatch.utils.lightning_ext import util_model + from kwutil.slugify_ext import smart_truncate + # hparams = util_model.model_hparams(model) + model_config = { + 'type': str(model.__class__), + # 'hp': smart_truncate(ub.urepr(hparams, compact=1, nl=0), max_length=8), + } + model_cfgstr = smart_truncate(ub.urepr( + model_config, compact=1, nl=0), max_length=64) + + args = (train_dpath, model_cfgstr) + + proc_name = 'dump_tensorboard' + + if not serial: + # This causes thread-unsafe warning messages in the inner loop + # Likely because we are forking while a thread is alive + if not hasattr(trainer, '_internal_procs'): + trainer._internal_procs = ub.ddict(dict) + + # Clear finished processes from the pool + for pid in list(trainer._internal_procs[proc_name].keys()): + proc = trainer._internal_procs[proc_name][pid] + if not proc.is_alive(): + trainer._internal_procs[proc_name].pop(pid) + + # only start a new process if there is room in the pool + if len(trainer._internal_procs[proc_name]) < 1: + import multiprocessing + proc = multiprocessing.Process(target=func, args=args) + proc.daemon = True + proc.start() + trainer._internal_procs[proc_name][proc.pid] = proc + else: + # Draw is already in progress + pass + else: + func(*args) + + def on_train_epoch_end(self, trainer, logs=None): + return self._on_epoch_end(trainer, logs=logs) + + def on_validation_epoch_end(self, trainer, logs=None): + return self._on_epoch_end(trainer, logs=logs) + + def on_test_epoch_end(self, trainer, logs=None): + return self._on_epoch_end(trainer, logs=logs) + + +def read_tensorboard_scalars(train_dpath, verbose=1, cache=1): + """ + Reads all tensorboard scalar events in a directory. + Caches them because reading events of interest from protobuf can be slow. + + Ignore: + train_dpath = '/home/joncrall/.cache/lightning_ext/tests/TensorboardPlotter/lightning_logs/version_2' + tb_data = read_tensorboard_scalars(train_dpath) + """ + try: + from tensorboard.backend.event_processing import event_accumulator + except ImportError: + raise ImportError('tensorboard/tensorflow is not installed') + train_dpath = ub.Path(train_dpath) + event_paths = sorted(train_dpath.glob('events.out.tfevents*')) + # make a hash so we will re-read of we need to + cfgstr = ub.hash_data(list(map(ub.hash_file, event_paths))) if cache else '' + cacher = ub.Cacher('tb_scalars', depends=cfgstr, enabled=cache, + dpath=train_dpath / '_cache') + datas = cacher.tryload() + if datas is None: + datas = {} + for p in ub.ProgIter(list(reversed(event_paths)), desc='read tensorboard', + enabled=verbose, verbose=verbose * 3): + p = os.fspath(p) + if verbose: + print('reading tensorboard scalars') + ea = event_accumulator.EventAccumulator(p) + if verbose: + print('loading tensorboard scalars') + ea.Reload() + if verbose: + print('iterate over scalars') + for key in ea.scalars.Keys(): + if key not in datas: + datas[key] = {'xdata': [], 'ydata': [], 'wall': []} + subdatas = datas[key] + events = ea.scalars.Items(key) + for e in events: + subdatas['xdata'].append(int(e.step)) + subdatas['ydata'].append(float(e.value)) + subdatas['wall'].append(float(e.wall_time)) + + # Order all information by its wall time + for _key, subdatas in datas.items(): + sortx = ub.argsort(subdatas['wall']) + for d, vals in subdatas.items(): + subdatas[d] = list(ub.take(vals, sortx)) + cacher.save(datas) + return datas + + +def _write_helper_scripts(out_dpath, train_dpath): + """ + Writes scripts to let the user refresh data on the fly + """ + train_dpath_ = train_dpath.resolve().shrinkuser() + + # TODO: make this a nicer python script that aranges figures nicely. + stack_fpath = (out_dpath / 'stack.sh') + stack_fpath.write_text(ub.codeblock( + fr''' + #!/usr/bin/env bash + kwimage stack_images --out "{train_dpath_}/monitor/tensorboard-stack.png" -- {train_dpath_}/monitor/tensorboard/*.png + ''')) + try: + stack_fpath.chmod('ug+x') + except PermissionError as ex: + print(f'Unable to change permissions on {stack_fpath}: {ex}') + + refresh_fpath = (out_dpath / 'redraw.sh') + refresh_fpath.write_text(ub.codeblock( + fr''' + #!/usr/bin/env bash + python -m tensorboard_plotter \ + {train_dpath_} + ''')) + try: + refresh_fpath.chmod('ug+x') + except PermissionError as ex: + print(f'Unable to change permissions on {refresh_fpath}: {ex}') + + +def _dump_measures(train_dpath, title='?name?', smoothing='auto', ignore_outliers=True, verbose=0): + """ + This is its own function in case we need to modify formatting + """ + import kwplot + import kwutil + from kwplot.auto_backends import BackendContext + import pandas as pd + import numpy as np # NOQA + + train_dpath = ub.Path(train_dpath).resolve() + if not train_dpath.name.startswith('version_'): + # hack: use knowledge of common directory structures to find + # the root directory of training output for a specific training run + if not (train_dpath / 'monitor').exists(): + if (train_dpath / '../monitor').exists(): + train_dpath = (train_dpath / '..') + elif (train_dpath / '../../monitor').exists(): + train_dpath = (train_dpath / '../..') + + tb_data = read_tensorboard_scalars(train_dpath, cache=0, verbose=verbose) + + out_dpath = ub.Path(train_dpath, 'monitor', 'tensorboard').ensuredir() + _write_helper_scripts(out_dpath, train_dpath) + + if isinstance(smoothing, str) and smoothing == 'auto': + smoothing_values = [0.6, 0.95] + elif isinstance(smoothing, list): + smoothing_values = [smoothing] + else: + smoothing_values = [smoothing] + + # plot_keys = [k for k in tb_data.keys() if '/' not in k] + plot_keys = [k for k in tb_data.keys()] + keys = set(tb_data.keys()).intersection(set(plot_keys)) + # no idea what hp metric is, but it doesn't seem important + # keys = keys - {'hp_metric'} + + if len(keys) == 0: + print('warning: no known keys to plot') + print(f'available keys: {list(tb_data.keys())}') + + USE_NEW_PLOT_PREF = 1 + if USE_NEW_PLOT_PREF: + # TODO: finish this + default_plot_preferences = kwutil.Yaml.loads(ub.codeblock( + ''' + attributes: + - pattern: [ + '*_acc*', '*_ap*', '*_mAP*', '*_auc*', '*_mcc*', '*_brier*', '*_mauc*', + '*_f1*', '*_iou*', + ] + ymax: 1 + ymin: 0 + + - pattern: ['*error*', '*loss*'] + ymin: 0 + + - pattern: ['*lr*', '*momentum*', '*epoch*'] + smoothing: null + + - pattern: ['hp_metric'] + ignore: true + ''')) + plot_preferences_fpath = train_dpath / 'plot_preferences.yaml' + if plot_preferences_fpath.exists(): + user_plot_preferences = kwutil.Yaml.coerce(plot_preferences_fpath) + plot_preferences = default_plot_preferences.copy() + plot_preferences.update(user_plot_preferences) + else: + plot_preferences = default_plot_preferences + + for item in plot_preferences['attributes']: + item['pattern_'] = kwutil.util_pattern.MultiPattern.coerce(item['pattern']) + + key_table = [] + for plot_key in keys: + row = {'key': plot_key} + row['smoothing'] = smoothing_values + for item in plot_preferences['attributes']: + if item['pattern_'].match(plot_key.lower()): + row.update(item) + row.pop('pattern', None) + row.pop('pattern_', None) + key_table.append(row) + else: + y01_measures = [ + '_acc', '_ap', '_mAP', '_auc', '_mcc', '_brier', '_mauc', + '_f1', '_iou', + ] + y0_measures = ['error', 'loss'] + HACK_NO_SMOOTH = {'lr', 'momentum', 'epoch'} + key_table = [] + for plot_key in tb_data.keys(): + row = {'key': plot_key} + if plot_key == 'hp_metric' or '/' in plot_key: + row['ignore'] = True + continue + if plot_key in y01_measures: + row['ymax'] = 1 + row['ymin'] = 0 + if plot_key in y0_measures: + if ignore_outliers: + row['ymax'] = 'ignore_outliers' + row['ymin'] = 0 + if plot_key in HACK_NO_SMOOTH: + row['smoothing'] = None + else: + row['smoothing'] = smoothing_values + key_table.append(row) + + if 0: + print(f'key_table = {ub.urepr(key_table, nl=1)}') + print(pd.DataFrame(key_table)) + key_table = [r for r in key_table if not r.get('ignore', False)] + + with BackendContext('agg'): + import seaborn as sns + sns.set() + nice = title + fig = kwplot.figure(fnum=1) + fig.clf() + ax = fig.gca() + + key_iter = ub.ProgIter(key_table, desc='dump plots', verbose=verbose * 3) + for key_row in key_iter: + key = key_row['key'] + key_iter.set_extra(key) + snskw = { + 'y': key, + 'x': 'step', + } + + d = tb_data[key] + df_orig = pd.DataFrame({key: d['ydata'], 'step': d['xdata']}) + num_non_nan = (~df_orig[key].isnull()).sum() + num_nan = (df_orig[key].isnull()).sum() + df_orig['smoothing'] = 0.0 + variants = [df_orig] + smoothing_values = key_row['smoothing'] + if smoothing_values: + for _smoothing_value in smoothing_values: + # if 0: + # # TODO: can we get a hueristic for how much smoothing + # # we might want? Look at the entropy of the derivative + # # curve? + # import scipy.stats + # deriv = np.diff(df_orig[key]) + # counts1, bins1 = np.histogram(deriv[deriv < 0], bins=25) + # counts2, bins2 = np.histogram(deriv[deriv >= 0], bins=25) + # counts = np.hstack([counts1, counts2]) + # # bins = np.hstack([bins1, bins2]) + # # dict(zip(bins, counts)) + # entropy = scipy.stats.entropy(counts) + # print(f'entropy={entropy}') + if _smoothing_value > 0: + df_smooth = df_orig.copy() + beta = _smoothing_value + ydata = df_orig[key] + df_smooth[key] = smooth_curve(ydata, beta) + df_smooth['smoothing'] = _smoothing_value + variants.append(df_smooth) + + if len(variants) == 1: + df = variants[0] + else: + if verbose: + print('Combine smoothed variants') + df = pd.concat(variants).reset_index() + snskw['hue'] = 'smoothing' + + kw = {} + + ymin = key_row.get('ymin', None) + ymax = key_row.get('max', None) + if ymin is not None: + kw['ymin'] = float(ymin) + if ymax is not None: + if ymax == 'ignore_outliers': + if num_non_nan > 3: + if verbose: + print('Finding outliers') + low, kw['ymax'] = tensorboard_inlier_ylim(ydata) + else: + kw['ymax'] = float(ymax) + + if verbose: + print('Begin plot') + # NOTE: this is actually pretty slow + # TODO: port title buidler to kwplot and use it + ax.cla() + try: + if num_non_nan <= 1: + sns.scatterplot(data=df, **snskw) + else: + # todo: we have an alternative in kwplot can + # handle nans, use that instead. + sns.lineplot(data=df, **snskw) + except Exception as ex: + title = nice + '\n' + key + str(ex) + else: + title = nice + '\n' + key + initial_ylim = ax.get_ylim() + if kw.get('ymax', None) is None: + kw['ymax'] = initial_ylim[1] + if kw.get('ymin', None) is None: + kw['ymin'] = initial_ylim[0] + try: + ax.set_ylim(kw['ymin'], kw['ymax']) + except Exception: + ... + if num_nan > 0: + title += '(num_nan={})'.format(num_nan) + + ax.set_title(title) + + # png is smaller than jpg for this kind of plot + fpath = out_dpath / (key.replace('/', '-') + '.png') + if verbose: + print('Save plot: ' + str(fpath)) + ax.figure.savefig(fpath) + ax.figure.subplots_adjust(top=0.8) + do_tensorboard_stack(train_dpath) + + +def do_tensorboard_stack(train_dpath): + # Do the kwimage stack as well. + import kwimage + tensorboard_dpath = train_dpath / 'monitor/tensorboard' + monitor_dpath = train_dpath / 'monitor' + image_paths = sorted(tensorboard_dpath.glob('*.png')) + images = [kwimage.imread(fpath) for fpath in image_paths] + canvas = kwimage.stack_images_grid(images) + stack_fpath = monitor_dpath / 'tensorboard-stack.png' + kwimage.imwrite(stack_fpath, canvas) + + +def smooth_curve(ydata, beta): + """ + Curve smoothing algorithm used by tensorboard + """ + import pandas as pd + alpha = 1.0 - beta + if alpha <= 0: + return ydata + ydata_smooth = pd.Series(ydata).ewm(alpha=alpha).mean().values + return ydata_smooth + + +# def inlier_ylim(ydata): +# """ +# outlier removal used by tensorboard +# """ +# import kwarray +# normalizer = kwarray.find_robust_normalizers(ydata, { +# 'low': 0.05, +# 'high': 0.95, +# }) +# low = normalizer['min_val'] +# high = normalizer['max_val'] +# return (low, high) + + +def tensorboard_inlier_ylim(ydata): + """ + outlier removal used by tensorboard + """ + import numpy as np + q1 = 0.05 + q2 = 0.95 + low_, high_ = np.quantile(ydata, [q1, q2]) + + # Extrapolate how big the entire span should be based on inliers + inner_q = q2 - q1 + inner_extent = high_ - low_ + extrap_total_extent = inner_extent / inner_q + + # amount of padding to add to either side + missing_p1 = q1 + missing_p2 = 1 - q2 + frac1 = missing_p1 / (missing_p2 + missing_p1) + frac2 = missing_p2 / (missing_p2 + missing_p1) + missing_extent = extrap_total_extent - inner_extent + + pad1 = missing_extent * frac1 + pad2 = missing_extent * frac2 + + low = low_ - pad1 + high = high_ + pad2 + return (low, high) + + +def redraw_cli(train_dpath): + """ + Create png plots for the tensorboard data in a training directory. + """ + from kwutil.util_yaml import Yaml + train_dpath = ub.Path(train_dpath) + + expt_name = train_dpath.parent.parent.name + + hparams_fpath = train_dpath / 'hparams.yaml' + if hparams_fpath.exists(): + print('Found hparams') + hparams = Yaml.load(hparams_fpath) + if 'name' in hparams: + title = hparams['name'] + else: + from kwutil.slugify_ext import smart_truncate + model_config = { + # 'type': str(model.__class__), + 'hp': smart_truncate(ub.urepr(hparams, compact=1, nl=0), max_length=8), + } + model_cfgstr = smart_truncate(ub.urepr( + model_config, compact=1, nl=0), max_length=64) + title = model_cfgstr + title = expt_name + '\n' + title + else: + print('Did not find hparams') + title = expt_name + + if 1: + # Add in other relevant data + # ... + config_fpath = train_dpath / 'config.yaml' + if config_fpath.exists(): + + config = Yaml.load(config_fpath) + trainer_config = config.get('trainer', {}) + optimizer_config = config.get('optimizer', {}) + data_config = config.get('data', {}) + optimizer_args = optimizer_config.get('init_args', {}) + + devices = trainer_config.get('devices', None) + + batch_size = data_config.get('batch_size', None) + accum_batches = trainer_config.get('accumulate_grad_batches', None) + optim_lr = optimizer_args.get('lr', None) + decay = optimizer_args.get('weight_decay', None) + # optim_name = optimizer_config.get('class_path', '?').split('.')[-1] + learn_dynamics_str = ub.codeblock( + f''' + BS=({batch_size} x {accum_batches}), LR={optim_lr}, decay={decay}, devs={devices} + ''' + ) + title = title + '\n' + learn_dynamics_str + # print(learn_dynamics_str) + + print(f'train_dpath={train_dpath}') + print(f'title={title}') + _dump_measures(train_dpath, title, verbose=1) + import rich + tensorboard_dpath = train_dpath / 'monitor/tensorboard' + rich.print(f'[link={tensorboard_dpath}]{tensorboard_dpath}[/link]') + + +class TensorboardPlotterCLI(scfg.DataConfig): + """ + Helper CLI executable to redraw on demand. + """ + train_dpath = scfg.Value('.', help='train_dpath', position=1) + + @classmethod + def main(cls, cmdline=1, **kwargs): + import rich + config = cls.cli(cmdline=cmdline, data=kwargs, strict=True) + rich.print('config = ' + ub.urepr(config, nl=1)) + redraw_cli(config.train_dpath) + + +if __name__ == '__main__': + """ + CommandLine: + python -m yolo.utils.callbacks.tensorboard_plotter . + python ~/code/YOLO-v9/yolo/utils/callbacks/tensorboard_plotter.py . + """ + TensorboardPlotterCLI.main() diff --git a/yolo/utils/config_utils.py b/yolo/utils/config_utils.py new file mode 100644 index 00000000..465373c4 --- /dev/null +++ b/yolo/utils/config_utils.py @@ -0,0 +1,36 @@ +import omegaconf +from typing import List + + +def build_config(overrides: List[str] = []) -> omegaconf.DictConfig: + """ + Creates an explicit config for testing. + + Example: + >>> from yolo.utils.config_utils import build_config + >>> cfg = build_config(overrides=['task=train']) + >>> cfg = build_config(overrides=['task=validation']) + >>> cfg = build_config(overrides=['task=inference']) + """ + import yolo + import os + import pathlib + from hydra import compose, initialize + + # This is annoying that we cant just specify an absolute path when it is + # robustly built. Furthermore, the relative path seems like it isn't even + # from the cwd, but the module that is currently being run. + + # Find the path that we need to be relative to in a somewhat portable + # manner (i.e. will work in a Jupyter snippet). + try: + path_base = pathlib.Path(__file__).parent + except NameError: + path_base = pathlib.Path.cwd() + yolo_path = pathlib.Path(yolo.__file__).parent + rel_yolo_path = pathlib.Path(os.path.relpath(yolo_path, path_base)) + # rel_yolo_path = yolo_path.relative_to(path_base, walk_up=True) # requires Python 3.12 + config_path = os.fspath(rel_yolo_path / 'config') + with initialize(config_path=config_path, version_base=None): + cfg = compose(config_name="config", overrides=overrides) + return cfg diff --git a/yolo/utils/dataset_utils.py b/yolo/utils/dataset_utils.py index dd9a66ab..b794b8c5 100644 --- a/yolo/utils/dataset_utils.py +++ b/yolo/utils/dataset_utils.py @@ -34,6 +34,14 @@ def locate_label_paths(dataset_path: Path, phase_name: Path) -> Tuple[Path, Path if txt_files: return txt_labels_path, "txt" + HANDLE_KWCOCO_FILES = 1 + if HANDLE_KWCOCO_FILES: + candidate = dataset_path / phase_name + if candidate.is_file(): + labels_path = dataset_path / phase_name + data_type = 'kwcoco' + return labels_path, data_type + logger.warning("No labels found in the specified dataset path and phase name.") return [], None @@ -97,12 +105,22 @@ def scale_segmentation( if annotations is None: return None + try: + import kwimage + except ImportError: + kwimage = None + seg_array_with_cat = [] h, w = image_dimensions["height"], image_dimensions["width"] for anno in annotations: category_id = anno["category_id"] if "segmentation" in anno: - seg_list = [item for sublist in anno["segmentation"] for item in sublist] + if kwimage is None: + # original fallback code + seg_list = [item for sublist in anno["segmentation"] for item in sublist] + else: + # Convert to original coco representation + seg_list = kwimage.MultiPolygon.coerce(anno["segmentation"]).to_coco('orig') elif "bbox" in anno: x, y, width, height = anno["bbox"] seg_list = [x, y, x + width, y, x + width, y + height, x, y + height] diff --git a/yolo/utils/kwcoco_utils.py b/yolo/utils/kwcoco_utils.py new file mode 100644 index 00000000..15429ea9 --- /dev/null +++ b/yolo/utils/kwcoco_utils.py @@ -0,0 +1,44 @@ +""" +Helpers for COCO / KWCoco integration +""" + + +def tensor_to_kwimage(yolo_annot_tensor, classes=None): + """ + Convert a raw output tensor to a kwimage Detections object + + Args: + yolo_annot_tensor (Tensor): + Each row corresponds to an annotation. + yolo_annot_tensor[:, 0] is the class index + yolo_annot_tensor[:, 1:5] is the ltrb bounding box + yolo_annot_tensor[:, 5] is the objectness confidence + Other columns are the per-class confidence + + classes (kwcoco.CategoryTree): + ... + + Example: + yolo_annot_tensor = torch.rand(1, 6) + """ + import kwimage + class_idxs = yolo_annot_tensor[:, 0].int() + boxes = kwimage.Boxes(yolo_annot_tensor[:, 1:5], format='xyxy') + dets = kwimage.Detections( + boxes=boxes, + class_idxs=class_idxs, + classes=classes, + ) + + if yolo_annot_tensor.shape[1] > 5: + scores = yolo_annot_tensor[:, 5] + dets.data['scores'] = scores + + if classes is not None: + if hasattr(classes, 'idx_to_id'): + # Add class-id information if that is available + import torch + idx_to_id = torch.Tensor(classes.idx_to_id).int().to(class_idxs.device) + class_ids = idx_to_id[class_idxs] + dets.data['class_ids'] = class_ids + return dets diff --git a/yolo/utils/logger.py b/yolo/utils/logger.py index 28602720..4880cf07 100644 --- a/yolo/utils/logger.py +++ b/yolo/utils/logger.py @@ -3,9 +3,16 @@ from lightning.pytorch.utilities.rank_zero import rank_zero_only from rich.console import Console from rich.logging import RichHandler +import os logger = logging.getLogger("yolo") logger.setLevel(logging.DEBUG) logger.propagate = False -if rank_zero_only.rank == 0 and not logger.hasHandlers(): - logger.addHandler(RichHandler(console=Console(), show_level=True, show_path=True, show_time=True, markup=True)) + +# allow the user to get a simpler output +# TODO: needs to be better integrated +DISABLE_RICH_HANDLER = bool(os.environ.get('DISABLE_RICH_HANDLER', '')) + +if not DISABLE_RICH_HANDLER: + if rank_zero_only.rank == 0 and not logger.hasHandlers(): + logger.addHandler(RichHandler(console=Console(), show_level=True, show_path=True, show_time=True, markup=True)) diff --git a/yolo/utils/logging_utils.py b/yolo/utils/logging_utils.py index f60410d4..8dd4feb4 100644 --- a/yolo/utils/logging_utils.py +++ b/yolo/utils/logging_utils.py @@ -40,6 +40,7 @@ from yolo.utils.logger import logger from yolo.utils.model_utils import EMA from yolo.utils.solver_utils import make_ap_table +from yolo.utils.kwcoco_utils import tensor_to_kwimage # TODO: should be moved to correct position @@ -48,7 +49,7 @@ def set_seed(seed): if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. - torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False @@ -107,7 +108,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx: epoch_descript = "[cyan]Train [white]|" batch_descript = "[green]Train [white]|" metrics = self.get_metrics(trainer, pl_module) - metrics.pop("v_num") + metrics.pop("v_num", None) for metrics_name, metrics_val in metrics.items(): if "Loss_step" in metrics_name: epoch_descript += f"{metrics_name.removesuffix('_step').split('/')[1]: ^9}|" @@ -126,7 +127,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) self._update(self.val_sanity_progress_bar_id, batch_idx + 1) elif self.val_progress_bar_id is not None: self._update(self.val_progress_bar_id, batch_idx + 1) - _, mAP = outputs + mAP = outputs['mAP'] mAP_desc = f" mAP :{mAP['map']*100:6.2f} | mAP50 :{mAP['map_50']*100:6.2f} |" self.progress.update(self.val_progress_bar_id, description=f"[green]Valid [white]|{mAP_desc}") self.refresh() @@ -222,20 +223,125 @@ def summarize( class ImageLogger(Callback): + def __init__(self): + # Number of validation / training batches to draw per epoch + self.num_draw_validation_per_epoch = 1 + self.num_draw_training_per_epoch = 1 + self.max_items_per_batch = float('inf') # maximum number of items to draw per batch + super().__init__() + def on_validation_batch_end(self, trainer: Trainer, pl_module, outputs, batch, batch_idx) -> None: - if batch_idx != 0: + if batch_idx >= self.num_draw_validation_per_epoch: + return + self.draw_batch_image(trainer, pl_module, outputs, batch, batch_idx) + + def on_train_batch_start(self, trainer: Trainer, pl_module, batch, batch_idx): + # We need to let the trainer know that we would like to draw its + # output. + if hasattr(trainer.model, 'request_draw'): + if batch_idx >= self.num_draw_training_per_epoch: + pl_module.request_draw = False + else: + pl_module.request_draw = True + + def on_train_batch_end(self, trainer: Trainer, pl_module, outputs, batch, batch_idx) -> None: + if batch_idx >= self.num_draw_training_per_epoch: return + self.draw_batch_image(trainer, pl_module, outputs, batch, batch_idx) + + def draw_batch_image(self, trainer, pl_module, outputs, batch, batch_idx): batch_size, images, targets, rev_tensor, img_paths = batch - predicts, _ = outputs - gt_boxes = targets[0] if targets.ndim == 3 else targets - pred_boxes = predicts[0] if isinstance(predicts, list) else predicts - images = [images[0]] + predicts = outputs.get('predicts', None) + if predicts is None: + # Cannot draw what is not provided + print('Warning, attempted to draw batch, ' + 'but the model did not provide the correct outputs') + return None + + # gt_boxes = targets[0] if targets.ndim == 3 else targets + # pred_boxes = predicts[0] if isinstance(predicts, list) else predicts + # images = [images[0]] step = trainer.current_epoch - for logger in trainer.loggers: - if isinstance(logger, WandbLogger): - logger.log_image("Input Image", images, step=step) - logger.log_image("Ground Truth", images, step=step, boxes=[log_bbox(gt_boxes)]) - logger.log_image("Prediction", images, step=step, boxes=[log_bbox(pred_boxes)]) + + for _logger in trainer.loggers: + if isinstance(_logger, WandbLogger): + # FIXME: not robust to configured image sizes, need to know + # that info. + for image, gt_boxes, pred_boxes in zip(images, targets, predicts): + _logger.log_image("Input Image", [image], step=step) + _logger.log_image("Ground Truth", [image], step=step, boxes=[log_bbox(gt_boxes)]) + _logger.log_image("Prediction", [image], step=step, boxes=[log_bbox(pred_boxes)]) + + # TODO: better config + import os + LOG_BATCH_VIZ_TO_DISK = bool(os.environ.get('LOG_BATCH_VIZ_TO_DISK', '')) + if LOG_BATCH_VIZ_TO_DISK: + import einops + import kwimage + + # TODO: + # get a batter output path + # import pathlib + # root_dpath = pathlib.Path(trainer.default_root_dir) + root_dpath = trainer.log_dpath + out_dpath = root_dpath / 'monitor/batches' / trainer.state.stage.name + out_dpath.mkdir(exist_ok=True, parents=True) + epoch = trainer.current_epoch + + num_draw = min(len(images), self.max_items_per_batch) + + for bx in range(num_draw): + image_chw = images[bx].data.cpu().numpy() + gt_boxes = targets[bx] + pred_boxes = predicts[bx] + image_hwc = einops.rearrange(image_chw, 'c h w -> h w c') + image_hwc = kwimage.ensure_uint255(image_hwc) + + # TODO: include confusion analysis + + # assert bx == 0, 'not handling multiple per batch' + true_dets = tensor_to_kwimage(gt_boxes).numpy() + pred_dets = tensor_to_kwimage(pred_boxes).numpy() + pred_dets = pred_dets.non_max_supress(thresh=0.3) + pred_dets_2 = pred_dets.compress(pred_dets.scores > 0.1) + if len(pred_dets_2) > 0: + pred_dets = pred_dets_2 + + raw_canvas = image_hwc.copy() + true_canvas = true_dets.draw_on(raw_canvas.copy(), color='green') + pred_canvas = pred_dets.draw_on(raw_canvas.copy(), color='blue') + + raw_canvas = kwimage.draw_header_text(raw_canvas, 'raw') + true_canvas = kwimage.draw_header_text(true_canvas, f'true, n={len(true_dets)}') + pred_canvas = kwimage.draw_header_text(pred_canvas, f'pred, n={len(pred_dets)}') + canvas = kwimage.stack_images([ + raw_canvas, true_canvas, pred_canvas + ], axis=1, pad=3) + + fname = f'img_epoch{epoch:04d}_batch{batch_idx:04d}_bx{bx:04d}.jpg' + fpath = out_dpath / fname + kwimage.imwrite(fpath, canvas) + + +def wandb_to_kwimage(wand_annots): + import numpy as np + import kwimage + box_list = [] + class_idxs = [] + for row in wand_annots['predictions']['box_data']: + pos = row['position'] + class_idx = row['class_id'] + xyxy = [pos['minX'], pos['minY'], pos['maxX'], pos['maxY']] + box_list.append(xyxy) + class_idxs.append(class_idx) + + boxes = kwimage.Boxes(np.array(box_list), format='xyxy') + dets = kwimage.Detections( + boxes=boxes, + class_idxs=np.array(class_idxs) + ) + dets = dets.compress(dets.class_idxs > -1) + return dets def setup_logger(logger_name, quite=False): @@ -271,23 +377,56 @@ def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=Tru save_path = validate_log_directory(cfg, cfg.name) - progress, loggers = [], [] + write_config(cfg, save_path) + + callbacks, loggers = [], [] if hasattr(cfg.task, "ema") and cfg.task.ema.enable: - progress.append(EMA(cfg.task.ema.decay)) + callbacks.append(EMA(cfg.task.ema.decay)) if quite: logger.setLevel(logging.ERROR) - return progress, loggers, save_path + return callbacks, loggers, save_path + + from yolo.utils.logger import DISABLE_RICH_HANDLER + + if not DISABLE_RICH_HANDLER: + callbacks.append(YOLORichProgressBar()) + callbacks.append(YOLORichModelSummary()) + + if 1: + import lightning + checkpoint_init_args = { + 'monitor': 'train_loss', + 'mode': 'min', + 'save_top_k': 5, + 'filename': '{epoch:04d}-{step:06d}-trainloss{train_loss:.3f}.ckpt', + 'save_last': True, + } + checkpointer = lightning.pytorch.callbacks.ModelCheckpoint(**checkpoint_init_args) + callbacks.append(checkpointer) + + callbacks.append(ImageLogger()) - progress.append(YOLORichProgressBar()) - progress.append(YOLORichModelSummary()) - progress.append(ImageLogger()) + print(f'cfg.use_tensorboard={cfg.use_tensorboard}') if cfg.use_tensorboard: - loggers.append(TensorBoardLogger(log_graph="all", save_dir=save_path)) + print(f'save_path={save_path}') + # loggers.append(TensorBoardLogger(log_graph="all", save_dir=save_path)) + loggers.append(TensorBoardLogger(save_path)) + from yolo.utils.callbacks.tensorboard_plotter import TensorboardPlotter + callbacks.append(TensorboardPlotter()) if cfg.use_wandb: loggers.append(WandbLogger(project="YOLO", name=cfg.name, save_dir=save_path, id=None)) - return progress, loggers, save_path + return callbacks, loggers, save_path + + +@rank_zero_only +def write_config(cfg, save_path): + # Dump the config to the disk in the output folder + from omegaconf import OmegaConf + config_text = OmegaConf.to_yaml(cfg) + config_fpath = save_path / f'{cfg.task.task}_config.yaml' + config_fpath.write_text(config_text) def log_model_structure(model: Union[ModuleList, YOLOLayer, YOLO]): diff --git a/yolo/utils/model_utils.py b/yolo/utils/model_utils.py index 9d6c0ce5..4845039b 100644 --- a/yolo/utils/model_utils.py +++ b/yolo/utils/model_utils.py @@ -33,6 +33,14 @@ def lerp(start: float, end: float, step: Union[int, float], total: int = 1): Returns: float: The interpolated value. + + Example: + >>> # from big to small + >>> lerp(1, 0, 10, 100) + 0.9 + >>> # from small to big + >>> lerp(0, 1, 10, 100) + 0.1 """ return start + (end - start) * step / total @@ -57,6 +65,14 @@ def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"): self.ema_state_dict = deepcopy(pl_module.model.state_dict()) pl_module.ema.load_state_dict(self.ema_state_dict) + @no_grad() + def on_train_batch_start(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None: + if self.ema_state_dict is None: + # If validation sanity checks are disabled, then we need to + # initialize the ema state before training starts. + self.ema_state_dict = deepcopy(pl_module.model.state_dict()) + pl_module.ema.load_state_dict(self.ema_state_dict) + @no_grad() def on_train_batch_end(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None: self.step += 1 @@ -72,45 +88,134 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer: An instance of the optimizer configured according to the provided settings. """ optimizer_class: Type[Optimizer] = getattr(torch.optim, optim_cfg.type) - - bias_params = [p for name, p in model.named_parameters() if "bias" in name] - norm_params = [p for name, p in model.named_parameters() if "weight" in name and "bn" in name] - conv_params = [p for name, p in model.named_parameters() if "weight" in name and "bn" not in name] - - model_parameters = [ - {"params": bias_params, "momentum": 0.937, "weight_decay": 0}, - {"params": conv_params, "momentum": 0.937}, - {"params": norm_params, "momentum": 0.937, "weight_decay": 0}, + optim_args = optim_cfg.args + + # Note the arguments that the optimizer class actually accepts + valid_optim_args = set(optim_args.keys()) + + named_params = dict(model.named_parameters()) + + named_groups = {} + + # Define groups that will have their optimizer params overwritten + # (NOTE: these params are valid for SGD, but may not be for other + # optimizers) + named_groups['bias'] = {name for name in named_params if 'bias' in name} + named_groups['norm'] = {name for name in named_params if "weight" in name and "bn" in name} + named_groups['conv'] = {name for name in named_params if "weight" in name and "bn" not in name} + + if __debug__: + import itertools as it + # Check that all groups are disjoint + for g1, g2 in it.combinations(named_groups.values(), 2): + assert len(g1 & g2) == 0 + # Check that all parmeters are in a group + used = set.union(*named_groups.values()) + all_param_name = set(named_params.keys()) + unused = all_param_name - used + assert len(unused) == 0 + + named_group_overrides = { + 'bias': {"momentum": 0.937, "weight_decay": 0}, + 'conv': {"momentum": 0.937}, + 'norm': {"momentum": 0.937, "weight_decay": 0}, + } + + # Remove any of the overrides that are valid arguments to the optimizer. + named_group_overrides = { + name: {k: v for k, v in overrides.items() if k in valid_optim_args} + for name, overrides in named_group_overrides.items() + } + + # Map the group names to the parameter objects + named_groups = { + group_name: [named_params[name] for name in param_names] + for group_name, param_names in named_groups.items() + } + + # Setup the input to standard torch optimizers + param_groups = [ + {"name": name, "params": params, **named_group_overrides[name]} + for name, params in named_groups.items() ] + # TODO: load momentum from config instead a fix number + warmup_schedule = { + 'momentum': { + 'start': 0.8, # Start Momemtum + 'normal': 0.937, # Normal Momemtum + 'peak_epoch': 3 # The warm up epoch num + } + } + def next_epoch(self, batch_num, epoch_idx): + """ + Args: + batch_num (int): the number of batches in the epoch + epoch_id (int): the epoch index + """ self.min_lr = self.max_lr - self.max_lr = [param["lr"] for param in self.param_groups] - # TODO: load momentum from config instead a fix number - # 0.937: Start Momentum - # 0.8 : Normal Momemtum - # 3 : The warm up epoch num - self.min_mom = lerp(0.8, 0.937, min(epoch_idx, 3), 3) - self.max_mom = lerp(0.8, 0.937, min(epoch_idx + 1, 3), 3) + self.max_lr = { + group['name']: group["lr"] + for group in self.param_groups + } + if 'momentum' in valid_optim_args: + mom0 = warmup_schedule['momentum']['start'] + mom1 = warmup_schedule['momentum']['normal'] + peak_epoch = warmup_schedule['momentum']['peak_epoch'] + self.min_mom = lerp(mom0, mom1, min(epoch_idx, peak_epoch), peak_epoch) + self.max_mom = lerp(mom0, mom1, min(epoch_idx + 1, peak_epoch), peak_epoch) self.batch_num = batch_num self.batch_idx = 0 def next_batch(self): self.batch_idx += 1 lr_dict = dict() - for lr_idx, param_group in enumerate(self.param_groups): - min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx] - param_group["lr"] = lerp(min_lr, max_lr, self.batch_idx, self.batch_num) - param_group["momentum"] = lerp(self.min_mom, self.max_mom, self.batch_idx, self.batch_num) - lr_dict[f"LR/{lr_idx}"] = param_group["lr"] - lr_dict[f"momentum/{lr_idx}"] = param_group["momentum"] + for param_group in self.param_groups: + group_name = param_group['name'] + # TODO: give user control if they want this commented or not. + USE_CUSTOMIZED_LR_SCHEDULE = 0 + if USE_CUSTOMIZED_LR_SCHEDULE: + min_lr, max_lr = self.min_lr[group_name], self.max_lr[group_name] + param_group["lr"] = lerp(min_lr, max_lr, self.batch_idx, self.batch_num) + # lr_dict[f"LR/{group_name}"] = param_group["lr"] + + # Add any other scheduled key here. + keys = ['weight_decay', "lr"] + for k in keys: + if k in param_group: + lr_dict[f"{k}/{group_name}"] = param_group[k] + + if "momentum" in valid_optim_args: + param_group["momentum"] = lerp(self.min_mom, self.max_mom, self.batch_idx, self.batch_num) + lr_dict[f"momentum/{group_name}"] = param_group["momentum"] return lr_dict + # Monkey patch in methods/attributes for more control over the schedule. optimizer_class.next_batch = next_batch optimizer_class.next_epoch = next_epoch + optimizer = optimizer_class(param_groups, **optim_args) + optimizer.max_lr = { + 'bias': 0.1, + 'conv': 0, + 'norm': 0, + } + + if 0: + # Test the schedule. + import ubelt as ub + batch_num = 3 + for epoch_idx in range(3): + optimizer.next_epoch(batch_num, epoch_idx) + ignore_names = {'defaults', 'param_groups', 'state'} + ignore_names |= {k for k in optimizer.__dict__.keys() if k.startswith('_')} + optim_dict = {k: v for k, v in optimizer.__dict__.items() if k not in ignore_names} + print(f'optim_dict = {ub.urepr(optim_dict, nl=1)}') + + for _ in range(batch_num): + lr_dict = optimizer.next_batch() + print(f'lr_dict = {ub.urepr(lr_dict, nl=0)}') - optimizer = optimizer_class(model_parameters, **optim_cfg.args) - optimizer.max_lr = [0.1, 0, 0] return optimizer @@ -124,13 +229,25 @@ def create_scheduler(optimizer: Optimizer, schedule_cfg: SchedulerConfig) -> _LR schedule = scheduler_class(optimizer, **schedule_cfg.args) if hasattr(schedule_cfg, "warmup"): wepoch = schedule_cfg.warmup.epochs - lambda1 = lambda epoch: (epoch + 1) / wepoch if epoch < wepoch else 1 - lambda2 = lambda epoch: 10 - 9 * ((epoch + 1) / wepoch) if epoch < wepoch else 1 + + def lambda1(epoch): + return (epoch + 1) / wepoch if epoch < wepoch else 1 + + def lambda2(epoch): + return 10 - 9 * ((epoch + 1) / wepoch) if epoch < wepoch else 1 + warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda2, lambda1, lambda1]) schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[wepoch - 1]) + if 0: + schedule.step() + print(_get_optim_lrs(optimizer)) return schedule +def _get_optim_lrs(optimizer): + return {group['name']: group['lr'] for group in optimizer.param_groups} + + def initialize_distributed() -> None: rank = int(os.getenv("RANK", "0")) local_rank = int(os.getenv("LOCAL_RANK", "0")) diff --git a/yolo/utils/trainer.py b/yolo/utils/trainer.py new file mode 100644 index 00000000..17ed2eaf --- /dev/null +++ b/yolo/utils/trainer.py @@ -0,0 +1,114 @@ +import lightning +import ubelt as ub + + +class YoloTrainer(lightning.Trainer): + """ + Simple trainer subclass so we can ensure a print happens directly before + the training loop. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._hacked_torch_global_callback = TorchGlobals(float32_matmul_precision='auto') + + def _run(self, *args, **kwargs): + # All I want is to print this directly before training starts. + # Is that so hard to do? + self._on_before_run() + super()._run(*args, **kwargs) + + def _run_stage(self, *args, **kwargs): + # All I want is to print this directly before training starts. + # Is that so hard to do? + self._on_before_run_stage() + super()._run_stage(*args, **kwargs) + + @property + def log_dpath(self): + """ + Get path to the the log directory if it exists. + """ + if self.logger is None: + # Fallback to default root dir + return ub.Path(self.default_root_dir) + # raise Exception('cannot get a log_dpath when no logger exists') + if self.logger.log_dir is None: + return ub.Path(self.default_root_dir) + # raise Exception('cannot get a log_dpath when logger.log_dir is None') + return ub.Path(self.logger.log_dir) + + def _on_before_run(self): + """ + Our custom "callback" + """ + self._hacked_torch_global_callback.before_setup_environment(self) + + def _on_before_run_stage(self): + """ + Our custom "callback" + """ + print(f'self.global_rank={self.global_rank}') + if self.global_rank == 0: + self._on_before_run_rank0() + self._handle_restart_details() + + def _on_before_run_rank0(self): + import rich + dpath = self.log_dpath + rich.print(f"Trainer log dpath:\n\n[link={dpath}]{dpath}[/link]\n") + + def _handle_restart_details(self): + """ + Handle chores when restarting from a previous checkpoint. + """ + if self.ckpt_path: + print('Detected that you are restarting from a previous checkpoint') + ckpt_path = ub.Path(self.ckpt_path) + assert ckpt_path.parent.name == 'checkpoints' + old_event_fpaths = list(ckpt_path.parent.parent.glob('events.out.tfevents.*')) + if len(old_event_fpaths): + print('Copying tensorboard events to new training directory directory') + for old_fpath in old_event_fpaths: + new_fpath = self.log_dpath / old_fpath.name + old_fpath.copy(new_fpath) + + +class TorchGlobals(lightning.pytorch.callbacks.Callback): + """ + Callback to setup torch globals. + + Note: this needs to be called before the accelerators are setup, and + existing callbacks don't have mechanisms for that, so we hack it in here. + + Args: + float32_matmul_precision (str): + can be 'medium', 'high', 'default', or 'auto'. + The 'default' value does not change any setting. + The 'auto' value defaults to 'medium' if the training devices have + ampere cores. + """ + + def __init__(self, float32_matmul_precision='default'): + self.float32_matmul_precision = float32_matmul_precision + + def before_setup_environment(self, trainer): + import torch + float32_matmul_precision = self.float32_matmul_precision + if float32_matmul_precision == 'default': + float32_matmul_precision = None + elif float32_matmul_precision == 'auto': + # Detect if we have Ampere tensor cores + # Ampere (V8) and later leverage tensor cores, where medium + # float32_matmul_precision becomes useful + if torch.cuda.is_available(): + device_versions = [torch.cuda.get_device_capability(device_id)[0] + for device_id in trainer.device_ids] + if all(v >= 8 for v in device_versions): + float32_matmul_precision = 'medium' + else: + float32_matmul_precision = None + else: + float32_matmul_precision = None + if float32_matmul_precision is not None: + torch.set_float32_matmul_precision(float32_matmul_precision)