
[WIP] Upgrades #150

Open: wants to merge 62 commits into main

62 commits
675e7a6
feat: Basic support for kwcoco files
Erotemic Jan 5, 2025
96aed7a
refactor: cleanup code golf
Erotemic Jan 5, 2025
1a0a31d
change: disable determinism by default
Erotemic Jan 5, 2025
90e15fd
docs: add fixme note
Erotemic Jan 5, 2025
b50209e
change: other deterministic disable
Erotemic Jan 5, 2025
f0b17d0
refactor: Remove import *, and use getattr to avoid an unsafe eval
Erotemic Jan 5, 2025
056cd57
fix: handle case where classes is not in epoch_metrics
Erotemic Jan 5, 2025
b78a7f3
fix: error when v_num is not in the loss dict
Erotemic Jan 5, 2025
1d9e692
lint: remove unused f-string
Erotemic Jan 23, 2025
623264a
fix: type error in main
Erotemic Jan 23, 2025
4a299ff
test: add doctest for DualLoss with helper config_utils
Erotemic Jan 23, 2025
444d5c8
feat: add lazy imports for faster startup time
Erotemic Jan 23, 2025
c63660b
feat: allow user to specify accelerator
Erotemic Jan 23, 2025
b34ef90
feat: allow user to simplify output with environ
Erotemic Jan 23, 2025
585dc34
refactor: disable validation sanity check for faster training respons…
Erotemic Jan 23, 2025
fbcd5c0
fix: ensure categories are remapped with kwcoco
Erotemic Jan 23, 2025
1665a0e
doc: add todo about data / classes
Erotemic Jan 23, 2025
0815249
fix: dont assume iscrowd exists
Erotemic Jan 23, 2025
439c631
fix: valid points check was incorrect
Erotemic Jan 23, 2025
039ff67
feat: use kwimage to handle more polygon reprs
Erotemic Jan 23, 2025
cc60ee0
add: kwcoco training tutorial
Erotemic Jan 23, 2025
8d899fa
refactor: improve on-disk batch viz
Erotemic Jan 24, 2025
3ccd4f3
feat: add weights loading log statement
Erotemic Jan 28, 2025
9c3eba6
fix: workaround weight loading issue at inference time
Erotemic Jan 28, 2025
975ae5f
feat: basic inference on a coco file
Erotemic Jan 28, 2025
aaed2a3
refactor: remove opencv-python requirement
Erotemic Jan 28, 2025
f62d00b
refactor: create a kwcoco utils
Erotemic Jan 28, 2025
f07548e
Add log statement
Erotemic Jan 29, 2025
62af38e
Add docstr
Erotemic Jan 29, 2025
29ba851
prep for ability to output kwcoco
Erotemic Jan 29, 2025
a9eee8e
fix: kwcoco file now properly writes at the end of inference.
Erotemic Feb 3, 2025
2486daa
fix: case when classes is a list
Erotemic Feb 4, 2025
d0a19d1
refactor: add log statement to lazy
Erotemic Feb 4, 2025
9e80e88
add: write config to disk at train time
Erotemic Feb 4, 2025
b06e8fd
add: minimum setuptools version
Erotemic Feb 6, 2025
57f82ac
more debug info
Erotemic Feb 23, 2025
d5f7ddd
log train loss
Erotemic Feb 23, 2025
0de5c13
add: custom trainer for better logging controls
Erotemic Feb 23, 2025
f780250
add: callback for 3090 optimization
Erotemic Feb 23, 2025
b0efc4e
refactor: Use more callbacks
Erotemic Feb 23, 2025
de2a4c5
save checkpoints with train loss
Erotemic Feb 23, 2025
81c2ed7
fix typo
Erotemic Feb 23, 2025
a1b7426
rework float32_matmul_precision hack
Erotemic Feb 23, 2025
0ff0639
add: tensorboad plotting callbacks
Erotemic Feb 23, 2025
ff12c1f
Better tensorboard plotter, training on demo works now
Erotemic Feb 23, 2025
4b40050
log more than 1 image
Erotemic Feb 23, 2025
77e089b
try to use overviews, but disable because it caused a crash
Erotemic Feb 23, 2025
442689f
remove assert for image logger
Erotemic Feb 23, 2025
8d7a738
minor debug tweaks
Erotemic Mar 8, 2025
5a69084
add: expose all lightning trainer args via hydra
Erotemic Mar 8, 2025
40d86ee
add trainer config yaml
Erotemic Mar 8, 2025
f8ebe45
Update logging location and log train batches (todo: make optional)
Erotemic Mar 8, 2025
4137596
improve image logger to show train and val
Erotemic Mar 8, 2025
b81f5e2
Fix outputs now being a dict
Erotemic Mar 8, 2025
238b44c
Re-expose horizontal and vertical flips
Erotemic Mar 8, 2025
33781e0
Add note
Erotemic Mar 8, 2025
7b28d45
Add notes
Erotemic Mar 9, 2025
4bd3823
Rework optimizer creation to handle more optimizers
Erotemic Mar 9, 2025
a6e1be4
Replace lambda with def
Erotemic Mar 9, 2025
4b51b9f
better schedule logging
Erotemic Mar 10, 2025
e0ce8c0
Add doctest examples to BoxMatcher, refactor collate_fn
Erotemic Mar 11, 2025
543599d
Fix issue with _is_coco
Erotemic Mar 11, 2025
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -27,7 +27,7 @@ include-package-data = true
[build-system]
build-backend = "setuptools.build_meta"
requires = [
"setuptools",
"setuptools>=61.0"
]

[project.scripts]
2 changes: 1 addition & 1 deletion requirements.txt
@@ -2,10 +2,10 @@ einops
faster-coco-eval
graphviz
hydra-core
lazy-loader
lightning
loguru
numpy
opencv-python
Pillow
pycocotools
requests
160 changes: 160 additions & 0 deletions train_kwcoco_demo.sh
@@ -0,0 +1,160 @@
#!/bin/bash
__doc__="
YOLO Training Tutorial with KWCOCO DemoData
===========================================

This demonstrates an end-to-end YOLO pipeline on toydata generated with kwcoco.
"

# Define where we will store results
BUNDLE_DPATH=$HOME/demo-yolo-kwcoco-train
mkdir -p "$BUNDLE_DPATH"

echo "
Generate Toy Data
-----------------

Now that we know where the data and our intermediate files will go, let's
generate the data we will use for training and evaluation.

The kwcoco package comes with a commandline utility called 'kwcoco toydata' to
accomplish this.
"

# Define the names of the kwcoco files to generate
TRAIN_FPATH=$BUNDLE_DPATH/vidshapes_rgb_train/data.kwcoco.json
VALI_FPATH=$BUNDLE_DPATH/vidshapes_rgb_vali/data.kwcoco.json
TEST_FPATH=$BUNDLE_DPATH/vidshapes_rgb_test/data.kwcoco.json

# Generate toy datasets using the "kwcoco toydata" tool
kwcoco toydata vidshapes32-frames10 --dst "$TRAIN_FPATH"
kwcoco toydata vidshapes4-frames10 --dst "$VALI_FPATH"
kwcoco toydata vidshapes2-frames6 --dst "$TEST_FPATH"

# Ensure legacy COCO structure for now
kwcoco conform "$TRAIN_FPATH" --inplace --legacy=True
kwcoco conform "$VALI_FPATH" --inplace --legacy=True
kwcoco conform "$TEST_FPATH" --inplace --legacy=True


echo "
Create the YOLO Configuration
-----------------------------

Constructing the YOLO configuration is not entirely kwcoco aware, so we need
to write a dataset config that points at our kwcoco files and lists the
category names ourselves.
"
# In the current version we need to write configs into the repo itself.
# It's a bit gross, but this should be somewhat robust.
# Find where the yolo repo is installed (we need to be careful that this is
# our fork of the WongKinYiu variant).
REPO_DPATH=$(python -c "import yolo, pathlib; print(pathlib.Path(yolo.__file__).parent.parent)")
MODULE_DPATH=$(python -c "import yolo, pathlib; print(pathlib.Path(yolo.__file__).parent)")
CONFIG_DPATH=$(python -c "import yolo.config, pathlib; print(pathlib.Path(yolo.config.__file__).parent / 'dataset')")
echo "REPO_DPATH = $REPO_DPATH"
echo "MODULE_DPATH = $MODULE_DPATH"
echo "CONFIG_DPATH = $CONFIG_DPATH"

DATASET_CONFIG_FPATH=$CONFIG_DPATH/kwcoco-demo.yaml

# Hack to construct the class part of the YAML
CLASS_YAML=$(python -c "if 1:
import kwcoco
train_dset = kwcoco.CocoDataset('$TRAIN_FPATH')
categories = train_dset.categories().objs
# It would be nice to have better class introspection, but in the meantime
# do the same sorting as yolo.tools.data_conversion.discretize_categories
categories = sorted(categories, key=lambda cat: cat['id'])
class_num = len(categories)
class_list = [c['name'] for c in categories]
print(f'class_num: {class_num}')
print(f'class_list: {class_list}')
")


CONFIG_YAML="
path: $BUNDLE_DPATH
train: $TRAIN_FPATH
validation: $VALI_FPATH

$CLASS_YAML
"
echo "$CONFIG_YAML" > "$DATASET_CONFIG_FPATH"

TRAIN_DPATH="$BUNDLE_DPATH/kwcoco-demo-train-dir"
# This might only work in development mode; otherwise we will get the
# site-packages copy. That still might be fine, but we do want to fix this so
# it runs anywhere.
cd "$REPO_DPATH"
export CUDA_VISIBLE_DEVICES="1,"
LOG_BATCH_VIZ_TO_DISK=1 python -m yolo.lazy \
task=train \
dataset=kwcoco-demo \
use_tensorboard=True \
use_wandb=False \
out_path="$TRAIN_DPATH" \
name=kwcoco-demo \
cpu_num=0 \
device=0 \
accelerator=auto \
task.data.batch_size=2 \
"image_size=[640, 640]" \
task.optimizer.args.lr=0.03
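
# While training runs you can monitor progress with tensorboard, e.g.:
#   tensorboard --logdir "$TRAIN_DPATH"
# (this assumes the lightning event files land under the training directory,
# which matches the checkpoint layout used below).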


### Show how to run inference

BUNDLE_DPATH=$HOME/demo-yolo-kwcoco-train
TRAIN_DPATH="$BUNDLE_DPATH/kwcoco-demo-train-dir"
TEST_FPATH=$BUNDLE_DPATH/vidshapes_rgb_test/data.kwcoco.json
# Grab a checkpoint
CKPT_FPATH=$(python -c "if 1:
import pathlib
root_dpath = pathlib.Path('$TRAIN_DPATH') / 'train/kwcoco-demo'
checkpoints = sorted(root_dpath.glob('lightning_logs/*/checkpoints/*'))
print(checkpoints[-1])
")
echo "CKPT_FPATH = $CKPT_FPATH"

export DISABLE_RICH_HANDLER=1
export CUDA_VISIBLE_DEVICES="1,"
python yolo/lazy.py \
task.data.source="$TEST_FPATH" \
task=inference \
dataset=kwcoco-demo \
use_wandb=False \
out_path=kwcoco-demo-inference \
name=kwcoco-inference-test \
cpu_num=8 \
weight="\"$CKPT_FPATH\"" \
accelerator=auto \
task.nms.min_confidence=0.01 \
task.nms.min_iou=0.3 \
task.nms.max_bbox=10


### Show how to run validation

# Grab a checkpoint
BUNDLE_DPATH=$HOME/demo-yolo-kwcoco-train
TRAIN_DPATH="$BUNDLE_DPATH/kwcoco-demo-train-dir"
CKPT_FPATH=$(python -c "if 1:
import pathlib
ckpt_dpath = pathlib.Path('$TRAIN_DPATH') / 'train/kwcoco-demo/checkpoints'
checkpoints = sorted(ckpt_dpath.glob('*'))
print(checkpoints[-1])
")
echo "CKPT_FPATH = $CKPT_FPATH"

#DISABLE_RICH_HANDLER=1
LOG_BATCH_VIZ_TO_DISK=1 python -m yolo.lazy \
task=validation \
dataset=kwcoco-demo \
use_wandb=False \
out_path="$TRAIN_DPATH" \
name=kwcoco-demo \
cpu_num=0 \
device=0 \
weight="'$CKPT_FPATH'" \
accelerator=auto \
"task.data.batch_size=2" \
"image_size=[224,224]"
109 changes: 77 additions & 32 deletions yolo/__init__.py
@@ -1,33 +1,78 @@
from yolo.config.config import Config, NMSConfig
from yolo.model.yolo import create_model
from yolo.tools.data_loader import AugmentationComposer, create_dataloader
from yolo.tools.drawer import draw_bboxes
from yolo.tools.solver import TrainModel
from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, create_converter
from yolo.utils.deploy_utils import FastModelLoader
from yolo.utils.logging_utils import (
ImageLogger,
YOLORichModelSummary,
YOLORichProgressBar,
"""
The MIT YOLO rewrite
"""

__autogen__ = """
mkinit ~/code/YOLO-v9/yolo/__init__.py --nomods --write --lazy-loader

# Check to see how long it takes to run a simple help command
time python -m yolo.lazy --help
"""

__submodules__ = {
'config.config': ['Config', 'NMSConfig'],
'model.yolo': ['create_model'],
'tools.data_loader': ['AugmentationComposer', 'create_dataloader'],
'tools.drawer': ['draw_bboxes'],
'tools.solver': ['TrainModel'],
'utils.bounding_box_utils': ['Anc2Box', 'Vec2Box', 'bbox_nms', 'create_converter'],
'utils.deploy_utils': ['FastModelLoader'],
'utils.logging_utils': [
'ImageLogger', 'YOLORichModelSummary',
'YOLORichProgressBar',
'validate_log_directory'
],
'utils.model_utils': ['PostProcess'],
}


import lazy_loader


__getattr__, __dir__, __all__ = lazy_loader.attach(
__name__,
submodules={},
submod_attrs={
'config.config': [
'Config',
'NMSConfig',
],
'model.yolo': [
'create_model',
],
'tools.data_loader': [
'AugmentationComposer',
'create_dataloader',
],
'tools.drawer': [
'draw_bboxes',
],
'tools.solver': [
'TrainModel',
],
'utils.bounding_box_utils': [
'Anc2Box',
'Vec2Box',
'bbox_nms',
'create_converter',
],
'utils.deploy_utils': [
'FastModelLoader',
],
'utils.logging_utils': [
'ImageLogger',
'YOLORichModelSummary',
'YOLORichProgressBar',
'validate_log_directory',
],
'utils.model_utils': [
'PostProcess',
],
},
)
from yolo.utils.model_utils import PostProcess

__all__ = [
"create_model",
"Config",
"YOLORichProgressBar",
"NMSConfig",
"YOLORichModelSummary",
"validate_log_directory",
"draw_bboxes",
"Vec2Box",
"Anc2Box",
"bbox_nms",
"create_converter",
"AugmentationComposer",
"ImageLogger",
"create_dataloader",
"FastModelLoader",
"TrainModel",
"PostProcess",
]

__all__ = ['Anc2Box', 'AugmentationComposer', 'Config', 'FastModelLoader',
'ImageLogger', 'NMSConfig', 'PostProcess', 'TrainModel', 'Vec2Box',
'YOLORichModelSummary', 'YOLORichProgressBar', 'bbox_nms',
'create_converter', 'create_dataloader', 'create_model',
'draw_bboxes', 'validate_log_directory']
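
For readers unfamiliar with lazy_loader, the attach call above generates a module-level __getattr__ that only imports the owning submodule on first attribute access, which is what makes "python -m yolo.lazy --help" start faster. A minimal hand-rolled sketch of the same idea, using a hypothetical two-entry table rather than the real one:

import importlib

_SUBMOD_ATTRS = {
    'Config': 'yolo.config.config',      # yolo.Config -> yolo.config.config.Config
    'create_model': 'yolo.model.yolo',   # yolo.create_model -> yolo.model.yolo.create_model
}

def __getattr__(name):
    # Import the owning submodule the first time the attribute is requested.
    if name in _SUBMOD_ATTRS:
        module = importlib.import_module(_SUBMOD_ATTRS[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
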
1 change: 1 addition & 0 deletions yolo/config/config.py
@@ -157,6 +157,7 @@ class Config:
use_tensorboard: bool

weight: Optional[str]
accelerator: str


@dataclass
1 change: 1 addition & 0 deletions yolo/config/config.yaml
@@ -10,3 +10,4 @@ defaults:
- dataset: coco
- model: v9-c
- general
- trainer: yolo
3 changes: 2 additions & 1 deletion yolo/config/task/train.yaml
@@ -14,7 +14,8 @@ data:
data_augment:
Mosaic: 1
# MixUp: 1
# HorizontalFlip: 0.5
HorizontalFlip: 0.0
VerticalFlip: 0.0
RandomCrop: 1
RemoveOutliers: 1e-8

47 changes: 47 additions & 0 deletions yolo/config/trainer/default.yaml
@@ -0,0 +1,47 @@
# Expose most of the lightning Trainer options from the config.

accelerator : auto # Supports passing different accelerator types ("cpu", "gpu", "tpu", "hpu", "mps", "auto") as well as custom accelerator instances.
strategy : auto # Supports different training strategies with aliases as well custom strategies. Default: ``"auto"``.
devices : auto # The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value ``-1`` to indicate all available devices should be used, or ``"auto"`` for automatic selection based on the chosen accelerator. Default: ``"auto"``.
num_nodes : 1 # Number of GPU nodes for distributed training. Default: ``1``.
precision : null # Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'), 16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed'). Can be used on CPU, GPU, TPUs, or HPUs. Default: ``'32-true'``.
fast_dev_run : False # Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es) of train, val and test to find any bugs (ie: a sort of unit test). Default: ``False``.
min_epochs : null # Force training for at least these many epochs. Disabled by default (None).
max_steps : -1 # Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1`` and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set ``max_epochs`` to ``-1``.
min_steps : null # Force training for at least these number of steps. Disabled by default (``None``).
max_time : null # Stop training after this amount of time has passed. Disabled by default (``None``). The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes seconds), as a :class:`datetime.timedelta`, or a dictionary with keys that will be passed to :class:`datetime.timedelta`.
limit_train_batches : null # How much of training dataset to check (float = fraction, int = num_batches). Default: ``1.0``.
limit_val_batches : null # How much of validation dataset to check (float = fraction, int = num_batches). Default: ``1.0``.
limit_test_batches : null # How much of test dataset to check (float = fraction, int = num_batches). Default: ``1.0``.
limit_predict_batches : null # How much of prediction dataset to check (float = fraction, int = num_batches). Default: ``1.0``.
overfit_batches : 0.0 # Overfit a fraction of training/validation data (float) or a set number of batches (int). Default: ``0.0``.
val_check_interval : null # How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training batches. An ``int`` value can only be higher than the number of training batches when ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches across epochs or during iteration-based training. Default: ``1.0``.
check_val_every_n_epoch : 1 # Perform a validation loop after every `N` training epochs. If ``None``, validation will be done solely based on the number of training batches, requiring ``val_check_interval`` to be an integer value. Default: ``1``.
num_sanity_val_steps : null # Sanity check runs n validation batches before starting the training routine. Set it to `-1` to run all batches in all validation dataloaders. Default: ``2``.
log_every_n_steps : null # How often to log within steps. Default: ``50``.
enable_checkpointing : null # If ``True``, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in :paramref:`~lightning.pytorch.trainer.trainer.Trainer.callbacks`. Default: ``True``.
enable_model_summary : null # Whether to enable model summarization by default. Default: ``True``.
accumulate_grad_batches : 1 # Accumulates gradients over k batches before stepping the optimizer. Default: 1.
gradient_clip_val : null # The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before. Default: ``None``.
gradient_clip_algorithm : null # The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"`` to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will be set to ``"norm"``.
deterministic : null # If ``True``, sets whether PyTorch operations must use deterministic algorithms. Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
benchmark : null # The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to. The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used (``False`` if not manually set). If :paramref:`~lightning.pytorch.trainer.trainer.Trainer.deterministic` is set to ``True``, this will default to ``False``. Override to manually set a different value. Default: ``None``.
inference_mode : True # Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during evaluation (``validate``/``test``/``predict``).
use_distributed_sampler : True # Whether to wrap the DataLoader's sampler with :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed sampler was already added, Lightning will not replace the existing one. For iterable-style datasets, we don't do this automatically.
profiler : null # To profile individual steps during training and assist in identifying bottlenecks. Default: ``None``.
detect_anomaly : False # Enable anomaly detection for the autograd engine. Default: ``False``.
barebones : False # Whether to run in "barebones mode", where all features that may impact raw speed are disabled. This is meant for analyzing the Trainer overhead and is discouraged during regular training runs. The following features are deactivated: :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_checkpointing`, :paramref:`~lightning.pytorch.trainer.trainer.Trainer.logger`, :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_progress_bar`, :paramref:`~lightning.pytorch.trainer.trainer.Trainer.log_every_n_steps`, :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_model_summary`, :paramref:`~lightning.pytorch.trainer.trainer.Trainer.num_sanity_val_steps`, :paramref:`~lightning.pytorch.trainer.trainer.Trainer.fast_dev_run`, :paramref:`~lightning.pytorch.trainer.trainer.Trainer.detect_anomaly`, :paramref:`~lightning.pytorch.trainer.trainer.Trainer.profiler`, :meth:`~lightning.pytorch.core.LightningModule.log`, :meth:`~lightning.pytorch.core.LightningModule.log_dict`.
plugins : null # Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins. Default: ``None``.
sync_batchnorm : False # Synchronize batch norm layers between process groups/whole world. Default: ``False``.
reload_dataloaders_every_n_epochs: 0 # Set to a positive integer to reload dataloaders every n epochs. Default: ``0``.


### EXPOSED ELSEWHERE
#default_root_dir : null # Default path for logs and weights when no logger/ckpt_callback passed. Default: ``os.getcwd()``. Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
#enable_progress_bar : null # Whether to enable the progress bar by default. Default: ``True``.
#max_epochs : null # Stop training once this number of epochs is reached. Disabled by default (None). If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``. To enable infinite training, set ``max_epochs = -1``.


### UNSUPPORTED
# callbacks : null # Add a callback or list of callbacks. Default: ``None``.
# logger : null # Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses the default ``TensorBoardLogger`` if it is installed, otherwise ``CSVLogger``. ``False`` will disable logging. If multiple loggers are provided, local files (checkpoints, profiler traces, etc.) are saved in the ``log_dir`` of the first logger. Default: ``True``.
10 changes: 10 additions & 0 deletions yolo/config/trainer/yolo.yaml
@@ -0,0 +1,10 @@
defaults:
- default

num_sanity_val_steps: 0
precision: "16-mixed"
log_every_n_steps: 1
gradient_clip_val: 10
gradient_clip_algorithm: "value"
#deterministic: True
deterministic: False
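
Since config.yaml now includes "- trainer: yolo" in its defaults, these Lightning Trainer options can be overridden from the command line like any other hydra value. A hedged example (assuming the trainer config group is forwarded to the Lightning Trainer as this layout suggests):

python -m yolo.lazy task=train dataset=kwcoco-demo \
    trainer.max_steps=100 \
    trainer.precision=32-true \
    trainer.gradient_clip_val=5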