# this code is based on <https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/loader.py>
from pathlib import Path
from random import Random

import PIL.Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms as T


class TextImageDataset(Dataset):
    def __init__(self,
                 folder,
                 text_len=256,
                 image_size=128,
                 truncate_captions=False,
                 resize_ratio=0.75,
                 tokenizer=None,
                 shuffle=False,
                 seed=None,
                 ):
        """
        @param folder: Folder containing images and text files matched by their paths' respective "stem"
        @param truncate_captions: If True, captions longer than text_len are truncated rather than raising an exception.
        """
        super().__init__()
        self.rng = Random(seed)
        self.shuffle = shuffle

        path = Path(folder)

        text_files = [*path.glob('**/*.txt')]
        image_files = [
            *path.glob('**/*.png'), *path.glob('**/*.jpg'),
            *path.glob('**/*.jpeg'), *path.glob('**/*.bmp')
        ]

        text_files = {text_file.stem: text_file for text_file in text_files}
        image_files = {image_file.stem: image_file for image_file in image_files}

        # keep only the stems that have both an image and a caption file;
        # sort before shuffling so the order is reproducible for a given seed
        keys = (image_files.keys() & text_files.keys())
        self.keys = sorted(keys)
        self.rng.shuffle(self.keys)

        self.text_files = {k: v for k, v in text_files.items() if k in keys}
        self.image_files = {k: v for k, v in image_files.items() if k in keys}
        self.text_len = text_len
        self.truncate_captions = truncate_captions
        self.resize_ratio = resize_ratio
        self.tokenizer = tokenizer
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB')
                     if img.mode != 'RGB' else img),
            T.RandomResizedCrop(image_size,
                                scale=(self.resize_ratio, 1.),
                                ratio=(1., 1.)),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.keys)

    def random_sample(self):
        return self.__getitem__(self.rng.randint(0, self.__len__() - 1))

    def sequential_sample(self, ind):
        if ind >= self.__len__() - 1:
            return self.__getitem__(0)
        return self.__getitem__(ind + 1)

    def skip_sample(self, ind):
        # fall back to another sample when the current one cannot be loaded
        if self.shuffle:
            return self.random_sample()
        return self.sequential_sample(ind=ind)

    def __getitem__(self, ind):
        key = self.keys[ind]

        text_file = self.text_files[key]
        image_file = self.image_files[key]

        # one caption per non-empty line of the text file
        descriptions = text_file.read_text().split('\n')
        descriptions = list(filter(lambda t: len(t) > 0, descriptions))
        if len(descriptions) == 0:
            print(f"No captions found in file {text_file}.")
            print(f"Skipping index {ind}")
            return self.skip_sample(ind)

        # unlike the upstream DALLE-pytorch loader, which samples a single
        # random caption, tokenize every caption and concatenate them into a
        # (num_captions, text_len) tensor
        tokenized_text = torch.cat([
            self.tokenizer.tokenize(
                description,
                self.text_len,
                truncate_text=self.truncate_captions
            )
            for description in descriptions
        ])

        try:
            image_tensor = self.image_transform(PIL.Image.open(image_file))
        except (PIL.UnidentifiedImageError, OSError) as corrupt_image_exceptions:
            print(f"An exception occurred trying to load file {image_file}.")
            print(f"Skipping index {ind}")
            return self.skip_sample(ind)

        # Success
        return tokenized_text, descriptions, image_tensor
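

# ---------------------------------------------------------------------------
# Usage sketch (an addition, not part of the upstream loader). Because
# __getitem__ returns a (num_captions, text_len) tensor whose first dimension
# varies per sample, plus a Python list of raw caption strings, torch's
# default collate cannot batch the samples directly; a small custom
# collate_fn is one way around that. `DummyTokenizer` and the './data' folder
# are hypothetical stand-ins: any tokenizer exposing
# tokenize(text, context_length, truncate_text=...) (e.g. DALLE-pytorch's
# SimpleTokenizer) would slot in the same way.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    class DummyTokenizer:
        # hypothetical tokenizer: returns a (1, context_length) tensor of
        # zeros so the example is self-contained
        def tokenize(self, text, context_length, truncate_text=False):
            return torch.zeros(1, context_length, dtype=torch.long)

    def collate_fn(batch):
        # keep only the first caption of each sample so shapes line up:
        # texts -> (batch, text_len), images -> (batch, 3, H, W)
        texts = torch.stack([tokens[0] for tokens, _, _ in batch])
        images = torch.stack([image for _, _, image in batch])
        captions = [descriptions for _, descriptions, _ in batch]
        return texts, captions, images

    dataset = TextImageDataset(
        './data',  # placeholder: folder of .txt/.jpg pairs matched by stem
        text_len=256,
        image_size=128,
        truncate_captions=True,
        tokenizer=DummyTokenizer(),
        shuffle=True,
        seed=42,
    )
    loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
    texts, captions, images = next(iter(loader))
    print(texts.shape, images.shape, len(captions))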