From 2bf7e004bdc0d5d6dc62feab790404f3689e67b5 Mon Sep 17 00:00:00 2001 From: Dmytro Trotsko Date: Wed, 13 Sep 2023 22:52:27 +0300 Subject: [PATCH] Added geo based restricted endpoint. Resulting data is based on user's role, so user can see only role allowed data (county based) --- .../server/test_differentiated_access.py | 118 +++++++++++++ src/server/endpoints/__init__.py | 2 + src/server/endpoints/differentiated_access.py | 159 ++++++++++++++++++ 3 files changed, 279 insertions(+) create mode 100644 integrations/server/test_differentiated_access.py create mode 100644 src/server/endpoints/differentiated_access.py diff --git a/integrations/server/test_differentiated_access.py b/integrations/server/test_differentiated_access.py new file mode 100644 index 000000000..c053c62f2 --- /dev/null +++ b/integrations/server/test_differentiated_access.py @@ -0,0 +1,118 @@ +import requests + +# third party +import mysql.connector + +# frirst party +from delphi.epidata.acquisition.covidcast.test_utils import ( + CovidcastBase, + CovidcastTestRow, +) + + +class DifferentiatedAccessTests(CovidcastBase): + def localSetUp(self): + """Perform per-test setup""" + self._db._cursor.execute( + 'update covidcast_meta_cache set timestamp = 0, epidata = "[]"' + ) + + def setUp(self): + # connect to the `epidata` database + + super().setUp() + + self.maxDiff = None + + cnx = mysql.connector.connect( + user="user", + password="pass", + host="delphi_database_epidata", + database="epidata", + ) + + cur = cnx.cursor() + + cur.execute("DELETE FROM `api_user`") + cur.execute("TRUNCATE TABLE `user_role`") + cur.execute("TRUNCATE TABLE `user_role_link`") + + cur.execute( + 'INSERT INTO `api_user`(`api_key`, `email`) VALUES ("api_key", "api_key@gmail.com")' + ) + cur.execute( + 'INSERT INTO `api_user`(`api_key`, `email`) VALUES("ny_key", "ny_key@gmail.com")' + ) + cur.execute('INSERT INTO `user_role`(`name`) VALUES("state:ny")') + cur.execute( + 'INSERT INTO `user_role_link`(`user_id`, `role_id`) SELECT `api_user`.`id`, 1 FROM `api_user` WHERE `api_key` = "ny_key"' + ) + + cnx.commit() + cur.close() + cnx.close() + + def request_based_on_row(self, row: CovidcastTestRow, **kwargs): + params = self.params_from_row(row, endpoint="differentiated_access", **kwargs) + # use local instance of the Epidata API + + response = requests.get( + "http://delphi_web_epidata/epidata/api.php", params=params + ) + response.raise_for_status() + return response.json() + + def _insert_placeholder_restricted_geo(self): + geo_values = ["36029", "36047", "36097", "36103", "36057", "36041", "36033"] + rows = [ + CovidcastTestRow.make_default_row( + source="restricted-source", + geo_type="county", + geo_value=geo_values[i], + time_value=2000_01_01 + i, + value=i * 1.0, + stderr=i * 10.0, + sample_size=i * 100.0, + issue=2000_01_03, + lag=2 - i, + ) + for i in [1, 2, 3] + ] + [ + # time value intended to overlap with the time values above, with disjoint geo values + CovidcastTestRow.make_default_row( + source="restricted-source", + geo_type="county", + geo_value=geo_values[i], + time_value=2000_01_01 + i - 3, + value=i * 1.0, + stderr=i * 10.0, + sample_size=i * 100.0, + issue=2000_01_03, + lag=5 - i, + ) + for i in [4, 5, 6] + ] + self._insert_rows(rows) + return rows + + def test_restricted_geo_ny_role(self): + # insert placeholder data + rows = self._insert_placeholder_restricted_geo() + + # make request + response = self.request_based_on_row(rows[0], token="ny_key") + expected = { + "result": 1, + "epidata": [rows[0].as_api_compatibility_row_dict()], + "message": "success", + } + self.assertEqual(response, expected) + + def test_restricted_geo_default_role(self): + # insert placeholder data + rows = self._insert_placeholder_restricted_geo() + + # make request + response = self.request_based_on_row(rows[0], token="api_key") + expected = {"result": -2, "message": "no results"} + self.assertEqual(response, expected) diff --git a/src/server/endpoints/__init__.py b/src/server/endpoints/__init__.py index 94f1de5b8..cecb40819 100644 --- a/src/server/endpoints/__init__.py +++ b/src/server/endpoints/__init__.py @@ -31,6 +31,7 @@ wiki, signal_dashboard_status, signal_dashboard_coverage, + differentiated_access ) endpoints = [ @@ -66,6 +67,7 @@ wiki, signal_dashboard_status, signal_dashboard_coverage, + differentiated_access ] __all__ = ["endpoints"] diff --git a/src/server/endpoints/differentiated_access.py b/src/server/endpoints/differentiated_access.py new file mode 100644 index 000000000..92ed57e5d --- /dev/null +++ b/src/server/endpoints/differentiated_access.py @@ -0,0 +1,159 @@ +from flask import Blueprint +from werkzeug.exceptions import Unauthorized + +from .._common import is_compatibility_mode +from .._params import ( + extract_date, + extract_dates, + extract_integer, + parse_geo_sets, + parse_source_signal_sets, + parse_time_set, +) +from .._query import QueryBuilder, execute_query +from .._security import current_user, sources_protected_by_roles +from .covidcast_utils.model import create_source_signal_alias_mapper +from delphi_utils import GeoMapper +from delphi.epidata.common.logger import get_structured_logger + +# first argument is the endpoint name +bp = Blueprint("differentiated_access", __name__) +alias = None + +latest_table = "epimetric_latest_v" +history_table = "epimetric_full_v" + + +def restrict_by_roles(source_signal_sets): + # takes a list of SourceSignalSet objects + # and returns only those from the list + # that the current user is permitted to access. + user = current_user + allowed_source_signal_sets = [] + for src_sig_set in source_signal_sets: + src = src_sig_set.source + if src in sources_protected_by_roles: + role = sources_protected_by_roles[src] + if user and user.has_role(role): + allowed_source_signal_sets.append(src_sig_set) + else: + # protected src and user does not have permission => leave it out of the srcsig sets + get_structured_logger("covcast_endpt").warning( + "non-authZd request for restricted 'source'", + api_key=(user and user.api_key), + src=src, + ) + else: + allowed_source_signal_sets.append(src_sig_set) + return allowed_source_signal_sets + + +def serve_geo_restricted(geo_sets: set): + geomapper = GeoMapper() + if not current_user: + raise Unauthorized("User is not authenticated.") + allowed_counties = set() + # Getting allowed counties set from user's roles. + # Example: role 'state:ny' will give user access to all counties in ny state. + for role in current_user.roles: + if role.name.startswith("state:"): + state = role.name.split(":", 1)[1] + counties_in_state = geomapper.get_geos_within(state, "county", "state") + allowed_counties.update(counties_in_state) + + for geo_set in geo_sets: + # Reject if `geo_type` is not county. + if geo_set.geo_type != "county": + raise Unauthorized("Only `county` geo_type is allowed") + # If `geo_value` = '*' then we want to query only that counties that user has access to. + if geo_set.geo_values is True: + geo_set.geo_values = list(allowed_counties) + # Actually we don't need to check whether `geo_set.geo_values` (user requested counties) is a superset of `allowed_counties` + # We do want to return set of counties that are in both `geo_set.geo_values` and `allowed_counties` + # Because if user requested less -> we will get only requested list of counties, in other case (user requested more + # than he can get -> he will get only that counties that he is allowed to). + + # elif set(geo_set.geo_values).issuperset(allowed_counties): + # geo_set.geo_values = list(set(geo_set.geo_values).intersection(allowed_counties)) + + # If user provided more counties that he is able to query, then we want to show him only + # that counties that he is allowed to. + else: + geo_set.geo_values = list( + set(geo_set.geo_values).intersection(allowed_counties) + ) + return geo_sets + + +@bp.route("/", methods=("GET", "POST")) +def handle(): + source_signal_sets = parse_source_signal_sets() + source_signal_sets = restrict_by_roles(source_signal_sets) + source_signal_sets, alias_mapper = create_source_signal_alias_mapper( + source_signal_sets + ) + time_set = parse_time_set() + geo_sets = serve_geo_restricted(parse_geo_sets()) + + as_of = extract_date("as_of") + issues = extract_dates("issues") + lag = extract_integer("lag") + + # build query + q = QueryBuilder(latest_table, "t") + + fields_string = ["geo_value", "signal"] + fields_int = [ + "time_value", + "direction", + "issue", + "lag", + "missing_value", + "missing_stderr", + "missing_sample_size", + ] + fields_float = ["value", "stderr", "sample_size"] + is_compatibility = is_compatibility_mode() + if is_compatibility: + q.set_sort_order("signal", "time_value", "geo_value", "issue") + else: + # transfer also the new detail columns + fields_string.extend(["source", "geo_type", "time_type"]) + q.set_sort_order( + "source", + "signal", + "time_type", + "time_value", + "geo_type", + "geo_value", + "issue", + ) + q.set_fields(fields_string, fields_int, fields_float) + + # basic query info + # data type of each field + # build the source, signal, time, and location (type and id) filters + + q.apply_source_signal_filters("source", "signal", source_signal_sets) + q.apply_geo_filters("geo_type", "geo_value", geo_sets) + q.apply_time_filter("time_type", "time_value", time_set) + + q.apply_issues_filter(history_table, issues) + q.apply_lag_filter(history_table, lag) + q.apply_as_of_filter(history_table, as_of) + + def transform_row(row, proxy): + if is_compatibility or not alias_mapper or "source" not in row: + return row + row["source"] = alias_mapper(row["source"], proxy["signal"]) + return row + + # send query + return execute_query( + str(q), + q.params, + fields_string, + fields_int, + fields_float, + transform=transform_row, + )