multi task with prompt
tzufgoogle committed May 9, 2022
1 parent 3f45e47 commit b9f49d5
Showing 8 changed files with 342 additions and 36 deletions.
26 changes: 22 additions & 4 deletions cabby/model/dataset_item.py
@@ -65,15 +65,25 @@ def from_TextGeoSplit(cls, train, valid, test, unique_cellids,
)

@classmethod
def load(cls, dataset_path: Text, train_path_dataset: Text,
valid_path_dataset: Text, test_path_dataset: Text,
unique_cellid_path: Text, tensor_cellid_path: Text,
def load(cls, dataset_dir: Text, model_type: Text,
s2_level: Text, unique_cellid_path: Text, tensor_cellid_path: Text,
label_to_cellid_path: Text):

dataset_model_path = os.path.join(dataset_dir, str(model_type))
dataset_path = os.path.join(dataset_model_path, str(s2_level))
train_path_dataset = os.path.join(dataset_path,'train.pth')
valid_path_dataset = os.path.join(dataset_path,'valid.pth')
test_path_dataset = os.path.join(dataset_path,'test.pth')
unique_cellid_path = os.path.join(dataset_path,"unique_cellid.npy")
tensor_cellid_path = os.path.join(dataset_path,"tensor_cellid.pth")
label_to_cellid_path = os.path.join(dataset_path,"label_to_cellid.npy")

logging.info("Loading dataset from <== {}.".format(dataset_path))
train_dataset = torch.load(train_path_dataset)
valid_dataset = torch.load(valid_path_dataset)
test_dataset = torch.load(test_path_dataset)
logging.info(f"Size of train set: {len(train_dataset)}" +
f", Size of validation set: {len(valid_dataset)}, Size of test set: {len(test_dataset)}")

unique_cellid = np.load(unique_cellid_path, allow_pickle='TRUE')
label_to_cellid = np.load(
@@ -144,8 +154,16 @@ def __init__(self, text_tokenizer, s2_tokenizer, data: pd.DataFrame, s2level: in


# Tokenize instructions.

instruction_list = data.instructions.tolist()
if 'T5' in model_type:
# Add prompt
instruction_list = [model_type + ": " + t for t in instruction_list]

logging.info(f"An example of the text encoded: '{instruction_list[0]}'")

self.encodings = self.text_tokenizer(
data.instructions.tolist(), truncation=True,
instruction_list, truncation=True,
padding=True, add_special_tokens=True)

data['far_cells'] = data.cellid.apply(
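A minimal sketch, not part of the commit, of what the reworked `TextGeoDataset.load` now assumes: the dataset files live under `dataset_dir/<model_type>/<s2_level>/`, and for T5 models each instruction is prefixed with the model type as a prompt. The directory, S2 level, and instruction text below are illustrative values only; the file names and the `"T5" in model_type` check mirror the diff.

```python
import os

dataset_dir = "/tmp/cabby_datasets"   # hypothetical location
model_type = "S2-Generation-T5"       # one of the model names used in model_trainer.py
s2_level = 18                         # hypothetical S2 cell level

# Layout derived inside TextGeoDataset.load: dataset_dir / model_type / s2_level
dataset_path = os.path.join(dataset_dir, str(model_type), str(s2_level))
for name in ["train.pth", "valid.pth", "test.pth",
             "unique_cellid.npy", "tensor_cellid.pth", "label_to_cellid.npy"]:
    print(os.path.join(dataset_path, name))

# Prompt prefix added in dataset_item.py for T5-based models.
instructions = ["Turn left after the bakery and stop at the bank."]  # illustrative text
if "T5" in model_type:
    instructions = [model_type + ": " + t for t in instructions]
print(instructions[0])
```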
24 changes: 14 additions & 10 deletions cabby/model/datasets.py
@@ -101,10 +101,9 @@ def process_route(self, route_str):
gutil.point_from_str_coord_xy(landmark_str) for landmark_str in ladmarks_str_list]


def process_landmarks(self, landmarks_str_one_line):
ladmarks_str_list = landmarks_str_one_line.split(';')
return [gutil.point_from_str_coord_yx(
landmark_str.split(':')[-1]) for landmark_str in ladmarks_str_list]
def process_landmarks(self, row):
points = [row['end_point'], row['start_point'], row['main_pivot'], row['near_pivot']]
return points

def get_specific_landmark(self, landmarks_str_one_line, landmark_name):

@@ -118,7 +117,7 @@ def get_specific_landmark(self, landmarks_str_one_line, landmark_name):
return landmark_found


def create_dataset(self, infer_only: bool = False
def create_dataset(self, infer_only: bool = False,
) -> dataset_item.TextGeoDataset:
'''Loads data and creates datasets and train, validate and test sets.
Returns:
@@ -219,13 +218,13 @@ def load_data(self, data_dir: str, ds_set: str, lines: bool):
ds['main_pivot'] = ds.landmarks.apply(
lambda x: self.get_specific_landmark(x, 'main_pivot'))

ds['landmarks'] = ds.landmarks.apply(self.process_landmarks)
ds['landmarks'] = ds.apply(self.process_landmarks, axis=1)

if 'route' in ds:
ds['route'] = ds.route.apply(self.process_route)
ds['route_fixed'] = ds.route.apply(self.get_fixed_point_along_route)

ds['start_end'] = ds.route.apply(self.get_fixed_point_along_route)
ds['start_end'] = ds.route.apply(self.get_fixed_point_along_route)
columns_keep = ds.columns.difference(
[
'instructions',
@@ -353,9 +352,14 @@ def load_data(self, data_dir: str, split: str, lines: bool):
return ds

def process_landmarks(self, landmarks_dict):
ladmarks_list = list(landmarks_dict.values())
return [gutil.point_from_list_coord_yx(
landmark_l[-1]) for landmark_l in ladmarks_list if landmark_l[-1]]
landmarks_corrds = [
landmarks_dict['end_point'][-1],
landmarks_dict['start_point'][-1],
landmarks_dict['main_pivot'][-1],
landmarks_dict['near_pivot'][-1]]
points = [gutil.point_from_list_coord_yx(
coord) for coord in landmarks_corrds]
return points

def process_route(self, route_list):
return [
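Both rewritten `process_landmarks` variants now emit the landmarks in one fixed order: end point, start point, main pivot, near pivot. A small sketch under stated assumptions: the helper below is only a stand-in for `gutil.point_from_list_coord_yx`, and the dictionary contents are hypothetical; the key order is the part that mirrors the diff.

```python
def point_from_list_coord_yx(coord):
    # Stand-in for cabby's gutil.point_from_list_coord_yx: takes a [lat, lng]
    # (y, x) list and returns an (x, y) pair. The real helper returns a geometry point.
    lat, lng = coord[0], coord[1]
    return (lng, lat)

# Hypothetical landmark entry; only the last element (coordinates) is used here.
landmarks_dict = {
    'end_point':   ['end_point', 'Cafe', [40.7500, -73.9900]],
    'start_point': ['start_point', 'Bank', [40.7480, -73.9920]],
    'main_pivot':  ['main_pivot', 'Park', [40.7490, -73.9910]],
    'near_pivot':  ['near_pivot', 'Shop', [40.7501, -73.9899]],
}

# Fixed ordering used by the new process_landmarks: end, start, main pivot, near pivot.
landmark_coords = [
    landmarks_dict['end_point'][-1],
    landmarks_dict['start_point'][-1],
    landmarks_dict['main_pivot'][-1],
    landmarks_dict['near_pivot'][-1],
]
points = [point_from_list_coord_yx(coord) for coord in landmark_coords]
print(points)
```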
16 changes: 16 additions & 0 deletions cabby/model/text/BUILD
@@ -24,6 +24,22 @@ py_binary(
],
)

py_binary(
name = 'model_trainer_multitask',
main = 'model_trainer_multitask.py',
srcs = ['model_trainer_multitask.py'],
deps = [
'//cabby/model/text:train',
'//cabby/model:datasets',
'//cabby/model:dataset_item',
"//cabby/model:util",
"//cabby/geo:util",
':models'

],
)



py_binary(
name = 'models',
54 changes: 34 additions & 20 deletions cabby/model/text/model_trainer.py
@@ -91,6 +91,7 @@
flags.DEFINE_string("model_path", None,
"A path of a model the model to be fine tuned\ evaluated.")


flags.DEFINE_integer(
'train_batch_size', default=4,
help=('Batch size for training.'))
@@ -107,6 +108,10 @@
'infer_only', default=False,
help=('Train and infer / just infer.'))

flags.DEFINE_bool(
'is_single_sample_train', default=False,
help=('Train on a single sample and do not evaluate.'))


flags.DEFINE_bool(
'is_val_loss_from_model', default=False,
@@ -150,23 +155,26 @@ def main(argv):
else:
sys.exit("Dataset invalid")

dataset = dataset_init(
data_dir = FLAGS.data_dir,
region = FLAGS.region,
s2level = FLAGS.s2_level,
model_type = FLAGS.model)


if FLAGS.is_single_sample_train:
FLAGS.train_batch_size = 1

if os.path.exists(dataset_path):
dataset_text = dataset_item.TextGeoDataset.load(
dataset_path = dataset_path,
train_path_dataset = train_path_dataset,
valid_path_dataset = valid_path_dataset,
test_path_dataset = test_path_dataset,
dataset_dir = FLAGS.dataset_dir,
model_type = str(FLAGS.model),
s2_level = FLAGS.s2_level,
label_to_cellid_path = label_to_cellid_path,
unique_cellid_path = unique_cellid_path,
tensor_cellid_path = tensor_cellid_path)

else:
dataset = dataset_init(
data_dir = FLAGS.data_dir,
region = FLAGS.region,
s2level = FLAGS.s2_level,
model_type = FLAGS.model)

if not os.path.exists(dataset_model_path):
os.mkdir(dataset_model_path)
logging.info("Preparing data.")
@@ -204,13 +212,17 @@ def main(argv):
if 'Dual-Encoder' in FLAGS.model:
run_model = models.DualEncoder(device=device)
elif FLAGS.model == 'S2-Generation-T5':
run_model = models.S2GenerationModel(dataset_text.label_to_cellid, device=device)
run_model = models.S2GenerationModel(
dataset_text.label_to_cellid, device=device)
elif FLAGS.model == 'S2-Generation-T5-Landmarks':
run_model = models.S2GenerationModel(dataset_text.label_to_cellid, is_landmarks=True, device=device)
run_model = models.S2GenerationModel(
dataset_text.label_to_cellid, is_landmarks=True, device=device)
elif FLAGS.model == 'S2-Generation-T5-Path':
run_model = models.S2GenerationModel(dataset_text.label_to_cellid, is_path=True, device=device)
run_model = models.S2GenerationModel(
dataset_text.label_to_cellid, is_path=True, device=device)
elif FLAGS.model == 'S2-Generation-T5-Warmup-start-end':
run_model = models.S2GenerationModel(dataset_text.label_to_cellid, is_warmup_start_end=True, device=device)
run_model = models.S2GenerationModel(
dataset_text.label_to_cellid, is_warmup_start_end=True, device=device)
elif FLAGS.model == 'Classification-Bert':
run_model = models.ClassificationModel(n_cells, device=device)
else:
@@ -250,7 +262,8 @@ def main(argv):
cells_tensor = dataset_text.unique_cellids_binary,
label_to_cellid = dataset_text.label_to_cellid,
is_distance_distribution = FLAGS.is_distance_distribution,
best_valid_loss = run_model.best_valid_loss
best_valid_loss = run_model.best_valid_loss,
is_single_sample_train = FLAGS.is_single_sample_train
)
if FLAGS.infer_only:
logging.info("Starting to infer model.")
@@ -262,17 +275,18 @@ def main(argv):
true_points,
pred_points)

accuracy = accuracy_score(true_vals, predictions)

evaluator = eu.Evaluator()
error_distances = evaluator.get_error_distances(trainer.metrics_path)
_, mean_distance, median_distance, max_error, norm_auc = (
evaluator.compute_metrics(error_distances))

logging.info(f"\nTest Accuracy: {accuracy}, \n" +
f"Mean distance: {mean_distance},\nMedian distance: {median_distance},\n" +
f"Max error: {max_error},\nNorm AUC: {norm_auc}")

logging.info(f"\
Mean distance: {mean_distance}, \
Median distance: {median_distance}, \
Max error: {max_error}, \
Norm AUC: {norm_auc}")

else:
logging.info("Starting to train model.")
trainer.train_model()
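A small sketch, not from the commit, of how the --model flag maps onto the S2GenerationModel variants selected in main(). The dictionary only mirrors the keyword arguments visible in the diff; describe_model_choice is a hypothetical helper, and the real constructors live in cabby/model/text/models.py and are not imported here.

```python
# Keyword arguments per model flag, as selected in model_trainer.py's main().
S2_GENERATION_VARIANTS = {
    'S2-Generation-T5': {},
    'S2-Generation-T5-Landmarks': {'is_landmarks': True},
    'S2-Generation-T5-Path': {'is_path': True},
    'S2-Generation-T5-Warmup-start-end': {'is_warmup_start_end': True},
}

def describe_model_choice(model_flag):
    # Mirrors the dispatch in main(); returns a description string instead of
    # constructing the real model, so this sketch runs without cabby installed.
    if 'Dual-Encoder' in model_flag:
        return 'models.DualEncoder(device=device)'
    if model_flag in S2_GENERATION_VARIANTS:
        kwargs = S2_GENERATION_VARIANTS[model_flag]
        return f'models.S2GenerationModel(label_to_cellid, device=device, **{kwargs})'
    if model_flag == 'Classification-Bert':
        return 'models.ClassificationModel(n_cells, device=device)'
    return 'unsupported model'

for name in list(S2_GENERATION_VARIANTS) + ['Classification-Bert']:
    print(f'{name} -> {describe_model_choice(name)}')
```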
(Diff for the remaining 4 changed files not shown.)
