-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathinput_data_levelDB_escape_data_augmentation.py
96 lines (85 loc) · 4.24 KB
/
input_data_levelDB_escape_data_augmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import input_data_levelDB_simulator_data_augmentation
import numpy as np
import random
from input_data_levelDB_simulator_data_augmentation import readImageFromDB
from ast import literal_eval
import leveldb
def readPointFromDB(db, key, size=(2, 1, 1)):
    """Read a 2-component point stored as a Python literal and reshape it.

    The value under ``key`` is parsed with ``ast.literal_eval`` (safe: only
    literals are accepted).  Its two components are swapped before returning.

    Args:
        db: object exposing ``Get(key)``; the stored value may be ``str``,
            ``bytes`` or ``bytearray``.
        key: database key holding the point literal, e.g. ``"(y, x)"``.
        size: shape to reshape the resulting int32 array into.

    Returns:
        ``numpy.ndarray`` of dtype int32 and shape ``size`` holding
        ``(point[1], point[0])``.
    """
    point_str = db.Get(key)
    # literal_eval only parses text; LevelDB bindings commonly return raw bytes.
    if isinstance(point_str, (bytes, bytearray)):
        point_str = point_str.decode('utf-8')
    point = literal_eval(point_str)
    # The stored literal is (Y, X); convert it to (X, Y).
    point = point[1], point[0]
    return np.asarray(point, dtype=np.int32).reshape(size)
def rotPoint90(point, origin, rotation):
    """Rotate ``point`` about ``origin`` by ``rotation`` quarter turns.

    With the offset ``(dx, dy) = point - origin``, one quarter turn maps it to
    ``(-dy, dx)`` (counter-clockwise when the second component points up).
    ``rotation`` is taken modulo 4, so negative values turn the other way.

    Returns:
        A new int32-cast offset added back onto ``origin``. When no turn is
        applied (rotation % 4 not in {1, 2, 3}), ``point`` itself is returned
        unchanged.
    """
    offset = point - origin
    turn_maps = {
        1: lambda dx, dy: (-dy, dx),
        2: lambda dx, dy: (-dx, -dy),
        3: lambda dx, dy: (dy, -dx),
    }
    transform = turn_maps.get(rotation % 4)
    if transform is None:
        return point
    turned = transform(offset[0], offset[1])
    return np.array(turned, dtype=np.int32) + origin
def flipHPoint(point, origin):
    """Mirror ``point`` across ``origin`` along its first component.

    The first component of the offset from ``origin`` is negated and the
    second is kept; the mirrored offset is cast to int32 before being added
    back onto ``origin``.
    """
    offset = point - origin
    mirrored = np.array((-offset[0], offset[1]), dtype=np.int32)
    return mirrored + origin
class DataSet(input_data_levelDB_simulator_data_augmentation.DataSet):
    """Image / target-point dataset with on-the-fly rotation and flip augmentation.

    Each entry of the key list is an integer that packs a database index
    together with the augmentation to apply (see ``_decode_key``). Batches
    pair an image with a 2-component point that is transformed consistently
    with the image.
    """

    def __init__(self, images_key, input_size, num_examples, db, validation, invert, rotate):
        # The parent's third positional argument is passed as an empty tuple;
        # its meaning is defined in the parent module.
        super(DataSet, self).__init__(images_key, input_size, (), num_examples, db, validation, invert, rotate)

    def _decode_key(self, key):
        """Split a packed key into ``(db_index, rotation, inversion)``.

        When rotation is enabled, the lowest two bits select the quarter turn.
        When inversion is enabled, the next bit selects a horizontal flip.
        The remaining bits form the database index.
        """
        rotation = 0
        inversion = 0
        if self.rotate:
            key, rotation = divmod(key, 4)
        if self.invert:
            key, inversion = divmod(key, 2)
        return key, rotation, inversion

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` examples from this data set.

        When the current epoch does not have ``batch_size`` keys left, the
        keys are reshuffled, the epoch counter is incremented and the batch is
        taken from the start of the new permutation.

        Returns:
            (images, points): ``images`` has shape
            ``(batch_size, input_size[0], input_size[1], input_size[2])``;
            ``points`` has shape ``(batch_size, 2, 1, 1)``.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of examples.
        """
        if batch_size > self._num_examples:
            raise ValueError('batch_size (%d) exceeds number of examples (%d)'
                             % (batch_size, self._num_examples))

        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        # Only restart when this batch would run past the end; the previous
        # test also discarded the last batch that still fit.
        if self._index_in_epoch > self._num_examples:
            print('end epoch')
            self._epochs_completed += 1
            # Reshuffle all keys with a single permutation for the next epoch.
            random.shuffle(self._images_key)
            start = 0
            self._index_in_epoch = batch_size

        height = self._input_size[0]
        width = self._input_size[1]
        channels = self._input_size[2]
        images = np.empty((batch_size, height, width, channels))
        points = np.empty((batch_size, 2, 1, 1))
        # NOTE(review): points are treated as (X, Y) = (column, row), as
        # readPointFromDB's swap comment states, so the X origin uses the width.
        origin = np.array((width // 2, height // 2), dtype=np.int32).reshape((2, 1, 1))
        prefix = 'val' if self._is_validation else ''

        for n in range(batch_size):
            key, rotation, inversion = self._decode_key(self._images_key[start + n])
            db_key = prefix + str(key)
            images[n] = readImageFromDB(self._db, db_key, self._input_size)
            points[n] = readPointFromDB(self._db, db_key + "point")

            images[n] = np.rot90(images[n], rotation)
            # np.rot90 turns the displayed image counter-clockwise (rows grow
            # downward). rotPoint90 turns (dx, dy) -> (-dy, dx), which is
            # clockwise in that frame, so pass the opposite number of turns.
            points[n] = rotPoint90(points[n], origin, -rotation)
            if inversion:
                images[n] = np.fliplr(images[n])
                points[n] = flipHPoint(points[n], origin)
        return images, points
class DataSetManager(input_data_levelDB_simulator_data_augmentation.DataSetManager):
    """Open the LevelDB database and build the train / validation DataSets.

    The example counts are read from the ``num_examples`` and
    ``num_examples_val`` entries. They are then scaled by 2 when flip
    augmentation is on and by 4 when rotation augmentation is on, because
    each augmented variant gets its own packed integer key.
    """

    def __init__(self, config):
        # NOTE(review): the parent __init__ is not called here (as in the
        # original); confirm that is intended.
        self.input_size = config.input_size
        self.db = leveldb.LevelDB(config.leveldb_path + 'db')
        self.num_examples = int(self.db.Get('num_examples'))
        self.num_examples_val = int(self.db.Get('num_examples_val'))

        multiplier = 1
        if config.invert:
            multiplier *= 2
        if config.rotate:
            multiplier *= 4
        self.num_examples *= multiplier
        self.num_examples_val *= multiplier

        # Materialize as lists: DataSet shuffles the keys in place, which a
        # Python 3 range object does not support.
        self.images_key = list(range(self.num_examples))
        self.images_key_val = list(range(self.num_examples_val))

        self.train = DataSet(self.images_key, config.input_size, self.num_examples, self.db,
                             validation=False, invert=config.invert, rotate=config.rotate)
        self.validation = DataSet(self.images_key_val, config.input_size, self.num_examples_val, self.db,
                                  validation=True, invert=config.invert, rotate=config.rotate)