-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerator.py
More file actions
148 lines (104 loc) · 5.32 KB
/
generator.py
File metadata and controls
148 lines (104 loc) · 5.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os.path
import json
import scipy.misc
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import resize
# Image generator for this exercise. Each call to ImageGenerator.next() returns the next input
# for a neural network: a batch of images together with their corresponding labels.
# Note: Python's iterator protocol uses __iter__/__next__; this class exposes a plain next() method instead.
class ImageGenerator:
    """Serve batches of (optionally augmented) images and their labels.

    Every file in ``file_path`` is loaded with ``np.load``. Its label is looked
    up in the JSON file at ``label_path`` under the file name without its
    extension. Each call to :meth:`next` returns one batch. When the data set
    size is not a multiple of ``batch_size``, the batch that reaches the end of
    the data is topped up with images from the start of the current file
    order. The following epoch then starts again from the first file.
    """

    def __init__(self, file_path, label_path, batch_size, image_size, rotation=False, mirroring=False, shuffle=False):
        """Load the label table and index the image files.

        Args:
            file_path: directory holding one image file per sample.
            label_path: JSON file mapping file stem -> label (presumably an
                integer class id, given ``class_dict`` -- confirm against data).
            batch_size: number of images returned by each call to ``next``.
            image_size: target ``(height, width, channels)``; every image is
                resized to this shape.
            rotation: if true, rotate each image by a random multiple of 90 degrees.
            mirroring: if true, flip each image left-right with probability 1/2.
            shuffle: if true, shuffle the file order now and at every new epoch.
        """
        self.file_path = file_path
        self.label_path = label_path
        self.batch_size = batch_size
        self.image_size = image_size
        self.rotation = rotation
        self.mirroring = mirroring
        self.shuffle = shuffle
        self.Curr_Index = 0  # position in image_files where the next batch starts
        self.Curr_Epoch = 0  # number of completed passes over the data
        # JSON text is UTF-8; don't depend on the platform default encoding.
        with open(self.label_path, 'r', encoding='utf-8') as f:
            self.labels = json.load(f)
        # os.listdir order is arbitrary and filesystem-dependent; sort it so an
        # unshuffled generator yields a reproducible sequence of batches.
        self.image_files = sorted(os.listdir(self.file_path))
        self.no_images = len(self.image_files)
        if self.shuffle:
            np.random.shuffle(self.image_files)
        self.class_dict = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog',
                           7: 'horse', 8: 'ship', 9: 'truck'}

    def next(self):
        """Return the next batch as a tuple ``(images, labels)`` of numpy arrays.

        ``images`` stacks ``batch_size`` resized images, augmented when
        ``rotation`` or ``mirroring`` is enabled. ``labels`` holds the matching
        entries from the label file. If fewer than ``batch_size`` images remain,
        the batch wraps around to the start of the file list. It wraps more
        than once if ``batch_size`` exceeds the number of images. The batch
        that reaches the end of the data completes the epoch. The epoch
        counter then advances and, when shuffling, the file order is
        reshuffled before this method returns.

        Raises:
            ValueError: if ``file_path`` contained no files.
        """
        if self.no_images == 0:
            raise ValueError("no image files found in {!r}".format(self.file_path))

        start = self.Curr_Index
        # Modular indexing fills a short final batch from the first files and
        # still yields exactly batch_size entries when batch_size > no_images.
        batch_files = [self.image_files[(start + offset) % self.no_images]
                       for offset in range(self.batch_size)]

        images = []
        labels = []
        for file_name in batch_files:
            img = np.load(os.path.join(self.file_path, file_name))
            img = resize(img, (self.image_size[0], self.image_size[1], self.image_size[2]))
            if self.rotation or self.mirroring:
                img = self.augment(img)
            images.append(img)
            stem = os.path.splitext(file_name)[0]  # e.g. '1.npy' -> '1'
            labels.append(self.labels[stem])

        end = start + self.batch_size
        if end >= self.no_images:
            # This batch consumed the last image, so the epoch is complete.
            # The next epoch restarts at the first file, which means images
            # used above as wrap-around filler are served again.
            self.Curr_Epoch += 1
            self.Curr_Index = 0
            if self.shuffle:
                np.random.shuffle(self.image_files)
        else:
            self.Curr_Index = end

        return np.array(images), np.array(labels)

    def augment(self, img):
        """Apply a random transformation to a single image and return it.

        With ``rotation`` enabled, the image is rotated by 90, 180 or 270
        degrees, chosen uniformly. With ``mirroring`` enabled, it is flipped
        left-right with probability 1/2. Both may apply to the same image.
        ``np.rot90`` turns the first two axes, so a non-square image changes
        shape under a 90 or 270 degree turn.
        """
        if self.rotation:
            # k is drawn from {1, 2, 3}, so a rotated image is always actually turned.
            img = np.rot90(img, np.random.randint(1, 4))
        if self.mirroring and np.random.randint(2) == 1:
            img = np.fliplr(img)
        return img

    def current_epoch(self):
        """Return the number of epochs completed so far."""
        return self.Curr_Epoch

    def class_name(self, x):
        """Return the class name for label ``x``, or ``None`` if ``x`` is unknown."""
        return self.class_dict.get(x)

    def show(self):
        """Fetch one batch with :meth:`next` and plot it in a 3-column grid.

        Each image is titled with its class name. This consumes a batch, so it
        advances the generator's index and possibly its epoch.
        """
        images, labels = self.next()
        columns = 3
        rows = (len(images) + columns - 1) // columns  # ceil(len(images) / columns)
        for i, (image, label) in enumerate(zip(images, labels)):
            plt.subplot(rows, columns, i + 1)
            plt.imshow(image)
            plt.title(self.class_name(label))
            plt.axis('off')
        plt.subplots_adjust(hspace=0.6)  # leave room so titles don't overlap the row above
        plt.show()