From 33580f61ba61ae54e6273666837a47c1f74af66c Mon Sep 17 00:00:00 2001 From: tzufgoogle Date: Thu, 29 Sep 2022 19:44:03 +0300 Subject: [PATCH] added basic quantization. Quantization added in input --- cabby/model/dataset_item.py | 83 ++++++---- cabby/model/datasets.py | 128 +++++++++----- cabby/model/text/model_trainer.py | 27 +-- cabby/model/text/model_trainer_multitask.py | 5 +- cabby/model/text/models.py | 175 +++++++++++++++++--- cabby/model/text/train.py | 30 +++- cabby/model/util.py | 3 +- run_all.sh | 13 +- 8 files changed, 335 insertions(+), 129 deletions(-) diff --git a/cabby/model/dataset_item.py b/cabby/model/dataset_item.py index fb8a9658..19b51514 100755 --- a/cabby/model/dataset_item.py +++ b/cabby/model/dataset_item.py @@ -60,10 +60,11 @@ class TextGeoDataset: unique_cellids_binary: torch.tensor = attr.ib() label_to_cellid: Dict[int, int] = attr.ib() coord_to_cellid: Dict[str, int] = attr.ib() + graph_embed_size: int = attr.ib() @classmethod def from_TextGeoSplit(cls, train, valid, test, unique_cellids, - unique_cellids_binary, label_to_cellid, coord_to_cellid): + unique_cellids_binary, label_to_cellid, coord_to_cellid, graph_embed_size): """Construct a TextGeoDataset.""" return TextGeoDataset( train, @@ -73,6 +74,7 @@ def from_TextGeoSplit(cls, train, valid, test, unique_cellids, unique_cellids_binary, label_to_cellid, coord_to_cellid, + graph_embed_size, ) @classmethod @@ -90,6 +92,8 @@ def load(cls, dataset_dir: Text, model_type: Text = None, tensor_cellid_path = os.path.join(dataset_dir, "tensor_cellid.pth") label_to_cellid_path = os.path.join(dataset_dir, "label_to_cellid.npy") coord_to_cellid_path = os.path.join(dataset_dir, "coord_to_cellid.npy") + graph_embed_size_path = os.path.join(dataset_dir, "graph_embed_size.npy") + logging.info("Loading dataset from <== {}.".format(dataset_dir)) train_dataset = torch.load(train_path_dataset) @@ -102,22 +106,30 @@ def load(cls, dataset_dir: Text, model_type: Text = None, label_to_cellid = np.load( label_to_cellid_path, allow_pickle='TRUE').item() tens_cells = torch.load(tensor_cellid_path) - coord_to_cellid = np.load(coord_to_cellid_path, allow_pickle='TRUE') + coord_to_cellid = np.load(coord_to_cellid_path, allow_pickle='TRUE').item() + graph_embed_size = np.load(graph_embed_size_path, allow_pickle='TRUE') + logging.info(f"Loaded dataset with graph embedding size {graph_embed_size}") dataset_text = TextGeoDataset( train_dataset, valid_dataset, test_dataset, - unique_cellid, tens_cells, label_to_cellid, coord_to_cellid) + unique_cellid, tens_cells, label_to_cellid, coord_to_cellid, graph_embed_size) return dataset_text @classmethod def save(cls, dataset_text: Any, dataset_path: Text, - train_path_dataset: Text, valid_path_dataset: Text, - test_path_dataset: Text, unique_cellid_path: Text, - tensor_cellid_path: Text, label_to_cellid_path: Text, - coord_to_cellid_path: Text): + graph_embed_size: int): os.mkdir(dataset_path) + train_path_dataset = os.path.join(dataset_path, 'train.pth') + valid_path_dataset = os.path.join(dataset_path, 'valid.pth') + test_path_dataset = os.path.join(dataset_path, 'test.pth') + unique_cellid_path = os.path.join(dataset_path, "unique_cellid.npy") + tensor_cellid_path = os.path.join(dataset_path, "tensor_cellid.pth") + label_to_cellid_path = os.path.join(dataset_path, "label_to_cellid.npy") + coord_to_cellid_path = os.path.join(dataset_path, "coord_to_cellid.npy") + graph_embed_size_path = os.path.join(dataset_path, "graph_embed_size.npy") + torch.save(dataset_text.train, train_path_dataset) 
torch.save(dataset_text.valid, valid_path_dataset) torch.save(dataset_text.test, test_path_dataset) @@ -125,6 +137,7 @@ def save(cls, dataset_text: Any, dataset_path: Text, torch.save(dataset_text.unique_cellids_binary, tensor_cellid_path) np.save(label_to_cellid_path, dataset_text.label_to_cellid) np.save(coord_to_cellid_path, dataset_text.coord_to_cellid) + np.save(graph_embed_size_path, graph_embed_size) logging.info("Saved data to ==> {}.".format(dataset_path)) @@ -174,7 +187,7 @@ def __init__(self, text_tokenizer, s2_tokenizer, data: pd.DataFrame, s2level: in self.encodings = self.text_tokenizer( data.instructions.tolist(), truncation=True, - padding=True, add_special_tokens=True) + padding=True, add_special_tokens=True, max_length=200) data['far_cells'] = data.cellid.apply( lambda cellid: unique_cells_df[unique_cells_df['cellid'] == cellid].far.iloc[0]) @@ -187,14 +200,16 @@ def __init__(self, text_tokenizer, s2_tokenizer, data: pd.DataFrame, s2level: in lambda x: gutil.tuple_from_point(x)).tolist() - self.coords = data.cellid.apply(lambda x: cellid_to_coord[x]).tolist() + self.coords_end = data.cellid.apply(lambda x: cellid_to_coord[x]).tolist() self.labels = data.cellid.apply(lambda x: cellid_to_label[x]).tolist() - self.start_point_cells = data.start_point.apply( + self.start_cells = data.start_point.apply( lambda x: gutil.cellid_from_point(x, s2level)) - self.start_point_labels = self.get_cell_to_lablel(self.start_point_cells.tolist()) + self.coords_start = self.start_cells.apply(lambda x: cellid_to_coord[x]).tolist() + + self.start_point_labels = self.get_cell_to_lablel(self.start_cells.tolist()) self.cellids = self.s2_tokenizer(cellids_array) @@ -209,14 +224,9 @@ def __init__(self, text_tokenizer, s2_tokenizer, data: pd.DataFrame, s2level: in self.start_point_labels, data.instructions.tolist())] if graph_embed_file: - - self.graph_embed_start = self.start_point_cells.apply( - lambda cell: torch.from_numpy( - util.get_valid_graph_embed(self.graph_embed_file, str(cell))).unsqueeze(0).unsqueeze(0)) - self.graph_embed_end = data['cellid'].apply( lambda cell: util.get_valid_graph_embed(self.graph_embed_file, str(cell))) - self.graph_embed_start = self.start_point_cells.apply( + self.graph_embed_start = self.start_cells.apply( lambda cell: util.get_valid_graph_embed(self.graph_embed_file, str(cell))) data['landmarks_cells'] = data.landmarks.apply( @@ -230,6 +240,9 @@ def __init__(self, text_tokenizer, s2_tokenizer, data: pd.DataFrame, s2level: in str(i).replace(':', f': Start at {str(s)}.') for s, i in zip( self.graph_embed_start, data.instructions.tolist())] + else: + self.graph_embed_start = np.zeros(len(self.cellids)) + self.landmarks_dist_raw = [] data['landmarks_cells'] = data.landmarks.apply( @@ -247,7 +260,7 @@ def __init__(self, text_tokenizer, s2_tokenizer, data: pd.DataFrame, s2level: in if is_dist: logging.info(f"Calculating distances between {dist_matrix.shape[0]} cells") - dist_lists = self.start_point_cells.apply(lambda start: self.calc_dist(start, dist_matrix)) + dist_lists = self.start_cells.apply(lambda start: self.calc_dist(start, dist_matrix)) self.prob = dist_lists.mapply( lambda row: [dprob(dist) for dist in row.tolist()]) @@ -257,8 +270,9 @@ def __init__(self, text_tokenizer, s2_tokenizer, data: pd.DataFrame, s2level: in self.set_generation_model(data) + del self.graph_embed_file - del self.start_point_cells + # del self.start_point_cells del self.s2_tokenizer del self.text_tokenizer @@ -299,7 +313,7 @@ def set_generation_model(self, data): output_text, 
truncation=True, padding=True, add_special_tokens=True).input_ids self.text_input_tokenized = self.text_tokenizer( - input_text, truncation=True, padding=True, add_special_tokens=True) + input_text, truncation=True, padding='max_length', add_special_tokens=True, max_length=200) def set_S2_Generation_T5_text_start_to_landmarks_dist(self, data): @@ -310,18 +324,18 @@ def set_S2_Generation_T5_text_start_to_end_dist(self, data): self.end_dist_raw = [f"{e_l} distance: {round(gutil.get_distance_between_points(s_p, e_p))}" for e_l, e_p, s_p in zip( - self.coords, self.end_point, data.start_point)] + self.coords_end, self.end_point, data.start_point)] return self.start_text_input_list, self.end_dist_raw def set_S2_Generation_T5_start_embedding_text_input(self, data): - return self.start_embed_text_input_list, self.coords + return self.start_text_input_list, self.coords_end def set_S2_Generation_T5_start_text_input(self, data): - return self.start_text_input_list, self.coords + return self.start_text_input_list, self.coords_end def set_S2_Generation_T5_Landmarks(self, data): @@ -358,7 +372,7 @@ def set_S2_Generation_T5_Warmup_start_end_to_dist(self, data): start_end_point_list_raw = [ f"{self.model_type}: {str(e)}, {str(s)}" for s, e in zip( - self.start_point_labels, self.coords)] + self.start_point_labels, self.coords_end)] return start_end_point_list_raw, dists_start_end @@ -369,7 +383,7 @@ def set_S2_Generation_T5_Warmup_start_end(self, data): start_end_point_list_raw = [ f"{self.model_type}: {str(e)}, {str(s)}" for s, e in zip( - self.start_point_labels, self.coords)] + self.start_point_labels, self.coords_end)] data['route_fixed'] = data.route_fixed.apply( lambda l: [gutil.cellid_from_point(x, self.s2level) for x in l]) @@ -379,7 +393,7 @@ def set_S2_Generation_T5_Warmup_start_end(self, data): return start_end_point_list_raw, route_fixed_label def set_S2_Generation_T5(self, data): - return data.instructions.tolist(), self.coords + return data.instructions.tolist(), self.coords_end def set_S2_Generation_T5_Path(self, data): @@ -397,19 +411,16 @@ def set_S2_Generation_T5_Path(self, data): return data.instructions.tolist(), route_label def set_S2_Generation_T5_text_start_embedding_to_landmarks_dist(self, data): - return self.start_embed_text_input_list, self.landmarks_dist_raw + return self.start_text_input_list, self.landmarks_dist_raw def set_S2_Generation_T5_text_start_embedding_to_landmarks(self, data): - return self.start_embed_text_input_list, self.landmark_label + return self.start_text_input_list, self.landmark_label def set_S2_Generation_T5_Warmup_cell_embed_to_cell_label(self, data): - graph_embed_end_and_prompt = [ - f"{self.model_type}: {str(e)}" for e in self.graph_embed_end - ] - return graph_embed_end_and_prompt, self.coords + return [self.model_type]*len(self.coords_start), self.coords_start def set_Classification_Bert(self, data): @@ -420,7 +431,7 @@ def set_Dual_Encoder_Bert(self, data): def print_sample(self, mode_expected, input, output): - assert 'T5'not in mode_expected or mode_expected in input, \ + assert 'T5' not in mode_expected or mode_expected in input, \ f"mode_expected: {mode_expected} \n input: {input}" if self.model_type == mode_expected: @@ -428,7 +439,7 @@ def print_sample(self, mode_expected, input, output): f"\n Example {self.model_type}: \n" + f" Input: '{input}'\n" + f" Output: {output}\n" + - f" Goal: {self.coords[0]}\n" + + f" Goal: {self.coords_end[0]}\n" + f" Start: {self.start_point_labels[0]}" ) @@ -462,6 +473,8 @@ def __getitem__(self, idx: int): 
text_input = torch.tensor(self.text_input_tokenized[idx]) + graph_embed_start = self.graph_embed_start[idx] + text_output = torch.tensor(self.text_output_tokenized[idx]) @@ -469,7 +482,7 @@ def __getitem__(self, idx: int): sample = {'text': text_input, 'cellid': cellid, 'neighbor_cells': neighbor_cells, 'far_cells': far_cells, 'end_point': end_point, 'label': label, - 'prob': prob, 'text_output': text_output, + 'prob': prob, 'text_output': text_output, 'graph_embed_start': graph_embed_start } return sample diff --git a/cabby/model/datasets.py b/cabby/model/datasets.py index c1235ed5..6f5e47e8 100644 --- a/cabby/model/datasets.py +++ b/cabby/model/datasets.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys from absl import logging +from gensim.models import KeyedVectors import numpy as np import os import pandas as pd @@ -71,14 +73,22 @@ def __init__( region: Optional[str], model_type: str, n_fixed_points: int = 4, - graph_embed_file: Any = None, + graph_embed_path: str = "", ): self.data_dir = data_dir self.s2level = s2level self.region = region self.model_type = model_type self.n_fixed_points = n_fixed_points - self.graph_embed_file = graph_embed_file + + self.graph_embed_size = 0 + self.graph_embed_file = None + if os.path.exists(graph_embed_path): + self.graph_embed_file = KeyedVectors.load_word2vec_format(graph_embed_path) + first_cell = self.graph_embed_file.index_to_key[0] + self.graph_embed_size = self.graph_embed_file[first_cell].shape[0] + + logging.info(f"Dataset with graph embedding size {self.graph_embed_size}") self.unique_cellid = {} self.cellid_to_label = {} @@ -159,13 +169,12 @@ def create_dataset( unique_cells_df = pd.DataFrame( {'point': points, 'cellid': self.unique_cellid}) - self.cellid_to_coord, self.coord_to_cellid = self.create_grid(unique_cells_df) + logging.info(f"Created grid") dist_matrix = unique_cells_df.point.mapply( lambda x: calc_dist(x, unique_cells_df) ) - dist_matrix = dist_matrix.to_numpy() unique_cells_df['far'] = unique_cells_df.point.swifter.apply( @@ -216,14 +225,14 @@ def create_dataset( return dataset_item.TextGeoDataset.from_TextGeoSplit( train_dataset, val_dataset, test_dataset, np.array(self.unique_cellid), - tens_cells, self.label_to_cellid, self.cellid_to_coord) + tens_cells, self.label_to_cellid, self.coord_to_cellid, self.graph_embed_size) def create_grid(self, unique_cells_df): unique_cells_df['lon'] = unique_cells_df.point.apply(lambda p: p.x) unique_cells_df['lat'] = unique_cells_df.point.apply(lambda p: p.y) - unique_cells_df = unique_cells_df.sort_values(by=['lat','lon'], ascending=True) + unique_cells_df = unique_cells_df.sort_values(by=['lat', 'lon'], ascending=True) cellid_to_coord = {} coord_to_cellid = {} @@ -235,37 +244,67 @@ def create_grid(self, unique_cells_df): y_current = 0 current_cellid = unique_cells_df.cellid.tolist()[0] + size_region = len(unique_cells_df) + min_cell = current_cellid + tmp = [] + begin_cell = None - for i in range(100): + for down_idx in range(20): cell = s2.S2CellId(current_cellid) - prev_cell = cell.GetEdgeNeighbors()[0] - prev_cell_id = prev_cell.id() - current_cellid = prev_cell_id + down_cell = cell.GetEdgeNeighbors()[1] + current_cellid = down_cell.id() - while current_cellid not in full_cellid_list: + for up_idx in range(20): cell = s2.S2CellId(current_cellid) - next_cell = cell.GetEdgeNeighbors()[2] - current_cellid = next_cell.id() + 
up_cell = cell.GetEdgeNeighbors()[3] + current_cellid = up_cell.id() + + for left_idx in range(round(size_region/20)): + cell = s2.S2CellId(current_cellid) + left_cell = cell.GetEdgeNeighbors()[0] + left_cell_id = left_cell.id() + current_cellid = left_cell_id + if current_cellid in full_cellid_list: + begin_cell = current_cellid + + if begin_cell: + break + + for right_idx in range(round(size_region/20)): + cell = s2.S2CellId(current_cellid) + right_cell = cell.GetEdgeNeighbors()[2] + right_cell_id = right_cell.id() + current_cellid = right_cell_id + if current_cellid in full_cellid_list: + begin_cell = current_cellid + break + + if begin_cell: + break + + current_cellid = begin_cell + + status_cell_list = (len(full_cellid_list), 0) while True: cell = s2.S2CellId(current_cellid) next_cell = cell.GetEdgeNeighbors()[2] next_cellid = next_cell.id() - is_next = False + is_next = False if current_cellid in full_cellid_list: - cellid_to_coord[current_cellid] = (x_current, y_current) full_cellid_list.remove(current_cellid) empty_cellid_list.append(current_cellid) - x_current+=1 + + status_cell_list = (len(full_cellid_list), 0) + x_current += 1 current_cellid = next_cellid is_next = True - if not is_next: y_current += 1 @@ -273,40 +312,52 @@ def create_grid(self, unique_cells_df): current_cellid = upper_cell.id() prev_cell = upper_cell.GetEdgeNeighbors()[0] prev_cell_id = prev_cell.id() - while prev_cell_id in full_cellid_list: + + counter_move_left = 0 + while prev_cell_id in full_cellid_list and counter_move_left100: + sys.exit(f"Problem with creating grid. " + + f"There are still {size_cell_list} cells not in grid: {full_cellid_list}. " + + f"The beginig cell: {begin_cell}") min_x = min(cellid_to_coord.items(), key=lambda x: x[1][0])[1][0] - if min_x<0: - add_x = -1*min_x + if min_x < 0: + + add_x = -1 * min_x new_cellid_to_coord = {} for cellid, (x, y) in cellid_to_coord.items(): + new_x = x + add_x new_cellid_to_coord[cellid] = (new_x, y) @@ -320,10 +371,11 @@ def create_grid(self, unique_cells_df): coord_format_to_cellid = {dataset_item.coord_format(coord): cellid for cellid, coord in cellid_to_coord.items()} cellid_to_coord_format = {cellid: dataset_item.coord_format(coord) for cellid, coord in cellid_to_coord.items()} - assert len(full_cellid_list)==0, f"full_cellid_list: {len(full_cellid_list)} empty_cellid_list:{len(empty_cellid_list)}" + assert len( + full_cellid_list) == 0, f"full_cellid_list: {len(full_cellid_list)} empty_cellid_list:{len(empty_cellid_list)}" assert len(cellid_to_coord_format) == unique_cells_df.cellid.shape[0] - return cellid_to_coord_format, coord_format_to_cellid + return cellid_to_coord_format, coord_format_to_cellid def get_fixed_point_along_route(self, row): points_list = row.route @@ -348,11 +400,11 @@ def __init__( s2level: int, region: Optional[str], n_fixed_points: int = 4, - graph_embed_file: Any = None, + graph_embed_path: str = "", model_type: str = "Dual-Encoder-Bert"): Dataset.__init__( - self, data_dir, s2level, region, model_type, n_fixed_points, graph_embed_file) + self, data_dir, s2level, region, model_type, n_fixed_points, graph_embed_path) self.train = self.load_data(data_dir, 'train', lines=True) self.valid = self.load_data(data_dir, 'dev', lines=True) self.test = self.load_data(data_dir, 'test', lines=True) @@ -402,11 +454,11 @@ def __init__( data_dir: str, s2level: int, region: Optional[str], - graph_embed_file: Any = None, + graph_embed_path: str = "", n_fixed_points: int = 4, model_type: str = "Dual-Encoder-Bert"): Dataset.__init__( - self, 
data_dir, s2level, None, model_type, n_fixed_points, graph_embed_file + self, data_dir, s2level, None, model_type, n_fixed_points, graph_embed_path ) train_ds, valid_ds, test_ds, ds = self.load_data(data_dir, lines=False) @@ -442,7 +494,7 @@ def load_data(self, data_dir: str, lines: bool): ds['instructions'] = ds.groupby( ['id'])['instruction'].transform(lambda x: ' '.join(x)) - ds = ds.drop_duplicates(subset='id', keep="last") + ds = ds.drop_duplicates(subset='id', keep="last", ignore_index=True) columns_keep = ds.columns.difference( ['map', 'id', 'instructions', 'end_point', 'start_point']) @@ -469,11 +521,11 @@ def __init__( s2level: int, region: Optional[str], n_fixed_points: int = 4, - graph_embed_file: Any = None, + graph_embed_path: str = "", model_type: str = "Dual-Encoder-Bert" ): Dataset.__init__( - self, data_dir, s2level, region, model_type, n_fixed_points, graph_embed_file) + self, data_dir, s2level, region, model_type, n_fixed_points, graph_embed_path) train_ds = self.load_data(data_dir, 'train', True) valid_ds = self.load_data(data_dir, 'dev', True) test_ds = self.load_data(data_dir, 'test', True) @@ -501,7 +553,8 @@ def load_data(self, data_dir: str, split: str, lines: bool): ds['landmarks_ner'] = ds.geo_landmarks.apply(self.process_landmarks_ner) ds['landmarks_ner_and_point'] = ds.geo_landmarks.apply(self.process_landmarks_ner_single) - ds = pd.concat([ds.drop(['geo_landmarks'], axis=1), ds['geo_landmarks'].apply(pd.Series)], axis=1) + ds = pd.concat( + [ds.drop(['geo_landmarks'], axis=1), ds['geo_landmarks'].apply(pd.Series)], axis=1) ds['end_osmid'] = ds.end_point.apply(lambda x: x[1]) ds['start_osmid'] = ds.start_point.apply(lambda x: x[1]) @@ -515,10 +568,11 @@ def load_data(self, data_dir: str, split: str, lines: bool): logging.info(f"Size of dataset before removal of duplication: {ds.shape[0]}") - ds = ds.drop_duplicates(subset=['end_osmid', 'start_osmid'], keep='last') + ds = ds.drop_duplicates(subset=['end_osmid', 'start_osmid'], keep='last', ignore_index=True) logging.info(f"Size of dataset after removal of duplication: {ds.shape[0]}") + ds = ds.reset_index(drop=True) return ds def process_landmarks(self, row): diff --git a/cabby/model/text/model_trainer.py b/cabby/model/text/model_trainer.py index 28ee6e30..807d9d04 100755 --- a/cabby/model/text/model_trainer.py +++ b/cabby/model/text/model_trainer.py @@ -48,7 +48,6 @@ import torch.nn as nn from torch.utils.data import DataLoader from transformers import AdamW -from gensim.models import KeyedVectors from cabby.evals import utils as eu from cabby.model.text import train @@ -92,7 +91,7 @@ flags.DEFINE_string("model_path", None, "A path of a model the model to be fine tuned\ evaluated.") -flags.DEFINE_string("save_graph_embed_path", default="", +flags.DEFINE_string("graph_embed_path", default="", help="The path to the graph embedding.") flags.DEFINE_integer( @@ -147,13 +146,7 @@ def main(argv): dataset_model_path = os.path.join(FLAGS.dataset_dir, str(FLAGS.model)) dataset_path = os.path.join(dataset_model_path, str(FLAGS.s2_level)) - train_path_dataset = os.path.join(dataset_path, 'train.pth') - valid_path_dataset = os.path.join(dataset_path, 'valid.pth') - test_path_dataset = os.path.join(dataset_path, 'test.pth') - unique_cellid_path = os.path.join(dataset_path, "unique_cellid.npy") - tensor_cellid_path = os.path.join(dataset_path, "tensor_cellid.pth") - label_to_cellid_path = os.path.join(dataset_path, "label_to_cellid.npy") - coord_to_cellid_path = os.path.join(dataset_path, "coord_to_cellid.npy") + assert 
FLAGS.task in TASKS @@ -169,9 +162,6 @@ def main(argv): if FLAGS.is_single_sample_train: FLAGS.train_batch_size = 1 - graph_embed_file = None - if os.path.exists(FLAGS.save_graph_embed_path): - graph_embed_file = KeyedVectors.load_word2vec_format(FLAGS.save_graph_embed_path) if os.path.exists(dataset_path): dataset_text = dataset_item.TextGeoDataset.load( @@ -187,7 +177,7 @@ def main(argv): s2level=FLAGS.s2_level, model_type=FLAGS.model, n_fixed_points=FLAGS.n_fixed_points, - graph_embed_file=graph_embed_file) + graph_embed_path=FLAGS.graph_embed_path) if not os.path.exists(dataset_model_path): os.mkdir(dataset_model_path) @@ -201,13 +191,7 @@ def main(argv): dataset_item.TextGeoDataset.save( dataset_text=dataset_text, dataset_path=dataset_path, - train_path_dataset=train_path_dataset, - valid_path_dataset=valid_path_dataset, - test_path_dataset=test_path_dataset, - label_to_cellid_path=label_to_cellid_path, - unique_cellid_path=unique_cellid_path, - tensor_cellid_path=tensor_cellid_path, - coord_to_cellid_path=coord_to_cellid_path) + graph_embed_size=dataset.graph_embed_size) n_cells = len(dataset_text.unique_cellids) logging.info("Number of unique cells: {}".format( @@ -231,7 +215,8 @@ def main(argv): device=device, is_distance_distribution=FLAGS.is_distance_distribution) elif 'T5' in FLAGS.model: run_model = models.S2GenerationModel( - dataset_text.coord_to_cellid, device=device, model_type=FLAGS.model) + dataset_text.coord_to_cellid, device=device, model_type=FLAGS.model, + vq_dim=dataset_text.graph_embed_size) elif FLAGS.model == 'Classification-Bert': run_model = models.ClassificationModel(n_cells, device=device) else: diff --git a/cabby/model/text/model_trainer_multitask.py b/cabby/model/text/model_trainer_multitask.py index ac5ae7d3..8132badc 100644 --- a/cabby/model/text/model_trainer_multitask.py +++ b/cabby/model/text/model_trainer_multitask.py @@ -86,6 +86,9 @@ "A path of a model the model to be fine tuned\ evaluated.") +flags.DEFINE_string("graph_embed_path", default="", + help="The path to the graph embedding.") + flags.DEFINE_integer( 'train_batch_size', default=4, help=('Batch size for training.')) @@ -152,7 +155,7 @@ def main(argv): 'cuda') if torch.cuda.is_available() else torch.device('cpu') run_model = models.S2GenerationModel( - dataset_valid_test.label_to_cellid ,device=device) + dataset_valid_test.label_to_cellid, device=device, vq_dim=dataset_train.graph_embed_size) run_model.to(device) diff --git a/cabby/model/text/models.py b/cabby/model/text/models.py index 04a14ae7..0c5ccbe3 100644 --- a/cabby/model/text/models.py +++ b/cabby/model/text/models.py @@ -20,6 +20,8 @@ import torch.nn as nn from transformers import DistilBertModel, DistilBertForSequenceClassification from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Model +from vector_quantize_pytorch import VectorQuantize +from transformers import ViTForMaskedImageModeling from typing import Dict, Sequence @@ -29,6 +31,7 @@ T5_TYPE = "t5-small" T5_DIM = 512 if T5_TYPE == "t5-small" else 768 BERT_TYPE = "distilbert-base-uncased" +N_TOKEN = 1024 criterion = nn.CosineEmbeddingLoss() @@ -46,7 +49,7 @@ def forward(self, text: Dict, is_print, *args ): sys.exit("Implement compute_loss function in model") - def predict(self, text, *args): + def predict(self, text, is_print, *args): sys.exit("Implement prediction function in model") @@ -87,7 +90,7 @@ def get_embed(self, text_feat, cellid): return text_embedding.shape[0], text_embedding, cellid_embedding - def predict(self, text, all_cells, *args): + def 
predict(self, text, is_print, all_cells, *args): batch = args[1] batch_dim, text_embedding_exp, cellid_embedding = self.get_embed(text, all_cells) cell_dim = cellid_embedding.shape[0] @@ -148,44 +151,73 @@ def __init__( self, label_to_cellid, device, - model_type='S2-Generation-T5' + model_type='S2-Generation-T5', + vq_dim=224 ): - GeneralModel.__init__(self, device) self.model = T5ForConditionalGeneration.from_pretrained(T5_TYPE) self.tokenizer = T5Tokenizer.from_pretrained(T5_TYPE) self.is_generation = True self.label_to_cellid = label_to_cellid self.model_type = model_type - self.max_size = len(str(len(label_to_cellid))) + self.decoder = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k") + self.num_patches = (self.decoder.config.image_size // self.decoder.config.patch_size) ** 2 + self.discriminator = Discriminator() + self.vq_dim = vq_dim + + self.vq = VectorQuantize( + dim=vq_dim, + codebook_size=N_TOKEN, + codebook_dim=16, + decay = 0.8, # the exponential moving average decay, lower means the dictionary will change faster + commitment_weight = 1. # the weight on the commitment loss + ) + if model_type not in ['S2-Generation-T5']: self.max_size = self.max_size * 100 self.quant = torch.quantization.QuantStub() + if self.vq_dim: + self.original_size_tokenizer = len(self.tokenizer) + logging.info(f"Size of tokenizer before resized: {self.original_size_tokenizer}") + + add_tokens = [f"GRAPH_{t}" for t in range(N_TOKEN)] + self.tokenizer.add_tokens(add_tokens) + self.model.resize_token_embeddings(len(self.tokenizer)) + logging.info(f"Resized tokenizer to: {len(self.tokenizer)}") + def forward(self, text, cellid, is_print, *args): + batch = args[0] input_ids, attention_mask, labels = self.get_input_output(batch, text) - output = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, return_dict=True) + if self.vq_dim: + loss, input_ids = self.get_loss_for_graph_embed(batch, input_ids, labels) + + else: + loss = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, return_dict=True).loss if is_print: input_ids_decoded = self.tokenizer.batch_decode( input_ids, skip_special_tokens=True) logging.info(f"Actual input decoded: {input_ids_decoded[0]}") - return output.loss + output_ids_decoded = self.tokenizer.batch_decode( + labels, skip_special_tokens=True) + logging.info(f"Actual output decoded: {output_ids_decoded[0]}") + + return loss def get_embed(self, text, cellid): text_dim = text['input_ids'].shape[0] return text_dim, text, cellid - def predict(self, text, *args): + def predict(self, text, is_print, *args): - label_to_cellid = args[1] batch = args[-1] input_ids, attention_mask, _ = self.get_input_output(batch, text) @@ -198,34 +230,105 @@ def predict(self, text, *args): min_length=1, ) + if self.vq_dim: + graph_embed_start = batch['graph_embed_start'] + input_ids, _, _ = self.get_input_indices_for_embedding(graph_embed_start, input_ids) + output_sequences = self.model.generate( + input_ids=input_ids, + num_beams=2, + max_length=self.max_size + 2, + min_length=1, + ) + + add_tokens = [f"GRAPH_{t}" for t in range(N_TOKEN)] + self.tokenizer.add_tokens(add_tokens) + self.model.resize_token_embeddings(len(self.tokenizer)) + prediction = self.tokenizer.batch_decode( output_sequences, skip_special_tokens=True) + if is_print: + logging.info(f"Actual prediction decoded: {prediction[0]}") + prediction_cellids = [] for pred_raw in prediction: - pred = pred_raw.split(";")[0].replace(" ", "") - - if not pred.isdigit(): - pred = 
0 - label_int = int(pred) - if label_int in label_to_cellid: - cell_id = label_to_cellid[label_int] + coord = pred_raw.split(";")[0].strip() + if coord in self.label_to_cellid: + cell_id = self.label_to_cellid[coord] else: - cell_id = label_to_cellid[0] + first_key = list(self.label_to_cellid.keys())[0] + cell_id = self.label_to_cellid[first_key] prediction_cellids.append(cell_id) prediction_coords = gutil.get_center_from_s2cellids(prediction_cellids) return prediction_coords + def get_vg(self, graph_embed): + quantized, indices, vq_loss = self.vq(graph_embed) # (1, 1024, 256), (1, 1024), (1) + + assert torch.max(indices) Dict[Text, float]: def predictions_to_points(preds: Sequence, label_to_cellid: Dict[int, int]) -> Sequence[Tuple[float, float]]: default_cell = list(label_to_cellid.values())[0] - cellids = [] + cellids = [] for label in preds: cellids.append(label_to_cellid[label] if label in label_to_cellid else default_cell) coords = util.get_center_from_s2cellids(cellids) @@ -220,7 +220,6 @@ def get_valid_cell_label(dict_lables: Dict[int, Any], cellid: int): def get_valid_graph_embed(gensim_dict_lables: Any, cellid: str): while cellid not in gensim_dict_lables.index_to_key: - logging.info(f"!!!!!!!!!!!!!!! cellid:{cellid} type: {type(cellid)}") cellid = str(util.neighbor_cellid(int(cellid), gensim_dict_lables.index_to_key)) return gensim_dict_lables[cellid] \ No newline at end of file diff --git a/run_all.sh b/run_all.sh index ee282225..efc5c707 100755 --- a/run_all.sh +++ b/run_all.sh @@ -33,7 +33,7 @@ echo "* graph embeddings *" echo "****************************************" GRAPH_EMBEDDING_PATH=$MAP_DIR/graph_embedding.pth -bazel-bin/cabby/data/metagraph/create_graph_embedding --region $REGION_NAME --s2_level 15 --s2_node_levels 15 --base_osm_map_filepath $MAP_DIR --save_embedding_path $GRAPH_EMBEDDING_PATH --num_walks 2 --walk_length 2 +bazel-bin/cabby/data/metagraph/create_graph_embedding --region $REGION_NAME --dimensions 224 --s2_level 15 --s2_node_levels 15 --base_osm_map_filepath $MAP_DIR --save_embedding_path $GRAPH_EMBEDDING_PATH --num_walks 2 --walk_length 2 echo "****************************************" echo "* models *" @@ -45,17 +45,16 @@ mkdir -p $OUTPUT_DIR_MODEL_RVS_FIXED_5 mkdir -p $OUTPUT_DIR_MODEL_HUMAN echo "* S2-Generation-T5-text-start-embedding-to-landmarks-dist - RVS DATA *" -bazel-bin/cabby/model/text/model_trainer --data_dir $OUTPUT_DIR --dataset_dir $OUTPUT_DIR_MODEL_RVS --region $REGION_NAME --s2_level 15 --output_dir $OUTPUT_DIR_MODEL_RVS --num_epochs 1 --task RVS --model S2-Generation-T5-text-start-embedding-to-landmarks-dist --save_graph_embed_path $GRAPH_EMBEDDING_PATH --far_distance_threshold 10 --is_distance_distribution True +bazel-bin/cabby/model/text/model_trainer --data_dir $OUTPUT_DIR --dataset_dir $OUTPUT_DIR_MODEL_RVS --region $REGION_NAME --s2_level 15 --output_dir $OUTPUT_DIR_MODEL_RVS --num_epochs 1 --task RVS --model S2-Generation-T5-text-start-embedding-to-landmarks-dist --graph_embed_path $GRAPH_EMBEDDING_PATH --far_distance_threshold 10 --is_distance_distribution True echo "* S2-Generation-T5-text-start-embedding-to-landmarks - RVS DATA *" -bazel-bin/cabby/model/text/model_trainer --data_dir $OUTPUT_DIR --dataset_dir $OUTPUT_DIR_MODEL_RVS --region $REGION_NAME --s2_level 15 --output_dir $OUTPUT_DIR_MODEL_RVS --num_epochs 1 --task RVS --model S2-Generation-T5-text-start-embedding-to-landmarks --save_graph_embed_path $GRAPH_EMBEDDING_PATH --far_distance_threshold 10 - +bazel-bin/cabby/model/text/model_trainer --data_dir $OUTPUT_DIR 
--dataset_dir $OUTPUT_DIR_MODEL_RVS --region $REGION_NAME --s2_level 15 --output_dir $OUTPUT_DIR_MODEL_RVS --num_epochs 1 --task RVS --model S2-Generation-T5-text-start-embedding-to-landmarks --graph_embed_path $GRAPH_EMBEDDING_PATH --far_distance_threshold 10 echo "* S2-Generation-T5-Warmup-cell-embed-to-cell-label - RVS DATA *" -bazel-bin/cabby/model/text/model_trainer --data_dir $OUTPUT_DIR --dataset_dir $OUTPUT_DIR_MODEL_RVS --region $REGION_NAME --s2_level 15 --output_dir $OUTPUT_DIR_MODEL_RVS --num_epochs 1 --task RVS --model S2-Generation-T5-Warmup-cell-embed-to-cell-label --save_graph_embed_path $GRAPH_EMBEDDING_PATH --far_distance_threshold 10 +bazel-bin/cabby/model/text/model_trainer --data_dir $OUTPUT_DIR --dataset_dir $OUTPUT_DIR_MODEL_RVS --region $REGION_NAME --s2_level 15 --output_dir $OUTPUT_DIR_MODEL_RVS --num_epochs 1 --task RVS --model S2-Generation-T5-Warmup-cell-embed-to-cell-label --graph_embed_path $GRAPH_EMBEDDING_PATH --far_distance_threshold 10 echo "* S2-Generation-T5-start-embedding-text-input - RVS DATA *" -bazel-bin/cabby/model/text/model_trainer --data_dir $OUTPUT_DIR --dataset_dir $OUTPUT_DIR_MODEL_RVS --region $REGION_NAME --s2_level 15 --output_dir $OUTPUT_DIR_MODEL_RVS --num_epochs 1 --task RVS --model S2-Generation-T5-start-embedding-text-input --save_graph_embed_path $GRAPH_EMBEDDING_PATH --far_distance_threshold 10 +bazel-bin/cabby/model/text/model_trainer --data_dir $OUTPUT_DIR --dataset_dir $OUTPUT_DIR_MODEL_RVS --region $REGION_NAME --s2_level 15 --output_dir $OUTPUT_DIR_MODEL_RVS --num_epochs 1 --task RVS --model S2-Generation-T5-start-embedding-text-input --graph_embed_path $GRAPH_EMBEDDING_PATH --far_distance_threshold 10 echo "* S2-Generation-T5-Warmup-start-end-to-dist - RVS DATA *" @@ -110,7 +109,7 @@ echo "* Landmarks-NER-2-S2-Generation-T5-Warmup - RVS DATA bazel-bin/cabby/model/text/model_trainer --data_dir ~/cabby/cabby/model/text/dataSamples/rvs --dataset_dir $OUTPUT_DIR_MODEL_RVS --region Manhattan --s2_level 15 --output_dir $OUTPUT_DIR_MODEL_RVS --num_epochs 1 --task RVS --model Landmarks-NER-2-S2-Generation-T5-Warmup echo "* multitask *" -bazel-bin/cabby/model/text/model_trainer_multitask --dataset_dir_train $OUTPUT_DIR_MODEL_RVS/S2-Generation-T5-start-text-input/15 --dataset_dir_train $OUTPUT_DIR_MODEL_RVS_FIXED_5/S2-Generation-T5-Warmup-start-end/15 --dataset_dir_train $OUTPUT_DIR_MODEL_RVS_FIXED_4/S2-Generation-T5-Warmup-start-end/15 --dataset_dir_test $OUTPUT_DIR_MODEL_HUMAN/S2-Generation-T5-Landmarks/15 --output_dir $OUTPUT_DIR_MODEL_RVS --num_epochs 1 +bazel-bin/cabby/model/text/model_trainer_multitask --dataset_dir_train $OUTPUT_DIR_MODEL_RVS/S2-Generation-T5-text-start-embedding-to-landmarks/15 --dataset_dir_train $OUTPUT_DIR_MODEL_RVS/S2-Generation-T5-Warmup-cell-embed-to-cell-label/15 --dataset_dir_test $OUTPUT_DIR_MODEL_RVS/S2-Generation-T5-text-start-embedding-to-landmarks/15 --output_dir $OUTPUT_DIR_MODEL_RVS --num_epochs 1 echo "* Baseline *" bazel-bin/cabby/model/baselines --data_dir ~/cabby/cabby/model/text/dataSamples/human --metrics_dir $OUTPUT_DIR_MODEL_HUMAN --task human --region Manhattan
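
Notes on the quantization path added by this patch (illustrative sketches in Python, not part of the diff).

Graph embedding loading: Dataset.__init__ now takes a graph_embed_path instead of a pre-loaded file and detects the embedding size from the first key. A minimal sketch of that logic, assuming a word2vec-format file keyed by S2 cell id strings such as the one run_all.sh writes to $MAP_DIR/graph_embedding.pth with --dimensions 224 (the path below is hypothetical):

import os
from gensim.models import KeyedVectors

graph_embed_path = "map/graph_embedding.pth"  # hypothetical location

graph_embed_file = None
graph_embed_size = 0  # 0 keeps the quantization branch in S2GenerationModel disabled
if os.path.exists(graph_embed_path):
  # Word2vec text format, keyed by S2 cell id strings.
  graph_embed_file = KeyedVectors.load_word2vec_format(graph_embed_path)
  first_cell = graph_embed_file.index_to_key[0]
  graph_embed_size = graph_embed_file[first_cell].shape[0]  # 224 with --dimensions 224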
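
Quantization of the input: S2GenerationModel now carries a VectorQuantize layer (a codebook of N_TOKEN = 1024 entries, projected to codebook_dim = 16) and resizes the T5 tokenizer with one GRAPH_<i> token per codebook entry, so a start cell's graph embedding can be expressed as discrete tokens in the model input. The exact splicing happens in get_input_indices_for_embedding / get_loss_for_graph_embed, which are not fully shown above; the sketch below, under the patch's settings (vq_dim = 224), simply prepends the tokens to the instruction, and the random embeddings and example text are made up:

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from vector_quantize_pytorch import VectorQuantize

N_TOKEN = 1024
VQ_DIM = 224  # graph_embed_size detected from the loaded KeyedVectors

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# Extend the vocabulary with one token per codebook entry, as the patch does.
tokenizer.add_tokens([f"GRAPH_{t}" for t in range(N_TOKEN)])
model.resize_token_embeddings(len(tokenizer))

vq = VectorQuantize(
  dim=VQ_DIM,
  codebook_size=N_TOKEN,
  codebook_dim=16,
  decay=0.8,            # exponential moving average decay of the codebook
  commitment_weight=1.  # weight of the commitment loss
)

# One start-cell embedding per example: (batch, sequence, dim).
graph_embed_start = torch.randn(2, 1, VQ_DIM)
quantized, indices, vq_loss = vq(graph_embed_start)  # indices: (2, 1), values in [0, N_TOKEN)

# Render each codebook index as a GRAPH_* token and put it in front of the instruction text.
graph_tokens = [" ".join(f"GRAPH_{int(i)}" for i in row) for row in indices]
instructions = ["Meet at the bar next to the park."] * 2  # made-up instructions
inputs = [f"{t}: {instr}" for t, instr in zip(graph_tokens, instructions)]
input_ids = tokenizer(inputs, return_tensors="pt", padding=True).input_ids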
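
Prediction decoding: predict() no longer parses an integer label; it takes the decoded text before ';' as a coordinate string and looks it up in label_to_cellid, which (like coord_to_cellid in the dataset) now maps coordinate strings to cell ids, falling back to the first entry. A small sketch of that lookup; the key format and the cell id values below are invented for illustration, the real keys come from dataset_item.coord_format:

def coords_to_cellids(predictions, coord_to_cellid):
  """Maps decoded coordinate strings to S2 cell ids, falling back to the first known cell."""
  default_cell = list(coord_to_cellid.values())[0]
  cellids = []
  for pred_raw in predictions:
    coord = pred_raw.split(";")[0].strip()
    cellids.append(coord_to_cellid.get(coord, default_cell))
  return cellids

# Made-up coordinate keys and cell id values.
coord_to_cellid = {"34 71": 9926595690882924544, "35 71": 9926595692963299328}
print(coords_to_cellids(["34 71; distance: 120", "not a coordinate"], coord_to_cellid))
# -> [9926595690882924544, 9926595690882924544]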