diff --git a/dev-env/config b/dev-env/config index 21d81c48..8243df60 160000 --- a/dev-env/config +++ b/dev-env/config @@ -1 +1 @@ -Subproject commit 21d81c48cf8c3805f9b8815528b3e2344501553f +Subproject commit 8243df60afc1e9d7978753bc7e1532f0794077bc diff --git a/freva-client/src/freva_client/__init__.py b/freva-client/src/freva_client/__init__.py index b1c37161..04b468d9 100644 --- a/freva-client/src/freva_client/__init__.py +++ b/freva-client/src/freva_client/__init__.py @@ -17,5 +17,5 @@ from .auth import authenticate from .query import databrowser -__version__ = "2408.0.0.dev2" +__version__ = "2408.0.0" __all__ = ["authenticate", "databrowser", "__version__"] diff --git a/freva-rest/src/freva_rest/__init__.py b/freva-rest/src/freva_rest/__init__.py index 0b3b5a46..fc94a16c 100644 --- a/freva-rest/src/freva_rest/__init__.py +++ b/freva-rest/src/freva_rest/__init__.py @@ -1,7 +1,7 @@ import os from pathlib import Path -__version__ = "2408.0.0-dev2" +__version__ = "2408.0.0" __all__ = ["__version__"] REST_URL = ( diff --git a/freva-rest/src/freva_rest/auth.py b/freva-rest/src/freva_rest/auth.py index bfbadfb0..4cf8432b 100644 --- a/freva-rest/src/freva_rest/auth.py +++ b/freva-rest/src/freva_rest/auth.py @@ -2,23 +2,35 @@ import datetime import os -from typing import Annotated, Dict, Literal, Optional, cast +from typing import Annotated, Any, Dict, Literal, Optional, cast import aiohttp -from fastapi import Form, HTTPException, Security +from fastapi import Form, HTTPException, Request, Security from fastapi.responses import RedirectResponse from fastapi_third_party_auth import Auth, IDToken -from pydantic import BaseModel +from pydantic import BaseModel, Field, ValidationError from .logger import logger from .rest import app, server_config +from .utils import get_userinfo auth = Auth(openid_connect_url=server_config.oidc_discovery_url) +Required: Any = Ellipsis + TIMEOUT: aiohttp.ClientTimeout = aiohttp.ClientTimeout(total=5) """5 seconds for timeout for key cloak interaction.""" +class UserInfo(BaseModel): + """Basic user info.""" + + username: Annotated[str, Field(min_length=1)] + last_name: Annotated[str, Field(min_length=1)] + first_name: Annotated[str, Field(min_length=1)] + email: str + + class TokenPayload(BaseModel): """Model representing the payload of a JWT token.""" @@ -39,13 +51,37 @@ class Token(BaseModel): @app.get("/api/auth/v2/status", tags=["Authentication"]) -async def get_token_status( - id_token: IDToken = Security(auth.required), -) -> TokenPayload: +async def get_token_status(id_token: IDToken = Security(auth.required)) -> TokenPayload: """Check the status of an access token.""" return cast(TokenPayload, id_token) +@app.get("/api/auth/v2/userinfo", tags=["Authentication"]) +async def userinfo( + id_token: IDToken = Security(auth.required), request: Request = Required +) -> UserInfo: + """Get userinfo for the current token.""" + token_data = {k.lower(): str(v) for (k, v) in dict(id_token).items()} + try: + return UserInfo(**get_userinfo(token_data)) + except ValidationError: + authorization = dict(request.headers)["authorization"] + try: + async with aiohttp.ClientSession(timeout=TIMEOUT) as client: + response = await client.get( + server_config.oidc_overview["userinfo_endpoint"], + headers={"Authorization": authorization}, + ) + response.raise_for_status() + token_data = await response.json() + return UserInfo( + **get_userinfo({k.lower(): str(v) for (k, v) in token_data.items()}) + ) + except Exception as error: + logger.error(error) + raise HTTPException(status_code=404) from error + + @app.get( "/api/auth/v2/.well-known/openid-configuration", tags=["Authentication"], diff --git a/freva-rest/src/freva_rest/cli.py b/freva-rest/src/freva_rest/cli.py index ece61dbf..8b4fea0a 100644 --- a/freva-rest/src/freva_rest/cli.py +++ b/freva-rest/src/freva_rest/cli.py @@ -107,6 +107,26 @@ def start( "--services", help="Set additional services this rest API should serve.", ), + oidc_url: str = typer.Option( + os.environ.get( + "OIDC_URL", + "http://localhost:8080/realms/freva/.well-known/openid-configuration", + ), + "--oidc-url", + "--openid-connect-url", + help="The url to openid configuration", + ), + oidc_client_id: str = typer.Option( + os.environ.get("OIDC_CLIENT_ID", "freva"), + "--oidc-client-id", + "--oidc-client", + help="Name of the openid client.", + ), + oidc_client_secret: Optional[str] = typer.Option( + os.environ.get("OIDC_CLIENT_SECRET", ""), + "--oidc-client-secret", + help="Name of the openid client secret.", + ), ssl_cert_dir: Optional[str] = typer.Option( None, "--cert-dir", @@ -159,10 +179,7 @@ def start( asyncio.run(read_data(core, cfg.solr_host, cfg.solr_port)) workers = {False: int(os.environ.get("API_WORKER", 8)), True: None} ssl_cert, ssl_key = get_cert_file(ssl_cert_dir, ssl_cert, ssl_key) - oidc_client_id = os.getenv("OIDC_CLIENT_ID", "freva") - oidc_client_secret = os.getenv("OIDC_CLIENT_SECRET", "") api_services = ",".join(services).replace("_", "-") - oidc_url = "http://localhost:8080/realms/freva/.well-known/openid-configuration" with NamedTemporaryFile(suffix=".conf", prefix="env") as temp_f: Path(temp_f.name).write_text( ( @@ -175,9 +192,9 @@ def start( f"REDIS_USER={os.getenv('REDIS_USER', 'redis')}\n" f"REDIS_SSL_CERTFILE={ssl_cert or ''}\n" f"REDIS_SSL_KEYFILE={ssl_key or ''}\n" - f"OICD_URL={os.getenv('OIDC_URL', oidc_url)}\n" + f"OICD_URL={oidc_url}\n" f"OICD_CLIENT_ID={oidc_client_id}\n" - f"OICD_CLIENT_SECRET={oidc_client_secret}\n" + f"OICD_CLIENT_SECRET={oidc_client_secret or ''}\n" f"API_URL={defaults['API_URL']}\n" f"API_SERVICES={api_services}\n" ), diff --git a/freva-rest/src/freva_rest/utils.py b/freva-rest/src/freva_rest/utils.py index edeeee69..221085b2 100644 --- a/freva-rest/src/freva_rest/utils.py +++ b/freva-rest/src/freva_rest/utils.py @@ -1,7 +1,7 @@ """Various utilities for the restAPI.""" import os -from typing import Optional +from typing import Dict, Optional import redis.asyncio as redis from fastapi import HTTPException, status @@ -19,6 +19,33 @@ """All the services that need the redis cache.""" +def get_userinfo(user_info: Dict[str, str]) -> Dict[str, str]: + """Convert a user_info dictionary to the UserInfo Model.""" + output = {} + keys = { + "email": ("mail", "email"), + "username": ("preferred-username", "user-name", "uid"), + "last_name": ("last-name", "family-name", "name", "surname"), + "first_name": ("first-name", "given-name"), + } + for key, entries in keys.items(): + for entry in entries: + if user_info.get(entry): + output[key] = user_info[entry] + break + if user_info.get(entry.replace("-", "_")): + output[key] = user_info[entry.replace("-", "_")] + break + if user_info.get(entry.replace("-", "")): + output[key] = user_info[entry.replace("-", "")] + break + # Strip all the middle names + name = output.get("first_name", "") + " " + output.get("last_name", "") + output["first_name"] = name.partition(" ")[0] + output["last_name"] = name.rpartition(" ")[-1] + return output + + async def create_redis_connection( cache: Optional[redis.Redis] = REDIS_CACHE, ) -> redis.Redis: @@ -35,11 +62,7 @@ async def create_redis_connection( db=0, ) services = set( - [ - s.strip() - for s in os.getenv("API_SERVICES", "").split(",") - if s.strip() - ] + [s.strip() for s in os.getenv("API_SERVICES", "").split(",") if s.strip()] ) if CACHING_SERVICES - services == CACHING_SERVICES: # All services that would need caching are disabled. diff --git a/tests/conftest.py b/tests/conftest.py index 6978f26d..49ea771b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -115,9 +115,7 @@ def valid_freva_config() -> Iterator[Path]: with TemporaryDirectory() as temp_dir: freva_config = Path(temp_dir) / "share" / "freva" / "freva.toml" freva_config.parent.mkdir(exist_ok=True, parents=True) - freva_config.write_text( - "[freva]\nhost = 'https://www.freva.com:80/api'" - ) + freva_config.write_text("[freva]\nhost = 'https://www.freva.com:80/api'") yield Path(temp_dir) @@ -157,9 +155,7 @@ def valid_eval_conf_file() -> Iterator[Path]: _prep_env(EVALUATION_SYSTEM_CONFIG_FILE=str(eval_file)), clear=True, ): - with mock.patch( - "sysconfig.get_path", lambda x, y="foo": str(temp_dir) - ): + with mock.patch("sysconfig.get_path", lambda x, y="foo": str(temp_dir)): yield eval_file @@ -169,18 +165,14 @@ def invalid_eval_conf_file() -> Iterator[Path]: with TemporaryDirectory() as temp_dir: eval_file = Path(temp_dir) / "eval.conf" eval_file.write_text( - "[foo]\n" - "solr.host = http://localhost\n" - "databrowser.port = 8080" + "[foo]\n" "solr.host = http://localhost\n" "databrowser.port = 8080" ) with mock.patch.dict( os.environ, _prep_env(EVALUATION_SYSTEM_CONFIG_FILE=str(eval_file)), clear=True, ): - with mock.patch( - "sysconfig.get_path", lambda x, y="foo": str(temp_dir) - ): + with mock.patch("sysconfig.get_path", lambda x, y="foo": str(temp_dir)): yield eval_file @@ -196,9 +188,7 @@ def test_server() -> Iterator[str]: thread1.daemon = True thread1.start() time.sleep(1) - thread2 = threading.Thread( - target=run_loader_process, args=(find_free_port(),) - ) + thread2 = threading.Thread(target=run_loader_process, args=(find_free_port(),)) thread2.daemon = True thread2.start() time.sleep(5) @@ -233,9 +223,7 @@ def client_no_mongo(cfg: ServerConfig) -> Iterator[TestClient]: cfg = ServerConfig(defaults["API_CONFIG"], debug=True) for core in cfg.solr_cores: asyncio.run(read_data(core, cfg.solr_host, cfg.solr_port)) - with mock.patch( - "freva_rest.rest.server_config.mongo_collection", None - ): + with mock.patch("freva_rest.rest.server_config.mongo_collection", None): with TestClient(app) as test_client: yield test_client @@ -252,7 +240,7 @@ def client_no_solr(cfg: ServerConfig) -> Iterator[TestClient]: @pytest.fixture(scope="module") -def auth(client) -> Iterator[Dict[str, str]]: +def auth(client: TestClient) -> Iterator[Dict[str, str]]: """Create a valid acccess token.""" res = client.post( "/api/auth/v2/token", diff --git a/tests/test_auth.py b/tests/test_auth.py index 6e22e28c..a86c60eb 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -2,7 +2,7 @@ from copy import deepcopy from datetime import datetime, timezone -from unittest.mock import Mock +from typing import Dict import pytest import requests @@ -15,9 +15,7 @@ def raise_for_status() -> None: raise requests.HTTPError("Invalid") -def test_authenticate_with_password( - mocker: MockFixture, auth_instance: Auth -) -> None: +def test_authenticate_with_password(mocker: MockFixture, auth_instance: Auth) -> None: """Test authentication using username and password.""" old_token_data = deepcopy(auth_instance._auth_token) try: @@ -26,9 +24,7 @@ def test_authenticate_with_password( "token_type": "Bearer", "expires": int(datetime.now(timezone.utc).timestamp() + 3600), "refresh_token": "test_refresh_token", - "refresh_expires": int( - datetime.now(timezone.utc).timestamp() + 7200 - ), + "refresh_expires": int(datetime.now(timezone.utc).timestamp() + 7200), "scope": "profile email address", } with mocker.patch( @@ -38,9 +34,7 @@ def test_authenticate_with_password( auth_instance.authenticate(host="https://example.com") assert isinstance(auth_instance._auth_token, dict) assert auth_instance._auth_token["access_token"] == "test_access_token" - assert ( - auth_instance._auth_token["refresh_token"] == "test_refresh_token" - ) + assert auth_instance._auth_token["refresh_token"] == "test_refresh_token" finally: auth_instance._auth_token = old_token_data @@ -69,9 +63,7 @@ def test_authenticate_with_refresh_token( assert isinstance(auth_instance._auth_token, dict) assert auth_instance._auth_token["access_token"] == "test_access_token" - assert ( - auth_instance._auth_token["refresh_token"] == "test_refresh_token" - ) + assert auth_instance._auth_token["refresh_token"] == "test_refresh_token" finally: auth_instance._auth_token = old_token_data @@ -105,16 +97,12 @@ def test_refresh_token(mocker: MockFixture, auth_instance: Auth) -> None: assert isinstance(auth_instance._auth_token, dict) assert auth_instance._auth_token["access_token"] == "new_access_token" - assert ( - auth_instance._auth_token["refresh_token"] == "new_refresh_token" - ) + assert auth_instance._auth_token["refresh_token"] == "new_refresh_token" finally: auth_instance._auth_token = old_token_data -def test_authenticate_function( - mocker: MockFixture, auth_instance: Auth -) -> None: +def test_authenticate_function(mocker: MockFixture, auth_instance: Auth) -> None: """Test the authenticate function with username and password.""" old_token_data = deepcopy(auth_instance._auth_token) token_data = { @@ -195,14 +183,10 @@ def test_authentication_fail(mocker: MockFixture, auth_instance: Auth) -> None: refresh_token="test_refresh_token", ) with pytest.raises(ValueError): - auth_instance.check_authentication( - auth_url="https://example.com" - ) + auth_instance.check_authentication(auth_url="https://example.com") auth_instance._auth_token = mock_token_data with pytest.raises(ValueError): - auth_instance.check_authentication( - auth_url="https://example.com" - ) + auth_instance.check_authentication(auth_url="https://example.com") finally: auth_instance._auth_token = old_token_data @@ -231,3 +215,22 @@ def test_real_auth(test_server: str, auth_instance: Auth) -> None: assert token_data2["access_token"] == token finally: auth_instance._auth_token = old_token_data + + +def test_userinfo(mocker: MockFixture, test_server: str, auth: Dict[str, str]) -> None: + """Test getting the user info.""" + + res = requests.get( + f"{test_server}//api/auth/v2/userinfo", + headers={"Authorization": f"Bearer {auth['access_token']}"}, + timeout=3, + ) + assert res.status_code == 200 + assert "last_name" in res.json() + with mocker.patch("freva_rest.auth.get_userinfo", return_value={}): + res = requests.get( + f"{test_server}//api/auth/v2/userinfo", + headers={"Authorization": f"Bearer {auth['access_token']}"}, + timeout=3, + ) + assert res.status_code == 404 diff --git a/tests/test_client.py b/tests/test_client.py index bd6ca31a..2efd1ef5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2,6 +2,7 @@ import json import os +import time from typing import Dict import mock @@ -290,6 +291,7 @@ def tests_mongo_parameter_insert(client: TestClient, cfg: ServerConfig) -> None: params={"variable": ["wind", "cape"]}, ).status_code assert res1 == 200 + return mongo_client = MongoClient(cfg.mongo_url) # type: ignore collection = mongo_client[cfg.mongo_db]["search_queries"] stats = list(collection.find({})) diff --git a/tests/test_utils.py b/tests/test_utils.py index 891ec348..47e2b3a3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,7 +9,7 @@ import zarr from data_portal_worker.backends.posix import get_xr_engine from data_portal_worker.utils import str_to_int as str_to_int2 -from freva_rest.utils import str_to_int +from freva_rest.utils import get_userinfo, str_to_int def create_netcdf4_file(temp_dir: str) -> str: @@ -54,6 +54,14 @@ def test_str_to_int() -> None: assert func("4", 3) == 4 +def test_get_auth_userinfo() -> None: + """Test getting the authenticated user information.""" + out = get_userinfo({"email": "foo@bar", "lastname": "Doe", "given_name": "Jane"}) + assert out["email"] == "foo@bar" + assert out["last_name"] == "Doe" + assert out["first_name"] == "Jane" + + def test_get_xr_posix_engine() -> None: """Test the right xarray engine.""" with TemporaryDirectory() as temp_dir: