
update dataset format and fix some bugs #90

Open · wants to merge 13 commits into main
2 changes: 1 addition & 1 deletion .github/workflows/deploy.yaml
@@ -2,7 +2,7 @@ name: Deploy Mode Validation & Inference

on:
  push:
-    branches: [main]
+    branches: [main,TRAIN]
  pull_request:
    branches: [main]

2 changes: 1 addition & 1 deletion tests/test_tools/test_data_loader.py
@@ -12,7 +12,7 @@ def test_create_dataloader_cache(train_cfg: Config):
    train_cfg.task.data.shuffle = False
    train_cfg.task.data.batch_size = 2

-    cache_file = Path("tests/data/train.cache")
+    cache_file = Path("tests/data/images/train.cache")
    cache_file.unlink(missing_ok=True)

    make_cache_loader = create_dataloader(train_cfg.task.data, train_cfg.dataset)
2 changes: 1 addition & 1 deletion tests/test_tools/test_solver.py
@@ -17,7 +17,7 @@
@pytest.fixture
def model_validator(validation_cfg: Config, model: YOLO, vec2box: Vec2Box, validation_progress_logger, device):
    validator = ModelValidator(
-        validation_cfg.task, validation_cfg.dataset, model, vec2box, validation_progress_logger, device
+        validation_cfg, model, vec2box, validation_progress_logger, device
    )
    return validator

118 changes: 98 additions & 20 deletions tests/test_utils/test_bounding_box_utils.py
@@ -146,23 +146,62 @@ def test_anc2box_autoanchor(inference_v7_cfg: Config):


def test_bbox_nms():
-    cls_dist = tensor(
-        [[[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]], [[0.4, 0.4, 0.2], [0.5, 0.4, 0.1]]]  # Example class distribution
+    cls_dist = torch.tensor(
+        [
+            [
+                [0.7, 0.1, 0.2],  # High confidence, class 0
+                [0.3, 0.6, 0.1],  # High confidence, class 1
+                [-3.0, -2.0, -1.0],  # Low confidence, class 2
+                [0.6, 0.2, 0.2],  # Medium confidence, class 0
+            ],
+            [
+                [0.55, 0.25, 0.2],  # Medium confidence, class 0
+                [-4.0, -0.5, -2.0],  # Low confidence, class 1
+                [0.15, 0.2, 0.65],  # Medium confidence, class 2
+                [0.8, 0.1, 0.1],  # High confidence, class 0
+            ],
+        ],
+        dtype=float32,
    )
-    bbox = tensor(
-        [[[50, 50, 100, 100], [60, 60, 110, 110]], [[40, 40, 90, 90], [70, 70, 120, 120]]],  # Example bounding boxes
+    bbox = torch.tensor(
+        [
+            [
+                [0, 0, 160, 120],  # Overlaps with box 4
+                [160, 120, 320, 240],
+                [0, 120, 160, 240],
+                [16, 12, 176, 132],
+            ],
+            [
+                [0, 0, 160, 120],  # Overlaps with box 4
+                [160, 120, 320, 240],
+                [0, 120, 160, 240],
+                [16, 12, 176, 132],
+            ],
+        ],
        dtype=float32,
    )
    nms_cfg = NMSConfig(min_confidence=0.5, min_iou=0.5)

-    expected_output = [
-        tensor(
-            [
-                [1.0000, 50.0000, 50.0000, 100.0000, 100.0000, 0.6682],
-                [0.0000, 60.0000, 60.0000, 110.0000, 110.0000, 0.6457],
-            ]
-        )
-    ]
+    # Batch 1:
+    # - box 1 is kept with class 0, since it has a higher confidence than the overlapping box 4, which is filtered out
+    # - box 2 is kept with class 1
+    # - box 3 is rejected by the confidence filter
+    # Batch 2:
+    # - box 4 is kept with class 0, since it has a higher confidence than the overlapping box 1, which is filtered out
+    # - box 2 is rejected by the confidence filter
+    # - box 3 is kept with class 2
+    expected_output = torch.tensor(
+        [
+            [
+                [0.0, 0.0, 0.0, 160.0, 120.0, 0.6682],
+                [1.0, 160.0, 120.0, 320.0, 240.0, 0.6457],
+            ],
+            [
+                [0.0, 16.0, 12.0, 176.0, 132.0, 0.6900],
+                [2.0, 0.0, 120.0, 160.0, 240.0, 0.6570],
+            ],
+        ]
+    )
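The expected confidences are consistent with a sigmoid applied to the raw class scores, assuming bbox_nms converts logits to probabilities before thresholding (a quick standalone check, not part of the diff):

    import torch

    logits = torch.tensor([0.7, 0.6, 0.8, 0.65])
    print(torch.sigmoid(logits))  # tensor([0.6682, 0.6457, 0.6900, 0.6570])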

    output = bbox_nms(cls_dist, bbox, nms_cfg)

@@ -171,13 +210,52 @@ def test_bbox_nms():


def test_calculate_map():
-    predictions = tensor([[0, 60, 60, 160, 160, 0.5], [0, 40, 40, 120, 120, 0.5]])  # [class, x1, y1, x2, y2]
-    ground_truths = tensor([[0, 50, 50, 150, 150], [0, 30, 30, 100, 100]])  # [class, x1, y1, x2, y2]
-
-    mAP = calculate_map(predictions, ground_truths)
-
-    expected_ap50 = tensor(0.5)
-    expected_ap50_95 = tensor(0.2)
-
-    assert isclose(mAP["mAP.5"], expected_ap50, atol=1e-5), f"AP50 mismatch"
-    assert isclose(mAP["mAP.5:.95"], expected_ap50_95, atol=1e-5), f"Mean AP mismatch"
+    # set up test data
+    predictions = torch.tensor([
+        [0, 60, 60, 160, 160, 0.9],  # [class, x1, y1, x2, y2, confidence]
+        [0, 40, 40, 120, 120, 0.8],
+        [1, 10, 10, 70, 70, 0.7],
+    ])
+    ground_truths = torch.tensor([
+        [0, 50, 50, 150, 150],  # [class, x1, y1, x2, y2]
+        [1, 15, 15, 65, 65],
+        [0, 30, 30, 100, 100],
+    ])
+
+    # test basic functionality
+    result = calculate_map(predictions, ground_truths)
+    assert "mAP.50" in result
+    assert "mAP.5:.95" in result
+    assert 0 <= result["mAP.50"] <= 1
+    assert 0 <= result["mAP.5:.95"] <= 1
+
+    # test class-level metrics
+    assert "class_mAP" in result
+    assert 0 in result["class_mAP"]
+    assert 1 in result["class_mAP"]
+
+    # test custom IoU thresholds
+    custom_thresholds = [0.3, 0.5, 0.7]
+    result_custom = calculate_map(predictions, ground_truths, iou_thresholds=custom_thresholds)
+    assert "mAP.30" in result_custom
+    assert "mAP.50" in result_custom
+    assert "mAP.70" in result_custom
+
+    # test edge cases: empty predictions / empty ground truths
+    empty_predictions = torch.zeros((0, 6))
+    empty_result = calculate_map(empty_predictions, ground_truths)
+    assert empty_result["mAP.50"] == 0
+
+    empty_ground_truths = torch.zeros((0, 5))
+    empty_gt_result = calculate_map(predictions, empty_ground_truths)
+    assert empty_gt_result["mAP.50"] == 0
+
+    # test perfect match
+    perfect_predictions = torch.tensor([
+        [0, 50, 50, 150, 150, 1.0],
+        [0, 30, 30, 100, 100, 1.0],
+        [1, 15, 15, 65, 65, 1.0],
+    ])
+    perfect_result = calculate_map(perfect_predictions, ground_truths)
+    assert pytest.approx(perfect_result["mAP.50"], 1e-6) == 1.0
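For the first prediction above, the overlap with its matching ground truth works out to about 0.68, so it counts as a true positive at the 0.5 threshold but not at 0.7, which is what the custom-threshold assertions exercise. A quick standalone check (using torchvision's box_iou here, an assumption; the repo's own IoU helper would do the same):

    import torch
    from torchvision.ops import box_iou

    pred = torch.tensor([[60.0, 60.0, 160.0, 160.0]])
    gt = torch.tensor([[50.0, 50.0, 150.0, 150.0]])
    print(box_iou(pred, gt))  # tensor([[0.6807]]) = 8100 / (10000 + 10000 - 8100)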
6 changes: 4 additions & 2 deletions yolo/config/dataset/mock.yaml
@@ -1,6 +1,8 @@
path: tests/data
-train: train
-validation: val
+image_train: images/train
+label_train: annotations/instances_train.json
+image_validation: images/val
+label_validation: annotations/instances_val.json

class_num: 80
class_list: ['Person', 'Bicycle', 'Car', 'Motorcycle', 'Airplane', 'Bus', 'Train', 'Truck', 'Boat', 'Traffic light', 'Fire hydrant', 'Stop sign', 'Parking meter', 'Bench', 'Bird', 'Cat', 'Dog', 'Horse', 'Sheep', 'Cow', 'Elephant', 'Bear', 'Zebra', 'Giraffe', 'Backpack', 'Umbrella', 'Handbag', 'Tie', 'Suitcase', 'Frisbee', 'Skis', 'Snowboard', 'Sports ball', 'Kite', 'Baseball bat', 'Baseball glove', 'Skateboard', 'Surfboard', 'Tennis racket', 'Bottle', 'Wine glass', 'Cup', 'Fork', 'Knife', 'Spoon', 'Bowl', 'Banana', 'Apple', 'Sandwich', 'Orange', 'Broccoli', 'Carrot', 'Hot dog', 'Pizza', 'Donut', 'Cake', 'Chair', 'Couch', 'Potted plant', 'Bed', 'Dining table', 'Toilet', 'Tv', 'Laptop', 'Mouse', 'Remote', 'Keyboard', 'Cell phone', 'Microwave', 'Oven', 'Toaster', 'Sink', 'Refrigerator', 'Book', 'Clock', 'Vase', 'Scissors', 'Teddy bear', 'Hair drier', 'Toothbrush']
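The new keys imply a COCO-style layout under path (inferred from the values above; the tree is not spelled out elsewhere in the PR):

    tests/data/
    ├── images/
    │   ├── train/
    │   └── val/
    └── annotations/
        ├── instances_train.json
        └── instances_val.json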
1 change: 1 addition & 0 deletions yolo/config/task/train.yaml
@@ -1,4 +1,5 @@
task: train
+mode: detection

defaults:
  - validation: ../validation
2 changes: 1 addition & 1 deletion yolo/lazy.py
@@ -32,7 +32,7 @@ def main(cfg: Config):
if cfg.task.task == "train":
solver = ModelTrainer(cfg, model, converter, progress, device, use_ddp)
if cfg.task.task == "validation":
solver = ModelValidator(cfg.task, cfg.dataset, model, converter, progress, device)
solver = ModelValidator(cfg, model, converter, progress, device)
if cfg.task.task == "inference":
solver = ModelTester(cfg, model, converter, progress, device)
progress.start()
83 changes: 58 additions & 25 deletions yolo/tools/data_loader.py
@@ -24,39 +24,73 @@


class YoloDataset(Dataset):
-    def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str = "train2017"):
+    def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str = "train"):
        augment_cfg = data_cfg.data_augment
        self.image_size = data_cfg.image_size
        phase_name = dataset_cfg.get(phase, phase)

        transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
        self.transform = AugmentationComposer(transforms, self.image_size)
        self.transform.get_more_data = self.get_more_data
-        self.data = self.load_data(Path(dataset_cfg.path), phase_name)
+
+        self.get_dataset_path(cfg=dataset_cfg, phase=phase)
+
+        self.data = []
+        for images_path, labels_path in zip(self.images_paths, self.labels_paths):
+            datas = self.load_data(images_path, labels_path, phase_name)
+            datas = [(images_path / data[0], *data[1:]) for data in datas]
+            self.data.extend(datas)
+
+    def get_dataset_path(self, cfg: DatasetConfig, phase: str = "train"):
+        # Resolve the image and label sources for this phase: a single string or a
+        # tuple of strings becomes a list of paths rooted at cfg.path.
+        images_paths = getattr(cfg, "image_" + phase)
+        if isinstance(images_paths, str):
+            images_paths = [images_paths]
+        elif isinstance(images_paths, tuple):
+            images_paths = list(images_paths)
+        self.images_paths = [Path(cfg.path) / images_path for images_path in images_paths]
+
+        labels_paths = getattr(cfg, "label_" + phase)
+        if isinstance(labels_paths, str):
+            labels_paths = [labels_paths]
+        elif isinstance(labels_paths, tuple):
+            labels_paths = list(labels_paths)
+        self.labels_paths = [Path(cfg.path) / labels_path for labels_path in labels_paths]
+
+        assert len(self.images_paths) == len(self.labels_paths)
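Because strings and tuples are both normalized to lists, a phase can point at one source or several. A minimal standalone sketch of the same resolution logic (hypothetical code, not the PR's):

    from pathlib import Path

    def resolve(root, sources):
        # One relative path (str) or several (tuple/list) become a list of paths under root.
        if isinstance(sources, str):
            sources = [sources]
        return [Path(root) / s for s in sources]

    print(resolve("tests/data", "images/train"))
    # [PosixPath('tests/data/images/train')]
    print(resolve("tests/data", ("images/train", "images/extra")))
    # [PosixPath('tests/data/images/train'), PosixPath('tests/data/images/extra')]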

-    def load_data(self, dataset_path: Path, phase_name: str):
+    def load_data(self, images_path: Path, labels_path: Path, phase_name: str):
"""
Loads data from a cache or generates a new cache for a specific dataset phase.

Parameters:
dataset_path (Path): The root path to the dataset directory.
images_path (Path): The root path to the images directory.
labels_path (Path): The root path to the labels directory.
phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.

Returns:
dict: The loaded data from the cache for the specified phase.
"""
-        cache_path = dataset_path / f"{phase_name}.cache"
+        cache_path = images_path.with_suffix(".cache")

        if not cache_path.exists():
-            logger.info("🏭 Generating {} cache", phase_name)
-            data = self.filter_data(dataset_path, phase_name)
+            data = self.filter_data(images_path, labels_path, phase_name)
+            logger.info("🏭 Generating {} cache, containing {} samples", phase_name, len(data))
            torch.save(data, cache_path)
        else:
            data = torch.load(cache_path)
-            logger.info("📦 Loaded {} cache", phase_name)
+            logger.info("📦 Loaded {} cache, containing {} samples", phase_name, len(data))
+            # TODO: add cache validation
+            # if data[0][0].parent == Path("images") / phase_name:
+            #     logger.info("✅ Cache validation successful")
+            # else:
+            #     logger.warning("⚠️ Cache validation failed, regenerating")
+            #     data = self.filter_data(images_path, labels_path, phase_name)
+            #     torch.save(data, cache_path)

        return data
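The cache now sits next to the images directory it indexes, which is why the test above moved from tests/data/train.cache to tests/data/images/train.cache:

    from pathlib import Path

    print(Path("tests/data/images/train").with_suffix(".cache"))
    # tests/data/images/train.cache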

def filter_data(self, dataset_path: Path, phase_name: str) -> list:
def filter_data(self, images_path: Path, labels_path: Path, phase_name: str) -> list:
"""
Filters and collects dataset information by pairing images with their corresponding labels.

@@ -67,8 +101,7 @@ def filter_data(self, dataset_path: Path, phase_name: str) -> list:
        Returns:
            list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.
        """
-        images_path = dataset_path / "images" / phase_name
-        labels_path, data_type = locate_label_paths(dataset_path, phase_name)
+        labels_path, data_type = locate_label_paths(labels_path, phase_name)
        images_list = sorted([p.name for p in Path(images_path).iterdir() if p.is_file()])
        if data_type == "json":
            annotations_index, image_info_dict = create_image_metadata(labels_path)
@@ -78,30 +111,29 @@ def filter_data(self, dataset_path: Path, phase_name: str) -> list:
        for image_name in track(images_list, description="Filtering data"):
            if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
                continue
-            image_id = Path(image_name).stem

            if data_type == "json":
-                image_info = image_info_dict.get(image_id, None)
+                image_info = image_info_dict.get(image_name, None)
+                # TODO: negative samples could also be loaded
                if image_info is None:
                    continue
                annotations = annotations_index.get(image_info["id"], [])
-                image_seg_annotations = scale_segmentation(annotations, image_info)
+                image_seg_annotations = scale_segmentation(annotations, image_info)  # coco2yolo
                if not image_seg_annotations:
                    continue

            elif data_type == "txt":
-                label_path = labels_path / f"{image_id}.txt"
+                label_path = labels_path / Path(image_name).with_suffix(".txt")
                if not label_path.is_file():
                    continue
-                with open(label_path, "r") as file:
-                    image_seg_annotations = [list(map(float, line.strip().split())) for line in file]
+                with label_path.open("r") as f:
+                    image_seg_annotations = [list(map(float, line.strip().split())) for line in f]
            else:
                image_seg_annotations = []
                # TODO: correct the box and log the image file

-            labels = self.load_valid_labels(image_id, image_seg_annotations)
-
-            img_path = images_path / image_name
-            data.append((img_path, labels))
+            labels = self.load_valid_labels(images_path / image_name, image_seg_annotations)
+            data.append((image_name, labels))
            valid_inputs += 1
        logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
        return data
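Entries are now cached by bare file name and joined onto their images root in __init__ (the list comprehension there), so an existing cache keeps working if the dataset root moves. A small illustration with hypothetical values:

    from pathlib import Path

    images_path = Path("tests/data/images/train")
    cached = [("0001.jpg", None)]  # (image name, labels tensor); labels elided here
    joined = [(images_path / name, labels) for name, labels in cached]
    print(joined[0][0])  # tests/data/images/train/0001.jpg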
@@ -133,6 +165,7 @@ def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Te

    def get_data(self, idx):
        img_path, bboxes = self.data[idx]
+        # img_path = self.images_path / img_path
        img = Image.open(img_path).convert("RGB")
        return img, bboxes, img_path

@@ -200,7 +233,7 @@ def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: st
if task == "inference":
return StreamDataLoader(data_cfg)

if dataset_cfg.auto_download:
if dataset_cfg.get("auto_download",None):
prepare_dataset(dataset_cfg, task)

return YoloDataLoader(data_cfg, dataset_cfg, task, use_ddp)
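Switching to .get keeps dataset configs that omit auto_download (such as mock.yaml above) from raising an attribute error. A quick sketch, assuming dataset_cfg is an OmegaConf node as is usual in Hydra projects:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"path": "tests/data"})  # no auto_download key
    print(cfg.get("auto_download", None))  # None, so the download step is skipped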
@@ -300,4 +333,4 @@ def stop(self):
        self.thread.join(timeout=1)

    def __len__(self):
-        return self.queue.qsize() if not self.is_stream else 0
\ No newline at end of file
+        return self.queue.qsize() if not self.is_stream else 0