Skip to content

Commit

Permalink
Merge pull request #61 from FREVA-CLINT/update-submodules
Browse files Browse the repository at this point in the history
Add userinfo endpoint.
  • Loading branch information
antarcticrainforest authored Aug 9, 2024
2 parents 9efc3f6 + 24df523 commit 6b9b824
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 65 deletions.
2 changes: 1 addition & 1 deletion dev-env/config
2 changes: 1 addition & 1 deletion freva-client/src/freva_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@
from .auth import authenticate
from .query import databrowser

__version__ = "2408.0.0.dev2"
__version__ = "2408.0.0"
__all__ = ["authenticate", "databrowser", "__version__"]
2 changes: 1 addition & 1 deletion freva-rest/src/freva_rest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from pathlib import Path

__version__ = "2408.0.0-dev2"
__version__ = "2408.0.0"
__all__ = ["__version__"]

REST_URL = (
Expand Down
48 changes: 42 additions & 6 deletions freva-rest/src/freva_rest/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,35 @@

import datetime
import os
from typing import Annotated, Dict, Literal, Optional, cast
from typing import Annotated, Any, Dict, Literal, Optional, cast

import aiohttp
from fastapi import Form, HTTPException, Security
from fastapi import Form, HTTPException, Request, Security
from fastapi.responses import RedirectResponse
from fastapi_third_party_auth import Auth, IDToken
from pydantic import BaseModel
from pydantic import BaseModel, Field, ValidationError

from .logger import logger
from .rest import app, server_config
from .utils import get_userinfo

auth = Auth(openid_connect_url=server_config.oidc_discovery_url)

Required: Any = Ellipsis

TIMEOUT: aiohttp.ClientTimeout = aiohttp.ClientTimeout(total=5)
"""5 seconds for timeout for key cloak interaction."""


class UserInfo(BaseModel):
"""Basic user info."""

username: Annotated[str, Field(min_length=1)]
last_name: Annotated[str, Field(min_length=1)]
first_name: Annotated[str, Field(min_length=1)]
email: str


class TokenPayload(BaseModel):
"""Model representing the payload of a JWT token."""

Expand All @@ -39,13 +51,37 @@ class Token(BaseModel):


@app.get("/api/auth/v2/status", tags=["Authentication"])
async def get_token_status(
id_token: IDToken = Security(auth.required),
) -> TokenPayload:
async def get_token_status(id_token: IDToken = Security(auth.required)) -> TokenPayload:
"""Check the status of an access token."""
return cast(TokenPayload, id_token)


@app.get("/api/auth/v2/userinfo", tags=["Authentication"])
async def userinfo(
id_token: IDToken = Security(auth.required), request: Request = Required
) -> UserInfo:
"""Get userinfo for the current token."""
token_data = {k.lower(): str(v) for (k, v) in dict(id_token).items()}
try:
return UserInfo(**get_userinfo(token_data))
except ValidationError:
authorization = dict(request.headers)["authorization"]
try:
async with aiohttp.ClientSession(timeout=TIMEOUT) as client:
response = await client.get(
server_config.oidc_overview["userinfo_endpoint"],
headers={"Authorization": authorization},
)
response.raise_for_status()
token_data = await response.json()
return UserInfo(
**get_userinfo({k.lower(): str(v) for (k, v) in token_data.items()})
)
except Exception as error:
logger.error(error)
raise HTTPException(status_code=404) from error


@app.get(
"/api/auth/v2/.well-known/openid-configuration",
tags=["Authentication"],
Expand Down
27 changes: 22 additions & 5 deletions freva-rest/src/freva_rest/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,26 @@ def start(
"--services",
help="Set additional services this rest API should serve.",
),
oidc_url: str = typer.Option(
os.environ.get(
"OIDC_URL",
"http://localhost:8080/realms/freva/.well-known/openid-configuration",
),
"--oidc-url",
"--openid-connect-url",
help="The url to openid configuration",
),
oidc_client_id: str = typer.Option(
os.environ.get("OIDC_CLIENT_ID", "freva"),
"--oidc-client-id",
"--oidc-client",
help="Name of the openid client.",
),
oidc_client_secret: Optional[str] = typer.Option(
os.environ.get("OIDC_CLIENT_SECRET", ""),
"--oidc-client-secret",
help="Name of the openid client secret.",
),
ssl_cert_dir: Optional[str] = typer.Option(
None,
"--cert-dir",
Expand Down Expand Up @@ -159,10 +179,7 @@ def start(
asyncio.run(read_data(core, cfg.solr_host, cfg.solr_port))
workers = {False: int(os.environ.get("API_WORKER", 8)), True: None}
ssl_cert, ssl_key = get_cert_file(ssl_cert_dir, ssl_cert, ssl_key)
oidc_client_id = os.getenv("OIDC_CLIENT_ID", "freva")
oidc_client_secret = os.getenv("OIDC_CLIENT_SECRET", "")
api_services = ",".join(services).replace("_", "-")
oidc_url = "http://localhost:8080/realms/freva/.well-known/openid-configuration"
with NamedTemporaryFile(suffix=".conf", prefix="env") as temp_f:
Path(temp_f.name).write_text(
(
Expand All @@ -175,9 +192,9 @@ def start(
f"REDIS_USER={os.getenv('REDIS_USER', 'redis')}\n"
f"REDIS_SSL_CERTFILE={ssl_cert or ''}\n"
f"REDIS_SSL_KEYFILE={ssl_key or ''}\n"
f"OICD_URL={os.getenv('OIDC_URL', oidc_url)}\n"
f"OICD_URL={oidc_url}\n"
f"OICD_CLIENT_ID={oidc_client_id}\n"
f"OICD_CLIENT_SECRET={oidc_client_secret}\n"
f"OICD_CLIENT_SECRET={oidc_client_secret or ''}\n"
f"API_URL={defaults['API_URL']}\n"
f"API_SERVICES={api_services}\n"
),
Expand Down
35 changes: 29 additions & 6 deletions freva-rest/src/freva_rest/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Various utilities for the restAPI."""

import os
from typing import Optional
from typing import Dict, Optional

import redis.asyncio as redis
from fastapi import HTTPException, status
Expand All @@ -19,6 +19,33 @@
"""All the services that need the redis cache."""


def get_userinfo(user_info: Dict[str, str]) -> Dict[str, str]:
"""Convert a user_info dictionary to the UserInfo Model."""
output = {}
keys = {
"email": ("mail", "email"),
"username": ("preferred-username", "user-name", "uid"),
"last_name": ("last-name", "family-name", "name", "surname"),
"first_name": ("first-name", "given-name"),
}
for key, entries in keys.items():
for entry in entries:
if user_info.get(entry):
output[key] = user_info[entry]
break
if user_info.get(entry.replace("-", "_")):
output[key] = user_info[entry.replace("-", "_")]
break
if user_info.get(entry.replace("-", "")):
output[key] = user_info[entry.replace("-", "")]
break
# Strip all the middle names
name = output.get("first_name", "") + " " + output.get("last_name", "")
output["first_name"] = name.partition(" ")[0]
output["last_name"] = name.rpartition(" ")[-1]
return output


async def create_redis_connection(
cache: Optional[redis.Redis] = REDIS_CACHE,
) -> redis.Redis:
Expand All @@ -35,11 +62,7 @@ async def create_redis_connection(
db=0,
)
services = set(
[
s.strip()
for s in os.getenv("API_SERVICES", "").split(",")
if s.strip()
]
[s.strip() for s in os.getenv("API_SERVICES", "").split(",") if s.strip()]
)
if CACHING_SERVICES - services == CACHING_SERVICES:
# All services that would need caching are disabled.
Expand Down
26 changes: 7 additions & 19 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@ def valid_freva_config() -> Iterator[Path]:
with TemporaryDirectory() as temp_dir:
freva_config = Path(temp_dir) / "share" / "freva" / "freva.toml"
freva_config.parent.mkdir(exist_ok=True, parents=True)
freva_config.write_text(
"[freva]\nhost = 'https://www.freva.com:80/api'"
)
freva_config.write_text("[freva]\nhost = 'https://www.freva.com:80/api'")
yield Path(temp_dir)


Expand Down Expand Up @@ -157,9 +155,7 @@ def valid_eval_conf_file() -> Iterator[Path]:
_prep_env(EVALUATION_SYSTEM_CONFIG_FILE=str(eval_file)),
clear=True,
):
with mock.patch(
"sysconfig.get_path", lambda x, y="foo": str(temp_dir)
):
with mock.patch("sysconfig.get_path", lambda x, y="foo": str(temp_dir)):
yield eval_file


Expand All @@ -169,18 +165,14 @@ def invalid_eval_conf_file() -> Iterator[Path]:
with TemporaryDirectory() as temp_dir:
eval_file = Path(temp_dir) / "eval.conf"
eval_file.write_text(
"[foo]\n"
"solr.host = http://localhost\n"
"databrowser.port = 8080"
"[foo]\n" "solr.host = http://localhost\n" "databrowser.port = 8080"
)
with mock.patch.dict(
os.environ,
_prep_env(EVALUATION_SYSTEM_CONFIG_FILE=str(eval_file)),
clear=True,
):
with mock.patch(
"sysconfig.get_path", lambda x, y="foo": str(temp_dir)
):
with mock.patch("sysconfig.get_path", lambda x, y="foo": str(temp_dir)):
yield eval_file


Expand All @@ -196,9 +188,7 @@ def test_server() -> Iterator[str]:
thread1.daemon = True
thread1.start()
time.sleep(1)
thread2 = threading.Thread(
target=run_loader_process, args=(find_free_port(),)
)
thread2 = threading.Thread(target=run_loader_process, args=(find_free_port(),))
thread2.daemon = True
thread2.start()
time.sleep(5)
Expand Down Expand Up @@ -233,9 +223,7 @@ def client_no_mongo(cfg: ServerConfig) -> Iterator[TestClient]:
cfg = ServerConfig(defaults["API_CONFIG"], debug=True)
for core in cfg.solr_cores:
asyncio.run(read_data(core, cfg.solr_host, cfg.solr_port))
with mock.patch(
"freva_rest.rest.server_config.mongo_collection", None
):
with mock.patch("freva_rest.rest.server_config.mongo_collection", None):
with TestClient(app) as test_client:
yield test_client

Expand All @@ -252,7 +240,7 @@ def client_no_solr(cfg: ServerConfig) -> Iterator[TestClient]:


@pytest.fixture(scope="module")
def auth(client) -> Iterator[Dict[str, str]]:
def auth(client: TestClient) -> Iterator[Dict[str, str]]:
"""Create a valid acccess token."""
res = client.post(
"/api/auth/v2/token",
Expand Down
53 changes: 28 additions & 25 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from copy import deepcopy
from datetime import datetime, timezone
from unittest.mock import Mock
from typing import Dict

import pytest
import requests
Expand All @@ -15,9 +15,7 @@ def raise_for_status() -> None:
raise requests.HTTPError("Invalid")


def test_authenticate_with_password(
mocker: MockFixture, auth_instance: Auth
) -> None:
def test_authenticate_with_password(mocker: MockFixture, auth_instance: Auth) -> None:
"""Test authentication using username and password."""
old_token_data = deepcopy(auth_instance._auth_token)
try:
Expand All @@ -26,9 +24,7 @@ def test_authenticate_with_password(
"token_type": "Bearer",
"expires": int(datetime.now(timezone.utc).timestamp() + 3600),
"refresh_token": "test_refresh_token",
"refresh_expires": int(
datetime.now(timezone.utc).timestamp() + 7200
),
"refresh_expires": int(datetime.now(timezone.utc).timestamp() + 7200),
"scope": "profile email address",
}
with mocker.patch(
Expand All @@ -38,9 +34,7 @@ def test_authenticate_with_password(
auth_instance.authenticate(host="https://example.com")
assert isinstance(auth_instance._auth_token, dict)
assert auth_instance._auth_token["access_token"] == "test_access_token"
assert (
auth_instance._auth_token["refresh_token"] == "test_refresh_token"
)
assert auth_instance._auth_token["refresh_token"] == "test_refresh_token"
finally:
auth_instance._auth_token = old_token_data

Expand Down Expand Up @@ -69,9 +63,7 @@ def test_authenticate_with_refresh_token(

assert isinstance(auth_instance._auth_token, dict)
assert auth_instance._auth_token["access_token"] == "test_access_token"
assert (
auth_instance._auth_token["refresh_token"] == "test_refresh_token"
)
assert auth_instance._auth_token["refresh_token"] == "test_refresh_token"
finally:
auth_instance._auth_token = old_token_data

Expand Down Expand Up @@ -105,16 +97,12 @@ def test_refresh_token(mocker: MockFixture, auth_instance: Auth) -> None:

assert isinstance(auth_instance._auth_token, dict)
assert auth_instance._auth_token["access_token"] == "new_access_token"
assert (
auth_instance._auth_token["refresh_token"] == "new_refresh_token"
)
assert auth_instance._auth_token["refresh_token"] == "new_refresh_token"
finally:
auth_instance._auth_token = old_token_data


def test_authenticate_function(
mocker: MockFixture, auth_instance: Auth
) -> None:
def test_authenticate_function(mocker: MockFixture, auth_instance: Auth) -> None:
"""Test the authenticate function with username and password."""
old_token_data = deepcopy(auth_instance._auth_token)
token_data = {
Expand Down Expand Up @@ -195,14 +183,10 @@ def test_authentication_fail(mocker: MockFixture, auth_instance: Auth) -> None:
refresh_token="test_refresh_token",
)
with pytest.raises(ValueError):
auth_instance.check_authentication(
auth_url="https://example.com"
)
auth_instance.check_authentication(auth_url="https://example.com")
auth_instance._auth_token = mock_token_data
with pytest.raises(ValueError):
auth_instance.check_authentication(
auth_url="https://example.com"
)
auth_instance.check_authentication(auth_url="https://example.com")
finally:
auth_instance._auth_token = old_token_data

Expand Down Expand Up @@ -231,3 +215,22 @@ def test_real_auth(test_server: str, auth_instance: Auth) -> None:
assert token_data2["access_token"] == token
finally:
auth_instance._auth_token = old_token_data


def test_userinfo(mocker: MockFixture, test_server: str, auth: Dict[str, str]) -> None:
"""Test getting the user info."""

res = requests.get(
f"{test_server}//api/auth/v2/userinfo",
headers={"Authorization": f"Bearer {auth['access_token']}"},
timeout=3,
)
assert res.status_code == 200
assert "last_name" in res.json()
with mocker.patch("freva_rest.auth.get_userinfo", return_value={}):
res = requests.get(
f"{test_server}//api/auth/v2/userinfo",
headers={"Authorization": f"Bearer {auth['access_token']}"},
timeout=3,
)
assert res.status_code == 404
Loading

0 comments on commit 6b9b824

Please sign in to comment.