Added basic quantization; quantization added to the input.
tzufgoogle committed Sep 29, 2022
1 parent 1429eb7 commit 33580f6
Showing 8 changed files with 335 additions and 129 deletions.
83 changes: 48 additions & 35 deletions cabby/model/dataset_item.py
@@ -60,10 +60,11 @@ class TextGeoDataset:
unique_cellids_binary: torch.tensor = attr.ib()
label_to_cellid: Dict[int, int] = attr.ib()
coord_to_cellid: Dict[str, int] = attr.ib()
+graph_embed_size: int = attr.ib()

@classmethod
def from_TextGeoSplit(cls, train, valid, test, unique_cellids,
-unique_cellids_binary, label_to_cellid, coord_to_cellid):
+unique_cellids_binary, label_to_cellid, coord_to_cellid, graph_embed_size):
"""Construct a TextGeoDataset."""
return TextGeoDataset(
train,
@@ -73,6 +74,7 @@ def from_TextGeoSplit(cls, train, valid, test, unique_cellids,
unique_cellids_binary,
label_to_cellid,
coord_to_cellid,
+graph_embed_size,
)

@classmethod
@@ -90,6 +92,8 @@ def load(cls, dataset_dir: Text, model_type: Text = None,
tensor_cellid_path = os.path.join(dataset_dir, "tensor_cellid.pth")
label_to_cellid_path = os.path.join(dataset_dir, "label_to_cellid.npy")
coord_to_cellid_path = os.path.join(dataset_dir, "coord_to_cellid.npy")
+graph_embed_size_path = os.path.join(dataset_dir, "graph_embed_size.npy")


logging.info("Loading dataset from <== {}.".format(dataset_dir))
train_dataset = torch.load(train_path_dataset)
@@ -102,29 +106,38 @@
label_to_cellid = np.load(
label_to_cellid_path, allow_pickle='TRUE').item()
tens_cells = torch.load(tensor_cellid_path)
-coord_to_cellid = np.load(coord_to_cellid_path, allow_pickle='TRUE')
+coord_to_cellid = np.load(coord_to_cellid_path, allow_pickle='TRUE').item()
+graph_embed_size = np.load(graph_embed_size_path, allow_pickle='TRUE')
+logging.info(f"Loaded dataset with graph embedding size {graph_embed_size}")

dataset_text = TextGeoDataset(
train_dataset, valid_dataset, test_dataset,
-unique_cellid, tens_cells, label_to_cellid, coord_to_cellid)
+unique_cellid, tens_cells, label_to_cellid, coord_to_cellid, graph_embed_size)

return dataset_text

@classmethod
def save(cls, dataset_text: Any, dataset_path: Text,
-train_path_dataset: Text, valid_path_dataset: Text,
-test_path_dataset: Text, unique_cellid_path: Text,
-tensor_cellid_path: Text, label_to_cellid_path: Text,
-coord_to_cellid_path: Text):
+graph_embed_size: int):
os.mkdir(dataset_path)

+train_path_dataset = os.path.join(dataset_path, 'train.pth')
+valid_path_dataset = os.path.join(dataset_path, 'valid.pth')
+test_path_dataset = os.path.join(dataset_path, 'test.pth')
+unique_cellid_path = os.path.join(dataset_path, "unique_cellid.npy")
+tensor_cellid_path = os.path.join(dataset_path, "tensor_cellid.pth")
+label_to_cellid_path = os.path.join(dataset_path, "label_to_cellid.npy")
+coord_to_cellid_path = os.path.join(dataset_path, "coord_to_cellid.npy")
+graph_embed_size_path = os.path.join(dataset_path, "graph_embed_size.npy")

torch.save(dataset_text.train, train_path_dataset)
torch.save(dataset_text.valid, valid_path_dataset)
torch.save(dataset_text.test, test_path_dataset)
np.save(unique_cellid_path, dataset_text.unique_cellids)
torch.save(dataset_text.unique_cellids_binary, tensor_cellid_path)
np.save(label_to_cellid_path, dataset_text.label_to_cellid)
np.save(coord_to_cellid_path, dataset_text.coord_to_cellid)
+np.save(graph_embed_size_path, graph_embed_size)

logging.info("Saved data to ==> {}.".format(dataset_path))

@@ -174,7 +187,7 @@ def __init__(self, text_tokenizer, s2_tokenizer, data: pd.DataFrame, s2level: in

self.encodings = self.text_tokenizer(
data.instructions.tolist(), truncation=True,
-padding=True, add_special_tokens=True)
+padding=True, add_special_tokens=True, max_length=200)

data['far_cells'] = data.cellid.apply(
lambda cellid: unique_cells_df[unique_cells_df['cellid'] == cellid].far.iloc[0])
@@ -187,14 +200,16 @@ def __init__(self, text_tokenizer, s2_tokenizer, data: pd.DataFrame, s2level: in
lambda x: gutil.tuple_from_point(x)).tolist()


-self.coords = data.cellid.apply(lambda x: cellid_to_coord[x]).tolist()
+self.coords_end = data.cellid.apply(lambda x: cellid_to_coord[x]).tolist()

self.labels = data.cellid.apply(lambda x: cellid_to_label[x]).tolist()

-self.start_point_cells = data.start_point.apply(
+self.start_cells = data.start_point.apply(
lambda x: gutil.cellid_from_point(x, s2level))

-self.start_point_labels = self.get_cell_to_lablel(self.start_point_cells.tolist())
+self.coords_start = self.start_cells.apply(lambda x: cellid_to_coord[x]).tolist()
+
+self.start_point_labels = self.get_cell_to_lablel(self.start_cells.tolist())

self.cellids = self.s2_tokenizer(cellids_array)

@@ -209,14 +224,9 @@ def __init__(self, text_tokenizer, s2_tokenizer, data: pd.DataFrame, s2level: in
self.start_point_labels, data.instructions.tolist())]

if graph_embed_file:

-self.graph_embed_start = self.start_point_cells.apply(
-lambda cell: torch.from_numpy(
-util.get_valid_graph_embed(self.graph_embed_file, str(cell))).unsqueeze(0).unsqueeze(0))

self.graph_embed_end = data['cellid'].apply(
lambda cell: util.get_valid_graph_embed(self.graph_embed_file, str(cell)))
-self.graph_embed_start = self.start_point_cells.apply(
+self.graph_embed_start = self.start_cells.apply(
lambda cell: util.get_valid_graph_embed(self.graph_embed_file, str(cell)))

data['landmarks_cells'] = data.landmarks.apply(
@@ -230,6 +240,9 @@ def __init__(self, text_tokenizer, s2_tokenizer, data: pd.DataFrame, s2level: in
str(i).replace(':', f': Start at {str(s)}.') for s, i in zip(
self.graph_embed_start, data.instructions.tolist())]

+else:
+self.graph_embed_start = np.zeros(len(self.cellids))

self.landmarks_dist_raw = []

data['landmarks_cells'] = data.landmarks.apply(
@@ -247,7 +260,7 @@ def __init__(self, text_tokenizer, s2_tokenizer, data: pd.DataFrame, s2level: in

if is_dist:
logging.info(f"Calculating distances between {dist_matrix.shape[0]} cells")
-dist_lists = self.start_point_cells.apply(lambda start: self.calc_dist(start, dist_matrix))
+dist_lists = self.start_cells.apply(lambda start: self.calc_dist(start, dist_matrix))

self.prob = dist_lists.mapply(
lambda row: [dprob(dist) for dist in row.tolist()])
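dprob itself is defined outside this diff; under the assumption that it maps a start-to-cell distance to a probability-like weight, a hypothetical stand-in could look like the sketch below (the name of the decay constant and its value are illustrative, not from the repo). The .mapply call appears to come from the mapply package, a parallelized drop-in for pandas .apply.

import numpy as np

def dprob(dist_m: float, tau: float = 500.0) -> float:
    # Hypothetical: nearby cells score near 1, far cells decay toward 0.
    return float(np.exp(-dist_m / tau))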
@@ -257,8 +270,9 @@ def __init__(self, text_tokenizer, s2_tokenizer, data: pd.DataFrame, s2level: in

self.set_generation_model(data)


del self.graph_embed_file
-del self.start_point_cells
+# del self.start_point_cells
del self.s2_tokenizer
del self.text_tokenizer

@@ -299,7 +313,7 @@ def set_generation_model(self, data):
output_text, truncation=True, padding=True, add_special_tokens=True).input_ids

self.text_input_tokenized = self.text_tokenizer(
-input_text, truncation=True, padding=True, add_special_tokens=True)
+input_text, truncation=True, padding='max_length', add_special_tokens=True, max_length=200)

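The tokenizer change above swaps dynamic padding (padding=True pads to the longest sequence in the batch) for fixed-length padding to max_length=200, so every tokenized instruction has the same length regardless of batch composition. A sketch of the difference, assuming a HuggingFace tokenizer, which the truncation/padding/add_special_tokens keywords suggest:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")
batch = ["Walk north past the bakery.", "Turn left."]

dynamic = tok(batch, truncation=True, padding=True, add_special_tokens=True)
fixed = tok(batch, truncation=True, padding="max_length",
            add_special_tokens=True, max_length=200)

print(len(dynamic["input_ids"][0]))  # length of the longest example in this batch
print(len(fixed["input_ids"][0]))    # always 200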
def set_S2_Generation_T5_text_start_to_landmarks_dist(self, data):

@@ -310,18 +324,18 @@ def set_S2_Generation_T5_text_start_to_end_dist(self, data):

self.end_dist_raw = [f"{e_l} distance: {round(gutil.get_distance_between_points(s_p, e_p))}"
for e_l, e_p, s_p in zip(
-self.coords, self.end_point, data.start_point)]
+self.coords_end, self.end_point, data.start_point)]

return self.start_text_input_list, self.end_dist_raw


def set_S2_Generation_T5_start_embedding_text_input(self, data):

-return self.start_embed_text_input_list, self.coords
+return self.start_text_input_list, self.coords_end

def set_S2_Generation_T5_start_text_input(self, data):

-return self.start_text_input_list, self.coords
+return self.start_text_input_list, self.coords_end


def set_S2_Generation_T5_Landmarks(self, data):
@@ -358,7 +372,7 @@ def set_S2_Generation_T5_Warmup_start_end_to_dist(self, data):

start_end_point_list_raw = [
f"{self.model_type}: {str(e)}, {str(s)}" for s, e in zip(
-self.start_point_labels, self.coords)]
+self.start_point_labels, self.coords_end)]

return start_end_point_list_raw, dists_start_end

@@ -369,7 +383,7 @@ def set_S2_Generation_T5_Warmup_start_end(self, data):

start_end_point_list_raw = [
f"{self.model_type}: {str(e)}, {str(s)}" for s, e in zip(
-self.start_point_labels, self.coords)]
+self.start_point_labels, self.coords_end)]

data['route_fixed'] = data.route_fixed.apply(
lambda l: [gutil.cellid_from_point(x, self.s2level) for x in l])
Expand All @@ -379,7 +393,7 @@ def set_S2_Generation_T5_Warmup_start_end(self, data):
return start_end_point_list_raw, route_fixed_label

def set_S2_Generation_T5(self, data):
-return data.instructions.tolist(), self.coords
+return data.instructions.tolist(), self.coords_end


def set_S2_Generation_T5_Path(self, data):
@@ -397,19 +411,16 @@ def set_S2_Generation_T5_Path(self, data):
return data.instructions.tolist(), route_label

def set_S2_Generation_T5_text_start_embedding_to_landmarks_dist(self, data):
-return self.start_embed_text_input_list, self.landmarks_dist_raw
+return self.start_text_input_list, self.landmarks_dist_raw


def set_S2_Generation_T5_text_start_embedding_to_landmarks(self, data):
-return self.start_embed_text_input_list, self.landmark_label
+return self.start_text_input_list, self.landmark_label


def set_S2_Generation_T5_Warmup_cell_embed_to_cell_label(self, data):

-graph_embed_end_and_prompt = [
-f"{self.model_type}: {str(e)}" for e in self.graph_embed_end
-]
-return graph_embed_end_and_prompt, self.coords
+return [self.model_type]*len(self.coords_start), self.coords_start

def set_Classification_Bert(self, data):

@@ -420,15 +431,15 @@ def set_Dual_Encoder_Bert(self, data):

def print_sample(self, mode_expected, input, output):

-assert 'T5'not in mode_expected or mode_expected in input, \
+assert 'T5' not in mode_expected or mode_expected in input, \
f"mode_expected: {mode_expected} \n input: {input}"

if self.model_type == mode_expected:
logging.info(
f"\n Example {self.model_type}: \n" +
f" Input: '{input}'\n" +
f" Output: {output}\n" +
-f" Goal: {self.coords[0]}\n" +
+f" Goal: {self.coords_end[0]}\n" +
f" Start: {self.start_point_labels[0]}"
)

@@ -462,14 +473,16 @@ def __getitem__(self, idx: int):

text_input = torch.tensor(self.text_input_tokenized[idx])

+graph_embed_start = self.graph_embed_start[idx]


text_output = torch.tensor(self.text_output_tokenized[idx])


sample = {'text': text_input, 'cellid': cellid,
'neighbor_cells': neighbor_cells,
'far_cells': far_cells, 'end_point': end_point, 'label': label,
-'prob': prob, 'text_output': text_output,
+'prob': prob, 'text_output': text_output, 'graph_embed_start': graph_embed_start
}

return sample
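With the new 'graph_embed_start' entry, each sample now carries its start-cell graph embedding, or a zero placeholder when no graph_embed_file was given. A hypothetical consumption sketch, assuming a loaded TextGeoDataset named dataset:

from torch.utils.data import DataLoader

loader = DataLoader(dataset.train, batch_size=4, shuffle=True)
for batch in loader:
    text_ids = batch["text"]                        # tokenized instructions
    graph_embed_start = batch["graph_embed_start"]  # zeros if embeddings are disabled
    break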
