Skip to content

[INTERN] Experimental training branch #1866

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 32 additions & 22 deletions doctr/datasets/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import os
import shutil
import traceback
from collections.abc import Callable
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -47,28 +48,37 @@ def _read_sample(self, index: int) -> tuple[Any, Any]:

def __getitem__(self, index: int) -> tuple[Any, Any]:
    """Return the (image, target) pair at *index* after applying the dataset transforms.

    If reading or transforming the requested sample raises, the traceback is
    printed (with the offending filename) and sample 0 is returned instead, so
    a single corrupt file does not abort a whole training epoch.

    Args:
        index: position of the sample in ``self.data``.

    Returns:
        The (possibly transformed) image and its target.

    Raises:
        Exception: re-raised when the fallback sample (index 0) itself fails,
            to avoid infinite recursion.
    """
    try:
        # Read image
        img, target = self._read_sample(index)
        # Pre-transforms (format conversion at run-time etc.)
        if self._pre_transforms is not None:
            img, target = self._pre_transforms(img, target)

        if self.img_transforms is not None:
            # typing issue cf. https://github.com/python/mypy/issues/5485
            img = self.img_transforms(img)

        if self.sample_transforms is not None:
            # Conditions to assess it is detection model with multiple classes
            # and avoid confusion with other tasks.
            if (
                isinstance(target, dict)
                and all(isinstance(item, np.ndarray) for item in target.values())
                and set(target.keys()) != {"boxes", "labels"}  # avoid confusion with obj detection target
            ):
                img_transformed = _copy_tensor(img)
                for class_name, bboxes in target.items():
                    img_transformed, target[class_name] = self.sample_transforms(img, bboxes)
                img = img_transformed
            else:
                img, target = self.sample_transforms(img, target)
    except Exception:
        img_name = self.data[index][0]
        print()
        print(f"!!!ERROR in Dataset on filename {img_name}")
        traceback.print_exc()
        print()
        # Guard: if the fallback sample itself is broken, re-raise instead of
        # recursing forever on index 0.
        if index == 0:
            raise
        return self.__getitem__(0)  # sample 0 is assumed to exist

    return img, target

Expand Down
16 changes: 11 additions & 5 deletions doctr/datasets/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,20 @@ def __init__(

self.data: list[tuple[str, tuple[np.ndarray, list[str]]]] = []
np_dtype = np.float32
missing_files = []
for img_name, label in labels.items():
# File existence check
if not os.path.exists(os.path.join(self.root, img_name)):
raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")

geoms, polygons_classes = self.format_polygons(label["polygons"], use_polygons, np_dtype)

self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes)))
missing_files.append(img_name)
# raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")
else:
geoms, polygons_classes = self.format_polygons(label["polygons"], use_polygons, np_dtype)
self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes)))
print("List of missing files:")
print(f"MISSING FILES: {len(missing_files)}")
from pprint import pprint

pprint(missing_files)

def format_polygons(
self, polygons: list | dict, use_polygons: bool, np_dtype: type
Expand Down
13 changes: 10 additions & 3 deletions doctr/datasets/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,18 @@ def __init__(
with open(labels_path, encoding="utf-8") as f:
labels = json.load(f)

missing_files = []
for img_name, label in labels.items():
if not os.path.exists(os.path.join(self.root, img_name)):
raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")

self.data.append((img_name, label))
missing_files.append(img_name)
# raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")
else:
self.data.append((img_name, label))
print("List of missing files:")
print(f"MISSING FILES: {len(missing_files)}")
from pprint import pprint

pprint(missing_files)

def merge_dataset(self, ds: AbstractDataset) -> None:
# Update data with new root for self
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/differentiable_binarization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class _DBNet:
shrink_ratio = 0.4
thresh_min = 0.3
thresh_max = 0.7
min_size_box = 3
min_size_box = 2
assume_straight_pages: bool = True

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/fast/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class _FAST(BaseModel):
<https://arxiv.org/pdf/2111.02394.pdf>`_.
"""

min_size_box: int = 3
min_size_box: int = 2
assume_straight_pages: bool = True
shrink_ratio = 0.4

Expand Down
6 changes: 3 additions & 3 deletions doctr/models/detection/linknet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
assume_straight_pages: bool = True,
) -> None:
super().__init__(box_thresh, bin_thresh, assume_straight_pages)
self.unclip_ratio = 1.5
self.unclip_ratio = 1.2

def polygon_to_box(
self,
Expand Down Expand Up @@ -149,9 +149,9 @@ class _LinkNet(BaseModel):
out_chan: number of channels for the output
"""

min_size_box: int = 3
min_size_box: int = 2
assume_straight_pages: bool = True
shrink_ratio = 0.5
shrink_ratio = 0.4

def build_target(
self,
Expand Down
62 changes: 46 additions & 16 deletions references/detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ def main(args):
class_names=val_set.class_names,
)

# Target building params
model.min_size_box = args.min_size_box
model.shrink_ratio = args.shrink_ratio

# Resume weights
if isinstance(args.resume, str):
pbar.write(f"Resuming {args.resume}")
Expand All @@ -272,7 +276,9 @@ def main(args):

if args.test_only:
pbar.write("Running evaluation")
val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp)
val_loss, recall, precision, mean_iou = evaluate(
model, val_loader, batch_transforms, val_metric, amp=args.amp, log=lambda **kwargs: None
)
pbar.write(
f"Validation loss: {val_loss:.6} (Recall: {recall:.2%} | Precision: {precision:.2%} | "
f"Mean IoU: {mean_iou:.2%})"
Expand All @@ -284,39 +290,53 @@ def main(args):
# Image augmentations
img_transforms = T.OneOf([
Compose([
T.RandomApply(T.ColorInversion(), 0.3),
T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.2),
T.RandomApply(T.ColorInversion(), 0.4),
T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.2)), 0.2),
]),
Compose([
T.RandomApply(T.RandomShadow(), 0.3),
T.RandomApply(T.RandomShadow(), 0.5),
T.RandomApply(T.GaussianNoise(), 0.1),
T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3),
RandomGrayscale(p=0.15),
T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.2)), 0.2),
]),
RandomPhotometricDistort(p=0.3),
lambda x: x, # Identity no transformation
RandomGrayscale(p=0.15),
RandomPhotometricDistort(p=0.25),
# lambda x: x, # Identity no transformation
])
# Image + target augmentations
sample_transforms = T.SampleCompose(
(
[
T.RandomHorizontalFlip(0.15),
T.RandomHorizontalFlip(0.1),
T.OneOf([
T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25),
T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25),
T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.15),
T.RandomResize(scale_range=(0.4, 0.5), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.5),
]),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
if not args.rotation
else [
T.RandomHorizontalFlip(0.15),
T.RandomHorizontalFlip(0.1),
T.OneOf([
T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25),
T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25),
T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.15),
# T.RandomResize(scale_range=(0.2, 0.25), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.5),
T.RandomApply(
T.Resize(
(args.input_size // 2, args.input_size // 2),
preserve_aspect_ratio=False,
symmetric_pad=False,
),
0.5,
),
T.RandomApply(
T.Resize(
(args.input_size // 2, args.input_size // 2), preserve_aspect_ratio=True, symmetric_pad=True
),
0.5,
),
]),
# Rotation augmentation
T.Resize(args.input_size, preserve_aspect_ratio=True),
T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
T.RandomApply(T.RandomRotate(90, expand=True), 0.75),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
)
Expand Down Expand Up @@ -384,7 +404,7 @@ def main(args):
elif args.sched == "onecycle":
scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader))
elif args.sched == "poly":
scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader))
scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader), power=1.0)

