|
10 | 10 | import torch
|
11 | 11 | import torch.utils.data
|
12 | 12 | import torchvision
|
| 13 | +from pycocotools.coco import COCO |
13 | 14 | from pycocotools import mask as coco_mask
|
14 | 15 | from PIL import Image
|
15 | 16 | from io import BytesIO
|
|
18 | 19 |
|
19 | 20 | import datasets.transforms as T
|
20 | 21 |
|
| 22 | +from torchvision.datasets import VisionDataset |
| 23 | +from typing import Any, Callable, Optional, Tuple, List |
21 | 24 |
|
22 | 25 | ZIPS = dict()
|
23 | 26 |
|
@@ -59,34 +62,92 @@ def my_Image_open(root, fname):
|
59 | 62 | iob = BytesIO(my_open(root, fname))
|
60 | 63 | return Image.open(iob)
|
61 | 64 |
|
class CocoDetectionOptim(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset
    read from a directory of per-sample annotation files.

    Every ``*.json`` file found directly inside ``annot`` is one sample. Each
    file is parsed with ``pycocotools.coco.COCO`` when the sample is accessed,
    and only the first image listed in the file is used.

    Args:
        root (string): Root directory where images are downloaded to.
        annot (string): Path to the directory holding the json annotation files.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root: str,
        annot: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        # os.listdir() yields entries in arbitrary order; sorting keeps the
        # index -> sample mapping stable across runs and processes.
        self.ann_paths = sorted(
            os.path.join(annot, name)
            for name in os.listdir(annot)
            if name.endswith('.json')
        )
        # (index, COCO) of the most recently parsed annotation file.
        self._coco_cache = None

    def _coco_at(self, index: int) -> "COCO":
        """Return the parsed annotation file for ``index``.

        The last parsed file is remembered so that loading the image and the
        target of the same sample parses the json only once.
        """
        cached = getattr(self, "_coco_cache", None)
        if cached is not None and cached[0] == index:
            return cached[1]
        coco = COCO(self.ann_paths[index])
        self._coco_cache = (index, coco)
        return coco

    def _first_image_id(self, index: int) -> Any:
        """Return the id of the first image listed in annotation file ``index``.

        Raises:
            ValueError: if the annotation file lists no images.
        """
        image_id = next(iter(self._coco_at(index).imgs), None)
        if image_id is None:
            raise ValueError(
                f"annotation file {self.ann_paths[index]} contains no images"
            )
        return image_id

    def _load_image(self, index: int) -> Image.Image:
        """Open the image of sample ``index`` from ``root`` and return it as RGB."""
        coco = self._coco_at(index)
        path = coco.loadImgs(self._first_image_id(index))[0]["file_name"]
        # convert() returns a fully loaded copy, so the underlying file can be
        # closed right away instead of lingering until garbage collection.
        with Image.open(os.path.join(self.root, path)) as img:
            return img.convert("RGB")

    def _load_target(self, index: int) -> List[Any]:
        """Return the annotation records attached to the image of sample ``index``."""
        coco = self._coco_at(index)
        return coco.loadAnns(coco.getAnnIds(self._first_image_id(index)))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for sample ``index``, after ``transforms`` if set."""
        image = self._load_image(index)
        target = self._load_target(index)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of annotation files, i.e. the number of samples."""
        return len(self.ann_paths)
| 118 | + |
| 119 | + |
class CocoDetection(CocoDetectionOptim):
    """Per-file COCO dataset that reads images through ``my_Image_open`` and
    converts raw annotations with ``ConvertCocoPolysToMask``.

    Args:
        img_folder: Root directory of the images.
        ann_folder: Directory holding one json annotation file per sample.
        transforms: Callable applied to ``(img, target)`` after the target has
            been converted; may be ``None``.
        return_masks: Forwarded to ``ConvertCocoPolysToMask``.
    """

    def __init__(self, img_folder, ann_folder, transforms, return_masks):
        super().__init__(img_folder, ann_folder)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)
        # (index, COCO) of the most recently parsed annotation file.
        self._coco_cache = None

    def _coco_at(self, index):
        """Return the parsed annotation file for ``index``.

        The last parsed file is remembered so that the image, the target and
        the image id of one sample share a single parse of the json file.
        """
        cached = getattr(self, "_coco_cache", None)
        if cached is not None and cached[0] == index:
            return cached[1]
        coco = COCO(self.ann_paths[index])
        self._coco_cache = (index, coco)
        return coco

    def _first_image_id(self, index):
        """Return the id of the first image listed in annotation file ``index``.

        Raises:
            ValueError: if the annotation file lists no images.
        """
        image_id = next(iter(self._coco_at(index).imgs), None)
        if image_id is None:
            raise ValueError(
                f"annotation file {self.ann_paths[index]} contains no images"
            )
        return image_id

    def __getitem__(self, idx):
        """Return the ``(img, target)`` pair for sample ``idx``."""
        img = self._load_image(idx)
        target = self._load_target(idx)

        # ``self.transforms`` is the attribute managed by the base dataset; no
        # transform is forwarded to the base constructor above, so this is
        # presumably a no-op in practice -- kept for parity with the base class.
        if self.transforms is not None:
            img, target = self.transforms(img, target)

        image_id = self._first_image_id(idx)
        target = {'image_id': image_id, 'annotations': target}
        img, target = self.prepare(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target

    def _load_image(self, index: int) -> Image.Image:
        """Read the image of sample ``index`` via ``my_Image_open`` and return it as RGB."""
        coco = self._coco_at(index)
        path = coco.loadImgs(self._first_image_id(index))[0]["file_name"]
        return my_Image_open(self.root, path).convert('RGB')

    def _load_target(self, index) -> List[Any]:
        """Return the annotation records attached to the image of sample ``index``."""
        coco = self._coco_at(index)
        return coco.loadAnns(coco.getAnnIds(self._first_image_id(index)))
90 | 151 |
|
91 | 152 |
|
92 | 153 | def convert_coco_poly_to_mask(segmentations, height, width):
|
@@ -207,10 +268,10 @@ def build(image_set, args):
|
207 | 268 | root = Path(args.coco_path)
|
208 | 269 | assert root.exists(), f'provided COCO path {root} does not exist'
|
209 | 270 | PATHS = {
|
210 |
| - "train": (root / "train", root / "annotations" / f'train_annotations.json'), |
211 |
| - "val": (root / "val", root / "annotations" / f'val_annotations.json'), |
| 271 | + "train": (root / "train", root / "annotations"), |
| 272 | + "val": (root / "val", root / "annotations"), |
212 | 273 | }
|
213 | 274 |
|
214 |
| - img_folder, ann_file = PATHS[image_set] |
215 |
| - dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks) |
| 275 | + img_folder, ann_folder = PATHS[image_set] |
| 276 | + dataset = CocoDetection(img_folder, ann_folder, transforms=make_coco_transforms(image_set), return_masks=args.masks) |
216 | 277 | return dataset
|
0 commit comments