Skip to content

Commit

Permalink
Merge pull request #114 from googleinterns/warmup-model
Browse files Browse the repository at this point in the history
Warmup model
  • Loading branch information
tzufgoogle authored May 9, 2022
2 parents 3200a68 + 7ba866f commit 3f45e47
Show file tree
Hide file tree
Showing 27 changed files with 967 additions and 504 deletions.
164 changes: 164 additions & 0 deletions app/notebooks/join_data.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion app/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
PIVOTS_COLORS = {"end_point":'red', "start_point":'green'}
ICON_TAGS = ['tourism','amenity', 'shop', 'leisure']

icon_dir = os.path.abspath("./static/osm_icons")
dirname = os.path.dirname(__file__)
icon_dir = os.path.join(dirname, "./static/osm_icons")


onlyfiles = [
f for f in os.listdir(icon_dir) if os.path.isfile(os.path.join(icon_dir, f))]
Expand Down
139 changes: 139 additions & 0 deletions cabby/evals/plot.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions cabby/evals/testdata/sample_test_evals_real.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
0-0 40.73999276636397 -73.99324488535609 40.74374243749733 -73.98208530470956
0-1 40.721682513136415 -74.00469558549776 40.717490300684155 -73.97684515379727
3 changes: 3 additions & 0 deletions cabby/geo/geo_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class RVSSample:
geo_landmarks: Dict[str, gpd.GeoDataFrame] = attr.ib()
geo_features: Dict[str, Any] = attr.ib()
route_len: int = attr.ib()
route: list = attr.ib()
instructions: str = attr.ib()
id: int = attr.ib()
version: float = attr.ib()
Expand All @@ -169,6 +170,7 @@ def to_rvs_sample(self,
landmark_list,
geo_entity.geo_features,
route_length,
geo_entity.route.coords[:],
instructions,
id,
VERSION,
Expand Down Expand Up @@ -205,6 +207,7 @@ def save(entities: Sequence[GeoEntity], path_to_save: str):
mode = 'w'

for geo_type, pivots_gdf in geo_types_all.items():
pivots_gdf = pivots_gdf.set_crs('epsg:4326')
pivots_gdf.to_file(
path_to_save, layer=geo_type, mode=mode, driver=_Geo_DataFrame_Driver)

Expand Down
27 changes: 22 additions & 5 deletions cabby/geo/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ def get_line_length(line: LineString) -> float:
return dist


def point_from_list_coord(coord: Sequence) -> Point:
def point_from_list_coord_yx(coord: Sequence) -> Point:
'''Converts coordinates in list format (latitude and longtitude) to Point.
E.g, of list [40.715865, -74.037258].
Arguments:
Expand All @@ -519,16 +519,33 @@ def point_from_list_coord(coord: Sequence) -> Point:

return Point(lon, lat)

def point_from_list_coord_xy(coord: Sequence) -> Point:
  '''Converts coordinates in list format (longitude and latitude) to Point.
  E.g, of list [-74.037258, 40.715865].
  Arguments:
    coord: A lng-lat coordinate to be converted to a point.
  Returns:
    A point.
  '''
  # x-y order: index 0 is longitude, index 1 is latitude (the reverse of
  # point_from_list_coord_yx). Point takes (lon, lat).
  lat = coord[1]
  lon = coord[0]

  return Point(lon, lat)


def point_from_str_coord_yx(coord_str: Text) -> Point:
'''Converts coordinates in string format (latitude and longtitude) to Point.
E.g, of string '(40.715865, -74.037258)' or 'POINT(40.715865 -74.037258)'.
E.g, of string '(40.715865, -74.037258)' or '[40.715865, -74.037258]' or 'POINT(40.715865 -74.037258)'.
Arguments:
coord: A lat-lng coordinate to be converted to a point.
Returns:
A point.
'''
list_coords_str = coord_str.replace("POINT", "").replace("(", "").replace(")", "").split(',')
coord_str = coord_str.replace("POINT", "").replace("(", "").replace(")", "")
coord_str = coord_str.replace("[", "").replace("]", "")

list_coords_str = coord_str.split(',')

if len(list_coords_str)==1:
list_coords_str = list_coords_str[0].split(' ')

Expand All @@ -540,9 +557,9 @@ def point_from_str_coord_yx(coord_str: Text) -> Point:

def point_from_str_coord_xy(coord_str: Text) -> Point:
'''Converts coordinates in string format (latitude and longtitude) to Point.
E.g, of string '(40.715865, -74.037258)' or 'POINT(40.715865 -74.037258)'.
E.g, of string '(-74.037258, 40.715865)' or '[-74.037258, 40.715865]' or 'POINT(-74.037258 40.715865)'.
Arguments:
coord: A lat-lng coordinate to be converted to a point.
coord: A lng-lat coordinate to be converted to a point.
Returns:
A point.
'''
Expand Down
2 changes: 2 additions & 0 deletions cabby/geo/walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,9 +1097,11 @@ def generate_and_save_rvs_routes(self,
def load_entities(path: str) -> Sequence[geo_item.GeoEntity]:
if not os.path.exists(path):
return []

geo_types_all = {}
for landmark_type in LANDMARK_TYPES:
geo_types_all[landmark_type] = gpd.read_file(path, layer=landmark_type)

geo_types_all['route'] = gpd.read_file(path, layer='path_features')['geometry']
geo_types_all['path_features'] = gpd.read_file(path, layer='path_features')
geo_entities = []
Expand Down
74 changes: 64 additions & 10 deletions cabby/model/dataset_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,20 @@ class TextGeoSplit(torch.utils.data.Dataset):
"""
def __init__(self, text_tokenizer, s2_tokenizer, data: pd.DataFrame, s2level: int,
unique_cells_df: pd.DataFrame, cellid_to_label: Dict[int, int],
dprob: util.DistanceProbability, is_dist: Boolean = False):
model_type: str, dprob: util.DistanceProbability, is_dist: Boolean = False):

self.text_tokenizer = text_tokenizer
self.s2_tokenizer = s2_tokenizer

self.is_dist = is_dist

data = data.assign(point=data.end_point)
data = data.assign(end_point=data.end_point)

data['cellid'] = data.point.apply(

data['cellid'] = data.end_point.apply(
lambda x: gutil.cellid_from_point(x, s2level))


data['neighbor_cells'] = data.cellid.apply(
lambda x: gutil.neighbor_cellid(x))

Expand All @@ -149,11 +151,12 @@ def __init__(self, text_tokenizer, s2_tokenizer, data: pd.DataFrame, s2level: in
data['far_cells'] = data.cellid.apply(
lambda cellid: unique_cells_df[unique_cells_df['cellid']==cellid].far.iloc[0])


cellids_array = np.array(data.cellid.tolist())
neighbor_cells_array = np.array(data.neighbor_cells.tolist())
far_cells_array = np.array(data.far_cells.tolist())

self.points = data.point.apply(
self.end_point = data.end_point.apply(
lambda x: gutil.tuple_from_point(x)).tolist()

self.labels = data.cellid.apply(lambda x: cellid_to_label[x]).tolist()
Expand All @@ -164,6 +167,48 @@ def __init__(self, text_tokenizer, s2_tokenizer, data: pd.DataFrame, s2level: in
self.far_cells = self.s2_tokenizer(far_cells_array)


if 'T5' in model_type and 'landmarks' in data:
data['landmarks'] = data.landmarks.apply(
lambda l: [gutil.cellid_from_point(x, s2level) for x in l])

self.landmarks = self.s2_tokenizer(data.landmarks.tolist())


else:
self.landmarks = [0] * len(self.cellids)
logging.warning("Landmarks not processed")


if 'T5' in model_type and 'route' in data:
data['route'] = data.route.apply(
lambda l: [gutil.cellid_from_point(x, s2level) for x in l])
route_array = np.array(data.route.tolist())
self.route = self.s2_tokenizer(route_array)

data['route_fixed'] = data.route_fixed.apply(
lambda l: [gutil.cellid_from_point(x, s2level) for x in l])
self.route_fixed = self.s2_tokenizer(data.route_fixed.tolist())

start_point_cells = data.start_point.apply(
lambda x: gutil.cellid_from_point(x, s2level))


start_point_list = [
'; '.join(
[str(cellid_to_label[e]), str(cellid_to_label[s])]) for s, e in zip(
start_point_cells.tolist(), data.cellid.tolist())]
self.start_end = self.text_tokenizer(
start_point_list, truncation=True, padding=True, add_special_tokens=True)

else:
self.route = [0] * len(self.cellids)
self.route_fixed = [0] * len(self.cellids)
self.start_end = {
'attention_mask': [0] * len(self.cellids),
'input_ids': [0] * len(self.cellids)}
logging.warning("Route not processed")


def __getitem__(self, idx: int):
'''Supports indexing such that TextGeoDataset[i] can be used to get
i-th sample.
Expand All @@ -175,18 +220,29 @@ def __getitem__(self, idx: int):
'''
text = {key: torch.tensor(val[idx])
for key, val in self.encodings.items()}
cellid = self.cellids[idx]

cellid = torch.tensor(self.cellids[idx])
landmarks = torch.tensor(self.landmarks[idx])

route = torch.tensor(self.route[idx])
route_fixed = torch.tensor(self.route_fixed[idx])

start_end = {key: torch.tensor(val[idx])
for key, val in self.start_end.items()}

neighbor_cells = torch.tensor(self.neighbor_cells[idx])
far_cells = torch.tensor(self.far_cells[idx])
point = torch.tensor(self.points[idx])
end_point = torch.tensor(self.end_point[idx])
label = torch.tensor(self.labels[idx])
if self.is_dist:
prob = torch.tensor(self.prob[idx])
else:
prob = torch.tensor([])

sample = {'text': text, 'cellid': cellid, 'neighbor_cells': neighbor_cells,
'far_cells': far_cells, 'point': point, 'label': label, 'prob': prob}
'far_cells': far_cells, 'end_point': end_point, 'label': label, 'prob': prob,
'landmarks': landmarks, 'route': route, 'route_fixed': route_fixed,
'start_end_input_ids': start_end['input_ids'], 'start_end_attention_mask': start_end['attention_mask']}

return sample

Expand All @@ -195,6 +251,4 @@ def __len__(self):

def calc_dist(start, unique_cells_df):
  '''Returns the distance from `start` to each row's `point` in `unique_cells_df`.

  Applied row-wise (axis=1) via the swifter-parallelized pandas apply.
  '''
  def _distance_to_row(end):
    # `end` is one row; `end.point` is presumably its shapely Point — TODO confirm.
    return gutil.get_distance_between_points(start, end.point)

  return unique_cells_df.swifter.apply(_distance_to_row, axis=1)


lambda end: gutil.get_distance_between_points(start, end.point), axis=1)
Loading

0 comments on commit 3f45e47

Please sign in to comment.