# Training monitoring
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
Expand Down Expand Up @@ -544,6 +564,16 @@ def parse_args():
"--save-interval-epoch", dest="save_interval_epoch", action="store_true", help="Save model every epoch"
)
parser.add_argument("--input_size", type=int, default=1024, help="model input size, H = W")
parser.add_argument(
"--min-size-box", type=int, default=2, help="minimum size of a box to be considered", dest="min_size_box"
)
parser.add_argument(
"--shrink-ratio",
type=float,
default=0.4,
help="shrink ratio for the polygons range [0.1, 0.9]",
dest="shrink_ratio",
)
parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam or AdamW)")
parser.add_argument("--wd", "--weight-decay", default=0, type=float, help="weight decay", dest="weight_decay")
parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading")
Expand Down
30 changes: 22 additions & 8 deletions references/detection/train_pytorch_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,10 @@ def main(rank: int, world_size: int, args):
class_names=class_names,
)

# Target building params
model.min_size_box = args.min_size_box
model.shrink_ratio = args.shrink_ratio

# Resume weights
if isinstance(args.resume, str):
pbar.write(f"Resuming {args.resume}")
Expand Down Expand Up @@ -303,33 +307,33 @@ def main(rank: int, world_size: int, args):
Compose([
T.RandomApply(T.RandomShadow(), 0.3),
T.RandomApply(T.GaussianNoise(), 0.1),
T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3),
T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.2),
RandomGrayscale(p=0.15),
]),
RandomPhotometricDistort(p=0.3),
RandomPhotometricDistort(p=0.2),
lambda x: x, # Identity no transformation
])
# Image + target augmentations
sample_transforms = T.SampleCompose(
(
[
T.RandomHorizontalFlip(0.15),
T.RandomHorizontalFlip(0.1),
T.OneOf([
T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25),
T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25),
T.RandomResize(scale_range=(0.75, 0.95), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25),
]),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
if not args.rotation
else [
T.RandomHorizontalFlip(0.15),
T.RandomHorizontalFlip(0.1),
T.OneOf([
T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25),
T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25),
T.RandomResize(scale_range=(0.75, 0.95), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25),
]),
# Rotation augmentation
T.Resize(args.input_size, preserve_aspect_ratio=True),
T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
T.RandomApply(T.RandomRotate(90, expand=True), 0.75),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
)
Expand Down Expand Up @@ -398,7 +402,7 @@ def main(rank: int, world_size: int, args):
elif args.sched == "onecycle":
scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader))
elif args.sched == "poly":
scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader))
scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader), power=1.0)

# Training monitoring
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
Expand Down Expand Up @@ -523,6 +527,16 @@ def parse_args():
"--save-interval-epoch", dest="save_interval_epoch", action="store_true", help="Save model every epoch"
)
parser.add_argument("--input_size", type=int, default=1024, help="model input size, H = W")
parser.add_argument(
"--min-size-box", type=int, default=2, help="minimum size of a box to be considered", dest="min_size_box"
)
parser.add_argument(
"--shrink-ratio",
type=float,
default=0.4,
help="shrink ratio for the polygons range [0.1, 0.9]",
dest="shrink_ratio",
)
parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam or AdamW)")
parser.add_argument("--wd", "--weight-decay", default=0, type=float, help="weight decay", dest="weight_decay")
parser.add_argument("-j", "--workers", type=int, default=0, help="number of workers used for dataloading")
Expand Down
Loading