diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..d720f307 --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +ignore = D203,E121,E124,E126,E203,E231,E261,E251,E701,F403,F405,E402 +exclude = .git,__pycache__,docs/source/conf.py,old,build,dist,*/migrations/*,tests.py,*/tests/* +max-line-length=180 +max-complexity = 25 diff --git a/README.md b/README.md index fca8d068..ba4554d1 100644 --- a/README.md +++ b/README.md @@ -118,7 +118,14 @@ Please consider the following branches: ### Executing Unit Tests -Once you have activate the virtualenv, unit tests can be executed as show below. +Once you have activated the virtualenv, further dependencies must be installed as shown below. + +```` +pip install -r requirements-dev.txt + +```` + +Then the unit tests can be executed as shown below. ```` pytest pyeudiw -x @@ -131,6 +138,23 @@ you can run the test by passing the mon user and password in this way PYEUDIW_MONGO_TEST_AUTH_INLINE="satosa:thatpassword@" pytest pyeudiw -x ```` +### Executing integration tests + +The iam-proxy-italia project must be configured and running. + +Integration tests check both the cross device flow and the same device flow. + +The cross device flow requires `playwright` to be installed. + +```` +cd example/satosa/integration_test + +playwright install + +PYEUDIW_MONGO_TEST_AUTH_INLINE="satosa:thatpassword@" pytest pyeudiw -x ```` + + ## Authors diff --git a/docs/TRUST.md b/docs/TRUST.md index 42fabef6..fa761953 100644 --- a/docs/TRUST.md +++ b/docs/TRUST.md @@ -29,18 +29,18 @@ Some HTTPC parameters are commonly used, have a default value and as an alternat ### Federation -Module `pyeudiw.trust.default.federation` provides a source of trusted entities and metadata based on [OpenID Federation](https://openid.net/specs/openid-federation-1_0.html) that is intended to be applicable to Issuer, Holders and Verifiers. 
In the specific case of the Verifier (this application), the module can expose verifier metadata at the `.well-known/openid-federation` endpoint. +Module `pyeudiw.trust.handler.federation` provides a source of trusted entities and metadata based on [OpenID Federation](https://openid.net/specs/openid-federation-1_0.html) that is intended to be applicable to Issuer, Holders and Verifiers. In the specific case of the Verifier (this application), the module can expose verifier metadata at the `.well-known/openid-federation` endpoint. The configuration parameters of the module are the following. | Parameter | Description | Example Value | | -------------------------------------------------------------- | --------------------------------------------------------- | ------------------------------------------------------------------------ | -| config.federation.metadata_type | The type of metadata to use for the federation | wallet_relying_party | +| config.federation.metadata_type | The type of metadata to use for the federation | openid_credential_verifier | | config.federation.authority_hints | The list of authority hints to use for the federation | [http://127.0.0.1:10000] | | config.federation.trust_anchors | The list of trust anchors to use for the federation | [http://127.0.0.1:10000] | | config.federation.default_sig_alg | The default signature algorithm to use for the federation | RS256 | -| config.federation.federation_entity_metadata.organization_name | The organization name | Developers Italia SATOSA OpenID4VP backend policy_uri, tos_uri, logo_uri | +| config.federation.federation_entity_metadata.organization_name | The organization name | IAM Proxy Italia OpenID4VP backend policy_uri, tos_uri, logo_uri | | config.federation.federation_entity_metadata.homepage_uri | The URI of the homepage | https://developers.italia.it | | config.federation.federation_entity_metadata.policy_uri | The URI of the policy | https://developers.italia.it/policy.html | | 
config.federation.federation_entity_metadata.tos_uri | The URI of the TOS | https://developers.italia.it/tos.html | diff --git a/example/satosa/integration_test/commons.py b/example/satosa/integration_test/commons.py index 69670390..8538ad5a 100644 --- a/example/satosa/integration_test/commons.py +++ b/example/satosa/integration_test/commons.py @@ -30,9 +30,10 @@ ) from pyeudiw.sd_jwt.holder import SDJWTHolder from pyeudiw.trust.model.trust_source import TrustSourceData -from saml2_sp import saml2_request -from settings import ( +from . saml2_sp import saml2_request + +from . settings import ( IDP_BASEURL, CONFIG_DB, RP_EID, @@ -177,8 +178,8 @@ def create_authorize_response(vp_token: str, state: str, response_uri: str) -> s ).content.decode() rp_ec = decode_jwt_payload(rp_ec_jwt) - assert response_uri == rp_ec["metadata"]["wallet_relying_party"]["response_uris_supported"][0] - encryption_key = rp_ec["metadata"]["wallet_relying_party"]["jwks"]["keys"][1] + # assert response_uri == rp_ec["metadata"]["openid_credential_verifier"]["response_uris"][0] + encryption_key = rp_ec["metadata"]["openid_credential_verifier"]["jwks"]["keys"][1] response = { "state": state, diff --git a/example/satosa/integration_test/cross_device_integration_test.py b/example/satosa/integration_test/cross_device_integration_test.py index bab9cf8a..11853cbc 100644 --- a/example/satosa/integration_test/cross_device_integration_test.py +++ b/example/satosa/integration_test/cross_device_integration_test.py @@ -7,7 +7,7 @@ from pyeudiw.jwt.utils import decode_jwt_payload -from commons import ( +from . commons import ( ISSUER_CONF, setup_test_db_engine, apply_trust_settings, @@ -18,7 +18,7 @@ extract_saml_attributes, verify_request_object_jwt ) -from settings import TIMEOUT_S +from . 
settings import TIMEOUT_S # put a trust attestation related itself into the storage # this is then used as trust_chain header parameter in the signed request object @@ -92,6 +92,7 @@ def run(playwright: Playwright): request_object_claims["nonce"], request_object_claims["client_id"] ) + wallet_response_data = create_authorize_response( verifiable_presentations, request_object_claims["state"], diff --git a/example/satosa/integration_test/same_device_integration_test.py b/example/satosa/integration_test/same_device_integration_test.py index 676da713..08b60cd4 100644 --- a/example/satosa/integration_test/same_device_integration_test.py +++ b/example/satosa/integration_test/same_device_integration_test.py @@ -4,7 +4,7 @@ from pyeudiw.jwt.utils import decode_jwt_payload -from commons import ( +from . commons import ( ISSUER_CONF, setup_test_db_engine, apply_trust_settings, @@ -15,7 +15,7 @@ extract_saml_attributes, verify_request_object_jwt ) -from settings import TIMEOUT_S +from . settings import TIMEOUT_S # put a trust attestation related itself into the storage # this is then used as trust_chain header parameter in the signed request object diff --git a/example/satosa/integration_test/settings.py b/example/satosa/integration_test/settings.py index 082693d2..152d2739 100644 --- a/example/satosa/integration_test/settings.py +++ b/example/satosa/integration_test/settings.py @@ -103,7 +103,7 @@ "sub": RP_EID, 'jwks': {"keys": rp_jwks}, "metadata": { - "wallet_relying_party": { + "openid_credential_verifier": { 'jwks': {"keys": []} }, "federation_entity": { diff --git a/example/satosa/pyeudiw_backend.yaml b/example/satosa/pyeudiw_backend.yaml index 8af38c48..2b125084 100644 --- a/example/satosa/pyeudiw_backend.yaml +++ b/example/satosa/pyeudiw_backend.yaml @@ -2,7 +2,7 @@ module: pyeudiw.satosa.backend.OpenID4VPBackend name: OpenID4VP config: - + ui: static_storage_url: !ENV SATOSA_BASE_STATIC template_folder: "templates" # project root @@ -20,7 +20,6 @@ config: module: 
pyeudiw.satosa.default.response_handler class: ResponseHandler path: '/response-uri' - entity_configuration: '/.well-known/openid-federation' status: '/status' get_response: '/get-response' @@ -107,38 +106,108 @@ config: subject_id_random_value: CHANGEME! network: - httpc_params: + httpc_params: &httpc_params connection: ssl: true session: timeout: 6 + # private jwk + metadata_jwks: &metadata_jwks + - crv: P-256 + d: KzQBowMMoPmSZe7G8QsdEWc1IvR2nsgE8qTOYmMcLtc + kid: dDwPWXz5sCtczj7CJbqgPGJ2qQ83gZ9Sfs-tJyULi6s + use: sig + kty: EC + x: TSO-KOqdnUj5SUuasdlRB2VVFSqtJOxuR5GftUTuBdk + y: ByWgQt1wGBSnF56jQqLdoO1xKUynMY-BHIDB3eXlR7 + - kty: RSA + d: QUZsh1NqvpueootsdSjFQz-BUvxwd3Qnzm5qNb-WeOsvt3rWMEv0Q8CZrla2tndHTJhwioo1U4NuQey7znijhZ177bUwPPxSW1r68dEnL2U74nKwwoYeeMdEXnUfZSPxzs7nY6b7vtyCoA-AjiVYFOlgKNAItspv1HxeyGCLhLYhKvS_YoTdAeLuegETU5D6K1xGQIuw0nS13Icjz79Y8jC10TX4FdZwdX-NmuIEDP5-s95V9DMENtVqJAVE3L-wO-NdDilyjyOmAbntgsCzYVGH9U3W_djh4t3qVFCv3r0S-DA2FD3THvlrFi655L0QHR3gu_Fbj3b9Ybtajpue_Q + e: AQAB + use: enc + kid: 9Cquk0X-fNPSdePQIgQcQZtD6J0IjIRrFigW2PPK_-w + n: utqtxbs-jnK0cPsV7aRkkZKA9t4S-WSZa3nCZtYIKDpgLnR_qcpeF0diJZvKOqXmj2cXaKFUE-8uHKAHo7BL7T-Rj2x3vGESh7SG1pE0thDGlXj4yNsg0qNvCXtk703L2H3i1UXwx6nq1uFxD2EcOE4a6qDYBI16Zl71TUZktJwmOejoHl16CPWqDLGo9GUSk_MmHOV20m4wXWkB4qbvpWVY8H6b2a0rB1B1YPOs5ZLYarSYZgjDEg6DMtZ4NgiwZ-4N1aaLwyO-GLwt9Vf-NBKwoxeRyD3zWE2FXRFBbhKGksMrCGnFDsNl5JTlPjaM3kYyImE941ggcuc495m-Fw + p: 2zmGXIMCEHPphw778YjVTar1eycih6fFSJ4I4bl1iq167GqO0PjlOx6CZ1-OdBTVU7HfrYRiUK_BnGRdPDn-DQghwwkB79ZdHWL14wXnpB5y-boHz_LxvjsEqXtuQYcIkidOGaMG68XNT1nM4F9a8UKFr5hHYT5_UIQSwsxlRQ0 + q: 2jMFt2iFrdaYabdXuB4QMboVjPvbLA-IVb6_0hSG_-EueGBvgcBxdFGIZaG6kqHqlB7qMsSzdptU0vn6IgmCZnX-Hlt6c5X7JB_q91PZMLTO01pbZ2Bk58GloalCHnw_mjPh0YPviH5jGoWM5RHyl_HDDMI-UeLkzP7ImxGizrM + + #This is the configuration for the relaying party metadata + metadata: &metadata + application_type: web + + #The following section contains all the algorithms supported for the encryption of response + 
authorization_encrypted_response_alg: *enc_alg_supported + authorization_encrypted_response_enc: *enc_enc_supported + authorization_signed_response_alg: *sig_alg_supported + + #Various informations of the client + client_id: # this field is autopopulated using internal variables base_url and name using the following format: "/" + client_name: Name of an example organization + contacts: + - ops@verifier.example.org + default_acr_values: + - https://www.spid.gov.it/SpidL2 + - https://www.spid.gov.it/SpidL3 + + #The following section contains all the algorithms supported for the encryption of id token response + id_token_encrypted_response_alg: *enc_alg_supported + id_token_encrypted_response_enc: *enc_enc_supported + id_token_signed_response_alg: *sig_alg_supported + + # loaded in the __init__ + # jwks: + + redirect_uris: + # This field is autopopulated using internal variables base_url and name using the following format: //redirect-uri" + request_uris: + # This field is autopopulated using internal variables base_url and name using the following format: //request-uri" + + # not necessary according to openid4vp + # default_max_age: 1111 + # require_auth_time: true + # subject_type: pairwise + + vp_formats: + vc+sd-jwt: + sd-jwt_alg_values: + - ES256 + - ES384 + kb-jwt_alg_values: + - ES256 + - ES384 + trust: direct_trust_sd_jwt_vc: module: pyeudiw.trust.handler.direct_trust_sd_jwt_vc class: DirectTrustSdJwtVc config: + cache_ttl: 0 + httpc_params: *httpc_params jwk_endpoint: /.well-known/jwt-vc-issuer direct_trust_jar: module: pyeudiw.trust.handler.direct_trust_jar class: DirectTrustJar config: + cache_ttl: 0 + httpc_params: *httpc_params jwk_endpoint: /.well-known/jar-issuer jwks: *metadata_jwks federation: module: pyeudiw.trust.handler.federation class: FederationHandler config: - metadata_type: "wallet_relying_party" + httpc_params: *httpc_params + cache_ttl: 0 + entity_configuration_exp: 600 + metadata_type: "openid_credential_verifier" + metadata: *metadata 
authority_hints: - http://127.0.0.1:8000 trust_anchors: - - public_keys: [] - - http://127.0.0.1:8000 + - http://127.0.0.1:8000: [] # array of public keys default_sig_alg: "RS256" trust_marks: [] federation_entity_metadata: - organization_name: Developers Italia SATOSA OpenID4VP backend + organization_name: IAM Proxy Italia OpenID4VP backend homepage_uri: https://developers.italia.it policy_uri: https://developers.italia.it tos_uri: https://developers.italia.it @@ -184,68 +253,3 @@ config: db_trust_sources_collection: trust_sources data_ttl: 63072000 # 2 years # - connection_params: - - # private jwk - metadata_jwks: &metadata_jwks - # !ENV PYEUDIW_METADATA_JWKS - - crv: P-256 - d: KzQBowMMoPmSZe7G8QsdEWc1IvR2nsgE8qTOYmMcLtc - kid: dDwPWXz5sCtczj7CJbqgPGJ2qQ83gZ9Sfs-tJyULi6s - use: sig - kty: EC - x: TSO-KOqdnUj5SUuasdlRB2VVFSqtJOxuR5GftUTuBdk - y: ByWgQt1wGBSnF56jQqLdoO1xKUynMY-BHIDB3eXlR7 - - kty: RSA - d: QUZsh1NqvpueootsdSjFQz-BUvxwd3Qnzm5qNb-WeOsvt3rWMEv0Q8CZrla2tndHTJhwioo1U4NuQey7znijhZ177bUwPPxSW1r68dEnL2U74nKwwoYeeMdEXnUfZSPxzs7nY6b7vtyCoA-AjiVYFOlgKNAItspv1HxeyGCLhLYhKvS_YoTdAeLuegETU5D6K1xGQIuw0nS13Icjz79Y8jC10TX4FdZwdX-NmuIEDP5-s95V9DMENtVqJAVE3L-wO-NdDilyjyOmAbntgsCzYVGH9U3W_djh4t3qVFCv3r0S-DA2FD3THvlrFi655L0QHR3gu_Fbj3b9Ybtajpue_Q - e: AQAB - use: enc - kid: 9Cquk0X-fNPSdePQIgQcQZtD6J0IjIRrFigW2PPK_-w - n: utqtxbs-jnK0cPsV7aRkkZKA9t4S-WSZa3nCZtYIKDpgLnR_qcpeF0diJZvKOqXmj2cXaKFUE-8uHKAHo7BL7T-Rj2x3vGESh7SG1pE0thDGlXj4yNsg0qNvCXtk703L2H3i1UXwx6nq1uFxD2EcOE4a6qDYBI16Zl71TUZktJwmOejoHl16CPWqDLGo9GUSk_MmHOV20m4wXWkB4qbvpWVY8H6b2a0rB1B1YPOs5ZLYarSYZgjDEg6DMtZ4NgiwZ-4N1aaLwyO-GLwt9Vf-NBKwoxeRyD3zWE2FXRFBbhKGksMrCGnFDsNl5JTlPjaM3kYyImE941ggcuc495m-Fw - p: 2zmGXIMCEHPphw778YjVTar1eycih6fFSJ4I4bl1iq167GqO0PjlOx6CZ1-OdBTVU7HfrYRiUK_BnGRdPDn-DQghwwkB79ZdHWL14wXnpB5y-boHz_LxvjsEqXtuQYcIkidOGaMG68XNT1nM4F9a8UKFr5hHYT5_UIQSwsxlRQ0 - q: 
2jMFt2iFrdaYabdXuB4QMboVjPvbLA-IVb6_0hSG_-EueGBvgcBxdFGIZaG6kqHqlB7qMsSzdptU0vn6IgmCZnX-Hlt6c5X7JB_q91PZMLTO01pbZ2Bk58GloalCHnw_mjPh0YPviH5jGoWM5RHyl_HDDMI-UeLkzP7ImxGizrM - - #This is the configuration for the relaying party metadata - metadata: - application_type: web - - #The following section contains all the algorithms supported for the encryption of response - authorization_encrypted_response_alg: *enc_alg_supported - authorization_encrypted_response_enc: *enc_enc_supported - authorization_signed_response_alg: *sig_alg_supported - - #Various informations of the client - client_id: # this field is autopopulated using internal variables base_url and name using the following format: "/" - client_name: Name of an example organization - contacts: - - ops@verifier.example.org - default_acr_values: - - https://www.spid.gov.it/SpidL2 - - https://www.spid.gov.it/SpidL3 - - default_max_age: 1111 - - #The following section contains all the algorithms supported for the encryption of id token response - id_token_encrypted_response_alg: *enc_alg_supported - id_token_encrypted_response_enc: *enc_enc_supported - id_token_signed_response_alg: *sig_alg_supported - - # loaded in the __init__ - # jwks: - - - redirect_uris: - # This field is autopopulated using internal variables base_url and name using the following format: //redirect-uri" - request_uris: - # This field is autopopulated using internal variables base_url and name using the following format: //request-uri" - - require_auth_time: true - subject_type: pairwise - - vp_formats: - vc+sd-jwt: - sd-jwt_alg_values: - - ES256 - - ES384 - kb-jwt_alg_values: - - ES256 - - ES384 diff --git a/html_linting.sh b/html_linting.sh new file mode 100644 index 00000000..72df234e --- /dev/null +++ b/html_linting.sh @@ -0,0 +1,16 @@ +echo -e '\nHTML linting:' +shopt -s globstar nullglob +for file in `find example -type f | grep html` +do + echo -e "\n$file:" + html_lint.py "$file" | awk -v path="file://$PWD/$file:" '$0=path$0' | sed -e 
's/: /:\n\t/'; +done + +errors=0 +for file in "${array[@]}" +do + errors=$((errors + $(html_lint.py "$file" | grep -c 'Error'))) +done + +echo -e "\nHTML errors: $errors" +if [ "$errors" -gt 0 ]; then exit 1; fi; diff --git a/linting.sh b/linting.sh index f5fdf989..a2020c41 100755 --- a/linting.sh +++ b/linting.sh @@ -7,23 +7,13 @@ autopep8 -r --in-place $SRC autoflake -r --in-place --remove-unused-variables --expand-star-imports --remove-all-unused-imports $SRC flake8 $SRC --count --select=E9,F63,F7,F82 --show-source --statistics -flake8 $SRC --max-line-length 120 --count --statistics + +# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide +flake8 $SRC --count --exit-zero --statistics + +isort --atomic $SRC + +black $SRC bandit -r -x $SRC/test* $SRC/* -echo -e '\nHTML linting:' -shopt -s globstar nullglob -for file in `find example -type f | grep html` -do - echo -e "\n$file:" - html_lint.py "$file" | awk -v path="file://$PWD/$file:" '$0=path$0' | sed -e 's/: /:\n\t/'; -done - -errors=0 -for file in "${array[@]}" -do - errors=$((errors + $(html_lint.py "$file" | grep -c 'Error'))) -done - -echo -e "\nHTML errors: $errors" -if [ "$errors" -gt 0 ]; then exit 1; fi; diff --git a/pyeudiw/tests/trust/test_trust_evaluation_helper.py b/oldies/_test_trust_evaluation_helper.py similarity index 96% rename from pyeudiw/tests/trust/test_trust_evaluation_helper.py rename to oldies/_test_trust_evaluation_helper.py index 9773085c..e846cbbe 100644 --- a/pyeudiw/tests/trust/test_trust_evaluation_helper.py +++ b/oldies/_test_trust_evaluation_helper.py @@ -1,10 +1,12 @@ -import pytest from datetime import datetime -from pyeudiw.tests.settings import CONFIG -from pyeudiw.trust import TrustEvaluationHelper + +import pytest + from pyeudiw.storage.db_engine import DBEngine, TrustType from pyeudiw.tests.federation.base import trust_chain_issuer -from pyeudiw.tests.x509.test_x509 import gen_chain, chain_to_pem +from pyeudiw.tests.settings import CONFIG +from 
pyeudiw.tests.x509.test_x509 import chain_to_pem, gen_chain +from pyeudiw.trust import TrustEvaluationHelper class TestTrustEvaluationHelper: diff --git a/oldies/federation/__init__.py b/oldies/federation/__init__.py new file mode 100644 index 00000000..811a4fac --- /dev/null +++ b/oldies/federation/__init__.py @@ -0,0 +1,13 @@ +def is_ec(payload: dict) -> None: + """ + Determines if payload dict is an Entity Configuration + + :param payload: the object to determine if is an Entity Configuration + :type payload: dict + """ + + try: + EntityConfigurationPayload(**payload) + except ValueError as e: + _msg = f"Invalid Entity Configuration: {e}" + raise InvalidEntityConfiguration(_msg) \ No newline at end of file diff --git a/oldies/federation/policy.py b/oldies/federation/policy.py new file mode 100644 index 00000000..9d75dda0 --- /dev/null +++ b/oldies/federation/policy.py @@ -0,0 +1,15 @@ +def diff2policy(new, old): + res = {} + for claim in set(new).intersection(set(old)): + if new[claim] == old[claim]: + continue + else: + res[claim] = {'value': new[claim]} + + for claim in set(new).difference(set(old)): + if claim in ['contacts']: + res[claim] = {'add': new[claim]} + else: + res[claim] = {'value': new[claim]} + + return res \ No newline at end of file diff --git a/oldies/federation/statement.py b/oldies/federation/statement.py new file mode 100644 index 00000000..e86e4042 --- /dev/null +++ b/oldies/federation/statement.py @@ -0,0 +1,19 @@ +def jwks_from_jwks_uri(jwks_uri: str, httpc_params: dict, http_async: bool = True) -> list[dict]: + """ + Retrieves jwks from an entity uri. + + :param jwks_uri: the uri where the jwks are located. + :type jwks_uri: str + :param httpc_params: parameters to perform http requests. + :type httpc_params: dict + :param http_async: if is set to True the operation will be performed in async (deafault True) + :type http_async: bool + + :returns: A list of entity jwks. 
+ :rtype: list[dict] + """ + + response = get_http_url(jwks_uri, httpc_params, http_async) + jwks = [i.json() for i in response] + + return jwks \ No newline at end of file diff --git a/oldies/federation/test/test_schema.py b/oldies/federation/test/test_schema.py new file mode 100644 index 00000000..eeee35eb --- /dev/null +++ b/oldies/federation/test/test_schema.py @@ -0,0 +1,9 @@ +def test_is_ec(): + is_ec(ta_ec) + + +def test_is_ec_false(): + try: + is_ec(ta_es) + except InvalidEntityConfiguration: + pass \ No newline at end of file diff --git a/oldies/jwk/__init__.py b/oldies/jwk/__init__.py new file mode 100644 index 00000000..7e312bb2 --- /dev/null +++ b/oldies/jwk/__init__.py @@ -0,0 +1,31 @@ +class RSAJWK(JWK): + def __init__(self, key: dict | None = None, hash_func: str = "SHA-256") -> None: + super().__init__(key, "RSA", hash_func, None) + + +class ECJWK(JWK): + def __init__( + self, key: dict | None = None, hash_func: str = "SHA-256", ec_crv: str = "P-256" + ) -> None: + super().__init__(key, "EC", hash_func, ec_crv) + + +def jwk_form_dict(key: dict, hash_func: str = "SHA-256") -> RSAJWK | ECJWK: + """ + Returns a JWK instance from a dict. + + :param key: a dict that represents the key. + :type key: dict + + :returns: a JWK instance. 
+ :rtype: JWK + """ + _kty = key.get("kty", None) + + if _kty is None or _kty not in ["EC", "RSA"]: + raise InvalidJwk("Invalid JWK") + elif _kty == "RSA": + return RSAJWK(key, hash_func) + else: + ec_crv = key.get("crv", "P-256") + return ECJWK(key, hash_func, ec_crv) \ No newline at end of file diff --git a/oldies/jwk/exceptions.py b/oldies/jwk/exceptions.py new file mode 100644 index 00000000..a7f041a8 --- /dev/null +++ b/oldies/jwk/exceptions.py @@ -0,0 +1,7 @@ + +class JwkError(Exception): + pass + + +class InvalidJwk(Exception): + pass diff --git a/oldies/jwk/parse.py b/oldies/jwk/parse.py new file mode 100644 index 00000000..2e897d76 --- /dev/null +++ b/oldies/jwk/parse.py @@ -0,0 +1,11 @@ +def adapt_key_to_JWK(key: dict | JWK | cryptojwt.jwk.JWK) -> JWK: + """Function adapt_key_to_JWK normalize key representation format to + the internal JWK. + """ + if isinstance(key, JWK): + return key + if isinstance(key, dict): + return JWK(key) + if isinstance(key, cryptojwt.jwk.JWK): + return JWK(key.to_dict()) + raise ValueError(f"not a valid or supported key format: {type(key)}") diff --git a/oldies/jwt/parse.py b/oldies/jwt/parse.py new file mode 100644 index 00000000..1eb38b4f --- /dev/null +++ b/oldies/jwt/parse.py @@ -0,0 +1,6 @@ +def _unsafe_decode_part(part: str) -> dict: + padding_needed = len(part) % 4 + if padding_needed: + part += "=" * (4 - padding_needed) + decoded_bytes = base64.urlsafe_b64decode(part) + return json.loads(decoded_bytes.decode("utf-8")) \ No newline at end of file diff --git a/oldies/oauth2/exceptions.py b/oldies/oauth2/exceptions.py new file mode 100644 index 00000000..2c2a0a85 --- /dev/null +++ b/oldies/oauth2/exceptions.py @@ -0,0 +1,2 @@ +class InvalidDPoPJwk(Exception): + pass \ No newline at end of file diff --git a/oldies/openid4vp/authorization_response.py b/oldies/openid4vp/authorization_response.py new file mode 100644 index 00000000..f7ec3511 --- /dev/null +++ b/oldies/openid4vp/authorization_response.py @@ -0,0 +1,35 @@ 
+import json +from cryptojwt.jwk.ec import ECKey +from cryptojwt.jwk.rsa import RSAKey +from pyeudiw.jwt.jws_helper import JWSHelper +from pyeudiw.jwt.utils import decode_jwt_header + +from pyeudiw.jwk.exceptions import KidNotFoundError + +def _get_jwk_kid_from_store(jwt: str, key_store: dict[str, dict]) -> dict: + headers = decode_jwt_header(jwt) + kid: str | None = headers.get("kid", None) + if kid is None: + raise KidNotFoundError( + "authorization response is missing mandatory parameter [kid] in header section" + ) + jwk_dict = key_store.get(kid, None) + if jwk_dict is None: + raise KidNotFoundError( + f"authorization response is encrypted with jwk with kid='{kid}' not found in store" + ) + return jwk_dict + + +def _decrypt_jwe(jwe: str, decrypting_jwk: dict[str, any]) -> dict: + decrypter = JWEHelper(decrypting_jwk) + return decrypter.decrypt(jwe) + + +def _verify_and_decode_jwt( + jwt: str, verifying_jwk: dict[dict, ECKey | RSAKey | dict] +) -> dict: + verifier = JWSHelper(verifying_jwk) + raw_payload: str = verifier.verify(jwt)["msg"] + payload: dict = json.loads(raw_payload) + return payload diff --git a/pyeudiw/openid4vp/direct_post_response.py b/oldies/openid4vp/direct_post_response.py similarity index 87% rename from pyeudiw/openid4vp/direct_post_response.py rename to oldies/openid4vp/direct_post_response.py index 46b54c10..cb8819c3 100644 --- a/pyeudiw/openid4vp/direct_post_response.py +++ b/oldies/openid4vp/direct_post_response.py @@ -1,22 +1,22 @@ import logging +from typing import Dict + import pydantic -from typing import Dict from pyeudiw.jwk import JWK - from pyeudiw.jwk.exceptions import KidNotFoundError from pyeudiw.jwt.jwe_helper import JWEHelper from pyeudiw.jwt.jws_helper import JWSHelper from pyeudiw.jwt.utils import decode_jwt_header, is_jwe_format from pyeudiw.openid4vp.exceptions import ( InvalidVPToken, - VPNotFound, + NoNonceInVPToken, VPInvalidNonce, - NoNonceInVPToken + VPNotFound, ) -from pyeudiw.openid4vp.schemas.vp_token import 
VPTokenPayload, VPTokenHeader -from pyeudiw.openid4vp.vp import Vp +from pyeudiw.openid4vp.schemas.vp_token import VPTokenHeader, VPTokenPayload from pyeudiw.openid4vp.utils import vp_parser +from pyeudiw.openid4vp.vp import Vp logger = logging.getLogger(__name__) @@ -54,7 +54,7 @@ def _decode_payload(self) -> None: :raises JWSVerificationError: if jws field is not in a JWS Format :raises JWEDecryptionError: if jwe field is not in a JWE Format """ - _kid = self.headers.get('kid', None) + _kid = self.headers.get("kid", None) if not _kid: raise KidNotFoundError( f"The JWT headers {self.headers} doesnt have any KID value" @@ -90,10 +90,10 @@ def _validate_vp(self, vp: dict) -> bool: try: # check nonce if self.nonce: - if not vp.payload.get('nonce', None): + if not vp.payload.get("nonce", None): raise NoNonceInVPToken() - if self.nonce != vp.payload['nonce']: + if self.nonce != vp.payload["nonce"]: raise VPInvalidNonce( "VP has a unknown nonce: " f"{self.nonce} != {vp.payload['nonce']}" @@ -101,9 +101,7 @@ def _validate_vp(self, vp: dict) -> bool: VPTokenPayload(**vp.payload) VPTokenHeader(**vp.headers) except pydantic.ValidationError as e: - raise InvalidVPToken( - f"VP is not valid, {e}: {vp.headers}.{vp.payload}" - ) + raise InvalidVPToken(f"VP is not valid, {e}: {vp.headers}.{vp.payload}") return True def validate(self) -> bool: @@ -120,9 +118,7 @@ def validate(self) -> bool: if all_valid is None: all_valid = True except Exception: - logger.error( - - ) + logger.error() all_valid = False return all_valid @@ -139,23 +135,21 @@ def get_presentation_vps(self) -> list[Vp]: if self._vps: return self._vps - _vps = self.payload.get('vp_token', []) + _vps = self.payload.get("vp_token", []) vps = [_vps] if isinstance(_vps, str) else _vps if not vps: - raise VPNotFound( - f'Vps are empty for response with nonce "{self.nonce}"' - ) + raise VPNotFound(f'Vps are empty for response with nonce "{self.nonce}"') for vp in vps: # TODO - add an exception handling here _vp = 
vp_parser(vp) self._vps.append(_vp) - cred_iss = _vp.credential_payload['iss'] + cred_iss = _vp.credential_payload["iss"] if not self.credentials_by_issuer.get(cred_iss, None): self.credentials_by_issuer[cred_iss] = [] - self.credentials_by_issuer[cred_iss].append(_vp.payload['vp']) + self.credentials_by_issuer[cred_iss].append(_vp.payload["vp"]) return self._vps diff --git a/oldies/openid4vp/exceptions.py b/oldies/openid4vp/exceptions.py new file mode 100644 index 00000000..137a557b --- /dev/null +++ b/oldies/openid4vp/exceptions.py @@ -0,0 +1,33 @@ +class KIDNotFound(Exception): + """ + Raised when kid is not present in the public key dict + """ + + +class VPSchemaException(Exception): + pass + + +class VPNotFound(Exception): + pass + + +class VPInvalidNonce(Exception): + pass + + +class NoNonceInVPToken(Exception): + """ + Raised when a given VP has no nonce + """ + +class InvalidVPSignature(InvalidVPKeyBinding): + """Raised when a VP contains a proof of possession key binding and + its signature verification failed. + """ + + +class RevokedVPToken(Exception): + """ + Raised when a given VP is revoked + """ \ No newline at end of file diff --git a/oldies/openid4vp/utils.py b/oldies/openid4vp/utils.py new file mode 100644 index 00000000..3020ba28 --- /dev/null +++ b/oldies/openid4vp/utils.py @@ -0,0 +1,77 @@ + +from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload + + +def vp_parser(jwt: str) -> Vp: + """ + Handle the jwt returning the correct VP istance. + + :param jwt: a string that represents the jwt. + :type jwt: str + + :raises VPFormatNotSupported: if the VP Digital credentials type is not implemented yet. + + :returns: the VP istance. 
+ :rtype: Vp + """ + + headers = decode_jwt_header(jwt) + + typ: str | None = headers.get("typ", None) + if typ is None: + raise ValueError("missing mandatory header [typ] in jwt header") + + match typ.lower(): + case "jwt": + return VpSdJwt(jwt) + case "vc+sd-jwt": + raise NotImplementedError( + "parsing of vp tokens with typ vc+sd-jwt not supported yet" + ) + case "mcdoc_cbor": + return VpMDocCbor(jwt) + case unsupported: + raise VPFormatNotSupported(f"parsing of unsupported vp typ [{unsupported}]") + +def infer_vp_header_claim(jws: str, claim_name: str) -> Any: + """ + Infer a claim from the header of a VP token. + + :param jws: the VP token + :type jws: str + + :param claim_name: the name of the claim to infer + :type claim_name: str + + :returns: the value of the claim + :rtype: Any + """ + headers = decode_jwt_header(jws) + claim_value = headers.get(claim_name, "") + return claim_value + + +def infer_vp_payload_claim(jws: str, claim_name: str) -> Any: + """ + Infer a claim from the payload of a VP token. 
+ + :param jws: the VP token + :type jws: str + + :param claim_name: the name of the claim to infer + :type claim_name: str + + :returns: the value of the claim + :rtype: Any + """ + headers = decode_jwt_payload(jws) + claim_value: str = headers.get(claim_name, "") + return claim_value + + +def infer_vp_typ(jws: str) -> str: + return infer_vp_header_claim(jws, claim_name="typ") + + +def infer_vp_iss(jws: str) -> str: + return infer_vp_payload_claim(jws, claim_name="iss") diff --git a/pyeudiw/openid4vp/vp_sd_jwt.py b/oldies/openid4vp/vp_sd_jwt.py similarity index 88% rename from pyeudiw/openid4vp/vp_sd_jwt.py rename to oldies/openid4vp/vp_sd_jwt.py index 3e369b71..871ad945 100644 --- a/pyeudiw/openid4vp/vp_sd_jwt.py +++ b/oldies/openid4vp/vp_sd_jwt.py @@ -1,13 +1,11 @@ from typing import Dict +from pyeudiw.jwk.exceptions import KidNotFoundError from pyeudiw.jwt.jws_helper import JWSHelper -from pyeudiw.jwt.verification import verify_jws_with_key from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload, is_jwt_format - - -from pyeudiw.jwk.exceptions import KidNotFoundError -from pyeudiw.openid4vp.vp import Vp +from pyeudiw.jwt.verification import verify_jws_with_key from pyeudiw.openid4vp.exceptions import InvalidVPToken +from pyeudiw.openid4vp.vp import Vp class VpSdJwt(Vp): @@ -43,13 +41,10 @@ def parse_digital_credential(self) -> None: Parse the digital credential of VP. """ - self.credential_headers = decode_jwt_header(self.payload['vp']) - self.credential_payload = decode_jwt_payload(self.payload['vp']) + self.credential_headers = decode_jwt_header(self.payload["vp"]) + self.credential_payload = decode_jwt_payload(self.payload["vp"]) - def verify( - self, - **kwargs - ) -> bool: + def verify(self, **kwargs) -> bool: """ Verifies a SDJWT. @@ -62,8 +57,7 @@ def verify( :returns: True if is valid, False otherwise. 
""" - issuer_jwks_by_kid: Dict[str, dict] = kwargs.get( - "issuer_jwks_by_kid", {}) + issuer_jwks_by_kid: Dict[str, dict] = kwargs.get("issuer_jwks_by_kid", {}) if not issuer_jwks_by_kid.get(self.credential_headers["kid"], None): raise KidNotFoundError( @@ -85,13 +79,11 @@ def verify( # TODO: with unit tests we have holder_disclosed_claims while in # interop we don't have it! - self.disclosed_user_attributes = result.get( - "holder_disclosed_claims", result - ) + self.disclosed_user_attributes = result.get("holder_disclosed_claims", result) # If IDA flatten the user attributes to be released - if 'verified_claims' in result: - result.update(result['verified_claims'].get('claims', {})) + if "verified_claims" in result: + result.update(result["verified_claims"].get("claims", {})) return True @@ -130,6 +122,6 @@ def credential_jwks(self) -> list[dict]: @property def credential_issuer(self) -> str: """Returns the credential issuer""" - if not self.credential_payload.get('iss', None): + if not self.credential_payload.get("iss", None): self.parse_digital_credential() - return self.credential_payload.get('iss', None) + return self.credential_payload.get("iss", None) diff --git a/oldies/satosa/default/openid4vp_backend.py b/oldies/satosa/default/openid4vp_backend.py new file mode 100644 index 00000000..bae3beca --- /dev/null +++ b/oldies/satosa/default/openid4vp_backend.py @@ -0,0 +1,14 @@ +from urllib.parse import quote_plus, urlencode + +def _build_authz_request_url(self, payload: dict) -> str: + scheme = self.config["authorization"]["url_scheme"] + if "://" not in scheme: + scheme = scheme + "://" + if not scheme.endswith("/"): + scheme = f"{scheme}/" + # NOTE: path component is currently unused by the protocol, but currently + # we leave it there as 'authorize' to stress the fact that this is an + # OAuth 2.0 request modified by JAR (RFC9101) + path = "authorize" + query_params = urlencode(payload, quote_via=quote_plus) + return f"{scheme}{path}?{query_params}" \ No 
newline at end of file diff --git a/oldies/satosa/default/response_handler.py b/oldies/satosa/default/response_handler.py new file mode 100644 index 00000000..dd79f6ff --- /dev/null +++ b/oldies/satosa/default/response_handler.py @@ -0,0 +1,39 @@ +def _is_same_device_flow(request_session: dict, context: Context) -> bool: + initiating_session_id: str | None = request_session.get("session_id", None) + if initiating_session_id is None: + raise ValueError( + "invalid session storage information: missing [session_id]" + ) + current_session_id: str | None = context.state.get("SESSION_ID", None) + if current_session_id is None: + raise ValueError("missing session id in wallet authorization response") + return initiating_session_id == current_session_id + +def _parse_http_request(self, context: Context) -> dict: + """Parse the http layer of the request to extract the dictionary data. + + :param context: the satosa context containing, among the others, the details of the HTTP request + :type context: satosa.Context + + :return: a dictionary containing the request data + :rtype: dict + + :raises BadRequestError: when request paramets are in a not processable state; the expected handling is returning 400 + """ + if ( + http_method := context.request_method.lower() + ) != ResponseHandler._SUPPORTED_RESPONSE_METHOD: + raise BadRequestError(f"HTTP method [{http_method}] not supported") + + if ( + content_type := context.http_headers["HTTP_CONTENT_TYPE"] + ) != ResponseHandler._SUPPORTED_RESPONSE_CONTENT_TYPE: + raise BadRequestError(f"HTTP content type [{content_type}] not supported") + + _endpoint = f"{self.server_url}{context.request_uri}" + + if self.config["metadata"].get("response_uris", None): + if _endpoint not in self.config["metadata"]["response_uris"]: + raise BadRequestError("response_uri not valid") + + return context.request \ No newline at end of file diff --git a/oldies/satosa/exceptions.py b/oldies/satosa/exceptions.py new file mode 100644 index 
00000000..ed1808df --- /dev/null +++ b/oldies/satosa/exceptions.py @@ -0,0 +1,14 @@ +class NoBoundEndpointError(Exception): + """ + Raised when a given url path is not bound to any endpoint function + """ + + +class NotTrustedFederationError(Exception): + pass + +class DPOPValidationError(Exception): + """ + Raised when a DPoP validation error occurs + """ + pass \ No newline at end of file diff --git a/pyeudiw/satosa/utils/dpop.py b/oldies/satosa/utils/dpop.py similarity index 79% rename from pyeudiw/satosa/utils/dpop.py rename to oldies/satosa/utils/dpop.py index c2e297d3..afd2d00a 100644 --- a/pyeudiw/satosa/utils/dpop.py +++ b/oldies/satosa/utils/dpop.py @@ -1,24 +1,27 @@ -import pydantic - from typing import Union + +import pydantic from satosa.context import Context from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload from pyeudiw.oauth2.dpop import DPoPVerifier from pyeudiw.openid4vp.schemas.wallet_instance_attestation import ( - WalletInstanceAttestationHeader, WalletInstanceAttestationPayload) + WalletInstanceAttestationHeader, + WalletInstanceAttestationPayload, +) +from pyeudiw.satosa.exceptions import DPOPValidationError from pyeudiw.satosa.utils.response import JsonResponse from pyeudiw.tools.base_logger import BaseLogger -from pyeudiw.satosa.exceptions import DPOPValidationError - class BackendDPoP(BaseLogger): """ Backend DPoP class. 
""" - def _request_endpoint_dpop(self, context: Context, *args) -> Union[JsonResponse, None]: + def _request_endpoint_dpop( + self, context: Context, *args + ) -> Union[JsonResponse, None]: """ Validates, if any, the DPoP http request header @@ -33,25 +36,30 @@ def _request_endpoint_dpop(self, context: Context, *args) -> Union[JsonResponse, :rtype: Union[JsonResponse, None] """ - if context.http_headers and 'HTTP_AUTHORIZATION' in context.http_headers: + if context.http_headers and "HTTP_AUTHORIZATION" in context.http_headers: # The wallet instance uses the endpoint authentication to give its WIA # take WIA - dpop_jws = context.http_headers['HTTP_AUTHORIZATION'].split()[-1] + dpop_jws = context.http_headers["HTTP_AUTHORIZATION"].split()[-1] _head = decode_jwt_header(dpop_jws) wia = decode_jwt_payload(dpop_jws) self._log_debug( - context, message=f"[FOUND WIA] Headers: {_head} and Payload: {wia}") + context, message=f"[FOUND WIA] Headers: {_head} and Payload: {wia}" + ) try: WalletInstanceAttestationHeader(**_head) except pydantic.ValidationError as e: self._log_warning( - context, message=f"[FOUND WIA] Invalid Headers: {_head}. Validation error: {e}") + context, + message=f"[FOUND WIA] Invalid Headers: {_head}. Validation error: {e}", + ) except Exception as e: self._log_warning( - context, message=f"[FOUND WIA] Invalid Headers: {_head}. Unexpected error: {e}") + context, + message=f"[FOUND WIA] Invalid Headers: {_head}. 
Unexpected error: {e}", + ) try: WalletInstanceAttestationPayload(**wia) @@ -72,9 +80,9 @@ def _request_endpoint_dpop(self, context: Context, *args) -> Union[JsonResponse, try: dpop = DPoPVerifier( - public_jwk=wia['cnf']['jwk'], - http_header_authz=context.http_headers['HTTP_AUTHORIZATION'], - http_header_dpop=context.http_headers['HTTP_DPOP'] + public_jwk=wia["cnf"]["jwk"], + http_header_authz=context.http_headers["HTTP_AUTHORIZATION"], + http_header_dpop=context.http_headers["HTTP_DPOP"], ) except pydantic.ValidationError as e: _msg = f"DPoP validation error: {e}" diff --git a/oldies/sd_jwt/exceptions.py b/oldies/sd_jwt/exceptions.py new file mode 100644 index 00000000..155d1e1e --- /dev/null +++ b/oldies/sd_jwt/exceptions.py @@ -0,0 +1,3 @@ +class UnknownCurveNistName(Exception): + pass + diff --git a/oldies/sd_jwt/sd_jwt.py b/oldies/sd_jwt/sd_jwt.py new file mode 100644 index 00000000..c00ff4f5 --- /dev/null +++ b/oldies/sd_jwt/sd_jwt.py @@ -0,0 +1,11 @@ + + +class SdJwtKb(SdJwt): + def __init__(self, token: str): + if not is_sd_jwt_kb_format(token): + raise ValueError( + f"input [token]={token} is not an sd-jwt with key binding with: maybe it is a regular jwt?" 
+ ) + super().__init__(token) + if not self.holder_kb: + raise ValueError("missing key binding jwt") \ No newline at end of file diff --git a/oldies/sd_jwt/utils/demo_utils.py b/oldies/sd_jwt/utils/demo_utils.py new file mode 100644 index 00000000..4de00c53 --- /dev/null +++ b/oldies/sd_jwt/utils/demo_utils.py @@ -0,0 +1,17 @@ +def print_repr(values: Union[str, list], nlines=2): + value = "\n".join(values) if isinstance(values, (list, tuple)) else values + _nlines = "\n" * nlines if nlines else "" + print(value, end=_nlines) + + +def print_decoded_repr(value: str, nlines=2): + seq = [] + for i in value.split("."): + try: + padded = f"{i}{'=' * divmod(len(i),4)[1]}" + seq.append(f"{base64.urlsafe_b64decode(padded).decode()}") + except Exception as e: + logging.debug(f"{e} - for value: {i}") + seq.append(i) + _nlines = "\n" * nlines if nlines else "" + print("\n.\n".join(seq), end=_nlines) \ No newline at end of file diff --git a/pyeudiw/tools/jwk_handling.py b/oldies/tools/jwk_handling.py similarity index 90% rename from pyeudiw/tools/jwk_handling.py rename to oldies/tools/jwk_handling.py index 8bd65e3a..1b6d7de8 100644 --- a/pyeudiw/tools/jwk_handling.py +++ b/oldies/tools/jwk_handling.py @@ -1,7 +1,7 @@ -from pyeudiw.jwk import JWK +from pyeudiw.jwk import JWK +from pyeudiw.jwk.jwks import find_jwk_by_kid from pyeudiw.openid4vp.interface import VpTokenParser from pyeudiw.trust.interface import TrustEvaluator -from pyeudiw.jwk import find_jwk_by_kid def find_vp_token_key(token_parser: VpTokenParser, key_source: TrustEvaluator) -> JWK: @@ -27,6 +27,7 @@ def find_vp_token_key(token_parser: VpTokenParser, key_source: TrustEvaluator) - if isinstance(verification_key, dict): raise NotImplementedError( - "TODO: matching of public key (ex. from x5c) with keys from trust source") + "TODO: matching of public key (ex. 
from x5c) with keys from trust source" + ) raise Exception(f"invalid state: key with type {type(verification_key)}") diff --git a/oldies/tools/utils.py b/oldies/tools/utils.py new file mode 100644 index 00000000..83e5a756 --- /dev/null +++ b/oldies/tools/utils.py @@ -0,0 +1,57 @@ +def get_jwks( + httpc_params: dict, metadata: dict, federation_jwks: list[dict] = [] +) -> dict: + """ + Get jwks or jwks_uri or signed_jwks_uri + + :param httpc_params: parameters to perform http requests. + :type httpc_params: dict + :param metadata: metadata of the entity + :type metadata: dict + :param federation_jwks: jwks of the federation + :type federation_jwks: list + + :returns: A list of responses. + :rtype: list[dict] + """ + jwks_list = [] + if metadata.get("jwks"): + jwks_list = metadata["jwks"]["keys"] + elif metadata.get("jwks_uri"): + try: + jwks_uri = metadata["jwks_uri"] + jwks_list = get_http_url([jwks_uri], httpc_params=httpc_params) + jwks_list = jwks_list[0].json() + except Exception as e: + logger.error(f"Failed to download jwks from {jwks_uri}: {e}") + elif metadata.get("signed_jwks_uri"): + try: + signed_jwks_uri = metadata["signed_jwks_uri"] + jwks_list = get_http_url([signed_jwks_uri], httpc_params=httpc_params)[ + 0 + ].json() + except Exception as e: + logger.error(f"Failed to download jwks from {signed_jwks_uri}: {e}") + return jwks_list + +def satisfy_interface(o: object, interface: type) -> bool: + """ + Returns true if and only if an object satisfy an interface. 
+ + :param o: an object (instance of a class) + :type o: object + :param interface: an interface type + :type interface: type + + :returns: True if the object satisfy the interface, otherwise False + """ + for cls_attr in dir(interface): + if cls_attr.startswith("_"): + continue + if not hasattr(o, cls_attr): + return False + if callable(getattr(interface, cls_attr)) and not callable( + getattr(o, cls_attr) + ): + return False + return True \ No newline at end of file diff --git a/pyeudiw/trust/trust_chain.py b/oldies/trust_chain.py similarity index 99% rename from pyeudiw/trust/trust_chain.py rename to oldies/trust_chain.py index 9c0fe93c..13501459 100644 --- a/pyeudiw/trust/trust_chain.py +++ b/oldies/trust_chain.py @@ -1,6 +1,7 @@ from typing import Optional from cryptojwt.jwt import utc_time_sans_frac + from pyeudiw.tools.base_logger import BaseLogger __author__ = "Roland Hedberg" diff --git a/oldies/trust_evaluation_helper.py b/oldies/trust_evaluation_helper.py new file mode 100644 index 00000000..46cc25c7 --- /dev/null +++ b/oldies/trust_evaluation_helper.py @@ -0,0 +1,291 @@ +import logging +from datetime import datetime + +from pyeudiw.federation.trust_chain_builder import TrustChainBuilder +from pyeudiw.federation.trust_chain_validator import StaticTrustChainValidator +from pyeudiw.federation.exceptions import ProtocolMetadataNotFound +from pyeudiw.satosa.exceptions import DiscoveryFailedError +from pyeudiw.storage.db_engine import DBEngine +from pyeudiw.jwt.utils import decode_jwt_payload, is_jwt_format +from pyeudiw.x509.verify import verify_x509_anchor, get_issuer_from_x5c, is_der_format + +from pyeudiw.storage.exceptions import EntryNotFound +from pyeudiw.trust.exceptions import ( + MissingProtocolSpecificJwks, + UnknownTrustAnchor, + InvalidTrustType, + MissingTrustType, + InvalidAnchor +) + +from pyeudiw.federation.statements import EntityStatement +from pyeudiw.federation.exceptions import TimeValidationError +from pyeudiw.federation.policy import 
TrustChainPolicy, combine + +logger = logging.getLogger(__name__) + + +class TrustEvaluationHelper: + def __init__(self, storage: DBEngine, httpc_params, trust_anchor: str = None, **kwargs): + self.exp: int = 0 + self.trust_chain: list[str] = [] + self.trust_anchor = trust_anchor + self.storage = storage + self.entity_id: str = "" + self.httpc_params = httpc_params + self.is_trusted = False + + for k, v in kwargs.items(): + setattr(self, k, v) + + def _get_evaluation_method(self): + # The trust chain can be either federation or x509 + # If the trust_chain is empty, and we don't have a trust anchor + if not self.trust_chain and not self.trust_anchor: + raise MissingTrustType( + "Static trust chain is not available" + ) + + try: + if is_jwt_format(self.trust_chain[0]): + return self.federation + except TypeError: + pass + + if is_der_format(self.trust_chain[0]): + return self.x509 + + raise InvalidTrustType( + "Invalid Trust Type: trust type not supported" + ) + + def evaluation_method(self) -> bool: + ev_method = self._get_evaluation_method() + return ev_method() + + def _update_chain(self, entity_id: str | None = None, exp: datetime | None = None, trust_chain: list | None = None): + if entity_id is not None: + self.entity_id = entity_id + + if exp is not None: + self.exp = exp + + if trust_chain is not None: + self.trust_chain = trust_chain + + def _handle_federation_chain(self): + _first_statement = decode_jwt_payload(self.trust_chain[-1]) + trust_anchor_eid = self.trust_anchor or _first_statement.get( + 'iss', None) + + if not trust_anchor_eid: + raise UnknownTrustAnchor( + "Unknown Trust Anchor: can't find 'iss' in the " + f"first entity statement: {_first_statement} " + ) + + try: + trust_anchor = self.storage.get_trust_anchor(trust_anchor_eid) + except EntryNotFound: + raise UnknownTrustAnchor( + f"Unknown Trust Anchor: '{trust_anchor_eid}' is not " + "a recognizable Trust Anchor." 
+ ) + + decoded_ec = decode_jwt_payload( + trust_anchor['federation']['entity_configuration'] + ) + jwks = decoded_ec.get('jwks', {}).get('keys', []) + + if not jwks: + raise MissingProtocolSpecificJwks( + f"Cannot find any jwks in {decoded_ec}" + ) + + tc = StaticTrustChainValidator( + self.trust_chain, jwks, self.httpc_params + ) + self._update_chain( + entity_id=tc.entity_id, + exp=tc.exp + ) + + _is_valid = False + + try: + _is_valid = tc.validate() + except TimeValidationError: + logger.warn(f"Trust Chain {tc.entity_id} is expired") + except Exception as e: + logger.warn( + f"Cannot validate Trust Chain {tc.entity_id} for the following reason: {e}") + + db_chain = None + + if not _is_valid: + try: + db_chain = self.storage.get_trust_attestation( + self.entity_id + )["federation"]["chain"] + if StaticTrustChainValidator(db_chain, jwks, self.httpc_params).is_valid: + self.is_trusted = True + return self.is_trusted + + except (EntryNotFound, Exception): + pass + + _is_valid = tc.update() + + self._update_chain( + trust_chain=tc.trust_chain, + exp=tc.exp + ) + + # the good trust chain is then stored + self.storage.add_or_update_trust_attestation( + entity_id=self.entity_id, + attestation=tc.trust_chain, + exp=datetime.fromtimestamp(tc.exp) + ) + + self.is_trusted = _is_valid + return _is_valid + + def _handle_x509_pem(self): + trust_anchor_eid = self.trust_anchor or get_issuer_from_x5c( + self.trust_chain) + _is_valid = False + + if not trust_anchor_eid: + raise UnknownTrustAnchor( + "Unknown Trust Anchor: can't find 'iss' in the " + "first entity statement" + ) + + try: + trust_anchor = self.storage.get_trust_anchor(trust_anchor_eid) + except EntryNotFound: + raise UnknownTrustAnchor( + f"Unknown Trust Anchor: '{trust_anchor_eid}' is not " + "a recognizable Trust Anchor." 
+ ) + + pem = trust_anchor['x509'].get('pem') + + if pem is None: + raise MissingTrustType( + f"Trust Anchor: '{trust_anchor_eid}' has no x509 trust entity" + ) + + try: + _is_valid = verify_x509_anchor(pem) + except Exception as e: + raise InvalidAnchor( + f"Anchor verification raised the following exception: {e}" + ) + + if not self.is_trusted and trust_anchor['federation'].get("chain", None) is not None: + self._handle_federation_chain() + + self.is_trusted = _is_valid + return _is_valid + + def federation(self) -> bool: + if len(self.trust_chain) == 0: + self.discovery(self.entity_id) + + if self.trust_chain: + self.is_valid = self._handle_federation_chain() + return self.is_valid + + return False + + def x509(self) -> bool: + self.is_valid = self._handle_x509_pem() + return self.is_valid + + def get_final_metadata(self, metadata_type: str, policies: list[dict]) -> dict: + policy_acc = {"metadata": {}, "metadata_policy": {}} + + for policy in policies: + policy_acc = combine(policy, policy_acc) + + self.final_metadata = decode_jwt_payload(self.trust_chain[0]) + + try: + # TODO: there are some cases where the jwks are taken from a uri ... + selected_metadata = { + "metadata": self.final_metadata['metadata'], + "metadata_policy": {} + } + + self.final_metadata = TrustChainPolicy().apply_policy( + selected_metadata, + policy_acc + ) + + return self.final_metadata["metadata"][metadata_type] + except KeyError: + raise ProtocolMetadataNotFound( + f"{metadata_type} not found in the final metadata:" + f" {self.final_metadata['metadata']}" + ) + + def get_trusted_jwks(self, metadata_type: str, policies: list[dict] = []) -> list[dict]: + return self.get_final_metadata( + metadata_type=metadata_type, + policies=policies + ).get('jwks', {}).get('keys', []) + + def discovery(self, entity_id: str, entity_configuration: EntityStatement | None = None): + """ + Updates fields ``trust_chain`` and ``exp`` based on the discovery process. 
+ + :raises: DiscoveryFailedError: raises an error if the discovery fails. + """ + trust_anchor_eid = self.trust_anchor + _ta_ec = self.storage.get_trust_anchor(entity_id=trust_anchor_eid) + ta_ec = _ta_ec['federation']['entity_configuration'] + + tcbuilder = TrustChainBuilder( + subject=entity_id, + trust_anchor=trust_anchor_eid, + trust_anchor_configuration=ta_ec, + subject_configuration=entity_configuration, + httpc_params=self.httpc_params + ) + + self._update_chain( + trust_chain=tcbuilder.get_trust_chain(), + exp=tcbuilder.exp + ) + is_good = tcbuilder.is_valid + if not is_good: + raise DiscoveryFailedError( + f"Discovery failed for entity {entity_id} with configuration {entity_configuration}" + ) + + @staticmethod + def build_trust_chain_for_entity_id(storage: DBEngine, entity_id, entity_configuration, httpc_params): + """ + Builds a ``TrustEvaluationHelper`` and returns it if the trust chain is valid. + In case the trust chain is invalid, tries to validate it in discovery before returning it. 
+ + :return: The svg data for html, base64 encoded + :rtype: str + """ + db_chain = storage.get_trust_attestation(entity_id) + + trust_evaluation_helper = TrustEvaluationHelper( + storage=storage, + httpc_params=httpc_params, + trust_chain=db_chain + ) + + is_good = trust_evaluation_helper.evaluation_method() + if is_good: + return trust_evaluation_helper + + trust_evaluation_helper.discovery( + entity_id=entity_id, entity_configuration=entity_configuration) + return trust_evaluation_helper diff --git a/oldies/x509_lost_n_found.py b/oldies/x509_lost_n_found.py new file mode 100644 index 00000000..1cc3be3e --- /dev/null +++ b/oldies/x509_lost_n_found.py @@ -0,0 +1,43 @@ + def _handle_x509_pem(self): + trust_anchor_eid = self.trust_anchor or get_issuer_from_x5c( + self.trust_chain) + _is_valid = False + + if not trust_anchor_eid: + raise UnknownTrustAnchor( + "Unknown Trust Anchor: can't find 'iss' in the " + "first entity statement" + ) + + try: + trust_anchor = self.storage.get_trust_anchor(trust_anchor_eid) + except EntryNotFound: + raise UnknownTrustAnchor( + f"Unknown Trust Anchor: '{trust_anchor_eid}' is not " + "a recognizable Trust Anchor." 
+ ) + + pem = trust_anchor['x509'].get('pem') + + if pem is None: + raise MissingTrustType( + f"Trust Anchor: '{trust_anchor_eid}' has no x509 trust entity" + ) + + try: + _is_valid = verify_x509_anchor(pem) + except Exception as e: + raise InvalidAnchor( + f"Anchor verification raised the following exception: {e}" + ) + + if not self.is_trusted and trust_anchor['federation'].get("chain", None) is not None: + self._handle_federation_chain() + + self.is_trusted = _is_valid + return _is_valid + + + def x509(self) -> bool: + self.is_valid = self._handle_x509_pem() + return self.is_valid diff --git a/pyeudiw/federation/__init__.py b/pyeudiw/federation/__init__.py index a74de24a..e69de29b 100644 --- a/pyeudiw/federation/__init__.py +++ b/pyeudiw/federation/__init__.py @@ -1,35 +0,0 @@ -from .exceptions import InvalidEntityStatement, InvalidEntityConfiguration -from pyeudiw.federation.schemas.entity_configuration import EntityStatementPayload, EntityConfigurationPayload - - -def is_es(payload: dict) -> None: - """ - Determines if payload dict is a Subordinate Entity Statement - - :param payload: the object to determine if is a Subordinate Entity Statement - :type payload: dict - """ - - try: - EntityStatementPayload(**payload) - if payload["iss"] == payload["sub"]: - _msg = "Invalid Entity Statement: iss and sub cannot be the same" - raise InvalidEntityStatement(_msg) - except ValueError as e: - _msg = f"Invalid Entity Statement: {e}" - raise InvalidEntityStatement(_msg) - - -def is_ec(payload: dict) -> None: - """ - Determines if payload dict is an Entity Configuration - - :param payload: the object to determine if is an Entity Configuration - :type payload: dict - """ - - try: - EntityConfigurationPayload(**payload) - except ValueError as e: - _msg = f"Invalid Entity Configuration: {e}" - raise InvalidEntityConfiguration(_msg) diff --git a/pyeudiw/federation/http_client.py b/pyeudiw/federation/http_client.py index b0f4571e..e9d52dbe 100644 --- 
a/pyeudiw/federation/http_client.py +++ b/pyeudiw/federation/http_client.py @@ -1,11 +1,14 @@ -import aiohttp import asyncio + +import aiohttp import requests from .exceptions import HttpError -async def fetch(session: aiohttp.ClientSession, url: str, httpc_params: dict) -> requests.Response: +async def fetch( + session: aiohttp.ClientSession, url: str, httpc_params: dict +) -> requests.Response: """ Fetches the content of a URL. @@ -26,7 +29,9 @@ async def fetch(session: aiohttp.ClientSession, url: str, httpc_params: dict) -> return await response -async def fetch_all(session: aiohttp.ClientSession, urls: list[str], httpc_params: dict) -> list[requests.Response]: +async def fetch_all( + session: aiohttp.ClientSession, urls: list[str], httpc_params: dict +) -> list[requests.Response]: """ Fetches the content of a list of URL. @@ -77,14 +82,11 @@ def http_get_sync(urls, httpc_params: dict) -> list[requests.Response]: :rtype: list[requests.Response] """ _conf = { - 'verify': httpc_params['connection']['ssl'], - 'timeout': httpc_params['session']['timeout'] + "verify": httpc_params["connection"]["ssl"], + "timeout": httpc_params["session"]["timeout"], } try: - res = [ - requests.get(url, **_conf) # nosec - B113 - for url in urls - ] + res = [requests.get(url, **_conf) for url in urls] # nosec - B113 except requests.exceptions.ConnectionError as e: raise HttpError(f"Connection error: {e}") @@ -111,9 +113,9 @@ async def http_get_async(urls, httpc_params: dict) -> list[requests.Response]: :returns: the list of responses :rtype: list[requests.Response] """ - if not isinstance(httpc_params['session']['timeout'], aiohttp.ClientTimeout): - httpc_params['session']['timeout'] = aiohttp.ClientTimeout( - total=httpc_params['session']['timeout'] + if not isinstance(httpc_params["session"]["timeout"], aiohttp.ClientTimeout): + httpc_params["session"]["timeout"] = aiohttp.ClientTimeout( + total=httpc_params["session"]["timeout"] ) async with 
aiohttp.ClientSession(**httpc_params.get("session", {})) as session: diff --git a/pyeudiw/federation/policy.py b/pyeudiw/federation/policy.py index 7845bb9a..216d6192 100644 --- a/pyeudiw/federation/policy.py +++ b/pyeudiw/federation/policy.py @@ -1,5 +1,6 @@ import logging from typing import Optional + from .exceptions import PolicyError __author__ = "Roland Hedberg" @@ -33,8 +34,15 @@ def combine_add(s1, s2): return list(set1.union(set2)) -POLICY_FUNCTIONS = {"subset_of", "superset_of", - "one_of", "add", "value", "default", "essential"} +POLICY_FUNCTIONS = { + "subset_of", + "superset_of", + "one_of", + "add", + "value", + "default", + "essential", +} OP2FUNC = { "subset_of": combine_subset_of, @@ -72,8 +80,8 @@ def do_value(superior, child, policy): def do_default(superior, child, policy): # A child's default can not override a superiors if policy in superior and policy in child: - if superior['default'] == child['default']: - return superior['default'] + if superior["default"] == child["default"]: + return superior["default"] else: raise PolicyError("Not allowed to change default") elif policy in superior: @@ -87,7 +95,7 @@ def do_essential(superior, child, policy): # but not the other way around if policy in superior and policy in child: - if not superior[policy] and child['essential']: + if not superior[policy] and child["essential"]: return True else: return superior[policy] @@ -104,7 +112,7 @@ def do_essential(superior, child, policy): "add": do_sub_one_super_add, "value": do_value, "default": do_default, - "essential": do_essential + "essential": do_essential, } @@ -128,58 +136,63 @@ def combine_claim_policy(superior, child): return {"value": superior["value"], "essential": child["essential"]} else: raise PolicyError( - f"value can only be combined with essential, not {child_set}") + f"value can only be combined with essential, not {child_set}" + ) elif "value" in child_set: if child["value"] != superior["value"]: # Not OK - raise PolicyError( - "Child 
can not set another value then superior") + raise PolicyError("Child can not set another value then superior") else: return superior else: raise PolicyError( - f"Not allowed combination of policies: {superior} + {child}") + f"Not allowed combination of policies: {superior} + {child}" + ) return superior else: if "essential" in superior_set and "essential" in child_set: # can only go from False to True - if superior["essential"] != child["essential"] and child["essential"] is False: + if ( + superior["essential"] != child["essential"] + and child["essential"] is False + ): raise PolicyError("Essential can not go from True to False") comb_policy = superior_set.union(child_set) if "one_of" in comb_policy: if "subset_of" in comb_policy or "superset_of" in comb_policy: raise PolicyError( - "one_of can not be combined with subset_of/superset_of") + "one_of can not be combined with subset_of/superset_of" + ) rule = {} for policy in comb_policy: rule[policy] = DO_POLICY[policy](superior, child, policy) - if comb_policy == {'superset_of', 'subset_of'}: + if comb_policy == {"superset_of", "subset_of"}: # make sure the subset_of is a superset of superset_of. - if set(rule['superset_of']).difference(set(rule['subset_of'])): - raise PolicyError('superset_of not a super set of subset_of') - elif comb_policy == {'superset_of', 'subset_of', 'default'}: + if set(rule["superset_of"]).difference(set(rule["subset_of"])): + raise PolicyError("superset_of not a super set of subset_of") + elif comb_policy == {"superset_of", "subset_of", "default"}: # make sure the subset_of is a superset of superset_of. 
- if set(rule['superset_of']).difference(set(rule['subset_of'])): - raise PolicyError('superset_of not a super set of subset_of') - if set(rule['default']).difference(set(rule['subset_of'])): - raise PolicyError('default not a sub set of subset_of') - if set(rule['superset_of']).difference(set(rule['default'])): - raise PolicyError('default not a super set of subset_of') - elif comb_policy == {'subset_of', 'default'}: - if set(rule['default']).difference(set(rule['subset_of'])): - raise PolicyError('default not a sub set of subset_of') - elif comb_policy == {'superset_of', 'default'}: - if set(rule['superset_of']).difference(set(rule['default'])): - raise PolicyError('default not a super set of subset_of') - elif comb_policy == {'one_of', 'default'}: - if isinstance(rule['default'], list): - if set(rule['default']).difference(set(rule['one_of'])): - raise PolicyError('default not a super set of one_of') + if set(rule["superset_of"]).difference(set(rule["subset_of"])): + raise PolicyError("superset_of not a super set of subset_of") + if set(rule["default"]).difference(set(rule["subset_of"])): + raise PolicyError("default not a sub set of subset_of") + if set(rule["superset_of"]).difference(set(rule["default"])): + raise PolicyError("default not a super set of subset_of") + elif comb_policy == {"subset_of", "default"}: + if set(rule["default"]).difference(set(rule["subset_of"])): + raise PolicyError("default not a sub set of subset_of") + elif comb_policy == {"superset_of", "default"}: + if set(rule["superset_of"]).difference(set(rule["default"])): + raise PolicyError("default not a super set of subset_of") + elif comb_policy == {"one_of", "default"}: + if isinstance(rule["default"], list): + if set(rule["default"]).difference(set(rule["one_of"])): + raise PolicyError("default not a super set of one_of") else: - if {rule['default']}.difference(set(rule['one_of'])): - raise PolicyError('default not a super set of one_of') + if 
{rule["default"]}.difference(set(rule["one_of"])): + raise PolicyError("default not a super set of one_of") return rule @@ -190,8 +203,8 @@ def combine(superior: dict, sub: dict) -> dict: :param sub: Dictionary with two keys metadata_policy and metadata :return: """ - sup_metadata = superior.get('metadata', {}) - sub_metadata = sub.get('metadata', {}) + sup_metadata = superior.get("metadata", {}) + sub_metadata = sub.get("metadata", {}) sup_m_set = set(sup_metadata.keys()) if sub_metadata: chi_m_set = set(sub_metadata.keys()) @@ -200,28 +213,27 @@ def combine(superior: dict, sub: dict) -> dict: for key in _overlap: if sup_metadata[key] != sub_metadata[key]: raise PolicyError( - 'A subordinate is not allowed to set a value different then the superiors') + "A subordinate is not allowed to set a value different then the superiors" + ) _metadata = sup_metadata.copy() _metadata.update(sub_metadata) - superior['metadata'] = _metadata + superior["metadata"] = _metadata # Now for metadata_policies - _sup_policy = superior.get('metadata_policy', {}) - _sub_policy = sub.get('metadata_policy', {}) + _sup_policy = superior.get("metadata_policy", {}) + _sub_policy = sub.get("metadata_policy", {}) if _sub_policy: sup_set = set(_sup_policy.keys()) - chi_set = set(sub['metadata_policy'].keys()) + chi_set = set(sub["metadata_policy"].keys()) # A metadata_policy claim can not change a metadata claim for claim in chi_set.intersection(sup_m_set): - combine_claim_policy( - {'value': sup_metadata[claim]}, _sub_policy[claim]) + combine_claim_policy({"value": sup_metadata[claim]}, _sub_policy[claim]) _mp = {} for claim in set(sup_set).intersection(chi_set): - _mp[claim] = combine_claim_policy( - _sup_policy[claim], _sub_policy[claim]) + _mp[claim] = combine_claim_policy(_sup_policy[claim], _sub_policy[claim]) for claim in sup_set.difference(chi_set): _mp[claim] = _sup_policy[claim] @@ -229,7 +241,7 @@ def combine(superior: dict, sub: dict) -> dict: for claim in 
chi_set.difference(sup_set): _mp[claim] = _sub_policy[claim] - superior['metadata_policy'] = _mp + superior["metadata_policy"] = _mp return superior @@ -277,29 +289,31 @@ def gather_policies(self, chain, entity_type): :return: The combined metadata policy """ - _rule = {'metadata_policy': {}, 'metadata': {}} - for _item in ['metadata_policy', 'metadata']: + _rule = {"metadata_policy": {}, "metadata": {}} + for _item in ["metadata_policy", "metadata"]: try: _rule[_item] = chain[0][_item][entity_type] except KeyError: pass for es in chain[1:]: - _sub_policy = {'metadata_policy': {}, 'metadata': {}} - for _item in ['metadata_policy', 'metadata']: + _sub_policy = {"metadata_policy": {}, "metadata": {}} + for _item in ["metadata_policy", "metadata"]: try: _sub_policy[_item] = es[_item][entity_type] except KeyError: pass - if _sub_policy == {'metadata_policy': {}, 'metadata': {}}: + if _sub_policy == {"metadata_policy": {}, "metadata": {}}: continue - _overlap = set(_sub_policy['metadata_policy']).intersection( - set(_sub_policy['metadata'])) + _overlap = set(_sub_policy["metadata_policy"]).intersection( + set(_sub_policy["metadata"]) + ) if _overlap: # Not allowed raise PolicyError( - 'Claim appearing both in metadata and metadata_policy not allowed') + "Claim appearing both in metadata and metadata_policy not allowed" + ) _rule = combine(_rule, _sub_policy) return _rule @@ -322,55 +336,74 @@ def _apply_metadata_policy(self, metadata, metadata_policy): # The is for claims that can have only one value # Should not be but ... 
if isinstance(metadata[claim], list): - _claim = [c for c in metadata[claim] if - c in metadata_policy[claim]['one_of']] + _claim = [ + c + for c in metadata[claim] + if c in metadata_policy[claim]["one_of"] + ] if _claim: metadata[claim] = _claim[0] else: raise PolicyError( - "{}: None of {} among {}".format(claim, metadata[claim], - metadata_policy[claim]['one_of'])) + "{}: None of {} among {}".format( + claim, + metadata[claim], + metadata_policy[claim]["one_of"], + ) + ) else: - if metadata[claim] in metadata_policy[claim]['one_of']: + if metadata[claim] in metadata_policy[claim]["one_of"]: pass else: raise PolicyError( - f"{metadata[claim]} not among {metadata_policy[claim]['one_of']}") + f"{metadata[claim]} not among {metadata_policy[claim]['one_of']}" + ) else: # The following is for claims that can have lists of values if "add" in metadata_policy[claim]: metadata[claim] = list( - union(metadata[claim], metadata_policy[claim]['add'])) + union(metadata[claim], metadata_policy[claim]["add"]) + ) if "subset_of" in metadata_policy[claim]: - _val = set(metadata_policy[claim]['subset_of']).intersection( - set(metadata[claim])) + _val = set(metadata_policy[claim]["subset_of"]).intersection( + set(metadata[claim]) + ) if _val: metadata[claim] = list(_val) else: - raise PolicyError("{} not subset of {}".format(metadata[claim], - metadata_policy[claim][ - 'subset_of'])) + raise PolicyError( + "{} not subset of {}".format( + metadata[claim], metadata_policy[claim]["subset_of"] + ) + ) if "superset_of" in metadata_policy[claim]: - if set(metadata_policy[claim]['superset_of']).difference( - set(metadata[claim])): - raise PolicyError("{} not superset of {}".format(metadata[claim], - metadata_policy[claim][ - 'superset_of'])) + if set(metadata_policy[claim]["superset_of"]).difference( + set(metadata[claim]) + ): + raise PolicyError( + "{} not superset of {}".format( + metadata[claim], + metadata_policy[claim]["superset_of"], + ) + ) else: pass # In policy but not in 
metadata for claim in policy_set.difference(metadata_set): if "value" in metadata_policy[claim]: - metadata[claim] = metadata_policy[claim]['value'] + metadata[claim] = metadata_policy[claim]["value"] elif "add" in metadata_policy[claim]: - metadata[claim] = metadata_policy[claim]['add'] + metadata[claim] = metadata_policy[claim]["add"] elif "default" in metadata_policy[claim]: - metadata[claim] = metadata_policy[claim]['default'] + metadata[claim] = metadata_policy[claim]["default"] if claim not in metadata: - if "essential" in metadata_policy[claim] and metadata_policy[claim]["essential"]: + if ( + "essential" in metadata_policy[claim] + and metadata_policy[claim]["essential"] + ): raise PolicyError(f"Essential claim '{claim}' missing") return metadata @@ -384,12 +417,11 @@ def apply_policy(self, metadata: dict, policy: dict) -> dict: :return: A metadata statement that adheres to a metadata policy """ - if policy['metadata_policy']: - metadata = self._apply_metadata_policy( - metadata, policy['metadata_policy']) + if policy["metadata_policy"]: + metadata = self._apply_metadata_policy(metadata, policy["metadata_policy"]) # All that are in metadata but not in policy should just remain - metadata.update(policy['metadata']) + metadata.update(policy["metadata"]) return metadata @@ -399,7 +431,7 @@ def _policy(self, trust_chain, entity_type: str): logger.debug("Combined policy: %s", combined_policy) try: # This should be the entity configuration - metadata = trust_chain.verified_chain[-1]['metadata'][entity_type] + metadata = trust_chain.verified_chain[-1]["metadata"][entity_type] except KeyError: return None else: @@ -409,7 +441,7 @@ def _policy(self, trust_chain, entity_type: str): logger.debug(f"After applied policy: {_metadata}") return _metadata - def __call__(self, trust_chain, entity_type: Optional[str] = ''): + def __call__(self, trust_chain, entity_type: Optional[str] = ""): """ :param trust_chain: TrustChain instance :param entity_type: Which Entity Type the 
entity are @@ -417,28 +449,14 @@ def __call__(self, trust_chain, entity_type: Optional[str] = ''): if len(trust_chain.verified_chain) > 1: if entity_type: trust_chain.metadata[entity_type] = self._policy( - trust_chain, entity_type) + trust_chain, entity_type + ) else: - for _type in trust_chain.verified_chain[-1]['metadata'].keys(): - trust_chain.metadata[_type] = self._policy( - trust_chain, _type) + for _type in trust_chain.verified_chain[-1]["metadata"].keys(): + trust_chain.metadata[_type] = self._policy(trust_chain, _type) else: - trust_chain.metadata = trust_chain.verified_chain[0]["metadata"][entity_type] + trust_chain.metadata = trust_chain.verified_chain[0]["metadata"][ + entity_type + ] trust_chain.combined_policy[entity_type] = {} - -def diff2policy(new, old): - res = {} - for claim in set(new).intersection(set(old)): - if new[claim] == old[claim]: - continue - else: - res[claim] = {'value': new[claim]} - - for claim in set(new).difference(set(old)): - if claim in ['contacts']: - res[claim] = {'add': new[claim]} - else: - res[claim] = {'value': new[claim]} - - return res diff --git a/pyeudiw/federation/schemas/entity_configuration.py b/pyeudiw/federation/schemas/entity_configuration.py index 901cc429..e97c2995 100644 --- a/pyeudiw/federation/schemas/entity_configuration.py +++ b/pyeudiw/federation/schemas/entity_configuration.py @@ -1,10 +1,12 @@ from typing import List, Literal, Optional from pydantic import BaseModel, HttpUrl, field_validator -from pydantic_core.core_schema import FieldValidationInfo +from pydantic_core.core_schema import ValidationInfo from pyeudiw.federation.schemas.federation_entity import FederationEntity -from pyeudiw.federation.schemas.wallet_relying_party import WalletRelyingParty +from pyeudiw.federation.schemas.openid_credential_verifier import ( + OpenIDCredentialVerifier, +) from pyeudiw.jwk.schemas.public import JwksSchema from pyeudiw.tools.schema_utils import check_algorithm @@ -16,12 +18,12 @@ class 
EntityConfigurationHeader(BaseModel): @field_validator("alg") @classmethod - def _check_alg(cls, alg, info: FieldValidationInfo): + def _check_alg(cls, alg, info: ValidationInfo): return check_algorithm(alg, info) class EntityConfigurationMetadataSchema(BaseModel): - wallet_relying_party: WalletRelyingParty + openid_credential_verifier: OpenIDCredentialVerifier federation_entity: FederationEntity @@ -35,7 +37,7 @@ class EntityConfigurationPayload(BaseModel): authority_hints: List[HttpUrl] -class EntityStatementPayload(BaseModel, extra='forbid'): +class EntityStatementPayload(BaseModel, extra="forbid"): exp: int iat: int iss: HttpUrl diff --git a/pyeudiw/federation/schemas/federation_configuration.py b/pyeudiw/federation/schemas/federation_configuration.py index 8c3edcb7..bfa24a93 100644 --- a/pyeudiw/federation/schemas/federation_configuration.py +++ b/pyeudiw/federation/schemas/federation_configuration.py @@ -1,5 +1,8 @@ from pydantic import BaseModel, HttpUrl -from pyeudiw.federation.schemas.wallet_relying_party import SigningAlgValuesSupported + +from pyeudiw.federation.schemas.openid_credential_verifier import ( + SigningAlgValuesSupported, +) from pyeudiw.jwk.schemas.public import JwkSchema diff --git a/pyeudiw/federation/schemas/wallet_relying_party.py b/pyeudiw/federation/schemas/openid_credential_verifier.py similarity index 97% rename from pyeudiw/federation/schemas/wallet_relying_party.py rename to pyeudiw/federation/schemas/openid_credential_verifier.py index 0c1efa2f..9946cf8d 100644 --- a/pyeudiw/federation/schemas/wallet_relying_party.py +++ b/pyeudiw/federation/schemas/openid_credential_verifier.py @@ -1,7 +1,9 @@ from enum import Enum from typing import List -from pyeudiw.jwk.schemas.public import JwksSchema + from pydantic import BaseModel, HttpUrl, PositiveInt + +from pyeudiw.jwk.schemas.public import JwksSchema from pyeudiw.openid4vp.schemas.vp_formats import VpFormats @@ -47,7 +49,7 @@ class AuthorizationSignedResponseAlg(str, Enum): es512 = 
"ES512" -class WalletRelyingParty(BaseModel): +class OpenIDCredentialVerifier(BaseModel): application_type: str client_id: HttpUrl client_name: str diff --git a/pyeudiw/federation/statements.py b/pyeudiw/federation/statements.py index bb81e32c..9d3e491d 100644 --- a/pyeudiw/federation/statements.py +++ b/pyeudiw/federation/statements.py @@ -1,52 +1,31 @@ from __future__ import annotations +import logging +from copy import deepcopy + import pydantic -from copy import deepcopy from pyeudiw.federation.exceptions import ( - UnknownKid, + InvalidEntityHeader, + InvalidEntityStatementPayload, MissingJwksClaim, MissingTrustMark, TrustAnchorNeeded, - InvalidEntityHeader, - InvalidEntityStatementPayload + UnknownKid, ) from pyeudiw.federation.schemas.entity_configuration import ( EntityConfigurationHeader, - EntityStatementPayload + EntityStatementPayload, ) +from pyeudiw.jwk.jwks import find_jwk_by_kid from pyeudiw.jwt.jws_helper import JWSHelper -from pyeudiw.jwt.utils import decode_jwt_payload, decode_jwt_header -from pyeudiw.jwk import find_jwk_by_kid +from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload from pyeudiw.tools.utils import get_http_url -import logging - OIDCFED_FEDERATION_WELLKNOWN_URL = ".well-known/openid-federation" logger = logging.getLogger(__name__) -def jwks_from_jwks_uri(jwks_uri: str, httpc_params: dict, http_async: bool = True) -> list[dict]: - """ - Retrieves jwks from an entity uri. - - :param jwks_uri: the uri where the jwks are located. - :type jwks_uri: str - :param httpc_params: parameters to perform http requests. - :type httpc_params: dict - :param http_async: if is set to True the operation will be performed in async (deafault True) - :type http_async: bool - - :returns: A list of entity jwks. 
- :rtype: list[dict] - """ - - response = get_http_url(jwks_uri, httpc_params, http_async) - jwks = [i.json() for i in response] - - return jwks - - def get_federation_jwks(jwt_payload: dict) -> list[dict]: """ Returns the list of JWKS inside a JWT payload. @@ -63,7 +42,9 @@ def get_federation_jwks(jwt_payload: dict) -> list[dict]: return keys -def get_entity_statements(urls: list[str] | str, httpc_params: dict, http_async: bool = True) -> list[bytes]: +def get_entity_statements( + urls: list[str] | str, httpc_params: dict, http_async: bool = True +) -> list[bytes]: """ Fetches an entity statement from the specified urls. @@ -82,13 +63,12 @@ def get_entity_statements(urls: list[str] | str, httpc_params: dict, http_async: for url in urls: logger.debug(f"Starting Entity Statement Request to {url}") - return [ - i.content for i in - get_http_url(urls, httpc_params, http_async) - ] + return [i.content for i in get_http_url(urls, httpc_params, http_async)] -def get_entity_configurations(subjects: list[str] | str, httpc_params: dict, http_async: bool = False) -> list[bytes]: +def get_entity_configurations( + subjects: list[str] | str, httpc_params: dict, http_async: bool = False +) -> list[bytes]: """ Fetches an entity configuration from the specified subjects. 
@@ -113,10 +93,7 @@ def get_entity_configurations(subjects: list[str] | str, httpc_params: dict, htt urls.append(url) logger.info(f"Starting Entity Configuration Request for {url}") - return [ - i.content for i in - get_http_url(urls, httpc_params, http_async) - ] + return [i.content for i in get_http_url(urls, httpc_params, http_async)] class TrustMark: @@ -172,7 +149,7 @@ def validate_by(self, ec: dict) -> bool: f"{self.header.get('kid')} not found in {ec.jwks}" ) - _jwk = find_jwk_by_kid(_kid, ec.jwks) + _jwk = find_jwk_by_kid(ec.jwks, _kid) # verify signature jwsh = JWSHelper(_jwk) @@ -189,13 +166,11 @@ def validate_by_its_issuer(self) -> bool: """ if not self.issuer_entity_configuration: self.issuer_entity_configuration = [ - i.content for i in - get_entity_configurations( - self.iss, self.httpc_params, False - ) + i.content + for i in get_entity_configurations(self.iss, self.httpc_params, False) ] - _kid = self.header.get('kid') + _kid = self.header.get("kid") try: ec = EntityStatement(self.issuer_entity_configuration[0]) ec.validate_by_itself() @@ -203,16 +178,16 @@ def validate_by_its_issuer(self) -> bool: logger.warning( f"Trust Mark validation failed by its Issuer: " f"{_kid} not found in " - f"{self.issuer_entity_configuration.jwks}") + f"{self.issuer_entity_configuration.jwks}" + ) return False except Exception: - logger.warning( - f"Issuer {self.iss} of trust mark {self.id} is not valid.") + logger.warning(f"Issuer {self.iss} of trust mark {self.id} is not valid.") self.is_valid = False return False # verify signature - _jwk = find_jwk_by_kid(_kid, ec.jwks) + _jwk = find_jwk_by_kid(ec.jwks, _kid) jwsh = JWSHelper(_jwk) payload = jwsh.verify(self.jwt) self.is_valid = True @@ -274,7 +249,6 @@ def __init__( # a dict with sup_sub : entity statement issued for self self.verified_by_superiors = {} - self.failed_by_superiors = {} # a dict with the paylaod of valid entity statements for each descendant subject self.verified_descendant_statements = {} @@ -286,7 
+260,9 @@ def __init__( self.verified_trust_marks = [] self.is_valid = False - def update_trust_anchor_conf(self, trust_anchor_entity_conf: 'EntityStatement') -> None: + def update_trust_anchor_conf( + self, trust_anchor_entity_conf: "EntityStatement" + ) -> None: """ Updates the internal Trust Anchor conf. @@ -311,11 +287,10 @@ def validate_by_itself(self) -> bool: _kid = self.header.get("kid") if _kid not in self.kids: - raise UnknownKid( - f"{_kid} not found in {self.jwks}") # pragma: no cover + raise UnknownKid(f"{_kid} not found in {self.jwks}") # pragma: no cover # verify signature - _jwk = find_jwk_by_kid(_kid, self.jwks) + _jwk = find_jwk_by_kid(self.jwks, _kid) jwsh = JWSHelper(_jwk) jwsh.verify(self.jwt) self.is_valid = True @@ -366,7 +341,8 @@ def validate_by_allowed_trust_marks(self) -> bool: if not trust_marks: raise MissingTrustMark( - "Required Trust marks are missing.") # pragma: no cover + "Required Trust marks are missing." + ) # pragma: no cover trust_mark_issuers_by_id = self.trust_anchor_entity_conf.payload.get( "trust_marks_issuers", {} @@ -406,8 +382,7 @@ def validate_by_allowed_trust_marks(self) -> bool: elif id_issuers and trust_mark.iss in id_issuers: is_valid = trust_mark.validate_by_its_issuer() elif not id_issuers: - is_valid = trust_mark.validate_by( - self.trust_anchor_entity_conf) + is_valid = trust_mark.validate_by(self.trust_anchor_entity_conf) if not trust_mark.is_valid: is_valid = False @@ -441,7 +416,8 @@ def get_superiors( """ # apply limits if defined authority_hints = authority_hints or deepcopy( - self.payload.get("authority_hints", [])) + self.payload.get("authority_hints", []) + ) if ( max_authority_hints and authority_hints != authority_hints[:max_authority_hints] @@ -473,16 +449,14 @@ def get_superiors( jwts = [self.trust_anchor_configuration] if not jwts: - jwts = get_entity_configurations( - authority_hints, self.httpc_params, False - ) + jwts = get_entity_configurations(authority_hints, self.httpc_params, False) for 
jwt in jwts: try: ec = self.__class__( jwt, httpc_params=self.httpc_params, - trust_anchor_entity_conf=self.trust_anchor_entity_conf + trust_anchor_entity_conf=self.trust_anchor_entity_conf, ) except Exception as e: logger.warning(f"Get Entity Configuration for {jwt}: {e}") @@ -521,26 +495,23 @@ def validate_descendant_statement(self, jwt: str) -> bool: EntityConfigurationHeader(**header) except pydantic.ValidationError as e: raise InvalidEntityHeader( # pragma: no cover - f"Trust Mark validation failed: " - f"{e}" + f"Trust Mark validation failed: " f"{e}" ) try: EntityStatementPayload(**payload) except pydantic.ValidationError as e: raise InvalidEntityStatementPayload( # pragma: no cover - f"Trust Mark validation failed: " - f"{e}" + f"Trust Mark validation failed: " f"{e}" ) _kid = header.get("kid") if _kid not in self.kids: - raise UnknownKid( - f"{_kid} not found in {self.jwks}") + raise UnknownKid(f"{_kid} not found in {self.jwks}") # verify signature - _jwk = find_jwk_by_kid(_kid, self.jwks) + _jwk = find_jwk_by_kid(self.jwks, _kid) jwsh = JWSHelper(_jwk) payload = jwsh.verify(jwt) @@ -548,7 +519,7 @@ def validate_descendant_statement(self, jwt: str) -> bool: self.verified_descendant_statements_as_jwt[payload["sub"]] = jwt return self.verified_descendant_statements - def validate_by_superior_statement(self, jwt: str, ec: 'EntityStatement') -> str: + def validate_by_superior_statement(self, jwt: str, ec: "EntityStatement") -> str: """ validates self with the jwks contained in statement of the superior :param jwt: the statement issued by a superior in form of JWT @@ -566,7 +537,7 @@ def validate_by_superior_statement(self, jwt: str, ec: 'EntityStatement') -> str ec.validate_by_itself() ec.validate_descendant_statement(jwt) _jwks = get_federation_jwks(payload) - _jwk = find_jwk_by_kid(self.header["kid"], _jwks) + _jwk = find_jwk_by_kid(_jwks, self.header["kid"]) jwsh = JWSHelper(_jwk) payload = jwsh.verify(self.jwt) @@ -629,16 +600,12 @@ def 
validate_by_superiors( logger.info(f"Getting entity statements from {_url}") jwts = get_entity_statements([_url], self.httpc_params, False) if not jwts: - logger.error( - f"Empty response for {_url}" - ) + logger.error(f"Empty response for {_url}") jwt = jwts[0] if jwt: self.validate_by_superior_statement(jwt, ec) else: - logger.error( - f"JWT validation for {_url}" - ) + logger.error(f"JWT validation for {_url}") return self.verified_by_superiors diff --git a/pyeudiw/federation/trust_chain/parse.py b/pyeudiw/federation/trust_chain/parse.py deleted file mode 100644 index 8e8238ac..00000000 --- a/pyeudiw/federation/trust_chain/parse.py +++ /dev/null @@ -1,6 +0,0 @@ -from cryptojwt.jwk.ec import ECKey -from cryptojwt.jwk.rsa import RSAKey - - -def get_public_key_from_trust_chain(trust_chain: list[str]) -> ECKey | RSAKey | dict: - raise NotImplementedError("TODO") diff --git a/pyeudiw/federation/trust_chain_builder.py b/pyeudiw/federation/trust_chain_builder.py index 1e3a5446..aca9a4f3 100644 --- a/pyeudiw/federation/trust_chain_builder.py +++ b/pyeudiw/federation/trust_chain_builder.py @@ -1,24 +1,18 @@ import datetime import json import logging - from collections import OrderedDict from typing import Union -from .policy import TrustChainPolicy +from pyeudiw.tools.utils import datetime_from_timestamp from .exceptions import ( InvalidEntityStatement, InvalidRequiredTrustMark, - MetadataDiscoveryException -) - -from .statements import ( - get_entity_configurations, - EntityStatement, + MetadataDiscoveryException, ) -from pyeudiw.tools.utils import datetime_from_timestamp - +from .policy import TrustChainPolicy +from .statements import EntityStatement, get_entity_configurations logger = logging.getLogger(__name__) @@ -82,7 +76,8 @@ def __init__( ) subject_configuration.update_trust_anchor_conf( - trust_anchor_configuration) + trust_anchor_configuration + ) subject_configuration.validate_by_itself() except Exception as e: _msg = f"Entity Configuration for 
{self.trust_anchor} failed: {e}" @@ -90,8 +85,7 @@ def __init__( raise InvalidEntityStatement(_msg) elif isinstance(trust_anchor_configuration, str): trust_anchor_configuration = EntityStatement( - jwt=trust_anchor_configuration, - httpc_params=self.httpc_params + jwt=trust_anchor_configuration, httpc_params=self.httpc_params ) self.trust_anchor_configuration = trust_anchor_configuration @@ -154,8 +148,7 @@ def apply_metadata_policy(self) -> dict: # once I filtered a concrete and unique trust path I can apply the metadata policy if path_found: logger.info(f"Found a trust path: {self.trust_path}") - self.final_metadata = self.subject_configuration.payload.get( - "metadata", {}) + self.final_metadata = self.subject_configuration.payload.get("metadata", {}) if not self.final_metadata: logger.error( f"Missing metadata in {self.subject_configuration.payload['metadata']}" @@ -164,9 +157,8 @@ def apply_metadata_policy(self) -> dict: for i in range(len(self.trust_path))[::-1]: self.trust_path[i - 1].sub - _pol = ( - self.trust_path[i] - .verified_descendant_statements.get("metadata_policy", {}) + _pol = self.trust_path[i].verified_descendant_statements.get( + "metadata_policy", {} ) for md_type, md in _pol.items(): if not self.final_metadata.get(md_type): @@ -197,8 +189,7 @@ def discovery(self) -> bool: :returns: the validity status of the updated chain :rtype: bool """ - logger.info( - f"Starting a Walk into Metadata Discovery for {self.subject}") + logger.info(f"Starting a Walk into Metadata Discovery for {self.subject}") self.tree_of_trust[0] = [self.subject_configuration] ecs_history = [] @@ -254,8 +245,7 @@ def get_trust_anchor_configuration(self) -> None: with the entity statement of trust anchor. 
""" if not isinstance(self.trust_anchor, EntityStatement): - logger.info( - f"Get Trust Anchor Entity Configuration for {self.subject}") + logger.info(f"Get Trust Anchor Entity Configuration for {self.subject}") ta_jwt = get_entity_configurations( self.trust_anchor, httpc_params=self.httpc_params )[0] @@ -300,8 +290,9 @@ def get_subject_configuration(self) -> None: self.subject, httpc_params=self.httpc_params ) self.subject_configuration = EntityStatement( - jwts[0], trust_anchor_entity_conf=self.trust_anchor_configuration, - httpc_params=self.httpc_params + jwts[0], + trust_anchor_entity_conf=self.trust_anchor_configuration, + httpc_params=self.httpc_params, ) self.subject_configuration.validate_by_itself() except Exception as e: @@ -346,9 +337,9 @@ def get_trust_chain(self) -> list[str]: # we keep just the leaf's and TA's EC, all the intermediates EC will be dropped ta_ec: str = "" for stat in self.trust_path: - if (self.subject == stat.sub == stat.iss): + if self.subject == stat.sub == stat.iss: res.append(stat.jwt) - elif (self.trust_anchor_configuration.sub == stat.sub == stat.iss): + elif self.trust_anchor_configuration.sub == stat.sub == stat.iss: ta_ec = stat.jwt if stat.verified_descendant_statements: diff --git a/pyeudiw/federation/trust_chain_validator.py b/pyeudiw/federation/trust_chain_validator.py index 9fae4907..465921ea 100644 --- a/pyeudiw/federation/trust_chain_validator.py +++ b/pyeudiw/federation/trust_chain_validator.py @@ -1,22 +1,22 @@ import logging -from pyeudiw.jwt.jws_helper import JWSHelper -from pyeudiw.tools.utils import iat_now -from pyeudiw.jwt.utils import decode_jwt_payload, decode_jwt_header -from pyeudiw.federation import is_es -from pyeudiw.federation.policy import TrustChainPolicy -from pyeudiw.federation.statements import ( - get_entity_configurations, - get_entity_statements -) + from pyeudiw.federation.exceptions import ( + InvalidEntityStatement, + KeyValidationError, MissingTrustAnchorPublicKey, TimeValidationError, - 
KeyValidationError, - InvalidEntityStatement ) - -from pyeudiw.jwk import find_jwk_by_kid -from pyeudiw.jwk.exceptions import KidNotFoundError, InvalidKid +from pyeudiw.federation.policy import TrustChainPolicy +from pyeudiw.federation.statements import ( + get_entity_configurations, + get_entity_statements, +) +from pyeudiw.federation.utils import is_es +from pyeudiw.jwk.jwks import find_jwk_by_kid +from pyeudiw.jwk.exceptions import InvalidKid, KidNotFoundError +from pyeudiw.jwt.jws_helper import JWSHelper +from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload +from pyeudiw.tools.utils import iat_now logger = logging.getLogger(__name__) @@ -80,9 +80,7 @@ def _validate_exp(self, exp: int) -> None: """ if not self._check_expired(exp): - raise TimeValidationError( - "Expired validation error" - ) + raise TimeValidationError("Expired validation error") def _validate_keys(self, fed_jwks: list[dict], st_header: dict) -> None: """ @@ -117,9 +115,7 @@ def validate(self) -> bool: """ # start from the last entity statement - rev_tc = [ - i for i in reversed(self.trust_chain) - ] + rev_tc = [i for i in reversed(self.trust_chain)] # inspect the entity statement kid header to know which # TA's public key to use for the validation @@ -127,14 +123,10 @@ def validate(self) -> bool: es_header = decode_jwt_header(last_element) es_payload = decode_jwt_payload(last_element) - ta_jwk = find_jwk_by_kid( - es_header.get("kid", None), self.trust_anchor_jwks - ) + ta_jwk = find_jwk_by_kid(self.trust_anchor_jwks, es_header.get("kid", None)) if not ta_jwk: - logger.error( - "Trust chain validation error: TA jwks not found." 
- ) + logger.error("Trust chain validation error: TA jwks not found.") return False # Validate the last statement with ta_jwk @@ -165,9 +157,7 @@ def validate(self) -> bool: st_payload = decode_jwt_payload(st) try: - jwk = find_jwk_by_kid( - st_header.get("kid", None), fed_jwks - ) + jwk = find_jwk_by_kid(fed_jwks, st_header.get("kid", None)) except (KidNotFoundError, InvalidKid): logger.error( f"Trust chain validation KidNotFoundError: {st_header} not in {fed_jwks}" @@ -226,7 +216,7 @@ def _update_st(self, st: str) -> str: :rtype: str """ payload = decode_jwt_payload(st) - iss = payload['iss'] + iss = payload["iss"] try: is_es(payload) @@ -237,9 +227,7 @@ def _update_st(self, st: str) -> str: # if it has the source_endpoint let's try a fast renewal download_url: str = payload.get("source_endpoint", "") if download_url: - jwt = self._retrieve_es( - f"{download_url}?sub={payload['sub']}", iss - ) + jwt = self._retrieve_es(f"{download_url}?sub={payload['sub']}", iss) else: ec = self._retrieve_ec(iss) ec_data = decode_jwt_payload(ec) diff --git a/pyeudiw/federation/utils.py b/pyeudiw/federation/utils.py new file mode 100644 index 00000000..687d2553 --- /dev/null +++ b/pyeudiw/federation/utils.py @@ -0,0 +1,23 @@ +from pyeudiw.federation.schemas.entity_configuration import ( + EntityStatementPayload, +) + +from .exceptions import InvalidEntityStatement + + +def is_es(payload: dict) -> None: + """ + Determines if payload dict is a Subordinate Entity Statement + + :param payload: the object to determine if is a Subordinate Entity Statement + :type payload: dict + """ + + try: + EntityStatementPayload(**payload) + if payload["iss"] == payload["sub"]: + _msg = "Invalid Entity Statement: iss and sub cannot be the same" + raise InvalidEntityStatement(_msg) + except ValueError as e: + _msg = f"Invalid Entity Statement: {e}" + raise InvalidEntityStatement(_msg) diff --git a/pyeudiw/jwk/__init__.py b/pyeudiw/jwk/__init__.py index 9b9cc1ae..5a18b381 100644 --- 
a/pyeudiw/jwk/__init__.py +++ b/pyeudiw/jwk/__init__.py @@ -8,12 +8,9 @@ from cryptojwt.jwk.jwk import key_from_jwk_dict from cryptojwt.jwk.rsa import new_rsa_key -from .exceptions import InvalidKid, KidNotFoundError, InvalidJwk +from pyeudiw.jwk.exceptions import InvalidKid, KidNotFoundError -KEY_TYPES_FUNC = dict( - EC=new_ec_key, - RSA=new_rsa_key -) +KEY_TYPES_FUNC = dict(EC=new_ec_key, RSA=new_rsa_key) class JWK: @@ -25,8 +22,8 @@ def __init__( self, key: Union[dict, None] = None, key_type: str = "EC", - hash_func: str = 'SHA-256', - ec_crv: str = "P-256" + hash_func: str = "SHA-256", + ec_crv: str = "P-256", ) -> None: """ Creates an instance of JWK. @@ -52,21 +49,21 @@ def __init__( if key: if isinstance(key, dict): self.key = key_from_jwk_dict(key) - key_type = key.get('kty', key_type) - self.kid = key.get('kid', "") + key_type = key.get("kty", key_type) + self.kid = key.get("kid", "") else: self.key = key else: # create new one - if key_type in ['EC', None]: - kwargs['crv'] = ec_crv - self.key = KEY_TYPES_FUNC[key_type or 'EC'](**kwargs) + if key_type in ["EC", None]: + kwargs["crv"] = ec_crv + self.key = KEY_TYPES_FUNC[key_type or "EC"](**kwargs) self.thumbprint = self.key.thumbprint(hash_function=hash_func) self.jwk = self.key.to_dict() self.jwk["kid"] = self.kid or self.thumbprint.decode() self.public_key = self.key.serialize() - self.public_key['kid'] = self.jwk["kid"] + self.public_key["kid"] = self.jwk["kid"] def as_json(self) -> str: """ @@ -125,64 +122,6 @@ def as_public_dict(self) -> dict: """ return self.public_key - def __repr__(self): + def __repr__(self) -> str: # private part! 
return self.as_json() - - -class RSAJWK(JWK): - def __init__(self, key: dict | None = None, hash_func: str = "SHA-256") -> None: - super().__init__(key, "RSA", hash_func, None) - - -class ECJWK(JWK): - def __init__(self, key: dict | None = None, hash_func: str = "SHA-256", ec_crv: str = "P-256") -> None: - super().__init__(key, "EC", hash_func, ec_crv) - - -def jwk_form_dict(key: dict, hash_func: str = "SHA-256") -> RSAJWK | ECJWK: - """ - Returns a JWK instance from a dict. - - :param key: a dict that represents the key. - :type key: dict - - :returns: a JWK instance. - :rtype: JWK - """ - _kty = key.get('kty', None) - - if _kty is None or _kty not in ['EC', 'RSA']: - raise InvalidJwk("Invalid JWK") - elif _kty == "RSA": - return RSAJWK(key, hash_func) - else: - ec_crv = key.get('crv', "P-256") - return ECJWK(key, hash_func, ec_crv) - - -def find_jwk_by_kid(kid: str, jwks: list[dict], as_dict: bool = True) -> dict | JWK: - """ - Find the JWK with the indicated kid in the jwks list. - - :param kid: the identifier of the jwk - :type kid: str - :param jwks: the list of jwks - :type jwks: list[dict] - :param as_dict: if True the return type will be a dict, JWK otherwise. - :type as_dict: bool - - :raises InvalidKid: if kid is None. - :raises KidNotFoundError: if kid is not in jwks list. 
- - :returns: the jwk with the indicated kid or an empty dict if no jwk is found - :rtype: dict | JWK - """ - if not kid: - raise InvalidKid("Kid cannot be empty") - for jwk in jwks: - valid_jwk = jwk.get("kid", None) - if valid_jwk and kid == valid_jwk: - return jwk if as_dict else JWK(jwk) - - raise KidNotFoundError(f"Key with Kid {kid} not found") diff --git a/pyeudiw/jwk/exceptions.py b/pyeudiw/jwk/exceptions.py index 6a80ddd3..03ce1ea3 100644 --- a/pyeudiw/jwk/exceptions.py +++ b/pyeudiw/jwk/exceptions.py @@ -8,11 +8,3 @@ class KidNotFoundError(Exception): class InvalidKid(Exception): pass - - -class JwkError(Exception): - pass - - -class InvalidJwk(Exception): - pass diff --git a/pyeudiw/jwk/jwks.py b/pyeudiw/jwk/jwks.py index fc8774fe..ace1ae5f 100644 --- a/pyeudiw/jwk/jwks.py +++ b/pyeudiw/jwk/jwks.py @@ -1,27 +1,32 @@ from pyeudiw.jwk import JWK +from pyeudiw.jwk.exceptions import InvalidKid, KidNotFoundError -def find_jwk_by_kid(jwks: list[dict], kid: str) -> dict | None: - """Find the key with the indicated kid in the given jwks list. - If multiple such keys are int he set, then the first found key - will be returned. +def find_jwk_by_kid(jwks: list[dict], kid: str, as_dict: bool = True) -> dict | JWK: + """ + Find the JWK with the indicated kid in the jwks list. :param kid: the identifier of the jwk :type kid: str :param jwks: the list of jwks :type jwks: list[dict] + :param as_dict: if True the return type will be a dict, JWK otherwise. + :type as_dict: bool + + :raises InvalidKid: if kid is None. + :raises KidNotFoundError: if kid is not in jwks list. 
- :returns: the jwk with the indicated kid or None if the such key can be found - :rtype: dict | None + :returns: the jwk with the indicated kid or an empty dict if no jwk is found + :rtype: dict | JWK """ if not kid: - raise ValueError("kid cannot be empty") + raise InvalidKid("Kid cannot be empty") for jwk in jwks: - obtained_kid = jwk.get("kid", None) - if kid == obtained_kid: - return jwk - return None + valid_jwk = jwk.get("kid", None) + if valid_jwk and kid == valid_jwk: + return jwk if as_dict else JWK(jwk) + raise KidNotFoundError(f"Key with Kid {kid} not found") def find_jwk_by_thumbprint(jwks: list[dict], thumbprint: bytes) -> dict | None: """Find if a jwk with the given thumbprint is part of the given JWKS. diff --git a/pyeudiw/jwk/parse.py b/pyeudiw/jwk/parse.py index e207ab73..934a42ef 100644 --- a/pyeudiw/jwk/parse.py +++ b/pyeudiw/jwk/parse.py @@ -1,23 +1,8 @@ -import cryptojwt -import cryptojwt.jwk -from cryptojwt.jwk.rsa import import_rsa_key, RSAKey +from cryptojwt.jwk.rsa import RSAKey, import_rsa_key from pyeudiw.jwk import JWK -def adapt_key_to_JWK(key: dict | JWK | cryptojwt.jwk.JWK) -> JWK: - """Function adapt_key_to_JWK normalize key representation format to - the internal JWK. - """ - if isinstance(key, JWK): - return key - if isinstance(key, dict): - return JWK(key) - if isinstance(key, cryptojwt.jwk.JWK): - return JWK(key.to_dict()) - raise ValueError(f"not a valid or supported key format: {type(key)}") - - def parse_key_from_x5c(x5c: list[str]) -> JWK: """Parse a key from an x509 chain. This function currently support only the parsing of public RSA key from such a chain. 
diff --git a/pyeudiw/jwk/schemas/public.py b/pyeudiw/jwk/schemas/public.py index 8bf22544..d4c2ebd1 100644 --- a/pyeudiw/jwk/schemas/public.py +++ b/pyeudiw/jwk/schemas/public.py @@ -18,18 +18,13 @@ _SUPPORTED_ALG_BY_KTY = { "RSA": ("PS256", "PS384", "PS512", "RS256", "RS384", "RS512"), - "EC": ("ES256", "ES384", "ES512") + "EC": ("ES256", "ES384", "ES512"), } # TODO: supported alg by kty and use _SUPPORTED_CRVS = Literal[ - "P-256", - "P-384", - "P-521", - "brainpoolP256r1", - "brainpoolP384r1", - "brainpoolP512r1" + "P-256", "P-384", "P-521", "brainpoolP256r1", "brainpoolP384r1", "brainpoolP512r1" ] @@ -54,20 +49,23 @@ class RSAJwkSchema(JwkBaseModel): class JwkSchema(BaseModel): kid: str # Base64url-encoded thumbprint string kty: _SUPPORTED_KTY - alg: Annotated[Union[_SUPPORTED_ALGS, None], - Field(validate_default=True)] = None - use: Annotated[Union[Literal["sig", "enc"], None], - Field(validate_default=True)] = None - n: Annotated[Union[str, None], Field( - validate_default=True)] = None # Base64urlUInt-encoded - e: Annotated[Union[str, None], Field( - validate_default=True)] = None # Base64urlUInt-encoded - x: Annotated[Union[str, None], Field( - validate_default=True)] = None # Base64urlUInt-encoded - y: Annotated[Union[str, None], Field( - validate_default=True)] = None # Base64urlUInt-encoded - crv: Annotated[Union[_SUPPORTED_CRVS, None], - Field(validate_default=True)] = None + alg: Annotated[Union[_SUPPORTED_ALGS, None], Field(validate_default=True)] = None + use: Annotated[Union[Literal["sig", "enc"], None], Field(validate_default=True)] = ( + None + ) + n: Annotated[Union[str, None], Field(validate_default=True)] = ( + None # Base64urlUInt-encoded + ) + e: Annotated[Union[str, None], Field(validate_default=True)] = ( + None # Base64urlUInt-encoded + ) + x: Annotated[Union[str, None], Field(validate_default=True)] = ( + None # Base64urlUInt-encoded + ) + y: Annotated[Union[str, None], Field(validate_default=True)] = ( + None # Base64urlUInt-encoded + ) 
+ crv: Annotated[Union[_SUPPORTED_CRVS, None], Field(validate_default=True)] = None def _must_specific_kty_only(v, exp_kty: _SUPPORTED_ALGS, v_name: str, values: dict): """validate a jwk parameter by that it is (1) defined and (2) mandatory @@ -96,7 +94,8 @@ def validate_alg(cls, v, values): kty = values.data.get("kty") if v not in _SUPPORTED_ALG_BY_KTY[kty]: raise ValueError( - f"alg value {v} is not compatible or not supported with kty {kty}") + f"alg value {v} is not compatible or not supported with kty {kty}" + ) return @field_validator("n") @@ -120,8 +119,7 @@ def validate_crv(cls, v, values): cls._must_specific_kty_only(v, "EC", "crv", values.data) -_JwkSchema_T = Annotated[Union[ECJwkSchema, RSAJwkSchema], - Field(discriminator="kty")] +_JwkSchema_T = Annotated[Union[ECJwkSchema, RSAJwkSchema], Field(discriminator="kty")] class JwksSchema(BaseModel): diff --git a/pyeudiw/jwt/helper.py b/pyeudiw/jwt/helper.py index 224ef07b..8733c8f1 100644 --- a/pyeudiw/jwt/helper.py +++ b/pyeudiw/jwt/helper.py @@ -1,24 +1,19 @@ import json - -from pyeudiw.jwk import JWK -from pyeudiw.jwk.parse import parse_key_from_x5c - -from pyeudiw.jwt.log import logger - - -from typing import TypeAlias, Literal +from typing import Literal, TypeAlias from cryptojwt.jwk.ec import ECKey -from cryptojwt.jwk.rsa import RSAKey -from cryptojwt.jwk.okp import OKPKey from cryptojwt.jwk.hmac import SYMKey from cryptojwt.jwk.jwk import key_from_jwk_dict +from cryptojwt.jwk.okp import OKPKey +from cryptojwt.jwk.rsa import RSAKey +from pyeudiw.jwk import JWK +from pyeudiw.jwk.parse import parse_key_from_x5c +from pyeudiw.jwt.log import logger from pyeudiw.jwt.utils import decode_jwt_payload from pyeudiw.tools.utils import iat_now -from . 
exceptions import LifetimeException - +from .exceptions import LifetimeException KeyLike: TypeAlias = ECKey | RSAKey | OKPKey | SYMKey SerializationFormat = Literal["compact", "json"] @@ -45,8 +40,7 @@ def __init__(self, jwks: list[KeyLike | dict] | KeyLike | dict): elif isinstance(jwks, (ECKey, RSAKey, OKPKey, SYMKey)): self.jwks = [jwks] else: - raise TypeError( - f"unable to handle input jwks with type {type(jwks)}") + raise TypeError(f"unable to handle input jwks with type {type(jwks)}") def get_jwk_by_kid(self, kid: str) -> KeyLike | None: if not kid: @@ -86,7 +80,8 @@ def find_self_contained_key(header: dict) -> tuple[set[str], JWK] | None: candidate_key = parse_key_from_x5c(header["x5c"]) except Exception as e: logger.debug( - f"failed to parse key from x5c chain {header['x5c']}", exc_info=e) + f"failed to parse key from x5c chain {header['x5c']}", exc_info=e + ) return set(["5xc"]), candidate_key if "jwk" in header: candidate_key = JWK(header["jwk"]) @@ -94,7 +89,8 @@ def find_self_contained_key(header: dict) -> tuple[set[str], JWK] | None: unsupported_claims = set(("trust_chain", "jku", "x5u", "x5t")) if unsupported_claims.intersection(header): raise NotImplementedError( - f"self contained key extraction form header with claims {unsupported_claims} not supported yet") + f"self contained key extraction form header with claims {unsupported_claims} not supported yet" + ) return None @@ -134,14 +130,14 @@ def validate_jwt_timestamps_claims(payload: dict, tolerance_s: int = 0) -> None: """ current_time = iat_now() - if 'iat' in payload: - if payload['iat'] - tolerance_s > current_time: + if "iat" in payload: + if payload["iat"] - tolerance_s > current_time: raise LifetimeException("Future issue time, token is invalid.") - if 'exp' in payload: - if payload['exp'] + tolerance_s <= current_time: + if "exp" in payload: + if payload["exp"] + tolerance_s <= current_time: raise LifetimeException("Token has expired.") - if 'nbf' in payload: - if payload['nbf'] - 
tolerance_s > current_time: + if "nbf" in payload: + if payload["nbf"] - tolerance_s > current_time: raise LifetimeException("Token not yet valid.") diff --git a/pyeudiw/jwt/jwe_helper.py b/pyeudiw/jwt/jwe_helper.py index 16f7cbeb..a31a714a 100644 --- a/pyeudiw/jwt/jwe_helper.py +++ b/pyeudiw/jwt/jwe_helper.py @@ -41,12 +41,12 @@ def encrypt(self, plain_dict: Union[dict, str, int, None], **kwargs) -> str: else: _payload = "" - encryption_keys = [ - key for key in self.jwks if key.appropriate_for("encrypt")] + encryption_keys = [key for key in self.jwks if key.appropriate_for("encrypt")] if len(encryption_keys) == 0: raise JWEEncryptionError( - "unable to produce JWE: no available encryption key(s)") + "unable to produce JWE: no available encryption key(s)" + ) for key in self.jwks: if isinstance(key, cryptojwt.jwk.rsa.RSAKey): @@ -62,27 +62,27 @@ def encrypt(self, plain_dict: Union[dict, str, int, None], **kwargs) -> str: alg=DEFAULT_ENC_ALG_MAP[key.kty], enc=DEFAULT_ENC_ENC_MAP[key.kty], kid=key.kid, - **kwargs + **kwargs, ) - if key.kty == 'EC': + if key.kty == "EC": _keyobj: JWE_EC - cek, encrypted_key, iv, params, epk = _keyobj.enc_setup( - msg=_payload, - key=key + cek, encrypted_key, iv, params, _ = _keyobj.enc_setup( + msg=_payload, key=key ) kwargs = { "params": params, "cek": cek, "iv": iv, - "encrypted_key": encrypted_key + "encrypted_key": encrypted_key, } return _keyobj.encrypt(**kwargs) else: return _keyobj.encrypt(key=key.public_key()) raise JWEEncryptionError( - "unable to produce JWE: no supported encryption key(s)") + "unable to produce JWE: no supported encryption key(s)" + ) def decrypt(self, jwe: str) -> dict: """ @@ -100,7 +100,8 @@ def decrypt(self, jwe: str) -> dict: jwe_header = decode_jwt_header(jwe) except (binascii.Error, Exception) as e: raise JWEDecryptionError( - f"Not a valid JWE format for the following reason: {e}") + f"Not a valid JWE format for the following reason: {e}" + ) _alg = jwe_header.get("alg") _enc = 
jwe_header.get("enc") diff --git a/pyeudiw/jwt/jws_helper.py b/pyeudiw/jwt/jws_helper.py index 74a5a696..d11916ff 100644 --- a/pyeudiw/jwt/jws_helper.py +++ b/pyeudiw/jwt/jws_helper.py @@ -1,18 +1,27 @@ import binascii -from copy import deepcopy import logging import os +from copy import deepcopy from typing import Any, Literal, Union from cryptojwt import JWS from cryptojwt.jwk.jwk import key_from_jwk_dict +from pyeudiw.jwk import JWK from pyeudiw.jwk.exceptions import KidError from pyeudiw.jwk.jwks import find_jwk_by_kid, find_jwk_by_thumbprint -from pyeudiw.jwt.exceptions import JWEEncryptionError, JWSSigningError, JWSVerificationError, LifetimeException -from pyeudiw.jwt.helper import JWHelperInterface, find_self_contained_key, serialize_payload, validate_jwt_timestamps_claims - -from pyeudiw.jwk import JWK +from pyeudiw.jwt.exceptions import ( + JWEEncryptionError, + JWSSigningError, + JWSVerificationError, + LifetimeException, +) +from pyeudiw.jwt.helper import ( + JWHelperInterface, + find_self_contained_key, + serialize_payload, + validate_jwt_timestamps_claims, +) from pyeudiw.jwt.utils import decode_jwt_header SerializationFormat = Literal["compact", "json"] @@ -21,27 +30,17 @@ DEFAULT_HASH_FUNC = "SHA-256" -DEFAULT_SIG_KTY_MAP = { - "RSA": "RS256", - "EC": "ES256" -} +DEFAULT_SIG_KTY_MAP = {"RSA": "RS256", "EC": "ES256"} -DEFAULT_SIG_ALG_MAP = { - "RSA": "RS256", - "EC": "ES256" -} +DEFAULT_SIG_ALG_MAP = {"RSA": "RS256", "EC": "ES256"} -DEFAULT_ENC_ALG_MAP = { - "RSA": "RSA-OAEP", - "EC": "ECDH-ES+A256KW" -} +DEFAULT_ENC_ALG_MAP = {"RSA": "RSA-OAEP", "EC": "ECDH-ES+A256KW"} -DEFAULT_ENC_ENC_MAP = { - "RSA": "A256CBC-HS512", - "EC": "A256GCM" -} +DEFAULT_ENC_ENC_MAP = {"RSA": "A256CBC-HS512", "EC": "A256GCM"} -DEFAULT_TOKEN_TIME_TOLERANCE = int(os.getenv("PYEUDIW_TOKEN_TIME_TOLERANCE", "60"), base=10) +DEFAULT_TOKEN_TIME_TOLERANCE = int( + os.getenv("PYEUDIW_TOKEN_TIME_TOLERANCE", "60"), base=10 +) class JWSHelper(JWHelperInterface): @@ -57,7 +56,7 @@ def 
sign( serialization_format: SerializationFormat = "compact", signing_kid: str = "", kid_in_header: bool = True, - **kwargs + **kwargs, ) -> str: """Generate a signed JWS with the given payload and header. This method provides no guarantee that the input header is fully preserved, @@ -103,15 +102,15 @@ def sign( # Select the signing key # TODO: check that singing key is either private or symmetric - signing_key = self._select_signing_key( - (protected, unprotected), signing_kid) + signing_key = self._select_signing_key((protected, unprotected), signing_kid) # Ensure the key ID in the header matches the signing key header_kid = protected.get("kid") signer_kid = signing_key.get("kid") if header_kid and signer_kid and (header_kid != signer_kid): raise JWSSigningError( - f"token header contains a kid {header_kid} that does not match the signing key kid {signer_kid}") + f"token header contains a kid {header_kid} that does not match the signing key kid {signer_kid}" + ) payload = serialize_payload(plain_dict) @@ -121,8 +120,7 @@ def sign( # Add "typ" header if not present if "typ" not in protected: - protected["typ"] = "sd-jwt" if self.is_sd_jwt( - plain_dict) else "JWT" + protected["typ"] = "sd-jwt" if self.is_sd_jwt(plain_dict) else "JWT" # Include the signing key's kid in the header if required if kid_in_header and signer_kid: @@ -141,9 +139,7 @@ def sign( if serialization_format == "compact": try: - signed = signer.sign_compact( - keys, protected=protected, **kwargs - ) + signed = signer.sign_compact(keys, protected=protected, **kwargs) return signed except Exception as e: raise JWSSigningError("Signing error: error in step", e) @@ -153,28 +149,33 @@ def sign( flatten=True, ) - def _select_signing_key(self, headers: tuple[dict, dict], signing_kid: str = "") -> dict: + def _select_signing_key( + self, headers: tuple[dict, dict], signing_kid: str = "" + ) -> dict: if len(self.jwks) == 0: raise JWEEncryptionError( - "signing error: no key available for signature; note 
that {'alg':'none'} is not supported") + "signing error: no key available for signature; note that {'alg':'none'} is not supported" + ) # Case 0: key forced by the user if signing_kid: signing_key = self.get_jwk_by_kid(signing_kid) if not signing_kid: raise JWEEncryptionError( - f"signing forced by using key with {signing_kid=}, but no such key is available") + f"signing forced by using key with {signing_kid=}, but no such key is available" + ) return signing_key.to_dict() # Case 1: only one key - if (signing_key := self._select_signing_key_by_uniqueness()): + if signing_key := self._select_signing_key_by_uniqueness(): return signing_key # Case 2: only one *singing* key - if (signing_key := self._select_key_by_use(use="sig")): + if signing_key := self._select_key_by_use(use="sig"): return signing_key # Case 3: match key by kid: this goes beyond what promised on the method definition - if (signing_key := self._select_key_by_kid(headers)): + if signing_key := self._select_key_by_kid(headers): return signing_key raise JWSSigningError( - "signing error: not possible to uniquely determine the signing key") + "signing error: not possible to uniquely determine the signing key" + ) def _select_signing_key_by_uniqueness(self) -> dict | None: if len(self.jwks) == 1: @@ -185,7 +186,7 @@ def _select_key_by_use(self, use: str) -> dict | None: candidate_signing_keys: list[dict] = [] for key in self.jwks: key_d = key.to_dict() - if use == key_d .get("use", ""): + if use == key_d.get("use", ""): candidate_signing_keys.append(key_d) if len(candidate_signing_keys) == 1: return candidate_signing_keys[0] @@ -202,7 +203,9 @@ def _select_key_by_kid(self, headers: tuple[dict, dict]) -> dict | None: return None return find_jwk_by_kid([key.to_dict() for key in self.jwks], kid) - def verify(self, jwt: str, tolerance_s: int = DEFAULT_TOKEN_TIME_TOLERANCE) -> (str | Any | bytes): + def verify( + self, jwt: str, tolerance_s: int = DEFAULT_TOKEN_TIME_TOLERANCE + ) -> str | Any | bytes: 
"""Verify a JWS with one of the initialized keys and validate standard standard claims if possible, such as 'iat' and 'exp'. Verification of tokens in JSON serialization format is not supported. @@ -231,22 +234,23 @@ def verify(self, jwt: str, tolerance_s: int = DEFAULT_TOKEN_TIME_TOLERANCE) -> ( verifying_key = self._select_verifying_key(header) if not verifying_key: raise JWSVerificationError( - f"Verification error: unable to find matching public key for header {header}") + f"Verification error: unable to find matching public key for header {header}" + ) # sanity check: kid must match if present - if (expected_kid := header.get("kid")): + if expected_kid := header.get("kid"): obtained_kid = verifying_key.get("kid", None) if obtained_kid and (obtained_kid != expected_kid): raise JWSVerificationError( KidError( "unexpected verification state: found a valid verifying key," - f"but its kid {obtained_kid} does not match token header kid {expected_kid}") + f"but its kid {obtained_kid} does not match token header kid {expected_kid}" + ) ) # Verify the JWS compact signature verifier = JWS(alg=header["alg"]) - msg: dict = verifier.verify_compact( - jwt, [key_from_jwk_dict(verifying_key)]) + msg: dict = verifier.verify_compact(jwt, [key_from_jwk_dict(verifying_key)]) # Validate JWT claims try: @@ -261,19 +265,22 @@ def _select_verifying_key(self, header: dict) -> dict | None: # case 1: can be found by header if "kid" in header: - if (verifying_key := find_jwk_by_kid(available_keys, header["kid"])): + if verifying_key := find_jwk_by_kid(available_keys, header["kid"]): return verifying_key # case 2: the token is self contained, and the verification key matches one of the key in the whitelist - if (self_contained_claims_key_pair := find_self_contained_key(header)): + if self_contained_claims_key_pair := find_self_contained_key(header): # check if the self contained key matches a trusted jwk - used_claims, candidate_key = self_contained_claims_key_pair + _, candidate_key = 
self_contained_claims_key_pair if hasattr(candidate_key, "thumbprint"): - if (verifying_key := find_jwk_by_thumbprint(available_keys, candidate_key.thumbprint)): + if verifying_key := find_jwk_by_thumbprint( + available_keys, candidate_key.thumbprint + ): return verifying_key else: logger.error( - f"Candidate key {candidate_key} does not have a thumbprint attribute.") + f"Candidate key {candidate_key} does not have a thumbprint attribute." + ) raise ValueError("Invalid key: missing thumbprint.") # case 3: if only one key and there is no header claim that can identitfy any key, than that MUST diff --git a/pyeudiw/jwt/log.py b/pyeudiw/jwt/log.py index 2c4f2744..988f4959 100644 --- a/pyeudiw/jwt/log.py +++ b/pyeudiw/jwt/log.py @@ -1,5 +1,4 @@ # This defined the package level logger import logging - logger = logging.getLogger(__name__) diff --git a/pyeudiw/jwt/parse.py b/pyeudiw/jwt/parse.py index fb4d35db..fd7dc2dc 100644 --- a/pyeudiw/jwt/parse.py +++ b/pyeudiw/jwt/parse.py @@ -1,13 +1,8 @@ -import json import base64 +import json from dataclasses import dataclass - -from pyeudiw.jwt.utils import is_jwt_format -from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload - -KeyIdentifier_T = str - +from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload, is_jwt_format @dataclass(frozen=True) class DecodedJwt: @@ -16,24 +11,17 @@ class DecodedJwt: This class is not meant to be instantiated directly. Use instead the static method parse(str) -> DecodedJwt. """ + jwt: str header: dict payload: dict signature: str @staticmethod - def parse(jws: str) -> 'DecodedJwt': + def parse(jws: str) -> "DecodedJwt": return unsafe_parse_jws(jws) -def _unsafe_decode_part(part: str) -> dict: - padding_needed = len(part) % 4 - if padding_needed: - part += "=" * (4 - padding_needed) - decoded_bytes = base64.urlsafe_b64decode(part) - return json.loads(decoded_bytes.decode("utf-8")) - - def unsafe_parse_jws(token: str) -> DecodedJwt: """ Parse a token into its components. 
diff --git a/pyeudiw/jwt/schemas/jwt.py b/pyeudiw/jwt/schemas/jwt.py index 6926026c..28a67742 100644 --- a/pyeudiw/jwt/schemas/jwt.py +++ b/pyeudiw/jwt/schemas/jwt.py @@ -1,5 +1,10 @@ from pydantic import BaseModel, Field -from pyeudiw.federation.schemas.wallet_relying_party import SigningAlgValuesSupported, EncryptionAlgValuesSupported, EncryptionEncValuesSupported + +from pyeudiw.federation.schemas.openid_credential_verifier import ( + EncryptionAlgValuesSupported, + EncryptionEncValuesSupported, + SigningAlgValuesSupported, +) class JWTConfig(BaseModel): diff --git a/pyeudiw/jwt/utils.py b/pyeudiw/jwt/utils.py index 1bec79ea..513b8ea6 100644 --- a/pyeudiw/jwt/utils.py +++ b/pyeudiw/jwt/utils.py @@ -2,10 +2,10 @@ import json import re -from pyeudiw.jwt.exceptions import JWTInvalidElementPosition, JWTDecodeError +from pyeudiw.jwt.exceptions import JWTDecodeError, JWTInvalidElementPosition # jwt regexp pattern is non terminating, hence it match jwt, sd-jwt and sd-jwt with kb -JWT_REGEXP = r'^[_\w\-]+\.[_\w\-]+\.[_\w\-]+' +JWT_REGEXP = r"^[_\w\-]+\.[_\w\-]+\.[_\w\-]+" def decode_jwt_element(jwt: str, position: int) -> dict: @@ -23,18 +23,17 @@ def decode_jwt_element(jwt: str, position: int) -> dict: :rtype: dict """ if position < 0: - raise JWTInvalidElementPosition( - f"Cannot accept negative position {position}") + raise JWTInvalidElementPosition(f"Cannot accept negative position {position}") if position > 2: raise JWTInvalidElementPosition( - f"Cannot accept position greater than 2 {position}") + f"Cannot accept position greater than 2 {position}" + ) splitted_jwt = jwt.split(".") if (len(splitted_jwt) - 1) < position: - raise JWTInvalidElementPosition( - f"JWT has no element in position {position}") + raise JWTInvalidElementPosition(f"JWT has no element in position {position}") try: if isinstance(jwt, bytes): diff --git a/pyeudiw/jwt/verification.py b/pyeudiw/jwt/verification.py index fa078f81..0a0340f4 100644 --- a/pyeudiw/jwt/verification.py +++ 
b/pyeudiw/jwt/verification.py @@ -1,9 +1,8 @@ +from cryptojwt.jwk import JWK from pyeudiw.jwt.exceptions import JWSVerificationError from pyeudiw.jwt.jws_helper import JWSHelper -from cryptojwt.jwk import JWK - def verify_jws_with_key(jws: str, key: JWK) -> None: """ @@ -13,5 +12,4 @@ def verify_jws_with_key(jws: str, key: JWK) -> None: verifier = JWSHelper(key) verifier.verify(jws) except Exception as e: - raise JWSVerificationError( - f"error during signature verification: {e}", e) + raise JWSVerificationError(f"error during signature verification: {e}", e) diff --git a/pyeudiw/oauth2/dpop/__init__.py b/pyeudiw/oauth2/dpop/__init__.py index f00560d8..833f230f 100644 --- a/pyeudiw/oauth2/dpop/__init__.py +++ b/pyeudiw/oauth2/dpop/__init__.py @@ -3,20 +3,12 @@ import logging import uuid +from pyeudiw.jwk.exceptions import KidError from pyeudiw.jwk.schemas.public import JwkSchema from pyeudiw.jwt.jws_helper import JWSHelper -from pyeudiw.oauth2.dpop.exceptions import ( - InvalidDPoP, - InvalidDPoPAth, - InvalidDPoPKid -) -from pyeudiw.jwk.exceptions import KidError - from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload -from pyeudiw.oauth2.dpop.schema import ( - DPoPTokenHeaderSchema, - DPoPTokenPayloadSchema -) +from pyeudiw.oauth2.dpop.exceptions import InvalidDPoP, InvalidDPoPAth, InvalidDPoPKid +from pyeudiw.oauth2.dpop.schema import DPoPTokenHeaderSchema, DPoPTokenPayloadSchema from pyeudiw.tools.utils import iat_now logger = logging.getLogger(__name__) @@ -58,15 +50,16 @@ def proof(self): "htm": "GET", "htu": self.htu, "iat": iat_now(), - "ath": base64.urlsafe_b64encode(hashlib.sha256(self.token.encode()).digest()).rstrip(b'=').decode() + "ath": base64.urlsafe_b64encode( + hashlib.sha256(self.token.encode()).digest() + ) + .rstrip(b"=") + .decode(), } jwt = self.signer.sign( data, - protected={ - 'typ': "dpop+jwt", - 'jwk': self.private_jwk.serialize() - }, - kid_in_header=False + protected={"typ": "dpop+jwt", "jwk": 
self.private_jwk.serialize()}, + kid_in_header=False, ) return jwt @@ -76,7 +69,7 @@ class DPoPVerifier: Helper class for validate DPoP proofs. """ - dpop_header_prefix = 'DPoP ' + dpop_header_prefix = "DPoP " def __init__( self, @@ -99,7 +92,7 @@ def __init__( """ self.public_jwk = public_jwk self.dpop_token = ( - http_header_authz.replace(self.dpop_header_prefix, '') + http_header_authz.replace(self.dpop_header_prefix, "") if self.dpop_header_prefix in http_header_authz else http_header_authz ) @@ -107,26 +100,17 @@ def __init__( try: JwkSchema(**public_jwk) except Exception as e: - logger.error( - "Jwk validation error, " - f"{e.__class__.__name__}: {e}" - ) + logger.error("Jwk validation error, " f"{e.__class__.__name__}: {e}") raise ValueError("JWK schema validation error during DPoP init") # If the jwt is invalid, this will raise an exception try: decode_jwt_header(http_header_dpop) except UnicodeDecodeError as e: - logger.error( - "DPoP proof validation error, " - f"{e.__class__.__name__}: {e}" - ) + logger.error("DPoP proof validation error, " f"{e.__class__.__name__}: {e}") raise ValueError("DPoP proof is not a valid JWT") except Exception as e: - logger.error( - "DPoP proof validation error, " - f"{e.__class__.__name__}: {e}" - ) + logger.error("DPoP proof validation error, " f"{e.__class__.__name__}: {e}") raise ValueError("DPoP proof is not a valid JWT") self.proof = http_header_dpop @@ -150,32 +134,29 @@ def validate(self) -> bool: dpop_valid = jws_verifier.verify(self.proof) except KidError as e: raise InvalidDPoPKid( - ( - "DPoP proof validation error, " - f"kid does not match: {e}" - ) + ("DPoP proof validation error, " f"kid does not match: {e}") ) except Exception as e: raise InvalidDPoP( - "DPoP proof validation error, " - f"{e.__class__.__name__}: {e}" + "DPoP proof validation error, " f"{e.__class__.__name__}: {e}" ) header = decode_jwt_header(self.proof) DPoPTokenHeaderSchema(**header) - if header['jwk'] != self.public_jwk: - raise 
InvalidDPoPAth(( - "DPoP proof validation error, " - "header['jwk'] != self.public_jwk, " - f"{header['jwk']} != {self.public_jwk}" - )) + if header["jwk"] != self.public_jwk: + raise InvalidDPoPAth( + ( + "DPoP proof validation error, " + "header['jwk'] != self.public_jwk, " + f"{header['jwk']} != {self.public_jwk}" + ) + ) payload = decode_jwt_payload(self.proof) DPoPTokenPayloadSchema(**payload) _ath = hashlib.sha256(self.dpop_token.encode()) - _ath_b64 = base64.urlsafe_b64encode( - _ath.digest()).rstrip(b'=').decode() - proof_valid = _ath_b64 == payload['ath'] + _ath_b64 = base64.urlsafe_b64encode(_ath.digest()).rstrip(b"=").decode() + proof_valid = _ath_b64 == payload["ath"] return dpop_valid and proof_valid diff --git a/pyeudiw/oauth2/dpop/exceptions.py b/pyeudiw/oauth2/dpop/exceptions.py index 6219de0e..acc1473b 100644 --- a/pyeudiw/oauth2/dpop/exceptions.py +++ b/pyeudiw/oauth2/dpop/exceptions.py @@ -1,7 +1,3 @@ -class InvalidDPoPJwk(Exception): - pass - - class InvalidDPoPAth(Exception): pass diff --git a/pyeudiw/openid4vp/authorization_request.py b/pyeudiw/openid4vp/authorization_request.py index 00a86977..f987b7d2 100644 --- a/pyeudiw/openid4vp/authorization_request.py +++ b/pyeudiw/openid4vp/authorization_request.py @@ -1,5 +1,5 @@ -from urllib.parse import quote_plus, urlencode import uuid +from urllib.parse import quote_plus, urlencode from pyeudiw.openid4vp.schemas.response import ResponseMode from pyeudiw.tools.utils import exp_from_now, iat_now @@ -19,7 +19,13 @@ def build_authorization_request_url(scheme: str, params: dict) -> str: return f"{scheme}{_sep}{query_params}" -def build_authorization_request_claims(client_id: str, state: str, response_uri: str, authorization_config: dict, nonce: str = "") -> dict: +def build_authorization_request_claims( + client_id: str, + state: str, + response_uri: str, + authorization_config: dict, + nonce: str = "", +) -> dict: """ Primitive function to build the payload claims of the (JAR) authorization request. 
:param client_id: the client identifier (who issue the jar token) @@ -45,21 +51,25 @@ def build_authorization_request_claims(client_id: str, state: str, response_uri: claims = { "client_id_scheme": "http", # that's federation. "client_id": client_id, - "response_mode": authorization_config.get("response_mode", ResponseMode.direct_post_jwt), + "response_mode": authorization_config.get( + "response_mode", ResponseMode.direct_post_jwt + ), "response_type": "vp_token", "response_uri": response_uri, "nonce": nonce, "state": state, "iss": authorization_config.get("auth_iss_id", client_id), "iat": iat_now(), - "exp": exp_from_now(minutes=authorization_config["expiration_time"]) + "exp": exp_from_now(minutes=authorization_config["expiration_time"]), } if authorization_config.get("scopes"): - claims["scope"] = ' '.join(authorization_config["scopes"]) + claims["scope"] = " ".join(authorization_config["scopes"]) # backend configuration validation should check that at least PE or DCQL must be configured within the authz request conf if authorization_config.get("presentation_definition"): - claims["presentation_definition"] = authorization_config["presentation_definition"] + claims["presentation_definition"] = authorization_config[ + "presentation_definition" + ] - if (_aud := authorization_config.get("aud")): + if _aud := authorization_config.get("aud"): claims["aud"] = _aud return claims diff --git a/pyeudiw/openid4vp/authorization_response.py b/pyeudiw/openid4vp/authorization_response.py index 413876f5..dee09bc5 100644 --- a/pyeudiw/openid4vp/authorization_response.py +++ b/pyeudiw/openid4vp/authorization_response.py @@ -1,19 +1,17 @@ -import json +import cryptojwt.jwe.exception import satosa.context - -from pyeudiw.jwk.exceptions import KidNotFoundError from pyeudiw.jwt.exceptions import JWEDecryptionError from pyeudiw.jwt.jwe_helper import JWEHelper -from pyeudiw.jwt.jws_helper import JWSHelper -from pyeudiw.jwt.utils import decode_jwt_header - -from cryptojwt.jwk.ec 
import ECKey -from cryptojwt.jwk.rsa import RSAKey -import cryptojwt.jwe.exception - -from pyeudiw.openid4vp.exceptions import AuthRespParsingException, AuthRespValidationException +from pyeudiw.openid4vp.exceptions import ( + AuthRespParsingException, + AuthRespValidationException, +) from pyeudiw.openid4vp.interface import AuthorizationResponseParser -from pyeudiw.openid4vp.schemas.response import AuthorizeResponseDirectPostJwt, AuthorizeResponsePayload, ResponseMode +from pyeudiw.openid4vp.schemas.response import ( + AuthorizeResponseDirectPostJwt, + AuthorizeResponsePayload, + ResponseMode, +) def detect_response_mode(context: satosa.context.Context) -> ResponseMode: @@ -26,7 +24,8 @@ def detect_response_mode(context: satosa.context.Context) -> ResponseMode: if "vp_token" in context.request: return ResponseMode.direct_post raise AuthRespParsingException( - "HTTP POST request body does not contain a recognized openid4vp response mode") + "HTTP POST request body does not contain a recognized openid4vp response mode" + ) def _check_http_post_headers(context: satosa.context.Context) -> None: @@ -40,7 +39,7 @@ def _check_http_post_headers(context: satosa.context.Context) -> None: # missing header is ok; but if it's there, it must be correct if context.http_headers: - content_type = context.http_headers['HTTP_CONTENT_TYPE'] + content_type = context.http_headers["HTTP_CONTENT_TYPE"] if "application/x-www-form-urlencoded" not in content_type: err_msg = f"HTTP content type [{content_type}] not supported" raise AuthRespParsingException(err_msg, err_msg) @@ -57,7 +56,9 @@ class DirectPostParser(AuthorizationResponseParser): def __init__(self): pass - def parse_and_validate(self, context: satosa.context.Context) -> AuthorizeResponsePayload: + def parse_and_validate( + self, context: satosa.context.Context + ) -> AuthorizeResponsePayload: _check_http_post_headers(context) resp_data: dict = context.request @@ -65,7 +66,8 @@ def parse_and_validate(self, context: 
satosa.context.Context) -> AuthorizeRespon return AuthorizeResponsePayload(**resp_data) except Exception as e: raise AuthRespParsingException( - "invalid data in direct_post request body", e) + "invalid data in direct_post request body", e + ) class DirectPostJwtJweParser(AuthorizationResponseParser): @@ -83,59 +85,44 @@ class DirectPostJwtJweParser(AuthorizationResponseParser): def __init__(self, jwe_decryptor: JWEHelper): self.jwe_decryptor = jwe_decryptor - def parse_and_validate(self, context: satosa.context.Context) -> AuthorizeResponsePayload: + def parse_and_validate( + self, context: satosa.context.Context + ) -> AuthorizeResponsePayload: _check_http_post_headers(context) resp_data_raw: dict = context.request try: resp_data = AuthorizeResponseDirectPostJwt(**resp_data_raw) except Exception as e: raise AuthRespParsingException( - "invalid data in direct_post.jwt request body", e) + "invalid data in direct_post.jwt request body", e + ) try: payload = self.jwe_decryptor.decrypt(resp_data.response) except JWEDecryptionError as e: raise AuthRespParsingException( - "invalid data in direct_post.jwt request body: not a jwe", e) + "invalid data in direct_post.jwt request body: not a jwe", e + ) except cryptojwt.jwe.exception.DecryptionFailed: raise AuthRespValidationException( - "invalid data in direct_post.jwt: unable to decrypt token") + "invalid data in direct_post.jwt: unable to decrypt token" + ) except Exception as e: # unfortunately library cryptojwt is not very exhaustive on why an operation failed... 
raise AuthRespValidationException( - "invalid data in direct_post.jwt request body", e) + "invalid data in direct_post.jwt request body", e + ) # iss, exp and aud MUST be OMITTED in the JWT Claims Set of the JWE if ("iss" in payload) or ("exp" in payload): - raise AuthRespParsingException("response token contains an unexpected lifetime claims", Exception( - "wallet mishbeahiour: JWe with bad claims")) + raise AuthRespParsingException( + "response token contains an unexpected lifetime claims", + Exception("wallet misbehaviour: JWE with bad claims"), + ) try: return AuthorizeResponsePayload(**payload) except Exception as e: raise AuthRespParsingException( - "invalid data in the direct_post.jwt: token payload does not have the expected claims", e) - - -def _get_jwk_kid_from_store(jwt: str, key_store: dict[str, dict]) -> dict: - headers = decode_jwt_header(jwt) - kid: str | None = headers.get("kid", None) - if kid is None: - raise KidNotFoundError( - "authorization response is missing mandatory parameter [kid] in header section") - jwk_dict = key_store.get(kid, None) - if jwk_dict is None: - raise KidNotFoundError( - f"authorization response is encrypted with jwk with kid='{kid}' not found in store") - return jwk_dict - - -def _decrypt_jwe(jwe: str, decrypting_jwk: dict[str, any]) -> dict: - decrypter = JWEHelper(decrypting_jwk) - return decrypter.decrypt(jwe) - - -def _verify_and_decode_jwt(jwt: str, verifying_jwk: dict[dict, ECKey | RSAKey | dict]) -> dict: - verifier = JWSHelper(verifying_jwk) - raw_payload: str = verifier.verify(jwt)["msg"] - payload: dict = json.loads(raw_payload) - return payload + "invalid data in the direct_post.jwt: token payload does not have the expected claims", + e, + ) \ No newline at end of file diff --git a/pyeudiw/openid4vp/exceptions.py b/pyeudiw/openid4vp/exceptions.py index 7f247d0a..d6bfcdbc 100644 --- a/pyeudiw/openid4vp/exceptions.py +++ b/pyeudiw/openid4vp/exceptions.py @@ -1,63 +1,31 @@ class
AuthRespParsingException(Exception): - """Raised when the http request corresponding to an authorization response is malformed. - """ + """Raised when the http request corresponding to an authorization response is malformed.""" + pass class AuthRespValidationException(Exception): """Raised when the http request corresponding to an authorization response is well formed, but not valid (for example, it might be wrapped in an expired token). """ - - -class KIDNotFound(Exception): - """ - Raised when kid is not present in the public key dict - """ - - -class VPSchemaException(Exception): - pass - - -class VPNotFound(Exception): - pass - - -class VPInvalidNonce(Exception): pass -class NoNonceInVPToken(Exception): - """ - Raised when a given VP has no nonce - """ - - class InvalidVPToken(Exception): """ Raised when a given VP is invalid """ + pass class InvalidVPKeyBinding(InvalidVPToken): """Raised when a given VP contains a proof of possession key binding with wrong parameters. """ - - -class InvalidVPSignature(InvalidVPKeyBinding): - """Raised when a VP contains a proof of possession key binding and - its signature verification failed. - """ - - -class RevokedVPToken(Exception): - """ - Raised when a given VP is revoked - """ + pass class VPFormatNotSupported(Exception): """ Raised when a given VP format is not supported """ + pass diff --git a/pyeudiw/openid4vp/interface.py b/pyeudiw/openid4vp/interface.py index 571259ee..1643a106 100644 --- a/pyeudiw/openid4vp/interface.py +++ b/pyeudiw/openid4vp/interface.py @@ -1,6 +1,6 @@ +import satosa.context from cryptojwt.jwk.ec import ECKey from cryptojwt.jwk.rsa import RSAKey -import satosa.context from pyeudiw.openid4vp.schemas.response import AuthorizeResponsePayload @@ -19,7 +19,9 @@ class AuthorizationResponseParser: object, method or interface. 
""" - def parse_and_validate(self, context: satosa.context.Context) -> AuthorizeResponsePayload: + def parse_and_validate( + self, context: satosa.context.Context + ) -> AuthorizeResponsePayload: """ Parse (and optionally validate) a satosa http request, wrapped in its own context, in order to extract an auhtorization response. diff --git a/pyeudiw/openid4vp/schemas/response.py b/pyeudiw/openid4vp/schemas/response.py index 38f346d5..a06a548e 100644 --- a/pyeudiw/openid4vp/schemas/response.py +++ b/pyeudiw/openid4vp/schemas/response.py @@ -46,8 +46,7 @@ class AuthorizeResponseDirectPostJwt: def __post_init__(self): jwt = self.response if not is_jwe_format(jwt) and not is_jwt_format(jwt): - raise ValueError( - f"input response={jwt} is neither jwt not jwe format") + raise ValueError(f"input response={jwt} is neither jwt not jwe format") @dataclass @@ -61,6 +60,7 @@ class AuthorizeResponsePayload: as it is not meant to validate the _content_ of the response; just that the representation lands with the proper expected claims """ + state: str vp_token: str | list[str] presentation_submission: dict diff --git a/pyeudiw/openid4vp/schemas/vp_formats.py b/pyeudiw/openid4vp/schemas/vp_formats.py index 254a3a50..a90ce95f 100644 --- a/pyeudiw/openid4vp/schemas/vp_formats.py +++ b/pyeudiw/openid4vp/schemas/vp_formats.py @@ -1,5 +1,6 @@ from enum import Enum from typing import List + from pydantic import BaseModel, Field @@ -13,9 +14,9 @@ class Algorithms(Enum): class VcSdJwt(BaseModel): - sd_jwt_alg_values: List[Algorithms] = Field([], alias='sd-jwt_alg_values') - kb_jwt_alg_values: List[Algorithms] = Field([], alias='kb-jwt_alg_values') + sd_jwt_alg_values: List[Algorithms] = Field([], alias="sd-jwt_alg_values") + kb_jwt_alg_values: List[Algorithms] = Field([], alias="kb-jwt_alg_values") class VpFormats(BaseModel): - vc_sd_jwt: VcSdJwt = Field(..., alias='vc+sd-jwt') + vc_sd_jwt: VcSdJwt = Field(..., alias="vc+sd-jwt") diff --git a/pyeudiw/openid4vp/schemas/vp_token.py 
b/pyeudiw/openid4vp/schemas/vp_token.py index 704cddf6..13fcfcee 100644 --- a/pyeudiw/openid4vp/schemas/vp_token.py +++ b/pyeudiw/openid4vp/schemas/vp_token.py @@ -1,7 +1,7 @@ from typing import Literal from pydantic import BaseModel, HttpUrl, field_validator -from pydantic_core.core_schema import FieldValidationInfo +from pydantic_core.core_schema import ValidationInfo from pyeudiw.sd_jwt.schema import is_sd_jwt_format from pyeudiw.tools.schema_utils import check_algorithm @@ -14,7 +14,7 @@ class VPTokenHeader(BaseModel): @field_validator("alg") @classmethod - def _check_alg(cls, alg, info: FieldValidationInfo): + def _check_alg(cls, alg, info: ValidationInfo): return check_algorithm(alg, info) diff --git a/pyeudiw/openid4vp/schemas/wallet_instance_attestation.py b/pyeudiw/openid4vp/schemas/wallet_instance_attestation.py index 891f4f26..9362dc02 100644 --- a/pyeudiw/openid4vp/schemas/wallet_instance_attestation.py +++ b/pyeudiw/openid4vp/schemas/wallet_instance_attestation.py @@ -1,7 +1,7 @@ from typing import Dict, List, Literal, Optional from pydantic import BaseModel, HttpUrl, field_validator -from pydantic_core.core_schema import FieldValidationInfo +from pydantic_core.core_schema import ValidationInfo from pyeudiw.openid4vp.schemas.cnf_schema import CNFSchema from pyeudiw.tools.schema_utils import check_algorithm @@ -32,7 +32,7 @@ class WalletInstanceAttestationHeader(BaseModel): @field_validator("alg") @classmethod - def _check_alg(cls, alg, info: FieldValidationInfo): + def _check_alg(cls, alg, info: ValidationInfo): return check_algorithm(alg, info) diff --git a/pyeudiw/openid4vp/schemas/wallet_instance_attestation_request.py b/pyeudiw/openid4vp/schemas/wallet_instance_attestation_request.py index 052b8bf0..4488f885 100644 --- a/pyeudiw/openid4vp/schemas/wallet_instance_attestation_request.py +++ b/pyeudiw/openid4vp/schemas/wallet_instance_attestation_request.py @@ -1,7 +1,7 @@ from typing import Literal from pydantic import BaseModel, HttpUrl, 
field_validator -from pydantic_core.core_schema import FieldValidationInfo +from pydantic_core.core_schema import ValidationInfo from pyeudiw.openid4vp.schemas.cnf_schema import CNFSchema from pyeudiw.tools.schema_utils import check_algorithm @@ -14,7 +14,7 @@ class WalletInstanceAttestationRequestHeader(BaseModel): @field_validator("alg") @classmethod - def _check_alg(cls, alg, info: FieldValidationInfo): + def _check_alg(cls, alg, info: ValidationInfo): return check_algorithm(alg, info) diff --git a/pyeudiw/openid4vp/utils.py b/pyeudiw/openid4vp/utils.py index af0aa32b..2d3adeac 100644 --- a/pyeudiw/openid4vp/utils.py +++ b/pyeudiw/openid4vp/utils.py @@ -1,13 +1,7 @@ from typing import Any from satosa.context import Context - from pyeudiw.openid4vp.schemas.flow import RemoteFlowType -from pyeudiw.openid4vp.vp import Vp -from pyeudiw.openid4vp.vp_mdoc_cbor import VpMDocCbor -from pyeudiw.openid4vp.vp_sd_jwt import VpSdJwt -from pyeudiw.openid4vp.exceptions import VPFormatNotSupported -from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload from pyeudiw.tools.mobile import is_smartphone @@ -15,59 +9,13 @@ def detect_flow_typ(context: Context) -> RemoteFlowType: """ Identitfy or guess the remote flow type based on the context of the user auhtnetication - """ - if is_smartphone(context.http_headers.get('HTTP_USER_AGENT')): - return RemoteFlowType.SAME_DEVICE - return RemoteFlowType.CROSS_DEVICE + :param context: the context of the user authentication + :type context: Context -def vp_parser(jwt: str) -> Vp: + :returns: the remote flow type + :rtype: RemoteFlowType """ - Handle the jwt returning the correct VP istance. - - :param jwt: a string that represents the jwt. - :type jwt: str - - :raises VPFormatNotSupported: if the VP Digital credentials type is not implemented yet. - - :returns: the VP istance. 
- :rtype: Vp - """ - - headers = decode_jwt_header(jwt) - - typ: str | None = headers.get("typ", None) - if typ is None: - raise ValueError("missing mandatory header [typ] in jwt header") - - match typ.lower(): - case "jwt": - return VpSdJwt(jwt) - case "vc+sd-jwt": - raise NotImplementedError( - "parsing of vp tokens with typ vc+sd-jwt not supported yet") - case "mcdoc_cbor": - return VpMDocCbor(jwt) - case unsupported: - raise VPFormatNotSupported( - f"parsing of unsupported vp typ [{unsupported}]") - - -def infer_vp_header_claim(jws: str, claim_name: str) -> Any: - headers = decode_jwt_header(jws) - claim_value = headers.get(claim_name, "") - return claim_value - - -def infer_vp_payload_claim(jws: str, claim_name: str) -> Any: - headers = decode_jwt_payload(jws) - claim_value: str = headers.get(claim_name, "") - return claim_value - - -def infer_vp_typ(jws: str) -> str: - return infer_vp_header_claim(jws, claim_name="typ") - - -def infer_vp_iss(jws: str) -> str: - return infer_vp_payload_claim(jws, claim_name="iss") + if is_smartphone(context.http_headers.get("HTTP_USER_AGENT")): + return RemoteFlowType.SAME_DEVICE + return RemoteFlowType.CROSS_DEVICE diff --git a/pyeudiw/openid4vp/vp.py b/pyeudiw/openid4vp/vp.py index 5818e1bd..bcf50a9a 100644 --- a/pyeudiw/openid4vp/vp.py +++ b/pyeudiw/openid4vp/vp.py @@ -5,12 +5,7 @@ WALLET_ATTESTATION_TYPE = "wallet-attestation+jwt" MDOC_BCOR_TYPE = "mdoc_cbor" -SUPPORTED_VC_TYPES = ( - JWT_TYPE, - VC_SD_JWT_TYPE, - WALLET_ATTESTATION_TYPE, - MDOC_BCOR_TYPE -) +SUPPORTED_VC_TYPES = (JWT_TYPE, VC_SD_JWT_TYPE, WALLET_ATTESTATION_TYPE, MDOC_BCOR_TYPE) class Vp(BaseLogger): @@ -38,8 +33,5 @@ def check_revocation(self): # TODO: check the revocation of the credential self._log_warning("VP", "Revocation check not implemented yet") - def verify( - self, - **kwargs - ) -> bool: + def verify(self, **kwargs) -> bool: raise NotImplementedError diff --git a/pyeudiw/openid4vp/vp_mdoc_cbor.py b/pyeudiw/openid4vp/vp_mdoc_cbor.py index 
9943a9eb..d89e2a6a 100644 --- a/pyeudiw/openid4vp/vp_mdoc_cbor.py +++ b/pyeudiw/openid4vp/vp_mdoc_cbor.py @@ -1,6 +1,7 @@ -from pyeudiw.openid4vp.vp import Vp from pymdoccbor.mdoc.verifier import MdocCbor +from pyeudiw.openid4vp.vp import Vp + class VpMDocCbor(Vp): def __init__(self, data: str) -> None: diff --git a/pyeudiw/openid4vp/vp_sd_jwt_vc.py b/pyeudiw/openid4vp/vp_sd_jwt_vc.py index a2503d99..a8abea50 100644 --- a/pyeudiw/openid4vp/vp_sd_jwt_vc.py +++ b/pyeudiw/openid4vp/vp_sd_jwt_vc.py @@ -1,5 +1,8 @@ from typing import Optional +from cryptojwt.jwk.ec import ECKey +from cryptojwt.jwk.rsa import RSAKey + from pyeudiw.jwt.helper import is_jwt_expired from pyeudiw.openid4vp.exceptions import InvalidVPKeyBinding from pyeudiw.openid4vp.interface import VpTokenParser, VpTokenVerifier @@ -7,26 +10,28 @@ from pyeudiw.sd_jwt.schema import VerifierChallenge, is_sd_jwt_kb_format from pyeudiw.sd_jwt.sd_jwt import SdJwt -from cryptojwt.jwk.ec import ECKey -from cryptojwt.jwk.rsa import RSAKey - class VpVcSdJwtParserVerifier(VpTokenParser, VpTokenVerifier): - def __init__(self, token: str, verifier_id: Optional[str] = None, verifier_nonce: Optional[str] = None): + def __init__( + self, + token: str, + verifier_id: Optional[str] = None, + verifier_nonce: Optional[str] = None, + ): self.token = token if not is_sd_jwt_kb_format(token): raise ValueError( - f"input [token]={token} is not an sd-jwt with key binding: maybe it is a regular jwt or key binding jwt is missing?") + f"input [token]={token} is not an sd-jwt with key binding: maybe it is a regular jwt or key binding jwt is missing?" 
+ ) self.verifier_id = verifier_id self.verifier_nonce = verifier_nonce # precomputed values self.sdjwt = SdJwt(self.token) def get_issuer_name(self) -> str: - iss = self.sdjwt.issuer_jwt.payload.get("iss", None) + iss = self.sdjwt.get_issuer_jwt().payload.get("iss", None) if not iss: - raise Exception( - "missing required information in token paylaod: [iss]") + raise Exception("missing required information in token payload: [iss]") return iss def get_credentials(self) -> dict: diff --git a/pyeudiw/presentation_exchange/schemas/oid4vc_presentation_definition.py b/pyeudiw/presentation_exchange/schemas/oid4vc_presentation_definition.py index ab287aef..9436b5e4 100644 --- a/pyeudiw/presentation_exchange/schemas/oid4vc_presentation_definition.py +++ b/pyeudiw/presentation_exchange/schemas/oid4vc_presentation_definition.py @@ -5,19 +5,19 @@ from __future__ import annotations from enum import Enum -from typing import Any, Dict, List, Optional, Union, Annotated +from typing import Annotated, Any, Dict, List, Optional, Union from pydantic import BaseModel, ConfigDict, Field, RootModel, conint class LimitDisclosure(Enum): - required = 'required' - preferred = 'preferred' + required = "required" + preferred = "preferred" class Constraints(BaseModel): model_config = ConfigDict( - extra='forbid', + extra="forbid", ) limit_disclosure: Optional[LimitDisclosure] = None fields: Optional[List[Any]] = None @@ -25,21 +25,21 @@ class PresentationDefinitionClaimFormatDesignations1(BaseModel): model_config = ConfigDict( - extra='forbid', + extra="forbid", ) alg: Optional[List[str]] = Field(None, min_length=1) class PresentationDefinitionClaimFormatDesignations2(BaseModel): model_config = ConfigDict( - extra='forbid', + extra="forbid", ) proof_type: Optional[List[str]] = Field(None, min_length=1) class PresentationDefinitionClaimFormatDesignations3(BaseModel): model_config = ConfigDict( - extra='forbid', + extra="forbid", ) @@ -47,48 +47,48 @@ class
PresentationDefinitionClaimFormatDesignations( RootModel[ Union[ Dict[ - Annotated[str, Field(pattern=r'^jwt$|^jwt_vc$|^jwt_vp$')], + Annotated[str, Field(pattern=r"^jwt$|^jwt_vc$|^jwt_vp$")], PresentationDefinitionClaimFormatDesignations1, ], Dict[ - Annotated[str, Field(pattern=r'^ldp_vc$|^ldp_vp$|^ldp$')], + Annotated[str, Field(pattern=r"^ldp_vc$|^ldp_vp$|^ldp$")], PresentationDefinitionClaimFormatDesignations2, ], Dict[ - Annotated[str, Field(pattern=r'^vc\+sd-jwt$')], + Annotated[str, Field(pattern=r"^vc\+sd-jwt$")], PresentationDefinitionClaimFormatDesignations3, - ] + ], ] ] ): root: Union[ Dict[ - Annotated[str, Field(pattern=r'^jwt$|^jwt_vc$|^jwt_vp$')], + Annotated[str, Field(pattern=r"^jwt$|^jwt_vc$|^jwt_vp$")], PresentationDefinitionClaimFormatDesignations1, ], Dict[ - Annotated[str, Field(pattern=r'^ldp_vc$|^ldp_vp$|^ldp$')], + Annotated[str, Field(pattern=r"^ldp_vc$|^ldp_vp$|^ldp$")], PresentationDefinitionClaimFormatDesignations2, ], Dict[ - Annotated[str, Field(pattern=r'^vc\+sd-jwt$')], + Annotated[str, Field(pattern=r"^vc\+sd-jwt$")], PresentationDefinitionClaimFormatDesignations2, - ] - ] = Field(..., title='Presentation Definition Claim Format Designations') + ], + ] = Field(..., title="Presentation Definition Claim Format Designations") class Rule(Enum): - pick = 'pick' + pick = "pick" class SubmissionRequirement1(BaseModel): model_config = ConfigDict( - extra='forbid', + extra="forbid", ) name: Optional[str] = None rule: Rule count: Optional[conint(ge=1)] = None - from_: str = Field(..., alias='from') + from_: str = Field(..., alias="from") class SubmissionRequirement(RootModel[SubmissionRequirement1]): @@ -97,7 +97,7 @@ class SubmissionRequirement(RootModel[SubmissionRequirement1]): class InputDescriptor(BaseModel): model_config = ConfigDict( - extra='forbid', + extra="forbid", ) id: str name: Optional[str] = None @@ -109,7 +109,7 @@ class InputDescriptor(BaseModel): class PresentationDefinition(BaseModel): model_config = ConfigDict( - 
extra='forbid', + extra="forbid", ) id: str input_descriptors: List[InputDescriptor] diff --git a/pyeudiw/satosa/backend.py b/pyeudiw/satosa/backend.py index dcb9e994..2c211692 100644 --- a/pyeudiw/satosa/backend.py +++ b/pyeudiw/satosa/backend.py @@ -1,13 +1,12 @@ -from pyeudiw.satosa.default.openid4vp_backend import OpenID4VPBackend as OID4VP -from pyeudiw.satosa.default.request_handler import RequestHandler -from pyeudiw.satosa.default.response_handler import ResponseHandler +from typing import Callable from satosa.context import Context from satosa.internal import InternalData from satosa.response import Response -from typing import Callable - +from pyeudiw.satosa.default.openid4vp_backend import OpenID4VPBackend as OID4VP +from pyeudiw.satosa.default.request_handler import RequestHandler +from pyeudiw.satosa.default.response_handler import ResponseHandler from pyeudiw.tools.utils import get_dynamic_class @@ -15,13 +14,15 @@ class OpenID4VPBackend(RequestHandler, ResponseHandler, OID4VP): """ A backend module (acting as a OpenID4VP SP). """ - def __new__(cls, - auth_callback_func: Callable[[Context, InternalData], Response], - internal_attributes: dict[str, dict[str, str | list[str]]], - config: dict[str, dict[str, str] | list[str]], - base_url: str, - name: str - ): + + def __new__( + cls, + auth_callback_func: Callable[[Context, InternalData], Response], + internal_attributes: dict[str, dict[str, str | list[str]]], + config: dict[str, dict[str, str] | list[str]], + base_url: str, + name: str, + ): """ Create a backend dynamically. 
@@ -49,26 +50,28 @@ def __new__(cls, tmp_bases = list(cls.__bases__) - if isinstance(request_backend_conf, dict) \ - and request_backend_conf.get("module", None) \ - and request_backend_conf.get("class", None): + if ( + isinstance(request_backend_conf, dict) + and request_backend_conf.get("module", None) + and request_backend_conf.get("class", None) + ): request_backend = get_dynamic_class( - request_backend_conf["module"], - request_backend_conf["class"] + request_backend_conf["module"], request_backend_conf["class"] ) tmp_bases[0] = request_backend response_handler_conf = dynamic_backend_conf.get("response", None) - if isinstance(response_handler_conf, dict) \ - and response_handler_conf.get("module", None) \ - and response_handler_conf.get("class", None): + if ( + isinstance(response_handler_conf, dict) + and response_handler_conf.get("module", None) + and response_handler_conf.get("class", None) + ): response_handler = get_dynamic_class( - response_handler_conf["module"], - response_handler_conf["class"] + response_handler_conf["module"], response_handler_conf["class"] ) tmp_bases[1] = response_handler diff --git a/pyeudiw/satosa/default/openid4vp_backend.py b/pyeudiw/satosa/default/openid4vp_backend.py index 79a52197..e27b3f50 100644 --- a/pyeudiw/satosa/default/openid4vp_backend.py +++ b/pyeudiw/satosa/default/openid4vp_backend.py @@ -1,39 +1,38 @@ import json -import pydantic import uuid - from typing import Callable -from urllib.parse import quote_plus, urlencode +import pydantic from satosa.context import Context from satosa.internal import InternalData from satosa.response import Redirect, Response +from pyeudiw.jwk import JWK from pyeudiw.openid4vp.authorization_request import build_authorization_request_url -from pyeudiw.openid4vp.utils import detect_flow_typ from pyeudiw.openid4vp.schemas.flow import RemoteFlowType +from pyeudiw.openid4vp.utils import detect_flow_typ +from pyeudiw.tools.base_logger import BaseLogger from pyeudiw.satosa.schemas.config 
import PyeudiwBackendConfig -from pyeudiw.jwk import JWK from pyeudiw.satosa.utils.html_template import Jinja2TemplateHandler from pyeudiw.satosa.utils.respcode import ResponseCodeSource from pyeudiw.satosa.utils.response import JsonResponse -from pyeudiw.satosa.utils.trust import BackendTrust from pyeudiw.storage.db_engine import DBEngine from pyeudiw.storage.exceptions import StorageWriteError from pyeudiw.tools.utils import iat_now from pyeudiw.trust.dynamic import CombinedTrustEvaluator +from pyeudiw.trust.handler.interface import TrustHandlerInterface from ..interfaces.openid4vp_backend import OpenID4VPBackendInterface -class OpenID4VPBackend(OpenID4VPBackendInterface, BackendTrust): +class OpenID4VPBackend(OpenID4VPBackendInterface, BaseLogger): def __init__( self, auth_callback_func: Callable[[Context, InternalData], Response], internal_attributes: dict[str, dict[str, str | list[str]]], config: dict[str, dict[str, str] | list[str]], base_url: str, - name: str + name: str, ) -> None: """ OpenID4VP backend module. 
@@ -53,46 +52,62 @@ def __init__( """ super().__init__(auth_callback_func, internal_attributes, base_url, name) + # to be inizialized by .db_engine() property + self._db_engine = None + self.config = config self._backend_url = f"{base_url}/{name}" self._client_id = self._backend_url - self.config['metadata']['client_id'] = self.client_id + self.config["metadata"]["client_id"] = self.client_id - self.config['metadata']['response_uris_supported'] = [] - self.config['metadata']['response_uris_supported'].append( - f"{self._backend_url}/response-uri") + self.config["metadata"]["response_uris"] = [] + self.config["metadata"]["response_uris"].append( + f"{self._backend_url}/response-uri" + ) - self.config['metadata']['request_uris'] = [] - self.config['metadata']['request_uris'].append( - f"{self._backend_url}/request-uri") + self.config["metadata"]["request_uris"] = [] + self.config["metadata"]["request_uris"].append( + f"{self._backend_url}/request-uri" + ) - self.default_exp = int(self.config['jwt']['default_exp']) + self.default_exp = int(self.config["jwt"]["default_exp"]) - federation_jwks = self.config['trust']['federation']['config']['federation_jwks'] + federation_jwks = self.config["trust"]["federation"]["config"][ + "federation_jwks" + ] if isinstance(federation_jwks, str): try: - self.config['trust']['federation']['config']['federation_jwks'] = json.loads(federation_jwks) + self.config["trust"]["federation"]["config"]["federation_jwks"] = ( + json.loads(federation_jwks) + ) except json.JSONDecodeError as e: - raise ValueError(f"Invalid federation_jwks {self.config['trust']['federation']['config']['federation_jwks']} JSON: {e}") - - if isinstance(self.config['trust']['federation']['config']['federation_jwks'] , dict): - self.config['trust']['federation']['config']['federation_jwks'] = [self.config['trust']['federation']['config']['federation_jwks']] - - if isinstance(self.config['metadata_jwks'], str): + raise ValueError( + f"Invalid federation_jwks 
{self.config['trust']['federation']['config']['federation_jwks']} JSON: {e}" + ) + + if isinstance( + self.config["trust"]["federation"]["config"]["federation_jwks"], dict + ): + self.config["trust"]["federation"]["config"]["federation_jwks"] = [ + self.config["trust"]["federation"]["config"]["federation_jwks"] + ] + + if isinstance(self.config["metadata_jwks"], str): try: - self.config['metadata_jwks'] = json.loads(self.config['metadata_jwks']) + self.config["metadata_jwks"] = json.loads(self.config["metadata_jwks"]) except json.JSONDecodeError as e: - raise ValueError(f"Invalid metadata_jwks {self.config['metadata_jwks']} JSON: {e}") - - if isinstance(self.config['metadata_jwks'], dict): - self.config['metadata_jwks'] = [self.config['metadata_jwks']] - self.metadata_jwks_by_kids = { - i['kid']: i for i in self.config['metadata_jwks'] + raise ValueError( + f"Invalid metadata_jwks {self.config['metadata_jwks']} JSON: {e}" + ) + + if isinstance(self.config["metadata_jwks"], dict): + self.config["metadata_jwks"] = [self.config["metadata_jwks"]] + + self.metadata_jwks_by_kids = {i["kid"]: i for i in self.config["metadata_jwks"]} + self.config["metadata"]["jwks"] = { + "keys": [JWK(i).public_key for i in self.config["metadata_jwks"]] } - self.config['metadata']['jwks'] = {"keys": [ - JWK(i).public_key for i in self.config['metadata_jwks'] - ]} # HTML template loader self.template = Jinja2TemplateHandler(self.config["ui"]) @@ -104,9 +119,7 @@ def __init__( self.registered_get_response_endpoint = None self._server_url = ( - self.base_url[:-1] - if self.base_url[-1] == '/' - else self.base_url + self.base_url[:-1] if self.base_url[-1] == "/" else self.base_url ) try: @@ -116,18 +129,26 @@ def __init__( self._log_warning("OpenID4VPBackend", debug_message) self.response_code_helper = ResponseCodeSource( - self.config["response_code"]["sym_key"]) + self.config["response_code"]["sym_key"] + ) + + # This loads all the configured trust evaluation mechanisms trust_configuration 
= self.config.get("trust", {}) self.trust_evaluator = CombinedTrustEvaluator.from_config( - trust_configuration, self.db_engine) - # Questo carica risorse, metadata endpoint (sotto formate di attributi con pattern *_endpoint) etc, che satosa deve pubblicare - self.init_trust_resources() + trust_configuration, self.db_engine, default_client_id = self.client_id + ) + + def get_trust_backend_by_class_name(self, class_name: str) -> TrustHandlerInterface: + + for i in self.trust_evaluator.handlers: + if i.__class__.__name__ == class_name: + return i @property def client_id(self): - if (_cid := self.config["authorization"].get("client_id")): + if _cid := self.config["authorization"].get("client_id"): return _cid - elif (_cid := self.config["metadata"].get("client_id")): + elif _cid := self.config["metadata"].get("client_id"): return _cid else: return self._client_id @@ -140,12 +161,12 @@ def register_endpoints(self) -> list[tuple[str, Callable[[Context], Response]]]: :rtype: Sequence[(str, Callable[[satosa.context.Context], satosa.response.Response]] :return: A list that can be used to map the request to SATOSA to this endpoint. """ + # This loads the metadata endpoints required by the supported/configured trust evaluation methods url_map = self.trust_evaluator.build_metadata_endpoints( - self.name, - self._backend_url + self.name, self._backend_url ) - for k, v in self.config['endpoints'].items(): + for k, v in self.config["endpoints"].items(): endpoint_value = v if isinstance(endpoint_value, dict): @@ -153,19 +174,18 @@ def register_endpoints(self) -> list[tuple[str, Callable[[Context], Response]]]: if not endpoint_value or not isinstance(endpoint_value, str): raise ValueError( - f"Invalid endpoint value for \"{k}\". Given value: {endpoint_value}" + f"Invalid endpoint value for '{k}'. 
Given value: {endpoint_value}" ) url_map.append( ( f"^{self.name}/{endpoint_value.lstrip('/')}$", - getattr(self, f"{k}_endpoint") + getattr(self, f"{k}_endpoint"), ) ) _endpoint = f"{self._backend_url}/{endpoint_value.lstrip('/')}" self._log_debug( - "OpenID4VPBackend", - f"Exposing backend entity endpoint = {_endpoint}" + "OpenID4VPBackend", f"Exposing backend entity endpoint = {_endpoint}" ) match k: case "get_response": @@ -194,7 +214,9 @@ def start_auth(self, context: Context, internal_request) -> Response: """ return self.pre_request_endpoint(context, internal_request) - def pre_request_endpoint(self, context: Context, internal_request, **kwargs) -> Response: + def pre_request_endpoint( + self, context: Context, internal_request, **kwargs + ) -> Response: """ This endpoint is called by the User-Agent/Wallet Instance before calling the request endpoint. It initializes the session and returns the request_uri to be used by the User-Agent/Wallet Instance. @@ -209,7 +231,8 @@ def pre_request_endpoint(self, context: Context, internal_request, **kwargs) -> """ self._log_function_debug( - "pre_request_endpoint", context, "internal_request", internal_request) + "pre_request_endpoint", context, "internal_request", internal_request + ) session_id = context.state["SESSION_ID"] state = str(uuid.uuid4()) @@ -219,7 +242,8 @@ def pre_request_endpoint(self, context: Context, internal_request, **kwargs) -> self._log_warning(context, _msg) return self._handle_400( context, - "previous authn session not found. It seems that the flow did not started with a valid authn request to one of the configured frontend." + "previous authn session not found. 
It seems that the flow did " + "not started with a valid authn request to one of the configured frontend.", ) flow_typ = detect_flow_typ(context) @@ -227,29 +251,30 @@ def pre_request_endpoint(self, context: Context, internal_request, **kwargs) -> # Init session try: self.db_engine.init_session( - state=state, - session_id=session_id, - remote_flow_typ=flow_typ.value + state=state, session_id=session_id, remote_flow_typ=flow_typ.value + ) + except StorageWriteError as e: + _msg = ( + f"Error while initializing session with state {state} and {session_id}." ) - except (StorageWriteError) as e: - _msg = f"Error while initializing session with state {state} and {session_id}." self._log_error(context, f"{_msg} for the following reason {e}") return self._handle_500(context, _msg, e) - except (Exception) as e: - _msg = f"Error while initializing session with state {state} and {session_id}." + except Exception as e: + _msg = ( + f"Error while initializing session with state {state} and {session_id}." 
+ ) self._log_error(context, _msg) return self._handle_500(context, _msg, e) # PAR payload = { - 'client_id': self.client_id, - 'request_uri': f"{self.absolute_request_url}?id={state}", + "client_id": self.client_id, + "request_uri": f"{self.absolute_request_url}?id={state}", } response_url = build_authorization_request_url( - self.config["authorization"]["url_scheme"], - payload + self.config["authorization"]["url_scheme"], payload ) match flow_typ: @@ -260,7 +285,11 @@ def pre_request_endpoint(self, context: Context, internal_request, **kwargs) -> case unsupported: _msg = f"unrecognized remote flow type: {unsupported}" self._log_error(context, _msg) - return self._handle_500(context, "something went wrong when creating your authentication request", Exception(_msg)) + return self._handle_500( + context, + "something went wrong when creating your authentication request", + Exception(_msg), + ) def _same_device_http_response(self, response_url: str) -> Response: return Redirect(response_url) @@ -274,7 +303,7 @@ def _cross_device_http_response(self, response_url: str, state: str) -> Response "qrcode_logo_path": self.config["qrcode"]["logo_path"], "qrcode_expiration_time": self.config["qrcode"]["expiration_time"], "state": state, - "status_endpoint": self.absolute_status_url + "status_endpoint": self.absolute_status_url, } ) return Response(result, content="text/html; charset=utf8", status="200") @@ -292,14 +321,15 @@ def get_response_endpoint(self, context: Context) -> Response: try: state = self.response_code_helper.recover_state(resp_code) except Exception: - return self._handle_400(context, "missing or invalid parameter [response_code]") + return self._handle_400( + context, "missing or invalid parameter [response_code]" + ) finalized_session = None try: finalized_session = self.db_engine.get_by_state_and_session_id( - state=state, - session_id=session_id + state=state, session_id=session_id ) except Exception as e: _msg = f"Error while retrieving internal 
response with response_code {resp_code} and session_id {session_id}: {e}" @@ -309,19 +339,17 @@ def get_response_endpoint(self, context: Context) -> Response: return self._handle_400(context, "session not found or invalid") _now = iat_now() - _exp = finalized_session['request_object']['exp'] + _exp = finalized_session["request_object"]["exp"] if _exp < _now: - return self._handle_400(context, f"session expired, request object exp is {_exp} while now is {_now}") + return self._handle_400( + context, + f"session expired, request object exp is {_exp} while now is {_now}", + ) internal_response = InternalData() - resp = internal_response.from_dict( - finalized_session['internal_response'] - ) + resp = internal_response.from_dict(finalized_session["internal_response"]) - return self.auth_callback_func( - context, - resp - ) + return self.auth_callback_func(context, resp) def status_endpoint(self, context: Context) -> JsonResponse: @@ -354,58 +382,38 @@ def status_endpoint(self, context: Context) -> JsonResponse: if iat_now() > request_object["exp"]: return self._handle_403("expired", "Request object expired") - if (session["finalized"] is True): + if session["finalized"] is True: resp_code = self.response_code_helper.create_code(state) return JsonResponse( { "redirect_uri": f"{self.registered_get_response_endpoint}?response_code={resp_code}" }, - status="200" + status="200", ) else: if request_object is not None: - return JsonResponse( - { - "response": "Accepted" - }, - status="202" - ) + return JsonResponse({"response": "Accepted"}, status="202") - return JsonResponse( - { - "response": "Request object issued" - }, - status="201" - ) + return JsonResponse({"response": "Request object issued"}, status="201") @property def db_engine(self) -> DBEngine: """Returns the DBEngine instance used by the class""" + if not self._db_engine: + self._db_engine = DBEngine(self.config["storage"]) + try: self._db_engine.is_connected except Exception as e: - if getattr(self, 
'_db_engine', None): + if getattr(self, "_db_engine", None): self._log_debug( "OpenID4VP db storage handling", - f"connection check silently fails and get restored: {e}" + f"connection check silently fails and get restored: {e}", ) self._db_engine = DBEngine(self.config["storage"]) return self._db_engine - def _build_authz_request_url(self, payload: dict) -> str: - scheme = self.config["authorization"]["url_scheme"] - if "://" not in scheme: - scheme = scheme + "://" - if not scheme.endswith("/"): - scheme = scheme + "/" - # NOTE: path component is currently unused by the protocol, but currently - # we leave it there as 'authorize' to stress the fact that this is an - # OAuth 2.0 request modified by JAR (RFC9101) - path = "authorize" - query_params = urlencode(payload, quote_via=quote_plus) - return f"{scheme}{path}?{query_params}" - @property def default_metadata_private_jwk(self) -> tuple: """Returns the default metadata private JWK""" diff --git a/pyeudiw/satosa/default/request_handler.py b/pyeudiw/satosa/default/request_handler.py index 4cc5cc0d..4f811c96 100644 --- a/pyeudiw/satosa/default/request_handler.py +++ b/pyeudiw/satosa/default/request_handler.py @@ -1,19 +1,17 @@ - from satosa.context import Context from pyeudiw.jwt.jws_helper import JWSHelper from pyeudiw.openid4vp.authorization_request import build_authorization_request_claims from pyeudiw.satosa.exceptions import HTTPError from pyeudiw.satosa.interfaces.request_handler import RequestHandlerInterface -from pyeudiw.satosa.utils.dpop import BackendDPoP from pyeudiw.satosa.utils.response import Response -from pyeudiw.satosa.utils.trust import BackendTrust +from pyeudiw.tools.base_logger import BaseLogger -class RequestHandler(RequestHandlerInterface, BackendDPoP, BackendTrust): +class RequestHandler(RequestHandlerInterface, BaseLogger): - _RESP_CONTENT_TYPE = "application/oauth-authz-req+jwt" _REQUEST_OBJECT_TYP = "oauth-authz-req+jwt" + _RESP_CONTENT_TYPE = f"application/{_REQUEST_OBJECT_TYP}" def 
request_endpoint(self, context: Context, *args) -> Response: self._log_function_debug("request_endpoint", context, "args", args) @@ -25,16 +23,18 @@ def request_endpoint(self, context: Context, *args) -> Response: "Error while retrieving id from qs_params: " f"{e.__class__.__name__}: {e}" ) - return self._handle_400(context, _msg, HTTPError(f"{e} with {context.__dict__}")) + return self._handle_400( + context, _msg, HTTPError(f"{e} with {context.__dict__}") + ) data = build_authorization_request_claims( self.client_id, state, self.absolute_response_url, - self.config["authorization"] + self.config["authorization"], ) - if (_aud := self.config["authorization"].get("aud")): + if _aud := self.config["authorization"].get("aud"): data["aud"] = _aud # take the session created in the pre-request authz endpoint try: @@ -44,22 +44,33 @@ def request_endpoint(self, context: Context, *args) -> Response: except ValueError as e: _msg = "Error while retrieving request object from database." - return self._handle_500(context, _msg, HTTPError(f"{e} with {context.__dict__}")) + return self._handle_500( + context, _msg, HTTPError(f"{e} with {context.__dict__}") + ) except (Exception, BaseException) as e: _msg = f"Error while updating request object: {e}" return self._handle_500(context, _msg, e) + _protected_jwt_headers = { + "typ": RequestHandler._REQUEST_OBJECT_TYP, + } + + # load all the trust handlers request jwt header parameters, if any + self.trust_evaluator.get_selfissued_jwt_header_trust_parameters(issuer=self.client_id) + + + # federation_trust_handler_backend_class: TrustHandlerInterface = ( + # self.get_trust_backend_by_class_name("FederationHandler") + # ) + helper = JWSHelper(self.default_metadata_private_jwk) request_object_jwt = helper.sign( data, - protected={ - 'trust_chain': self.get_backend_trust_chain(), - 'typ': RequestHandler._REQUEST_OBJECT_TYP - } + protected=_protected_jwt_headers, ) return Response( message=request_object_jwt, status="200", - 
content=RequestHandler._RESP_CONTENT_TYPE + content=RequestHandler._RESP_CONTENT_TYPE, ) diff --git a/pyeudiw/satosa/default/response_handler.py b/pyeudiw/satosa/default/response_handler.py index 8dbe7c37..e96eb0c1 100644 --- a/pyeudiw/satosa/default/response_handler.py +++ b/pyeudiw/satosa/default/response_handler.py @@ -1,69 +1,48 @@ -from copy import deepcopy import datetime import hashlib import json import logging -import pydantic - +from copy import deepcopy from typing import Any + from satosa.context import Context from satosa.internal import AuthenticationInformation, InternalData from satosa.response import Redirect from pyeudiw.jwt.jwe_helper import JWEHelper -from pyeudiw.openid4vp.authorization_response import AuthorizeResponsePayload, DirectPostJwtJweParser, DirectPostParser, detect_response_mode -from pyeudiw.openid4vp.exceptions import AuthRespParsingException, AuthRespValidationException, InvalidVPKeyBinding, InvalidVPToken, KIDNotFound +from pyeudiw.openid4vp.authorization_response import ( + AuthorizeResponsePayload, + DirectPostJwtJweParser, + DirectPostParser, + detect_response_mode, +) +from pyeudiw.openid4vp.exceptions import ( + AuthRespParsingException, + AuthRespValidationException, + InvalidVPKeyBinding, +) from pyeudiw.openid4vp.interface import VpTokenParser, VpTokenVerifier from pyeudiw.openid4vp.schemas.flow import RemoteFlowType from pyeudiw.openid4vp.schemas.response import ResponseMode -from pyeudiw.openid4vp.vp import Vp from pyeudiw.openid4vp.vp_sd_jwt_vc import VpVcSdJwtParserVerifier -from pyeudiw.openid4vp.vp_sd_jwt import VpSdJwt -from pyeudiw.satosa.exceptions import (AuthorizeUnmatchedResponse, BadRequestError, FinalizedSessionError, - InvalidInternalStateError, NotTrustedFederationError, HTTPError) +from pyeudiw.satosa.exceptions import ( + AuthorizeUnmatchedResponse, + FinalizedSessionError, + HTTPError, + InvalidInternalStateError, +) from pyeudiw.satosa.interfaces.response_handler import ResponseHandlerInterface from 
pyeudiw.satosa.utils.response import JsonResponse -from pyeudiw.satosa.utils.trust import BackendTrust from pyeudiw.sd_jwt.schema import VerifierChallenge from pyeudiw.storage.exceptions import StorageWriteError from pyeudiw.tools.utils import iat_now -class ResponseHandler(ResponseHandlerInterface, BackendTrust): +class ResponseHandler(ResponseHandlerInterface): _SUPPORTED_RESPONSE_METHOD = "post" _SUPPORTED_RESPONSE_CONTENT_TYPE = "application/x-www-form-urlencoded" _ACCEPTED_ISSUER_METADATA_TYPE = "openid_credential_issuer" - def _handle_credential_trust(self, context: Context, vp: Vp) -> bool: - try: - # establish the trust with the issuer of the credential by checking it to the revocation - # inspect VP's iss or trust_chain if available or x5c if available - # TODO: X.509 as alternative to Federation - - # for each single vp token, take the credential within it, use cnf.jwk to validate the vp token signature -> if not exception - # establish the trust to each credential issuer - tchelper = self._validate_trust(context, vp.payload['vp']) - - if not tchelper.is_trusted: - return self._handle_400(context, f"Trust Evaluation failed for {tchelper.entity_id}") - - # TODO: generalyze also for x509 - if isinstance(vp, VpSdJwt): - credential_jwks = tchelper.get_trusted_jwks( - metadata_type='openid_credential_issuer' - ) - vp.set_credential_jwks(credential_jwks) - except InvalidVPToken: - return self._handle_400(context, f"Cannot validate VP: {vp.jwt}") - except pydantic.ValidationError as e: - return self._handle_400(context, f"Error validating schemas: {e}") - except KIDNotFound as e: - return self._handle_400(context, f"Kid error: {e}") - except NotTrustedFederationError as e: - return self._handle_400(context, f"Not trusted federation error: {e}") - except Exception as e: - return self._handle_400(context, f"VP parsing error: {e}") - def _extract_all_user_attributes(self, attributes_by_issuers: dict) -> dict: # for all the valid credentials, take the payload and 
the disclosure and disclose user attributes # returns the user attributes ... @@ -72,31 +51,6 @@ def _extract_all_user_attributes(self, attributes_by_issuers: dict) -> dict: all_user_attributes.update(**i) return all_user_attributes - def _parse_http_request(self, context: Context) -> dict: - """Parse the http layer of the request to extract the dictionary data. - - :param context: the satosa context containing, among the others, the details of the HTTP request - :type context: satosa.Context - - :return: a dictionary containing the request data - :rtype: dict - - :raises BadRequestError: when request paramets are in a not processable state; the expected handling is returning 400 - """ - if (http_method := context.request_method.lower()) != ResponseHandler._SUPPORTED_RESPONSE_METHOD: - raise BadRequestError(f"HTTP method [{http_method}] not supported") - - if (content_type := context.http_headers['HTTP_CONTENT_TYPE']) != ResponseHandler._SUPPORTED_RESPONSE_CONTENT_TYPE: - raise BadRequestError( - f"HTTP content type [{content_type}] not supported") - - _endpoint = f"{self.server_url}{context.request_uri}" - if self.config["metadata"].get('response_uris_supported', None): - if _endpoint not in self.config["metadata"]['response_uris_supported']: - raise BadRequestError("response_uri not valid") - - return context.request - def _retrieve_session_from_state(self, state: str) -> dict: """_retrieve_session_and_nonce_from_state tries to recover an authenticasion session by matching it with the state. 
Returns the whole @@ -117,53 +71,51 @@ def _retrieve_session_from_state(self, state: str) -> dict: request_session = self.db_engine.get_by_state(state=state) except Exception as err: raise AuthorizeUnmatchedResponse( - f"unable to find document-session associated to state {state}", err) + f"unable to find document-session associated to state {state}", err + ) if not request_session: raise InvalidInternalStateError( - f"unable to find document-session associated to state {state}") + f"unable to find document-session associated to state {state}" + ) if request_session.get("finalized", True): raise FinalizedSessionError( - f"cannot accept response: session for state {state} corrupted or already finalized") + f"cannot accept response: session for state {state} corrupted or already finalized" + ) nonce = request_session.get("nonce", None) if not nonce: raise InvalidInternalStateError( - f"unable to find nonce in session associated to state {state}: corrupted data") + f"unable to find nonce in session associated to state {state}: corrupted data" + ) return request_session - def _is_same_device_flow(request_session: dict, context: Context) -> bool: - initiating_session_id: str | None = request_session.get( - "session_id", None) - if initiating_session_id is None: - raise ValueError( - "invalid session storage information: missing [session_id]") - current_session_id: str | None = context.state.get("SESSION_ID", None) - if current_session_id is None: - raise ValueError( - "missing session id in wallet authorization response") - return initiating_session_id == current_session_id - - def response_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonResponse: + def response_endpoint( + self, context: Context, *args: tuple + ) -> Redirect | JsonResponse: self._log_function_debug("response_endpoint", context, "args", args) # parse and eventually decrypt jwt in response try: - authz_payload: AuthorizeResponsePayload = self._parse_authorization_response( - context) + 
authz_payload: AuthorizeResponsePayload = ( + self._parse_authorization_response(context) + ) except AuthRespParsingException as e400: self._handle_400(context, e400.args[0], e400.args[1]) except AuthRespValidationException as e401: self._handle_401( - context, "invalid authentication method: token might be invalid or expired", e401) + context, + "invalid authentication method: token might be invalid or expired", + e401, + ) self._log_debug( - context, f"response URI endpoint response with payload {authz_payload}") + context, f"response URI endpoint response with payload {authz_payload}" + ) request_session: dict = {} try: - request_session = self._retrieve_session_from_state( - authz_payload.state) + request_session = self._retrieve_session_from_state(authz_payload.state) except AuthorizeUnmatchedResponse as e400: return self._handle_400(context, e400.args[0], e400.args[1]) except InvalidInternalStateError as e500: @@ -177,20 +129,23 @@ def response_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonRe # (3) we use all disclosed claims in vp tokens to build the user identity attributes_by_issuer: dict[str, dict[str, Any]] = {} credential_issuers: list[str] = [] - encoded_vps: list[str] = [authz_payload.vp_token] if isinstance( - authz_payload.vp_token, str) else authz_payload.vp_token + encoded_vps: list[str] = ( + [authz_payload.vp_token] + if isinstance(authz_payload.vp_token, str) + else authz_payload.vp_token + ) for vp_token in encoded_vps: # verify vp token and extract user information try: token_parser, token_verifier = self._vp_verifier_factory( - authz_payload.presentation_submission, vp_token, request_session) + authz_payload.presentation_submission, vp_token, request_session + ) except ValueError as e: return self._handle_400(context, f"VP parsing error: {e}") token_issuer = token_parser.get_issuer_name() - whitelisted_keys = self.trust_evaluator.get_public_keys( - token_issuer) + whitelisted_keys = 
self.trust_evaluator.get_public_keys(token_issuer) try: token_verifier.verify_signature(whitelisted_keys) except Exception as e: @@ -204,28 +159,27 @@ def response_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonRe claims = token_parser.get_credentials() iss = token_parser.get_issuer_name() attributes_by_issuer[iss] = claims - self._log_debug( - context, f"disclosed claims {claims} from issuer {iss}") + self._log_debug(context, f"disclosed claims {claims} from issuer {iss}") - all_attributes = self._extract_all_user_attributes( - attributes_by_issuer) - iss_list_serialized = ";".join( - credential_issuers) # marshaling is whatever + all_attributes = self._extract_all_user_attributes(attributes_by_issuer) + iss_list_serialized = ";".join(credential_issuers) # marshaling is whatever internal_resp = self._translate_response( - all_attributes, iss_list_serialized, context) + all_attributes, iss_list_serialized, context + ) state = authz_payload.state response_code = self.response_code_helper.create_code(state) try: self.db_engine.update_response_object( - request_session['nonce'], state, internal_resp + request_session["nonce"], state, internal_resp ) # authentication finalized! - self.db_engine.set_finalized(request_session['document_id']) + self.db_engine.set_finalized(request_session["document_id"]) if self.effective_log_level == logging.DEBUG: request_session = self.db_engine.get_by_state(state=state) self._log_debug( - context, f"Session update on storage: {request_session}") + context, f"Session update on storage: {request_session}" + ) except StorageWriteError as e: # TODO - do we have to block in the case the update cannot be done? 
@@ -236,8 +190,11 @@ def response_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonRe flow_type = RemoteFlowType(request_session["remote_flow_typ"]) except ValueError as e: self._log_error( - context, f"unable to identify flow from stored session: {e}") - return self._handle_500(context, "error in authentication response processing", e) + context, f"unable to identify flow from stored session: {e}" + ) + return self._handle_500( + context, "error in authentication response processing", e + ) match flow_type: case RemoteFlowType.SAME_DEVICE: @@ -248,9 +205,15 @@ def response_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonRe case unsupported: _msg = f"unrecognized remote flow type: {unsupported}" self._log_error(context, _msg) - return self._handle_500(context, "error in authentication response processing", Exception(_msg)) + return self._handle_500( + context, + "error in authentication response processing", + Exception(_msg), + ) - def _translate_response(self, response: dict, issuer: str, context: Context) -> InternalData: + def _translate_response( + self, response: dict, issuer: str, context: Context + ) -> InternalData: """ Translates wallet response to SATOSA internal response. 
:type response: dict[str, str] @@ -264,63 +227,52 @@ def _translate_response(self, response: dict, issuer: str, context: Context) -> """ # it may depends by credential type and attested security context evaluated # if WIA was previously submitted by the Wallet - timestamp_epoch = ( - response.get("auth_time") - or response.get("iat") - or iat_now() - ) + timestamp_epoch = response.get("auth_time") or response.get("iat") or iat_now() timestamp_dt = datetime.datetime.fromtimestamp( - timestamp_epoch, - datetime.timezone.utc + timestamp_epoch, datetime.timezone.utc ) timestamp_iso = timestamp_dt.isoformat().replace("+00:00", "Z") auth_class_ref = ( - response.get("acr") or - response.get("amr") or - self.config["authorization"]["default_acr_value"] + response.get("acr") + or response.get("amr") + or self.config["authorization"]["default_acr_value"] ) - auth_info = AuthenticationInformation( - auth_class_ref, timestamp_iso, issuer) + auth_info = AuthenticationInformation(auth_class_ref, timestamp_iso, issuer) # TODO - ACR values internal_resp = InternalData(auth_info=auth_info) # (re)define the response subject sub = "" - pepper = self.config.get("user_attributes", {})[ - 'subject_id_random_value' - ] + pepper = self.config.get("user_attributes", {})["subject_id_random_value"] for i in self.config.get("user_attributes", {}).get("unique_identifiers", []): if response.get(i): _sub = response[i] - sub = hashlib.sha256( - f"{_sub}~{pepper}".encode( - ) - ).hexdigest() + sub = hashlib.sha256(f"{_sub}~{pepper}".encode()).hexdigest() break if not sub: self._log( context, - level='warning', + level="warning", message=( "[USER ATTRIBUTES] Missing subject id from OpenID4VP presentation " "setting a random one for interop for internal frontends" - ) + ), ) sub = hashlib.sha256( f"{json.dumps(response).encode()}~{pepper}".encode() ).hexdigest() response["sub"] = [sub] - internal_resp.attributes = self.converter.to_internal( - "openid4vp", response - ) + internal_resp.attributes = 
self.converter.to_internal("openid4vp", response) internal_resp.subject_id = sub return internal_resp - def _parse_authorization_response(self, context: Context) -> AuthorizeResponsePayload: + def _parse_authorization_response( + self, context: Context + ) -> AuthorizeResponsePayload: response_mode = detect_response_mode(context) match response_mode: case ResponseMode.direct_post: @@ -333,15 +285,18 @@ def _parse_authorization_response(self, context: Context) -> AuthorizeResponsePa case _: raise AuthRespParsingException( f"invalid or unrecognized response mode: {response_mode}", - Exception("invalid program state") + Exception("invalid program state"), ) - def _vp_verifier_factory(self, presentation_submission: dict, token: str, session_data: dict) -> tuple[VpTokenParser, VpTokenVerifier]: + def _vp_verifier_factory( + self, presentation_submission: dict, token: str, session_data: dict + ) -> tuple[VpTokenParser, VpTokenVerifier]: # TODO: la funzione dovrebbe consumare la presentation submission per sapere quale token # ritornare - per ora viene ritornata l'unica implementazione possibile challenge = self._get_verifier_challenge(session_data) token_processor = VpVcSdJwtParserVerifier( - token, challenge["aud"], challenge["nonce"]) + token, challenge["aud"], challenge["nonce"] + ) return (token_processor, deepcopy(token_processor)) def _get_verifier_challenge(self, session_data: dict) -> VerifierChallenge: diff --git a/pyeudiw/satosa/exceptions.py b/pyeudiw/satosa/exceptions.py index 1c04562f..42be71e4 100644 --- a/pyeudiw/satosa/exceptions.py +++ b/pyeudiw/satosa/exceptions.py @@ -4,6 +4,7 @@ class BadRequestError(Exception): This exception should be raised when we want to return an HTTP 400 Bad Request """ + pass class InternalServerError(Exception): @@ -12,6 +13,7 @@ class InternalServerError(Exception): This exception should be raised when we want to return an HTTP 400 Bad Request """ + pass class InvalidInternalStateError(InternalServerError): @@ -19,6 +21,7 @@ 
class InvalidInternalStateError(InternalServerError): This is specification of InternalServerError that specify that the internal error is caused by an invalid backend, storage or cache state. """ + pass class FinalizedSessionError(BadRequestError): @@ -26,15 +29,6 @@ class FinalizedSessionError(BadRequestError): Raised when an authorization request or respsonse attempts at updating or modifying an already finalized authentication session. """ - - -class NoBoundEndpointError(Exception): - """ - Raised when a given url path is not bound to any endpoint function - """ - - -class NotTrustedFederationError(Exception): pass @@ -42,27 +36,25 @@ class DiscoveryFailedError(Exception): """ Raised when the discovery fails """ + pass class HTTPError(Exception): """ Raised when an error occurs during an HTTP request """ + pass class EmptyHTTPError(HTTPError): """ Default HTTP empty error """ - - -class DPOPValidationError(Exception): - """ - Raised when a DPoP validation error occurs - """ + pass class AuthorizeUnmatchedResponse(Exception): """ Raised when an authorization response cannot be matched to an authentication request """ + pass diff --git a/pyeudiw/satosa/interfaces/openid4vp_backend.py b/pyeudiw/satosa/interfaces/openid4vp_backend.py index 02c6a352..bd4b5905 100644 --- a/pyeudiw/satosa/interfaces/openid4vp_backend.py +++ b/pyeudiw/satosa/interfaces/openid4vp_backend.py @@ -7,7 +7,9 @@ class OpenID4VPBackendInterface(EventHandlerInterface): - def pre_request_endpoint(self, context: Context, internal_request, **kwargs) -> Response: + def pre_request_endpoint( + self, context: Context, internal_request, **kwargs + ) -> Response: """ This endpoint is called by the User-Agent/Wallet Instance before calling the request endpoint. It initializes the session and returns the request_uri to be used by the User-Agent/Wallet Instance. 
diff --git a/pyeudiw/satosa/interfaces/request_handler.py b/pyeudiw/satosa/interfaces/request_handler.py index 91723b7d..8d971dd2 100644 --- a/pyeudiw/satosa/interfaces/request_handler.py +++ b/pyeudiw/satosa/interfaces/request_handler.py @@ -11,7 +11,9 @@ class RequestHandlerInterface(EventHandlerInterface): Interface for request handlers. """ - def request_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonResponse: + def request_endpoint( + self, context: Context, *args: tuple + ) -> Redirect | JsonResponse: """ This endpoint is called by the User-Agent/Wallet Instance to retrieve the signed signed Request Object. diff --git a/pyeudiw/satosa/schemas/autorization.py b/pyeudiw/satosa/schemas/autorization.py index 7c7f4295..281872cf 100644 --- a/pyeudiw/satosa/schemas/autorization.py +++ b/pyeudiw/satosa/schemas/autorization.py @@ -1,7 +1,9 @@ from pydantic import BaseModel, Field, HttpUrl from pyeudiw.openid4vp.schemas.response import ResponseMode -from pyeudiw.presentation_exchange.schemas.oid4vc_presentation_definition import PresentationDefinition +from pyeudiw.presentation_exchange.schemas.oid4vc_presentation_definition import ( + PresentationDefinition, +) class AuthorizationConfig(BaseModel): diff --git a/pyeudiw/satosa/schemas/config.py b/pyeudiw/satosa/schemas/config.py index b66b18c2..8a64c9f0 100644 --- a/pyeudiw/satosa/schemas/config.py +++ b/pyeudiw/satosa/schemas/config.py @@ -1,13 +1,16 @@ from pydantic import BaseModel -from pyeudiw.federation.schemas.wallet_relying_party import WalletRelyingParty -from pyeudiw.jwt.schemas.jwt import JWTConfig + +from pyeudiw.federation.schemas.openid_credential_verifier import ( + OpenIDCredentialVerifier, +) from pyeudiw.jwk.schemas.public import JwkSchema +from pyeudiw.jwt.schemas.jwt import JWTConfig +from pyeudiw.satosa.schemas.autorization import AuthorizationConfig from pyeudiw.satosa.schemas.endpoint import EndpointsConfig from pyeudiw.satosa.schemas.qrcode import QRCode from 
pyeudiw.satosa.schemas.response import ResponseConfig -from pyeudiw.satosa.schemas.autorization import AuthorizationConfig -from pyeudiw.satosa.schemas.user_attributes import UserAttributesConfig from pyeudiw.satosa.schemas.ui import UiConfig +from pyeudiw.satosa.schemas.user_attributes import UserAttributesConfig from pyeudiw.storage.schemas.storage import Storage from pyeudiw.trust.model import TrustModuleConfiguration_T @@ -24,4 +27,4 @@ class PyeudiwBackendConfig(BaseModel): trust: dict[str, TrustModuleConfiguration_T] metadata_jwks: list[JwkSchema] storage: Storage - metadata: WalletRelyingParty + metadata: OpenIDCredentialVerifier diff --git a/pyeudiw/satosa/schemas/endpoint.py b/pyeudiw/satosa/schemas/endpoint.py index ffb40eb1..e3c8918c 100644 --- a/pyeudiw/satosa/schemas/endpoint.py +++ b/pyeudiw/satosa/schemas/endpoint.py @@ -1,21 +1,24 @@ from typing import Union + from pydantic import BaseModel, field_validator _CONFIG_ENDPOINT_KEYS = ["module", "class", "path"] class EndpointsConfig(BaseModel): - pre_request: str + pre_request: Union[str, dict] response: Union[str, dict] request: Union[str, dict] - entity_configuration: str - status: str - get_response: str + status: Union[str, dict] + get_response: Union[str, dict] - @field_validator("pre_request", "entity_configuration", "status", "get_response") + @field_validator("pre_request", "response", "request", "status", "get_response") def must_start_with_slash(cls, v): - if not v.startswith('/'): - raise ValueError(f"{v} must start with '/'") + if isinstance(v, str) and not v.startswith("/"): + raise ValueError(f"Endpoints: {v} must start with '/'") + elif isinstance(v, dict): + if not v["path"].startswith("/"): + raise ValueError(f"Endpoints: {v['path']} must start with '/'") return v @field_validator("response", "request") @@ -25,10 +28,9 @@ def must_start_with_slash_path(cls, v): endpoint_value = v.get("path", None) if not endpoint_value or not isinstance(endpoint_value, str): - raise ValueError( - 
f"Invalid config endpoint structure for {endpoint_value}") + raise ValueError(f"Invalid config endpoint structure for {endpoint_value}") - if not endpoint_value.startswith('/'): + if not endpoint_value.startswith("/"): raise ValueError(f"{endpoint_value} must start with '/'") return v diff --git a/pyeudiw/satosa/schemas/qrcode.py b/pyeudiw/satosa/schemas/qrcode.py index d4a274c6..773e86aa 100644 --- a/pyeudiw/satosa/schemas/qrcode.py +++ b/pyeudiw/satosa/schemas/qrcode.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, field_validator, Field +from pydantic import BaseModel, Field, field_validator class QRCode(BaseModel): @@ -9,8 +9,8 @@ class QRCode(BaseModel): expiration_time: int = Field(..., gt=0) logo_path: str - @field_validator('logo_path') + @field_validator("logo_path") def must_start_with_slash(cls, v): - if v.startswith('/'): + if v.startswith("/"): raise ValueError(f'{v} must start without "/"') return v diff --git a/pyeudiw/satosa/schemas/user_attributes.py b/pyeudiw/satosa/schemas/user_attributes.py index a78036ab..6f77f04c 100644 --- a/pyeudiw/satosa/schemas/user_attributes.py +++ b/pyeudiw/satosa/schemas/user_attributes.py @@ -5,8 +5,8 @@ class UserAttributesConfig(BaseModel): unique_identifiers: list[str] subject_id_random_value: str - @field_validator('subject_id_random_value') + @field_validator("subject_id_random_value") def validate_subject_id_random_value(cls, v): - if v == 'CHANGEME!': + if v == "CHANGEME!": raise ValueError('subject_id_random_value must not be "CHANGEME!"') return v diff --git a/pyeudiw/satosa/utils/base_http_error_handler.py b/pyeudiw/satosa/utils/base_http_error_handler.py index 5b84f6ca..97615baa 100644 --- a/pyeudiw/satosa/utils/base_http_error_handler.py +++ b/pyeudiw/satosa/utils/base_http_error_handler.py @@ -7,13 +7,13 @@ class BaseHTTPErrorHandler(BaseLogger): def _serialize_error( - self, - context: Context, - message: str, - troubleshoot: str, - err: str, - err_code: str, - level: str + self, + context: Context, + 
message: str, + troubleshoot: str, + err: str, + err_code: str, + level: str, ) -> JsonResponse: """ Serializes an error. @@ -38,16 +38,10 @@ def _serialize_error( _msg = f"{message}:" if err: _msg += f" {err}." - self._log( - context, level=level, - message=f"{_msg} {troubleshoot}" - ) + self._log(context, level=level, message=f"{_msg} {troubleshoot}") - return JsonResponse({ - "error": message, - "error_description": troubleshoot - }, - status=err_code + return JsonResponse( + {"error": message, "error_description": troubleshoot}, status=err_code ) def _handle_500(self, context: Context, msg: str, err: Exception) -> JsonResponse: @@ -71,10 +65,17 @@ def _handle_500(self, context: Context, msg: str, err: Exception) -> JsonRespons f"{msg}", f"{msg}. {err.__class__.__name__}: {err}", "500", - "error" + "error", ) - def _handle_40X(self, code_number: str, message: str, context: Context, troubleshoot: str, err: Exception) -> JsonResponse: + def _handle_40X( + self, + code_number: str, + message: str, + context: Context, + troubleshoot: str, + err: Exception, + ) -> JsonResponse: """ Handles a 40X error. @@ -99,10 +100,12 @@ def _handle_40X(self, code_number: str, message: str, context: Context, troubles troubleshoot, f"{err.__class__.__name__}: {err}", f"40{code_number}", - "error" + "error", ) - def _handle_400(self, context: Context, troubleshoot: str, err: Exception = EmptyHTTPError("")) -> JsonResponse: + def _handle_400( + self, context: Context, troubleshoot: str, err: Exception = EmptyHTTPError("") + ) -> JsonResponse: """ Handles a 400 error. @@ -118,7 +121,9 @@ def _handle_400(self, context: Context, troubleshoot: str, err: Exception = Empt """ return self._handle_40X("0", "invalid_request", context, troubleshoot, err) - def _handle_401(self, context, troubleshoot: str, err: EmptyHTTPError = EmptyHTTPError("")): + def _handle_401( + self, context, troubleshoot: str, err: EmptyHTTPError = EmptyHTTPError("") + ): """ Handles a 401 error. 
@@ -135,7 +140,9 @@ def _handle_401(self, context, troubleshoot: str, err: EmptyHTTPError = EmptyHTT return self._handle_40X("1", "invalid_client", context, troubleshoot, err) - def _handle_403(self, context, troubleshoot: str, err: EmptyHTTPError = EmptyHTTPError("")): + def _handle_403( + self, context, troubleshoot: str, err: EmptyHTTPError = EmptyHTTPError("") + ): """ Handles a 403 error. diff --git a/pyeudiw/satosa/utils/html_template.py b/pyeudiw/satosa/utils/html_template.py index f0a1b9b5..eeefaebe 100644 --- a/pyeudiw/satosa/utils/html_template.py +++ b/pyeudiw/satosa/utils/html_template.py @@ -32,9 +32,7 @@ def __init__(self, config: Dict[str, Any]): } ) - self.qrcode_page = self.loader.get_template( - config["qrcode_template"] - ) + self.qrcode_page = self.loader.get_template(config["qrcode_template"]) # TODO - for rendering custom errors # self.error_page = self.loader.get_template( diff --git a/pyeudiw/satosa/utils/respcode.py b/pyeudiw/satosa/utils/respcode.py index bdde5bfc..7c347d50 100644 --- a/pyeudiw/satosa/utils/respcode.py +++ b/pyeudiw/satosa/utils/respcode.py @@ -1,8 +1,9 @@ import base64 -from dataclasses import dataclass, field -from cryptography.hazmat.primitives.ciphers.aead import AESGCM import secrets import string +from dataclasses import dataclass, field + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM CODE_SYM_KEY_LEN = 32 # in bytes (256 bits) @@ -38,40 +39,40 @@ def recover_state(self, code: str) -> str: def decode_key(key: str) -> bytes: if not set(key) <= set(string.hexdigits): - raise ValueError( - "key in format different than hex currently not supported") + raise ValueError("key in format different than hex currently not supported") key_len = len(key) - if key_len != 2*CODE_SYM_KEY_LEN: + if key_len != 2 * CODE_SYM_KEY_LEN: raise ValueError( - f"invalid key: key should be {CODE_SYM_KEY_LEN} bytes, obtained instead: {key_len//2}") + f"invalid key: key should be {CODE_SYM_KEY_LEN} bytes, obtained instead: 
{key_len//2}" + ) return bytes.fromhex(key) def _base64_encode_no_pad(b: bytes) -> str: - return base64.urlsafe_b64encode(b).decode().rstrip('=') + return base64.urlsafe_b64encode(b).decode().rstrip("=") def _base64_decode_no_pad(s: str) -> bytes: - padded = s + "="*((4 - len(s) % 4) % 4) + padded = s + "=" * ((4 - len(s) % 4) % 4) return base64.urlsafe_b64decode(padded) def _encrypt_state(msg: bytes, key: bytes) -> bytes: nonce = secrets.token_bytes(12) - ciphertext = AESGCM(key).encrypt(nonce, msg, b'') + ciphertext = AESGCM(key).encrypt(nonce, msg, b"") return nonce + ciphertext def _decrypt_code(encrypted_token: bytes, key: bytes) -> bytes: nonce = encrypted_token[:12] ciphertext = encrypted_token[12:] - dec = AESGCM(key).decrypt(nonce, ciphertext, b'') + dec = AESGCM(key).decrypt(nonce, ciphertext, b"") return dec def create_code(state: str, key: str) -> str: bkey = decode_key(key) - msg = bytes(state, encoding='utf-8') + msg = bytes(state, encoding="utf-8") code = _encrypt_state(msg, bkey) return _base64_encode_no_pad(code) @@ -80,4 +81,4 @@ def recover_state(code: str, key: str) -> str: bkey = decode_key(key) enc = _base64_decode_no_pad(code) state = _decrypt_code(enc, bkey) - return state.decode(encoding='utf-8') + return state.decode(encoding="utf-8") diff --git a/pyeudiw/satosa/utils/trust.py b/pyeudiw/satosa/utils/trust.py deleted file mode 100644 index d2978e19..00000000 --- a/pyeudiw/satosa/utils/trust.py +++ /dev/null @@ -1,212 +0,0 @@ -import json - -from satosa.context import Context -from satosa.response import Response - -from pyeudiw.jwk import JWK - -from pyeudiw.jwt.jws_helper import JWSHelper -from pyeudiw.jwt.utils import decode_jwt_header -from pyeudiw.satosa.exceptions import (DiscoveryFailedError, - NotTrustedFederationError) -from pyeudiw.storage.exceptions import EntryNotFound -from pyeudiw.tools.base_logger import BaseLogger -from pyeudiw.tools.utils import exp_from_now, iat_now -from pyeudiw.trust import TrustEvaluationHelper -from 
pyeudiw.trust.trust_anchors import update_trust_anchors_ecs - - -class BackendTrust(BaseLogger): - """ - Backend Trust class. - """ - - def init_trust_resources(self) -> None: - """ - Initializes the trust resources. - """ - # TODO: adapt method to init ALL types of trust resources (if configured) - - # private keys by kid - self.federations_jwks_by_kids = { - i['kid']: i for i in self.config['trust']['federation']['config']['federation_jwks'] - } - # dumps public jwks - self.federation_public_jwks = [ - JWK(i).public_key for i in self.config['trust']['federation']['config']['federation_jwks'] - ] - # we close the connection in this constructor since it must be fork safe and - # get reinitialized later on, within each fork - self.update_trust_anchors() - - try: - self.get_backend_trust_chain() - except Exception as e: - self._log_critical( - "Backend Trust", - f"Cannot fetch the trust anchor configuration: {e}" - ) - - self.db_engine.close() - self._db_engine = None - - def entity_configuration_endpoint(self, context: Context) -> Response: - """ - Entity Configuration endpoint. - - :param context: The current context - :type context: Context - - :return: The entity configuration - :rtype: Response - """ - - if context.qs_params.get('format', '') == 'json': - return Response( - json.dumps(self.entity_configuration_as_dict), - status="200", - content="application/json" - ) - - return Response( - self.entity_configuration, - status="200", - content="application/entity-statement+jwt" - ) - - def update_trust_anchors(self): - """ - Updates the trust anchors of current instance. 
- """ - - tas = self.config['trust']['federation']['config']['trust_anchors'] - self._log_info("Trust Anchors updates", f"Trying to update: {tas}") - - for ta in tas: - try: - update_trust_anchors_ecs( - db=self.db_engine, - trust_anchors=[ta], - httpc_params=self.config['network']['httpc_params'] - ) - except Exception as e: - self._log_warning("Trust Anchor updates", - f"{ta} update failed: {e}") - - self._log_info("Trust Anchor updates", f"{ta} updated") - - def get_backend_trust_chain(self) -> list[str]: - """ - Get the backend trust chain. In case something raises an Exception (e.g. faulty storage), logs a warning message - and returns an empty list. - - :return: The trust chain - :rtype: list - """ - try: - trust_evaluation_helper = TrustEvaluationHelper.build_trust_chain_for_entity_id( - storage=self.db_engine, - entity_id=self.client_id, - entity_configuration=self.entity_configuration, - httpc_params=self.config['network']['httpc_params'] - ) - self.db_engine.add_or_update_trust_attestation( - entity_id=self.client_id, - attestation=trust_evaluation_helper.trust_chain, - exp=trust_evaluation_helper.exp - ) - return trust_evaluation_helper.trust_chain - - except (DiscoveryFailedError, EntryNotFound, Exception) as e: - message = ( - f"Error while building trust chain for client with id: {self.client_id}. " - f"{e.__class__.__name__}: {e}" - ) - self._log_warning("Trust Chain", message) - - return [] - - def _validate_trust(self, context: Context, jws: str) -> TrustEvaluationHelper: - """ - Validates the trust of the given jws. - - :param context: the request context - :type context: satosa.context.Context - :param jws: the jws to validate - :type jws: str - - :raises: NotTrustedFederationError: raises an error if the trust evaluation fails. 
- - :return: the trust evaluation helper - :rtype: TrustEvaluationHelper - """ - - self._log_debug(context, "[TRUST EVALUATION] evaluating trust.") - - headers = decode_jwt_header(jws) - trust_eval = TrustEvaluationHelper( - self.db_engine, - httpc_params=self.config['network']['httpc_params'], - **headers - ) - - try: - trust_eval.evaluation_method() - except EntryNotFound: - message = ( - "[TRUST EVALUATION] not found for " - f"{trust_eval.entity_id}" - ) - self._log_error(context, message) - raise NotTrustedFederationError( - f"{trust_eval.entity_id} not found for Trust evaluation." - ) - except Exception as e: - message = ( - "[TRUST EVALUATION] failed for " - f"{trust_eval.entity_id}: {e}" - ) - self._log_error(context, message) - raise NotTrustedFederationError( - f"{trust_eval.entity_id} is not trusted." - ) - - return trust_eval - - @property - def default_federation_private_jwk(self) -> dict: - """Returns the default federation private jwk.""" - return tuple(self.federations_jwks_by_kids.values())[0] - - @property - def entity_configuration_as_dict(self) -> dict: - """Returns the entity configuration as a dictionary.""" - ec_payload = { - "exp": exp_from_now(minutes=self.default_exp), - "iat": iat_now(), - "iss": self.client_id, - "sub": self.client_id, - "jwks": { - "keys": self.federation_public_jwks - }, - "metadata": { - self.config['trust']['federation']['config']["metadata_type"]: self.config['metadata'], - "federation_entity": self.config['trust']['federation']['config']['federation_entity_metadata'] - }, - "authority_hints": self.config['trust']['federation']['config']['authority_hints'] - } - return ec_payload - - @property - def entity_configuration(self) -> dict: - """Returns the entity configuration as a JWT.""" - data = self.entity_configuration_as_dict - jwshelper = JWSHelper(self.default_federation_private_jwk) - return jwshelper.sign( - protected={ - "alg": self.config['trust']['federation']['config']["default_sig_alg"], - "kid": 
self.default_federation_private_jwk["kid"], - "typ": "entity-statement+jwt" - }, - plain_dict=data - ) diff --git a/pyeudiw/sd_jwt/common.py b/pyeudiw/sd_jwt/common.py index a4947b04..43e2ca1b 100644 --- a/pyeudiw/sd_jwt/common.py +++ b/pyeudiw/sd_jwt/common.py @@ -2,20 +2,14 @@ import os import random import secrets - from base64 import urlsafe_b64decode, urlsafe_b64encode from dataclasses import dataclass from hashlib import sha256 from json import loads -from typing import List - -from . exceptions import SDJWTHasSDClaimException +from typing import List, Union -from . import ( - SD_DIGESTS_KEY, - JSON_SER_DISCLOSURE_KEY, - JSON_SER_KB_JWT_KEY -) +from pyeudiw.sd_jwt import JSON_SER_DISCLOSURE_KEY, JSON_SER_KB_JWT_KEY, SD_DIGESTS_KEY +from pyeudiw.sd_jwt.exceptions import SDJWTHasSDClaimException logger = logging.getLogger(__name__) @@ -26,15 +20,16 @@ class SDObj: value: any - # Make hashable def __hash__(self): + """Hash the object.""" return hash(self.value) class SDJWTCommon: SD_JWT_HEADER = os.getenv( # TODO: dc is only for digital credential, while you might use another typ ... - "SD_JWT_HEADER", "dc+sd-jwt" + "SD_JWT_HEADER", + "dc+sd-jwt", ) # overwriteable with extra_header_parameters = {"typ": "other-example+sd-jwt"} KB_JWT_TYP_HEADER = "kb+jwt" HASH_ALG = {"name": "sha-256", "fn": sha256} @@ -45,30 +40,79 @@ class SDJWTCommon: def __init__(self, serialization_format): if serialization_format not in ("compact", "json"): - raise ValueError( - f"Unknown serialization format: {serialization_format}") + raise ValueError(f"Unknown serialization format: {serialization_format}") self._serialization_format = serialization_format - def _b64hash(self, raw): - # Calculate the SHA 256 hash and output it base64 encoded + def _b64hash(self, raw: bytes) -> str: + """ + Calculate the SHA 256 hash and output it base64 encoded. + + :param raw: The raw data to hash. + :type raw: bytes + + :return: The base64 encoded hash. 
+ :rtype: str + """ return self._base64url_encode(self.HASH_ALG["fn"](raw).digest()) - def _combine(self, *parts): + def _combine(self, *parts) -> str: + """ + Combine the parts with the separator. + + :param parts: The parts to combine. + :type parts: str + + :return: The combined string. + :rtype: str + """ return self.COMBINED_SERIALIZATION_FORMAT_SEPARATOR.join(parts) - def _split(self, combined): + def _split(self, combined: str) -> List[str]: + """ + Split the combined string. + + :param combined: The combined string. + :type combined: str + + :return: The parts. + :rtype: List[str] + """ return combined.split(self.COMBINED_SERIALIZATION_FORMAT_SEPARATOR) @staticmethod def _base64url_encode(data: bytes) -> str: + """ + Encode the data in base64url encoding. + + :param data: The data to encode. + :type data: bytes + + :return: The base64url encoded data. + :rtype: str + """ return urlsafe_b64encode(data).decode("ascii").strip("=") @staticmethod def _base64url_decode(b64data: str) -> bytes: + """ + Decode the base64url encoded data. + + :param b64data: The base64url encoded data. + :type b64data: str + + :return: The decoded data. + :rtype: bytes + """ padded = f"{b64data}{'=' * divmod(len(b64data),4)[1]}" return urlsafe_b64decode(padded) - def _generate_salt(self): + def _generate_salt(self) -> str: + """ + Generate a salt. + + :return: The salt. + :rtype: str + """ if self.unsafe_randomness: # This is not cryptographically secure, but it is deterministic # and allows for repeatable output for the generation of the examples. @@ -81,7 +125,14 @@ def _generate_salt(self): else: return self._base64url_encode(secrets.token_bytes(16)) - def _create_hash_mappings(self, disclosurses_list: List): + def _create_hash_mappings(self, disclosurses_list: List) -> None: + """ + Create the hash mappings for the disclosures. + + :param disclosurses_list: The list of disclosures. 
+ :type disclosurses_list: List + """ + # Mapping from hash of disclosure to the decoded disclosure self._hash_to_decoded_disclosure = {} @@ -101,33 +152,46 @@ def _create_hash_mappings(self, disclosurses_list: List): self._hash_to_decoded_disclosure[_hash] = decoded_disclosure self._hash_to_disclosure[_hash] = disclosure - def _check_for_sd_claim(self, the_object): + def _check_for_sd_claim(self, obj: Union[dict, list, any]) -> None: + """ + Check for the presence of the _sd claim in the object. + + :param obj: The object to check. + :type obj: Union[dict, list, any] + """ + # Recursively check for the presence of the _sd claim, also # works for arrays and nested objects. - if isinstance(the_object, dict): - for key, value in the_object.items(): + if isinstance(obj, dict): + for key, value in obj.items(): if key == SD_DIGESTS_KEY: - raise SDJWTHasSDClaimException(the_object) + raise SDJWTHasSDClaimException(obj) else: self._check_for_sd_claim(value) - elif isinstance(the_object, list): - for item in the_object: + elif isinstance(obj, list): + for item in obj: self._check_for_sd_claim(item) else: return - def _parse_sd_jwt(self, sd_jwt): + def _parse_sd_jwt(self, sd_jwt: str) -> None: + """ + Parse the SD-JWT. + + :param sd_jwt: The SD-JWT to parse. 
+ :type sd_jwt: str + """ + if self._serialization_format == "compact": ( self._unverified_input_sd_jwt, *self._input_disclosures, - self._unverified_input_key_binding_jwt + self._unverified_input_key_binding_jwt, ) = self._split(sd_jwt) # Extract only the body from SD-JWT without verifying the signature _, jwt_body, _ = self._unverified_input_sd_jwt.split(".") - self._unverified_input_sd_jwt_payload = self._base64url_decode( - jwt_body) + self._unverified_input_sd_jwt_payload = self._base64url_decode(jwt_body) self._unverified_compact_serialized_input_sd_jwt = ( self._unverified_input_sd_jwt ) @@ -138,8 +202,7 @@ def _parse_sd_jwt(self, sd_jwt): self._unverified_input_sd_jwt_parsed = loads(sd_jwt) self._unverified_input_sd_jwt_payload = loads( - self._base64url_decode( - self._unverified_input_sd_jwt_parsed["payload"]) + self._base64url_decode(self._unverified_input_sd_jwt_parsed["payload"]) ) # distinguish between flattened and general JSON serialization (RFC7515) @@ -157,7 +220,7 @@ def _parse_sd_jwt(self, sd_jwt): [ self._unverified_input_sd_jwt_parsed["protected"], self._unverified_input_sd_jwt_parsed["payload"], - self._unverified_input_sd_jwt_parsed["signature"] + self._unverified_input_sd_jwt_parsed["signature"], ] ) @@ -186,12 +249,20 @@ def _parse_sd_jwt(self, sd_jwt): else: raise ValueError("Invalid JSON serialization of SD-JWT") - def _calculate_kb_hash(self, disclosures): + def _calculate_kb_hash(self, disclosures: List[str]) -> str: + """ + Calculate the hash over the key binding. + + :param disclosures: The list of disclosures. + :type disclosures: List[str] + + :return: The hash over the key binding. 
+ :rtype: str + """ + # Temporarily create the combined presentation in order to create the hash over it # Note: For JSON Serialization, the compact representation of the SD-JWT is restored from the parsed JSON (see common.py) string_to_hash = self._combine( - self._unverified_compact_serialized_input_sd_jwt, - *disclosures, - "" + self._unverified_compact_serialized_input_sd_jwt, *disclosures, "" ) return self._b64hash(string_to_hash.encode("ascii")) diff --git a/pyeudiw/sd_jwt/disclosure.py b/pyeudiw/sd_jwt/disclosure.py index ebe3c5e6..99860661 100644 --- a/pyeudiw/sd_jwt/disclosure.py +++ b/pyeudiw/sd_jwt/disclosure.py @@ -17,7 +17,15 @@ class SDJWTDisclosure: def __post_init__(self): self._hash() - def _hash(self): + def _hash(self) -> None: + """ + Hash the claim. + + This method hashes the claim using the issuer's hashing algorithm. The hashed claim is stored in the + `hash` attribute. + + :return: None + """ salt = self.issuer._generate_salt() if self.key is None: data = [salt, self.value] diff --git a/pyeudiw/sd_jwt/exceptions.py b/pyeudiw/sd_jwt/exceptions.py index efcda1fc..e8801089 100644 --- a/pyeudiw/sd_jwt/exceptions.py +++ b/pyeudiw/sd_jwt/exceptions.py @@ -1,10 +1,5 @@ from . 
import SD_DIGESTS_KEY - -class UnknownCurveNistName(Exception): - pass - - class InvalidKeyBinding(Exception): pass diff --git a/pyeudiw/sd_jwt/holder.py b/pyeudiw/sd_jwt/holder.py index 63600300..34b88678 100644 --- a/pyeudiw/sd_jwt/holder.py +++ b/pyeudiw/sd_jwt/holder.py @@ -1,32 +1,33 @@ import logging +from itertools import zip_longest +from json import dumps, loads +from time import time +from typing import Dict, List, Optional, Union +from pyeudiw.jwt.helper import KeyLike -from pyeudiw.jwt.jws_helper import JWSHelper -from pyeudiw.sd_jwt.common import SDJWTCommon +from cryptojwt.jws.jws import JWS +from pyeudiw.jwt.jws_helper import JWSHelper from pyeudiw.sd_jwt import ( DEFAULT_SIGNING_ALG, - SD_DIGESTS_KEY, - SD_LIST_PREFIX, - KB_DIGEST_KEY, JSON_SER_DISCLOSURE_KEY, JSON_SER_KB_JWT_KEY, + KB_DIGEST_KEY, + SD_DIGESTS_KEY, + SD_LIST_PREFIX, ) -from json import dumps -from time import time -from typing import Dict, List, Optional -from itertools import zip_longest - -from cryptojwt.jws.jws import JWS -from json import dumps, loads +from pyeudiw.sd_jwt.common import SDJWTCommon logger = logging.getLogger(__name__) - class SDJWTHolder(SDJWTCommon): + """ + SDJWTHolder is a class to create a holder presentation from a SD-JWT. + """ + hs_disclosures: List key_binding_jwt_header: Dict key_binding_jwt_payload: Dict - key_binding_jwt: JWS serialized_key_binding_jwt: str = "" sd_jwt_presentation: str @@ -34,7 +35,19 @@ class SDJWTHolder(SDJWTCommon): _hash_to_decoded_disclosure: Dict _hash_to_disclosure: Dict - def __init__(self, sd_jwt_issuance: str, serialization_format: str = "compact"): + def __init__( + self, + sd_jwt_issuance: str, + serialization_format: str = "compact") -> None: + """ + Creates an instance of SDJWTHolder. + + :param sd_jwt_issuance: The SD-JWT to create a presentation from. + :param serialization_format: The serialization format of the SD-JWT. + + :param serialization_format: The serialization format of the SD-JWT. 
+ :type serialization_format: str + """ super().__init__(serialization_format=serialization_format) self._parse_sd_jwt(sd_jwt_issuance) @@ -49,8 +62,23 @@ def __init__(self, sd_jwt_issuance: str, serialization_format: str = "compact"): self._create_hash_mappings(self._input_disclosures) def create_presentation( - self, claims_to_disclose, nonce=None, aud=None, holder_key=None, sign_alg=None - ): + self, + claims_to_disclose: Union[dict, bool, None], + nonce: Union[str, None] = None, + aud: Union[str, None] = None, + holder_key: Union[KeyLike, None] = None, + sign_alg: Union[str, None] = None + ) -> None: + """ + Create a holder presentation from the SD-JWT. + + :param claims_to_disclose: The claims to disclose. If True, all claims are disclosed. + :param nonce: The nonce to include in the key binding JWT. + :param aud: The audience to include in the key binding JWT. + :param holder_key: The key to sign the key binding JWT with. + :param sign_alg: The signing algorithm to use for the key binding JWT. + """ + # Select the disclosures self.hs_disclosures = [] @@ -58,8 +86,7 @@ def create_presentation( # Optional: Create a key binding JWT if nonce and aud and holder_key: - sd_jwt_presentation_hash = self._calculate_kb_hash( - self.hs_disclosures) + sd_jwt_presentation_hash = self._calculate_kb_hash(self.hs_disclosures) self._create_key_binding_jwt( nonce, aud, sd_jwt_presentation_hash, holder_key, sign_alg ) @@ -99,14 +126,31 @@ def create_presentation( self.sd_jwt_presentation = dumps(presentation) - def _select_disclosures(self, sd_jwt_claims, claims_to_disclose): - # Recursively process the claims in sd_jwt_claims. In each - # object found therein, look at the SD_DIGESTS_KEY. If it - # contains hash digests for claims that should be disclosed, - # then add the corresponding disclosures to the claims_to_disclose. 
+ def _select_disclosures( + self, + sd_jwt_claims: Union[bytes, list, dict], + claims_to_disclose: Union[dict, bool, None]) -> Union[dict, list, None]: + """ + Recursively process the claims in sd_jwt_claims. In each + object found therein, look at the SD_DIGESTS_KEY. If it + contains hash digests for claims that should be disclosed, + then add the corresponding disclosures to the claims_to_disclose. + + :param sd_jwt_claims: The claims to process. + :param claims_to_disclose: The claims to disclose. + + :type sd_jwt_claims: bytes | list | dict + :type claims_to_disclose: dict | True | None + - if (type(sd_jwt_claims) is bytes): - return self._select_disclosures_dict(loads(self.sd_jwt_payload.decode('utf-8')), claims_to_disclose) + :returns: The claims to disclose. + :rtype: dict | list | None + """ + + if type(sd_jwt_claims) is bytes: + return self._select_disclosures_dict( + loads(self.sd_jwt_payload.decode("utf-8")), claims_to_disclose + ) if type(sd_jwt_claims) is list: return self._select_disclosures_list(sd_jwt_claims, claims_to_disclose) elif type(sd_jwt_claims) is dict: @@ -114,7 +158,26 @@ def _select_disclosures(self, sd_jwt_claims, claims_to_disclose): else: pass - def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose): + def _select_disclosures_list( + self, + sd_jwt_claims: list, + claims_to_disclose: Union[list, bool, None]) -> list: + + """ + Process the claims in a list. + + :param sd_jwt_claims: The claims to process. + :param claims_to_disclose: The claims to disclose. + + :type sd_jwt_claims: list + :type claims_to_disclose: list | True | None + + :raises ValueError: If the disclosure information is not an array. + + :returns: The claims to disclose. 
+ :rtype: list + """ + if claims_to_disclose is None: return [] if claims_to_disclose is True: @@ -151,8 +214,7 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose): ): continue - self.hs_disclosures.append( - self._hash_to_disclosure[digest_to_check]) + self.hs_disclosures.append(self._hash_to_disclosure[digest_to_check]) if isinstance(disclosure_value, dict): if claims_to_disclose_element is True: # Tolerate a "True" for a disclosure of an object @@ -186,7 +248,25 @@ def _select_disclosures_list(self, sd_jwt_claims, claims_to_disclose): else: self._select_disclosures(element, claims_to_disclose_element) - def _select_disclosures_dict(self, sd_jwt_claims, claims_to_disclose): + def _select_disclosures_dict( + self, + sd_jwt_claims: dict, + claims_to_disclose: Union[dict, bool, None]) -> dict: + """ + Process the claims in a dictionary. + + :param sd_jwt_claims: The claims to process. + :param claims_to_disclose: The claims to disclose. + + :type sd_jwt_claims: dict + :type claims_to_disclose: dict | True | None + + :raises ValueError: If the disclosure information is not a dictionary. + + :returns: The claims to disclose. 
+ :rtype: dict + """ + if claims_to_disclose is None: return {} if claims_to_disclose is True: @@ -211,8 +291,7 @@ def _select_disclosures_dict(self, sd_jwt_claims, claims_to_disclose): f"In _select_disclosures_dict: {key}, {value}, {claims_to_disclose}" ) if key in claims_to_disclose and claims_to_disclose[key]: - logger.debug( - f"Adding disclosure for {digest_to_check}") + logger.debug(f"Adding disclosure for {digest_to_check}") self.hs_disclosures.append( self._hash_to_disclosure[digest_to_check] ) @@ -227,15 +306,28 @@ def _select_disclosures_dict(self, sd_jwt_claims, claims_to_disclose): f"Check claims_to_disclose for key: {key}, value: {value}" ) from None - self._select_disclosures( - value, claims_to_disclose.get(key, None)) + self._select_disclosures(value, claims_to_disclose.get(key, None)) else: - self._select_disclosures( - value, claims_to_disclose.get(key, None)) + self._select_disclosures(value, claims_to_disclose.get(key, None)) def _create_key_binding_jwt( - self, nonce, aud, presentation_hash, holder_key, sign_alg: Optional[str] = None - ): + self, + nonce: Union[str, None], + aud: Union[str, None], + presentation_hash, + holder_key: Union[KeyLike | list[KeyLike | dict] | dict], + sign_alg: Optional[str] = None + ) -> None: + """ + Create a key binding JWT. + + :param nonce: The nonce to include in the key binding JWT. + :param aud: The audience to include in the key binding JWT. + :param presentation_hash: The hash of the presentation. + :param holder_key: The key to sign the key binding JWT with. + :param sign_alg: The signing algorithm to use for the key binding JWT. 
+ """ + _alg = sign_alg or DEFAULT_SIGNING_ALG self.key_binding_jwt_header = { @@ -254,5 +346,5 @@ def _create_key_binding_jwt( self.serialized_key_binding_jwt = signer.sign( self.key_binding_jwt_payload, protected=self.key_binding_jwt_header, - kid_in_header=False + kid_in_header=False, ) diff --git a/pyeudiw/sd_jwt/issuer.py b/pyeudiw/sd_jwt/issuer.py index c78b8ba5..b57624e1 100644 --- a/pyeudiw/sd_jwt/issuer.py +++ b/pyeudiw/sd_jwt/issuer.py @@ -1,25 +1,21 @@ import logging import secrets - from typing import Dict, List, Union +from cryptojwt.jwk.jwk import key_from_jwk_dict +from cryptojwt.jws.jws import JWS from pyeudiw.jwt.jws_helper import JWSHelper -from pyeudiw.sd_jwt.common import SDJWTCommon, SDObj - from pyeudiw.sd_jwt import ( DEFAULT_SIGNING_ALG, DIGEST_ALG_KEY, + JSON_SER_DISCLOSURE_KEY, SD_DIGESTS_KEY, SD_LIST_PREFIX, - JSON_SER_DISCLOSURE_KEY, ) - +from pyeudiw.sd_jwt.common import SDJWTCommon, SDObj from pyeudiw.sd_jwt.disclosure import SDJWTDisclosure -from cryptojwt.jws.jws import JWS -from cryptojwt.jwk.jwk import key_from_jwk_dict - logger = logging.getLogger(__name__) @@ -162,8 +158,7 @@ def _create_sd_claims_object(self, user_claims: Dict): for _ in range( sr.randint(self.DECOY_MIN_ELEMENTS, self.DECOY_MAX_ELEMENTS) ): - sd_claims[SD_DIGESTS_KEY].append( - self._create_decoy_claim_entry()) + sd_claims[SD_DIGESTS_KEY].append(self._create_decoy_claim_entry()) # Delete the SD_DIGESTS_KEY if it is empty if len(sd_claims[SD_DIGESTS_KEY]) == 0: @@ -194,19 +189,19 @@ def _create_signed_jws(self): _unprotected_headers = {} for i, key in enumerate(self._issuer_keys): - _unprotected_headers = { - "kid": key["kid"]} if "kid" in key else None + _unprotected_headers = {"kid": key["kid"]} if "kid" in key else None if self._serialization_format == "json" and i == 0: _unprotected_headers = _unprotected_headers or {} _unprotected_headers[JSON_SER_DISCLOSURE_KEY] = [ - d.b64 for d in self.ii_disclosures] + d.b64 for d in self.ii_disclosures + ] self.sd_jwt = 
JWSHelper(jwks=self._issuer_keys) self.serialized_sd_jwt = self.sd_jwt.sign( self.sd_jwt_payload, protected=_protected_headers, unprotected=_unprotected_headers, - serialization_format=self._serialization_format + serialization_format=self._serialization_format, ) def _create_combined(self): diff --git a/pyeudiw/sd_jwt/schema.py b/pyeudiw/sd_jwt/schema.py index e0d11479..ee32a269 100644 --- a/pyeudiw/sd_jwt/schema.py +++ b/pyeudiw/sd_jwt/schema.py @@ -1,7 +1,8 @@ import logging -import sys import re +import sys from typing import Dict, Literal, Optional, TypeVar + from typing_extensions import Self if float(f"{sys.version_info.major}.{sys.version_info.minor}") >= 3.12: @@ -13,8 +14,7 @@ from pyeudiw.jwk.schemas.public import JwkSchema - -_OptionalDict_T = TypeVar('T', None, dict) +_OptionalDict_T = TypeVar("T", None, dict) _IDENTIFYING_VC_TYP = "vc+sd-jwt" _IDENTIFYING_KB_TYP = "kb+jwt" @@ -50,14 +50,14 @@ class VcSdJwtHeaderSchema(BaseModel): def validate_typ(cls, v: str) -> str: if v != _IDENTIFYING_VC_TYP: raise ValueError( - f"header parameter [typ] must be '{_IDENTIFYING_VC_TYP}', found instead '{v}'") + f"header parameter [typ] must be '{_IDENTIFYING_VC_TYP}', found instead '{v}'" + ) return v @model_validator(mode="after") def check_typ_when_not_x5c(self) -> Self: if (not self.x5c) and (not self.kid): - raise ValueError( - "[kid] must be defined if [x5c] claim is not defined") + raise ValueError("[kid] must be defined if [x5c] claim is not defined") return self @@ -98,7 +98,8 @@ def validate_status(cls, v: dict) -> dict: _StatusSchema(**v) except ValueError as e: raise ValueError( - f"parameter [status] value '{v}' does not comply with schema {_StatusSchema.model_fields}: {e}") + f"parameter [status] value '{v}' does not comply with schema {_StatusSchema.model_fields}: {e}" + ) return v @field_validator("verification") @@ -107,7 +108,8 @@ def validate_verification(cls, v: dict) -> dict: _VerificationSchema(**v) except ValueError as e: raise ValueError( - 
f"parameter [verification] value '{v}' does not comply with schema {_VerificationSchema.model_fields}: {e}") + f"parameter [verification] value '{v}' does not comply with schema {_VerificationSchema.model_fields}: {e}" + ) return v @@ -119,7 +121,8 @@ class KeyBindingJwtHeader(BaseModel): def validate_typ(cls, v: str) -> str: if v != _IDENTIFYING_KB_TYP: raise ValueError( - f"header parameter [typ] must be '{_IDENTIFYING_KB_TYP}', found instead '{v}'") + f"header parameter [typ] must be '{_IDENTIFYING_KB_TYP}', found instead '{v}'" + ) return v diff --git a/pyeudiw/sd_jwt/sd_jwt.py b/pyeudiw/sd_jwt/sd_jwt.py index 6541b302..e19ab651 100644 --- a/pyeudiw/sd_jwt/sd_jwt.py +++ b/pyeudiw/sd_jwt/sd_jwt.py @@ -1,31 +1,26 @@ +import json import logging from hashlib import sha256 -import json from typing import Any, Callable, TypeVar -from pyeudiw.jwt.jws_helper import JWSHelper -from pyeudiw.sd_jwt.common import SDJWTCommon +from cryptojwt.jwk.ec import ECKey +from cryptojwt.jwk.rsa import RSAKey + +from pyeudiw.jwt.jws_helper import JWSHelper +from pyeudiw.jwt.parse import DecodedJwt from pyeudiw.jwt.utils import base64_urldecode, base64_urlencode from pyeudiw.jwt.verification import verify_jws_with_key from pyeudiw.sd_jwt.common import SDJWTCommon from pyeudiw.sd_jwt.exceptions import InvalidKeyBinding, UnsupportedSdAlg -from pyeudiw.sd_jwt.schema import is_sd_jwt_format, is_sd_jwt_kb_format, VerifierChallenge -from pyeudiw.jwt.parse import DecodedJwt -from pyeudiw.tools.utils import iat_now - - -from cryptojwt.jwk.ec import ECKey -from cryptojwt.jwk.rsa import RSAKey - -from . import ( - DEFAULT_SD_ALG, - DIGEST_ALG_KEY, - SD_DIGESTS_KEY, - SD_LIST_PREFIX +from pyeudiw.sd_jwt.schema import ( + VerifierChallenge, + is_sd_jwt_format, ) +from . 
import DEFAULT_SD_ALG, DIGEST_ALG_KEY, SD_DIGESTS_KEY, SD_LIST_PREFIX + _JsonTypes = dict | list | str | int | float | bool | None -_JsonTypes_T = TypeVar('_JsonTypes_T', bound=_JsonTypes) +_JsonTypes_T = TypeVar("_JsonTypes_T", bound=_JsonTypes) FORMAT_SEPARATOR = SDJWTCommon.COMBINED_SERIALIZATION_FORMAT_SEPARATOR @@ -46,7 +41,8 @@ class SdJwt: def __init__(self, token: str): if not is_sd_jwt_format(token): raise ValueError( - f"input [token]={token} is not an sd-jwt with: maybe it is a regular jwt?") + f"input [token]={token} is not an sd-jwt with: maybe it is a regular jwt?" + ) self.token = token # precomputed values self.token_without_kb: str = "" @@ -57,8 +53,11 @@ def __init__(self, token: str): def _post_init_precomputed_values(self): iss_jwt, *disclosures, kb_jwt = self.token.split(FORMAT_SEPARATOR) - self.token_without_kb = iss_jwt + FORMAT_SEPARATOR + \ - ''.join(disc + FORMAT_SEPARATOR for disc in disclosures) + self.token_without_kb = ( + iss_jwt + + FORMAT_SEPARATOR + + "".join(disc + FORMAT_SEPARATOR for disc in disclosures) + ) self.issuer_jwt = DecodedJwt.parse(iss_jwt) self.disclosures = disclosures if kb_jwt: @@ -69,17 +68,19 @@ def get_confirmation_key(self) -> dict: cnf: dict = self.issuer_jwt.payload.get("cnf", {}).get("jwk", {}) if not cnf: raise ValueError( - "missing confirmation (cnf) key from issuer payload claims") + "missing confirmation (cnf) key from issuer payload claims" + ) return cnf - def get_encoded_disclosures(self) -> list[str]: - return self.disclosures - def get_disclosed_claims(self) -> dict: - return _extract_claims_from_payload(self.issuer_jwt.payload, self.disclosures, SUPPORTED_SD_ALG_FN[self.get_sd_alg()]) + return _extract_claims_from_payload( + self.issuer_jwt.payload, + self.disclosures, + SUPPORTED_SD_ALG_FN[self.get_sd_alg()], + ) - def get_issuer_jwt(self) -> str: - return self.issuer_jwt.jwt + def get_issuer_jwt(self) -> DecodedJwt: + return self.issuer_jwt def get_holder_key_binding_jwt(self) -> str: return 
self.holder_kb.jwt @@ -90,7 +91,9 @@ def get_sd_alg(self) -> str: def has_key_binding(self) -> bool: return self.holder_kb is not None - def verify_issuer_jwt_signature(self, keys: list[ECKey | RSAKey | dict] | ECKey | RSAKey | dict) -> None: + def verify_issuer_jwt_signature( + self, keys: list[ECKey | RSAKey | dict] | ECKey | RSAKey | dict + ) -> None: jws_verifier = JWSHelper(keys) jws_verifier.verify(self.issuer_jwt.jwt) @@ -104,8 +107,9 @@ def verify_holder_kb_jwt(self, challenge: VerifierChallenge) -> None: """ if not self.has_key_binding(): return - _verify_key_binding(self.token_without_kb, self.get_sd_alg(), - self.holder_kb, challenge) + _verify_key_binding( + self.token_without_kb, self.get_sd_alg(), self.holder_kb, challenge + ) self.verify_holder_kb_jwt_signature() def verify_holder_kb_jwt_signature(self) -> None: @@ -115,24 +119,15 @@ def verify_holder_kb_jwt_signature(self) -> None: verify_jws_with_key(self.holder_kb.jwt, cnf) -class SdJwtKb(SdJwt): - - def __init__(self, token: str): - if not is_sd_jwt_kb_format(token): - raise ValueError( - f"input [token]={token} is not an sd-jwt with key binding with: maybe it is a regular jwt?") - super().__init__(token) - if not self.holder_kb: - raise ValueError("missing key binding jwt") - - def _verify_challenge(hkb: DecodedJwt, challenge: VerifierChallenge): if (obt := hkb.payload.get("aud", None)) != (exp := challenge["aud"]): raise InvalidKeyBinding( - f"challenge audience {exp} does not match obtained audience {obt}") + f"challenge audience {exp} does not match obtained audience {obt}" + ) if (obt := hkb.payload.get("nonce", None)) != (exp := challenge["nonce"]): raise InvalidKeyBinding( - f"challenge nonce {exp} does not match obtained nonce {obt}") + f"challenge nonce {exp} does not match obtained nonce {obt}" + ) def _verify_sd_hash(token_without_hkb: str, sd_hash_alg: str, expected_digest: str): @@ -141,7 +136,8 @@ def _verify_sd_hash(token_without_hkb: str, sd_hash_alg: str, expected_digest: s raise 
UnsupportedSdAlg(f"unsupported sd_alg: {sd_hash_alg}") if expected_digest != (obt_digest := hash_fn(token_without_hkb)): raise InvalidKeyBinding( - f"sd-jwt digest {obt_digest} does not match expected digest {expected_digest}") + f"sd-jwt digest {obt_digest} does not match expected digest {expected_digest}" + ) def _verify_iat(payload: dict) -> None: @@ -153,17 +149,22 @@ def _verify_iat(payload: dict) -> None: raise ValueError("missing or invalid parameter [iat] in kbjwt") -def _verify_key_binding(token_without_hkb: str, sd_hash_alg: str, hkb: DecodedJwt, challenge: VerifierChallenge): +def _verify_key_binding( + token_without_hkb: str, + sd_hash_alg: str, + hkb: DecodedJwt, + challenge: VerifierChallenge, +): _verify_challenge(hkb, challenge) _verify_sd_hash( - token_without_hkb, - sd_hash_alg, - hkb.payload.get("sd_hash", "sha-256") + token_without_hkb, sd_hash_alg, hkb.payload.get("sd_hash", "sha-256") ) _verify_iat(hkb.payload) -def _disclosures_to_hash_mappings(disclosures: list[str], sd_alg: Callable[[str], str]) -> tuple[dict[str, str], dict[str, Any]]: +def _disclosures_to_hash_mappings( + disclosures: list[str], sd_alg: Callable[[str], str] +) -> tuple[dict[str, str], dict[str, Any]]: """ :returns: in order (i) hash_to_disclosure, a map: digest -> raw disclosure (base64 encoded) @@ -173,8 +174,7 @@ def _disclosures_to_hash_mappings(disclosures: list[str], sd_alg: Callable[[str] hash_to_disclosure: dict[str, str] = {} hash_to_dec_disclosure: dict[str, Any] = {} for disclosure in disclosures: - decoded_disclosure = json.loads( - base64_urldecode(disclosure).decode("utf-8")) + decoded_disclosure = json.loads(base64_urldecode(disclosure).decode("utf-8")) digest = sd_alg(disclosure) if digest in hash_to_dec_disclosure: raise ValueError(f"duplicate disclosure for digest {digest}") @@ -183,33 +183,56 @@ def _disclosures_to_hash_mappings(disclosures: list[str], sd_alg: Callable[[str] return hash_to_disclosure, hash_to_dec_disclosure -def 
_extract_claims_from_payload(payload: dict, disclosures: list[str], sd_alg: Callable[[str], str]) -> dict: +def _extract_claims_from_payload( + payload: dict, disclosures: list[str], sd_alg: Callable[[str], str] +) -> dict: hash_to_disclosure, hash_to_dec_disclosure = _disclosures_to_hash_mappings( - disclosures, sd_alg) + disclosures, sd_alg + ) return _unpack_claims(payload, hash_to_dec_disclosure, sd_alg, []) def _is_element_leaf(element: Any) -> bool: - return (type(element) is dict and len(element) == 1 and SD_LIST_PREFIX in element - and type(element[SD_LIST_PREFIX]) is str) + return ( + type(element) is dict + and len(element) == 1 + and SD_LIST_PREFIX in element + and type(element[SD_LIST_PREFIX]) is str + ) -def _unpack_json_array(claims: list, decoded_disclosures_by_digest: dict[str, Any], sd_alg: Callable[[str], str], processed_digests: list[str]) -> list: +def _unpack_json_array( + claims: list, + decoded_disclosures_by_digest: dict[str, Any], + sd_alg: Callable[[str], str], + processed_digests: list[str], +) -> list: result = [] for element in claims: if _is_element_leaf(element): digest: str = element[SD_LIST_PREFIX] if digest in decoded_disclosures_by_digest: _, value = decoded_disclosures_by_digest[digest] - result.append(_unpack_claims( - value, decoded_disclosures_by_digest, sd_alg, processed_digests)) + result.append( + _unpack_claims( + value, decoded_disclosures_by_digest, sd_alg, processed_digests + ) + ) else: - result.append(_unpack_claims( - element, decoded_disclosures_by_digest, sd_alg, processed_digests)) + result.append( + _unpack_claims( + element, decoded_disclosures_by_digest, sd_alg, processed_digests + ) + ) return result -def _unpack_json_dict(claims: dict, decoded_disclosures_by_digest: dict[str, Any], sd_alg: Callable[[str], str], proceessed_digests: list[str]) -> dict: +def _unpack_json_dict( + claims: dict, + decoded_disclosures_by_digest: dict[str, Any], + sd_alg: Callable[[str], str], + proceessed_digests: list[str], +) -> 
dict: # First, try to figure out if there are any claims to be # disclosed in this dict. If so, replace them by their # disclosed values. @@ -217,12 +240,12 @@ def _unpack_json_dict(claims: dict, decoded_disclosures_by_digest: dict[str, Any for k, v in claims.items(): if k != SD_DIGESTS_KEY and k != DIGEST_ALG_KEY: filtered_unpacked_claims[k] = _unpack_claims( - v, decoded_disclosures_by_digest, sd_alg, proceessed_digests) + v, decoded_disclosures_by_digest, sd_alg, proceessed_digests + ) for disclosed_digests in claims.get(SD_DIGESTS_KEY, []): if disclosed_digests in proceessed_digests: - raise ValueError( - f"duplicate hash found in SD-JWT: {disclosed_digests}") + raise ValueError(f"duplicate hash found in SD-JWT: {disclosed_digests}") proceessed_digests.append(disclosed_digests) if disclosed_digests in decoded_disclosures_by_digest: @@ -232,16 +255,25 @@ def _unpack_json_dict(claims: dict, decoded_disclosures_by_digest: dict[str, Any f"duplicate key found when unpacking disclosed claim: '{key}' in {filtered_unpacked_claims}; this is not allowed." 
) unpacked_value = _unpack_claims( - value, decoded_disclosures_by_digest, sd_alg, proceessed_digests) + value, decoded_disclosures_by_digest, sd_alg, proceessed_digests + ) filtered_unpacked_claims[key] = unpacked_value return filtered_unpacked_claims -def _unpack_claims(claims: _JsonTypes_T, decoded_disclosures_by_digest: dict[str, Any], - sd_alg: Callable[[str], str], proceessed_digests: list[str]) -> _JsonTypes_T: +def _unpack_claims( + claims: _JsonTypes_T, + decoded_disclosures_by_digest: dict[str, Any], + sd_alg: Callable[[str], str], + proceessed_digests: list[str], +) -> _JsonTypes_T: if type(claims) is list: - return _unpack_json_array(claims, decoded_disclosures_by_digest, sd_alg, proceessed_digests) + return _unpack_json_array( + claims, decoded_disclosures_by_digest, sd_alg, proceessed_digests + ) elif type(claims) is dict: - return _unpack_json_dict(claims, decoded_disclosures_by_digest, sd_alg, proceessed_digests) + return _unpack_json_dict( + claims, decoded_disclosures_by_digest, sd_alg, proceessed_digests + ) else: return claims diff --git a/pyeudiw/sd_jwt/utils/demo_utils.py b/pyeudiw/sd_jwt/utils/demo_utils.py index cb23a88c..36068164 100644 --- a/pyeudiw/sd_jwt/utils/demo_utils.py +++ b/pyeudiw/sd_jwt/utils/demo_utils.py @@ -1,11 +1,8 @@ -import base64 import logging import random -import yaml import sys - +import yaml from cryptojwt.jwk.jwk import key_from_jwk_dict -from typing import Union logger = logging.getLogger("sd_jwt") @@ -21,34 +18,15 @@ def load_yaml_settings(file): # 'issuer_key' can be used instead of 'issuer_keys' in the key settings; will be converted to an array anyway if "issuer_key" in settings["key_settings"]: if "issuer_keys" in settings["key_settings"]: - sys.exit( - "Settings file cannot define both 'issuer_key' and 'issuer_keys'.") + sys.exit("Settings file cannot define both 'issuer_key' and 'issuer_keys'.") settings["key_settings"]["issuer_keys"] = [ - settings["key_settings"]["issuer_key"]] + 
settings["key_settings"]["issuer_key"] + ] return settings -def print_repr(values: Union[str, list], nlines=2): - value = "\n".join(values) if isinstance(values, (list, tuple)) else values - _nlines = "\n" * nlines if nlines else "" - print(value, end=_nlines) - - -def print_decoded_repr(value: str, nlines=2): - seq = [] - for i in value.split("."): - try: - padded = f"{i}{'=' * divmod(len(i),4)[1]}" - seq.append(f"{base64.urlsafe_b64decode(padded).decode()}") - except Exception as e: - logging.debug(f"{e} - for value: {i}") - seq.append(i) - _nlines = "\n" * nlines if nlines else "" - print("\n.\n".join(seq), end=_nlines) - - def get_jwk(jwk_kwargs: dict = {}, no_randomness: bool = False, random_seed: int = 0): """ jwk_kwargs = { @@ -66,8 +44,7 @@ def get_jwk(jwk_kwargs: dict = {}, no_randomness: bool = False, random_seed: int issuer_keys = [key_from_jwk_dict(k) for k in jwk_kwargs["issuer_keys"]] holder_key = key_from_jwk_dict(jwk_kwargs["holder_key"]) else: - _kwargs = { - "key_size": jwk_kwargs["key_size"], "kty": jwk_kwargs["kty"]} + _kwargs = {"key_size": jwk_kwargs["key_size"], "kty": jwk_kwargs["kty"]} issuer_keys = [key_from_jwk_dict(_kwargs)] holder_key = key_from_jwk_dict(_kwargs) diff --git a/pyeudiw/sd_jwt/utils/yaml_specification.py b/pyeudiw/sd_jwt/utils/yaml_specification.py index 60eb1260..74f1fec5 100644 --- a/pyeudiw/sd_jwt/utils/yaml_specification.py +++ b/pyeudiw/sd_jwt/utils/yaml_specification.py @@ -1,11 +1,75 @@ -from pyeudiw.sd_jwt.common import SDObj -import yaml import sys +import yaml +from io import TextIOWrapper +from pyeudiw.sd_jwt.common import SDObj +from typing import Union + +class _SDKeyTag(yaml.YAMLObject): + """ + YAML tag for selective disclosure keys. + + This class is used to define a custom YAML tag for selective disclosure keys. This tag is used to indicate + that a key in a YAML mapping is a selective disclosure key, and that its value should be parsed as a selective + disclosure object. 
+ """ + + yaml_tag = "!sd" + + @classmethod + def from_yaml(cls, loader, node): + # If this is a scalar node, it can be a string, int, float, etc.; unfortunately, since we tagged + # it with !sd, we cannot rely on the default YAML loader to parse it into the correct data type. + # Instead, we must manually resolve it. + if isinstance(node, yaml.ScalarNode): + # If the 'style' is '"', then the scalar is a string; otherwise, we must resolve it. + if node.style == '"': + mp = loader.construct_yaml_str(node) + else: + resolved_type = yaml.resolver.Resolver().resolve( + yaml.ScalarNode, node.value, (True, False) + ) + if resolved_type == "tag:yaml.org,2002:str": + mp = loader.construct_yaml_str(node) + elif resolved_type == "tag:yaml.org,2002:int": + mp = loader.construct_yaml_int(node) + elif resolved_type == "tag:yaml.org,2002:float": + mp = loader.construct_yaml_float(node) + elif resolved_type == "tag:yaml.org,2002:bool": + mp = loader.construct_yaml_bool(node) + elif resolved_type == "tag:yaml.org,2002:null": + mp = None + else: + raise Exception( + f"Unsupported scalar type for selective disclosure (!sd): {resolved_type}; node is {node}, style is {node.style}" + ) + return SDObj(mp) + elif isinstance(node, yaml.MappingNode): + return SDObj(loader.construct_mapping(node)) + elif isinstance(node, yaml.SequenceNode): + return SDObj(loader.construct_sequence(node)) + else: + raise Exception( + "Unsupported node type for selective disclosure (!sd): {}".format( + node + ) + ) +def _yaml_load_specification(file_buffer: TextIOWrapper): + return yaml.load(file_buffer, Loader=yaml.FullLoader) # nosec + +def load_yaml_specification(file_path: str) -> dict: + """ + Load a YAML specification file and return the parsed content. + + :param file_path: Path to the YAML file. + :type file_path: str + + :returns: The parsed content of the YAML file. 
+ :rtype: dict + """ -def load_yaml_specification(file): # create new resolver for tags - with open(file, "r") as f: + with open(file_path, "r") as f: example = _yaml_load_specification(f) for property in ("user_claims", "holder_disclosed_claims"): @@ -14,65 +78,24 @@ def load_yaml_specification(file): return example +def remove_sdobj_wrappers(data: Union[SDObj, dict, list, any]) -> Union[dict, list, any]: + """ + Recursively remove SDObj wrappers from the data structure. -def _yaml_load_specification(f): - resolver = yaml.resolver.Resolver() - - # Define custom YAML tag to indicate selective disclosure - class SDKeyTag(yaml.YAMLObject): - yaml_tag = "!sd" - - @classmethod - def from_yaml(cls, loader, node): - # If this is a scalar node, it can be a string, int, float, etc.; unfortunately, since we tagged - # it with !sd, we cannot rely on the default YAML loader to parse it into the correct data type. - # Instead, we must manually resolve it. - if isinstance(node, yaml.ScalarNode): - # If the 'style' is '"', then the scalar is a string; otherwise, we must resolve it. 
- if node.style == '"': - mp = loader.construct_yaml_str(node) - else: - resolved_type = resolver.resolve( - yaml.ScalarNode, node.value, (True, False)) - if resolved_type == "tag:yaml.org,2002:str": - mp = loader.construct_yaml_str(node) - elif resolved_type == "tag:yaml.org,2002:int": - mp = loader.construct_yaml_int(node) - elif resolved_type == "tag:yaml.org,2002:float": - mp = loader.construct_yaml_float(node) - elif resolved_type == "tag:yaml.org,2002:bool": - mp = loader.construct_yaml_bool(node) - elif resolved_type == "tag:yaml.org,2002:null": - mp = None - else: - raise Exception( - f"Unsupported scalar type for selective disclosure (!sd): {resolved_type}; node is {node}, style is {node.style}" - ) - return SDObj(mp) - elif isinstance(node, yaml.MappingNode): - return SDObj(loader.construct_mapping(node)) - elif isinstance(node, yaml.SequenceNode): - return SDObj(loader.construct_sequence(node)) - else: - raise Exception( - "Unsupported node type for selective disclosure (!sd): {}".format( - node - ) - ) - - return yaml.load(f, Loader=yaml.FullLoader) # nosec - - -""" -Takes an object that has been parsed from a YAML file and removes the SDObj wrappers. -""" + :param data: The data structure to remove SDObj wrappers from. + :type data: SDObj | dict | list | any + :returns: The data structure with SDObj wrappers removed. 
+ :rtype: dict | list | any + """ -def remove_sdobj_wrappers(data): if isinstance(data, SDObj): return remove_sdobj_wrappers(data.value) elif isinstance(data, dict): - return {remove_sdobj_wrappers(key): remove_sdobj_wrappers(value) for key, value in data.items()} + return { + remove_sdobj_wrappers(key): remove_sdobj_wrappers(value) + for key, value in data.items() + } elif isinstance(data, list): return [remove_sdobj_wrappers(value) for value in data] else: diff --git a/pyeudiw/sd_jwt/verifier.py b/pyeudiw/sd_jwt/verifier.py index cc65a991..adac66a3 100644 --- a/pyeudiw/sd_jwt/verifier.py +++ b/pyeudiw/sd_jwt/verifier.py @@ -1,24 +1,22 @@ import logging +from typing import Callable, Dict, List, Union + +from cryptojwt.jwk.jwk import key_from_jwk_dict +from cryptojwt.jws.jws import JWS from pyeudiw.jwt.exceptions import JWSVerificationError from pyeudiw.jwt.helper import validate_jwt_timestamps_claims from pyeudiw.jwt.jws_helper import JWSHelper +from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload +from pyeudiw.sd_jwt.common import SDJWTCommon from . 
import ( DEFAULT_SIGNING_ALG, DIGEST_ALG_KEY, + KB_DIGEST_KEY, SD_DIGESTS_KEY, SD_LIST_PREFIX, - KB_DIGEST_KEY ) -from pyeudiw.sd_jwt.common import SDJWTCommon - -from typing import Dict, List, Union, Callable - -from cryptojwt.jwk.jwk import key_from_jwk_dict -from cryptojwt.jws.jws import JWS - -from pyeudiw.jwt.utils import decode_jwt_payload, decode_jwt_header logger = logging.getLogger(__name__) @@ -66,12 +64,16 @@ def _verify_sd_jwt( parsed_input_sd_jwt = JWS(alg=sign_alg) if self._serialization_format == "json": - _deserialize_sd_jwt_payload = decode_jwt_header( - self._unverified_input_sd_jwt_parsed["payload"]) + _deserialize_sd_jwt_payload: dict = decode_jwt_header( + self._unverified_input_sd_jwt_parsed["payload"] + ) unverified_issuer = _deserialize_sd_jwt_payload.get("iss", None) - unverified_header_parameters = self._unverified_input_sd_jwt_parsed['header'] + unverified_header_parameters = self._unverified_input_sd_jwt_parsed[ + "header" + ] issuer_public_key_input = cb_get_issuer_key( - unverified_issuer, unverified_header_parameters) + unverified_issuer, unverified_header_parameters + ) issuer_public_key = [] for key in issuer_public_key_input: @@ -84,15 +86,16 @@ def _verify_sd_jwt( issuer_public_key.append(key) self._sd_jwt_payload = parsed_input_sd_jwt.verify_json( - jws=self._unverified_input_sd_jwt, - keys=issuer_public_key + jws=self._unverified_input_sd_jwt, keys=issuer_public_key ) elif self._serialization_format == "compact": unverified_header_parameters = decode_jwt_header( - self._unverified_input_sd_jwt) + self._unverified_input_sd_jwt + ) sign_alg = sign_alg or unverified_header_parameters.get( - "alg", DEFAULT_SIGNING_ALG) + "alg", DEFAULT_SIGNING_ALG + ) parsed_input_sd_jwt = JWS(alg=sign_alg) parsed_payload = decode_jwt_payload(self._unverified_input_sd_jwt) @@ -116,7 +119,7 @@ def _verify_sd_jwt( self._sd_jwt_payload = parsed_input_sd_jwt.verify_compact( jws=self._unverified_input_sd_jwt, keys=issuer_public_key, - sigalg=sign_alg + 
sigalg=sign_alg, ) try: @@ -143,11 +146,9 @@ def _verify_key_binding_jwt( # Verify the key binding JWT using the holder public key if self._serialization_format == "json": - _deserialize_sd_jwt_payload = decode_jwt_header( - self._unverified_input_sd_jwt_parsed["payload"]) + decode_jwt_header(self._unverified_input_sd_jwt_parsed["payload"]) - holder_public_key_payload_jwk = self._holder_public_key_payload.get( - "jwk", None) + holder_public_key_payload_jwk = self._holder_public_key_payload.get("jwk", None) if not holder_public_key_payload_jwk: raise ValueError( @@ -160,10 +161,12 @@ def _verify_key_binding_jwt( parsed_input_key_binding_jwt = JWSHelper(jwks=pubkey) verified_payload = parsed_input_key_binding_jwt.verify( - self._unverified_input_key_binding_jwt) + self._unverified_input_key_binding_jwt + ) key_binding_jwt_header = decode_jwt_header( - self._unverified_input_key_binding_jwt) + self._unverified_input_key_binding_jwt + ) if key_binding_jwt_header["typ"] != self.KB_JWT_TYP_HEADER: raise ValueError("Invalid header typ") @@ -229,15 +232,15 @@ def _unpack_disclosed_claims(self, sd_jwt_claims): for digest in sd_jwt_claims.get(SD_DIGESTS_KEY, []): if digest in self._duplicate_hash_check: - raise ValueError( - f"Duplicate hash found in SD-JWT: {digest}") + raise ValueError(f"Duplicate hash found in SD-JWT: {digest}") self._duplicate_hash_check.append(digest) if digest in self._hash_to_decoded_disclosure: _, key, value = self._hash_to_decoded_disclosure[digest] if key in pre_output: raise ValueError( - f"Duplicate key found when unpacking disclosed claim: '{key}' in {pre_output}. This is not allowed." + "Duplicate key found when unpacking disclosed claim: " + f"'{key}' in {pre_output}. This is not allowed." 
) unpacked_value = self._unpack_disclosed_claims(value) pre_output[key] = unpacked_value diff --git a/pyeudiw/storage/base_cache.py b/pyeudiw/storage/base_cache.py index da08f540..21e734d8 100644 --- a/pyeudiw/storage/base_cache.py +++ b/pyeudiw/storage/base_cache.py @@ -1,5 +1,6 @@ from enum import Enum from typing import Callable + from .base_db import BaseDB @@ -13,7 +14,9 @@ class BaseCache(BaseDB): Interface class for cache storage. """ - def try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> tuple[dict, RetrieveStatus]: + def try_retrieve( + self, object_name: str, on_not_found: Callable[[], str] + ) -> tuple[dict, RetrieveStatus]: """ Try to retrieve an object from the cache. If the object is not found, call the on_not_found function. diff --git a/pyeudiw/storage/base_storage.py b/pyeudiw/storage/base_storage.py index 7be2222a..74fac983 100644 --- a/pyeudiw/storage/base_storage.py +++ b/pyeudiw/storage/base_storage.py @@ -1,6 +1,7 @@ import datetime from enum import Enum from typing import Union + from pymongo.results import UpdateResult from .base_db import BaseDB @@ -15,17 +16,17 @@ class TrustType(Enum): trust_type_map: dict = { TrustType.X509: "x509", TrustType.FEDERATION: "federation", - TrustType.DIRECT_TRUST_SD_JWT_VC: "direct_trust_sd_jwt_vc" + TrustType.DIRECT_TRUST_SD_JWT_VC: "direct_trust_sd_jwt_vc", } trust_attestation_field_map: dict = { TrustType.X509: "x5c", - TrustType.FEDERATION: "chain" + TrustType.FEDERATION: "chain", } trust_anchor_field_map: dict = { TrustType.X509: "pem", - TrustType.FEDERATION: "entity_configuration" + TrustType.FEDERATION: "entity_configuration", } @@ -34,7 +35,9 @@ class BaseStorage(BaseDB): Interface class for storage. """ - def init_session(self, document_id: str, session_id: str, state: str, remote_flow_typ: str) -> str: + def init_session( + self, document_id: str, session_id: str, state: str, remote_flow_typ: str + ) -> str: """ Initialize a session. 
@@ -66,7 +69,9 @@ def has_session_retention_ttl(self) -> bool: """ raise NotImplementedError() - def add_dpop_proof_and_attestation(self, document_id, dpop_proof: dict, attestation: dict) -> UpdateResult: + def add_dpop_proof_and_attestation( + self, document_id, dpop_proof: dict, attestation: dict + ) -> UpdateResult: """ Add a dpop proof and an attestation to the session. @@ -95,7 +100,9 @@ def set_finalized(self, document_id: str) -> UpdateResult: raise NotImplementedError() - def update_request_object(self, document_id: str, request_object: dict) -> UpdateResult: + def update_request_object( + self, document_id: str, request_object: dict + ) -> UpdateResult: """ Update the request object of the session. @@ -109,7 +116,9 @@ def update_request_object(self, document_id: str, request_object: dict) -> Updat """ raise NotImplementedError() - def update_response_object(self, nonce: str, state: str, response_object: dict) -> UpdateResult: + def update_response_object( + self, nonce: str, state: str, response_object: dict + ) -> UpdateResult: """ Update the response object of the session. @@ -177,7 +186,14 @@ def has_trust_anchor(self, entity_id: str) -> bool: def has_trust_source(self, entity_id: str) -> bool: raise NotImplementedError() - def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType, jwks: dict) -> str: + def add_trust_attestation( + self, + entity_id: str, + attestation: list[str], + exp: datetime, + trust_type: TrustType, + jwks: dict, + ) -> str: """ Add a trust attestation. @@ -197,7 +213,9 @@ def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: dat """ raise NotImplementedError() - def add_trust_attestation_metadata(self, entity_id: str, metadata_type: str, metadata: dict) -> str: + def add_trust_attestation_metadata( + self, entity_id: str, metadata_type: str, metadata: dict + ) -> str: """ Add a trust attestation metadata. 
@@ -237,7 +255,13 @@ def get_trust_source(self, entity_id: str) -> Union[dict, None]: """ raise NotImplementedError() - def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType): + def add_trust_anchor( + self, + entity_id: str, + entity_configuration: str, + exp: datetime, + trust_type: TrustType, + ): """ Add a trust anchor. @@ -255,7 +279,14 @@ def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datet """ raise NotImplementedError() - def update_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType, jwks: dict) -> str: + def update_trust_attestation( + self, + entity_id: str, + attestation: list[str], + exp: datetime, + trust_type: TrustType, + jwks: dict, + ) -> str: """ Update a trust attestation. @@ -275,7 +306,13 @@ def update_trust_attestation(self, entity_id: str, attestation: list[str], exp: """ raise NotImplementedError() - def update_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType) -> str: + def update_trust_anchor( + self, + entity_id: str, + entity_configuration: str, + exp: datetime, + trust_type: TrustType, + ) -> str: """ Update a trust anchor. @@ -333,7 +370,9 @@ def get_by_nonce_state(self, state: str, nonce: str) -> Union[dict, None]: """ raise NotImplementedError() - def get_by_state_and_session_id(self, state: str, session_id: str = "") -> Union[dict, None]: + def get_by_state_and_session_id( + self, state: str, session_id: str = "" + ) -> Union[dict, None]: """ Get a session by state and session id. 
@@ -360,7 +399,9 @@ def get_by_session_id(self, session_id: str) -> Union[dict, None]: raise NotImplementedError() # TODO: create add_or_update for all the write methods - def add_or_update_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime) -> str: + def add_or_update_trust_attestation( + self, entity_id: str, attestation: list[str], exp: datetime + ) -> str: """ Add or update a trust attestation. diff --git a/pyeudiw/storage/db_engine.py b/pyeudiw/storage/db_engine.py index f2523d73..ee7eb1ff 100644 --- a/pyeudiw/storage/db_engine.py +++ b/pyeudiw/storage/db_engine.py @@ -1,18 +1,13 @@ import uuid from datetime import datetime -from typing import Callable, Union, Tuple +from typing import Callable, Tuple, Union + from pyeudiw.storage.base_cache import BaseCache, RetrieveStatus from pyeudiw.storage.base_storage import BaseStorage, TrustType -from pyeudiw.storage.exceptions import ( - ChainNotExist, - StorageWriteError, - EntryNotFound -) +from pyeudiw.storage.exceptions import ChainNotExist, EntryNotFound, StorageWriteError from pyeudiw.tools.base_logger import BaseLogger - -from .base_db import BaseDB - from pyeudiw.tools.utils import dynamic_class_loader +from pyeudiw.storage.base_db import BaseDB class DBEngine(BaseStorage, BaseCache, BaseLogger): @@ -44,7 +39,10 @@ def init_session(self, session_id: str, state: str, remote_flow_typ: str) -> str for db_name, storage in self.storages: try: storage.init_session( - document_id, session_id=session_id, state=state, remote_flow_typ=remote_flow_typ + document_id, + session_id=session_id, + state=state, + remote_flow_typ=remote_flow_typ, ) except StorageWriteError as e: self._log_critical( @@ -52,7 +50,7 @@ def init_session(self, session_id: str, state: str, remote_flow_typ: str) -> str ( f"Error while initializing session with document_id {document_id}. 
" f"Cannot write document with id {document_id} on {db_name}: {e}" - ) + ), ) raise e @@ -87,8 +85,7 @@ def write(self, method: str, *args, **kwargs): replica_count += 1 except Exception as e: self._log_critical( - e.__class__.__name__, - f"Error {_err_msg} on {db_name}: {e}" + e.__class__.__name__, f"Error {_err_msg} on {db_name}: {e}" ) if not replica_count: @@ -96,12 +93,14 @@ def write(self, method: str, *args, **kwargs): return replica_count - def add_dpop_proof_and_attestation(self, document_id, dpop_proof: dict, attestation: dict): + def add_dpop_proof_and_attestation( + self, document_id, dpop_proof: dict, attestation: dict + ): return self.write( "add_dpop_proof_and_attestation", document_id, dpop_proof=dpop_proof, - attestation=attestation + attestation=attestation, ) def set_finalized(self, document_id: str): @@ -110,7 +109,9 @@ def set_finalized(self, document_id: str): def update_request_object(self, document_id: str, request_object: dict) -> int: return self.write("update_request_object", document_id, request_object) - def update_response_object(self, nonce: str, state: str, response_object: dict) -> int: + def update_response_object( + self, nonce: str, state: str, response_object: dict + ) -> int: return self.write("update_response_object", nonce, state, response_object) def get(self, method: str, *args, **kwargs) -> Union[dict, None]: @@ -139,7 +140,7 @@ def get(self, method: str, *args, **kwargs) -> Union[dict, None]: except EntryNotFound as e: self._log_debug( e.__class__.__name__, - f"Cannot find result by method {method} on {db_name} with {args} {kwargs}: {str(e)}" + f"Cannot find result by method {method} on {db_name} with {args} {kwargs}: {str(e)}", ) raise EntryNotFound(f"Cannot find any result by method {method}") @@ -159,11 +160,24 @@ def has_trust_anchor(self, entity_id: str) -> bool: def has_trust_source(self, entity_id: str) -> bool: return self.get_trust_source(entity_id) is not None - def add_trust_attestation(self, entity_id: str, 
attestation: list[str] = [], exp: datetime = None, trust_type: TrustType = TrustType.FEDERATION, jwks: list[dict] = []) -> str: - return self.write("add_trust_attestation", entity_id, attestation, exp, trust_type, jwks) + def add_trust_attestation( + self, + entity_id: str, + attestation: list[str] = [], + exp: datetime = None, + trust_type: TrustType = TrustType.FEDERATION, + jwks: list[dict] = [], + ) -> str: + return self.write( + "add_trust_attestation", entity_id, attestation, exp, trust_type, jwks + ) - def add_trust_attestation_metadata(self, entity_id: str, metadat_type: str, metadata: dict) -> str: - return self.write("add_trust_attestation_metadata", entity_id, metadat_type, metadata) + def add_trust_attestation_metadata( + self, entity_id: str, metadat_type: str, metadata: dict + ) -> str: + return self.write( + "add_trust_attestation_metadata", entity_id, metadat_type, metadata + ) def add_trust_source(self, trust_source: dict) -> str: return self.write("add_trust_source", trust_source) @@ -171,21 +185,62 @@ def add_trust_source(self, trust_source: dict) -> str: def get_trust_source(self, entity_id: str) -> dict: return self.get("get_trust_source", entity_id) - def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType = TrustType.FEDERATION) -> str: - return self.write("add_trust_anchor", entity_id, entity_configuration, exp, trust_type) + def add_trust_anchor( + self, + entity_id: str, + entity_configuration: str, + exp: datetime, + trust_type: TrustType = TrustType.FEDERATION, + ) -> str: + return self.write( + "add_trust_anchor", entity_id, entity_configuration, exp, trust_type + ) - def update_trust_attestation(self, entity_id: str, attestation: list[str] = [], exp: datetime = None, trust_type: TrustType = TrustType.FEDERATION, jwks: list[dict] = []) -> str: - return self.write("update_trust_attestation", entity_id, attestation, exp, trust_type, jwks) + def update_trust_attestation( + self, + 
entity_id: str, + attestation: list[str] = [], + exp: datetime = None, + trust_type: TrustType = TrustType.FEDERATION, + jwks: list[dict] = [], + ) -> str: + return self.write( + "update_trust_attestation", entity_id, attestation, exp, trust_type, jwks + ) - def add_or_update_trust_attestation(self, entity_id: str, attestation: list[str] = [], exp: datetime = None, trust_type: TrustType = TrustType.FEDERATION, jwks: list[dict] = []) -> str: + def add_or_update_trust_attestation( + self, + entity_id: str, + attestation: list[str] = [], + exp: datetime = None, + trust_type: TrustType = TrustType.FEDERATION, + jwks: list[dict] = [], + ) -> str: try: self.get_trust_attestation(entity_id) - return self.write("update_trust_attestation", entity_id, attestation, exp, trust_type, jwks) + return self.write( + "update_trust_attestation", + entity_id, + attestation, + exp, + trust_type, + jwks, + ) except (EntryNotFound, ChainNotExist): - return self.write("add_trust_attestation", entity_id, attestation, exp, trust_type, jwks) + return self.write( + "add_trust_attestation", entity_id, attestation, exp, trust_type, jwks + ) - def update_trust_anchor(self, entity_id: str, entity_configuration: dict, exp: datetime, trust_type: TrustType = TrustType.FEDERATION) -> str: - return self.write("update_trust_anchor", entity_id, entity_configuration, exp, trust_type) + def update_trust_anchor( + self, + entity_id: str, + entity_configuration: dict, + exp: datetime, + trust_type: TrustType = TrustType.FEDERATION, + ) -> str: + return self.write( + "update_trust_anchor", entity_id, entity_configuration, exp, trust_type + ) def try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> dict: # if no cache instance exist return the object @@ -193,15 +248,14 @@ def try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> dic return on_not_found() # if almost one cache instance exist try to retrieve - cache_object, status, idx = self._cache_try_retrieve( - 
object_name, on_not_found) + cache_object, status, idx = self._cache_try_retrieve(object_name, on_not_found) # if the status is retrieved return the object if status == RetrieveStatus.RETRIEVED: return cache_object # else try replicate the data on all the other istances - replica_instances = self.caches[:idx] + self.caches[idx + 1:] + replica_instances = self.caches[:idx] + self.caches[idx + 1 :] for cache_name, cache in replica_instances: try: @@ -209,7 +263,7 @@ def try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> dic except Exception as e: self._log_critical( e.__class__.__name__, - f"Cannot replicate cache object with identifier {object_name} on cache {cache_name}" + f"Cannot replicate cache object with identifier {object_name} on cache {cache_name}", ) return cache_object @@ -222,14 +276,15 @@ def overwrite(self, object_name: str, value_gen_fn: Callable[[], str]) -> dict: except Exception as e: self._log_critical( e.__class__.__name__, - f"Cannot overwrite cache object with identifier {object_name} on cache {cache_name}" + f"Cannot overwrite cache object with identifier {object_name} on cache {cache_name}", ) return cache_object def exists_by_state_and_session_id(self, state: str, session_id: str = "") -> bool: for db_name, storage in self.storages: found = storage.exists_by_state_and_session_id( - state=state, session_id=session_id) + state=state, session_id=session_id + ) if found: return True return False @@ -238,14 +293,13 @@ def get_by_state(self, state: str) -> Union[dict, None]: return self.get_by_state_and_session_id(state=state) def get_by_nonce_state(self, state: str, nonce: str) -> Union[dict, None]: - return self.get('get_by_nonce_state', state=state, nonce=nonce) + return self.get("get_by_nonce_state", state=state, nonce=nonce) - def get_by_state_and_session_id(self, state: str, session_id: str = "") -> Union[dict, None]: + def get_by_state_and_session_id( + self, state: str, session_id: str = "" + ) -> Union[dict, None]: 
return self.get("get_by_state_and_session_id", state, session_id) - def get_by_session_id(self, session_id: str) -> Union[dict, None]: - return self.get("get_by_session_id", session_id) - @property def is_connected(self): _connected = False @@ -257,19 +311,21 @@ def is_connected(self): except Exception as e: self._log_debug( e.__class__.__name__, - f"Error while checking db engine connection on {db_name}: {e} " + f"Error while checking db engine connection on {db_name}: {e} ", ) if True in _cons.values() and not all(_cons.values()): self._log_warning( "DB Engine", f"Not all the storage are found available, storages misalignment: " - f"{_cons}" + f"{_cons}", ) return _connected - def _cache_try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> tuple[dict, RetrieveStatus, int]: + def _cache_try_retrieve( + self, object_name: str, on_not_found: Callable[[], str] + ) -> tuple[dict, RetrieveStatus, int]: """ Try to retrieve an object from the cache. If the object is not found, call the on_not_found function. 
@@ -287,16 +343,15 @@ def _cache_try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) for i, (cache_name, cache_istance) in enumerate(self.caches): try: cache_object, status = cache_istance.try_retrieve( - object_name, on_not_found) + object_name, on_not_found + ) return cache_object, status, i except Exception as e: self._log_critical( e.__class__.__name__, - f"Cannot retrieve cache object with identifier {object_name} on cache database {cache_name}" + f"Cannot retrieve cache object with identifier {object_name} on cache database {cache_name}", ) - raise ConnectionRefusedError( - "Cannot write cache object on any instance" - ) + raise ConnectionRefusedError("Cannot write cache object on any instance") def _close_list(self, db_list: list[Tuple[str, BaseDB]]) -> None: """ @@ -314,11 +369,13 @@ def _close_list(self, db_list: list[Tuple[str, BaseDB]]) -> None: except Exception as e: self._log_critical( e.__class__.__name__, - f"Error while closing db engine {db_name}: {e}" + f"Error while closing db engine {db_name}: {e}", ) raise e - def _handle_instance(self, instance: dict) -> tuple[BaseStorage | None, BaseCache | None]: + def _handle_instance( + self, instance: dict + ) -> tuple[BaseStorage | None, BaseCache | None]: """ Handle the initialization of a storage/cache instance. 
@@ -336,7 +393,7 @@ def _handle_instance(self, instance: dict) -> tuple[BaseStorage | None, BaseCach storage_instance = dynamic_class_loader( storage_conf["module"], storage_conf["class"], - storage_conf.get("init_params", {}) + storage_conf.get("init_params", {}), ) cache_instance = None @@ -344,7 +401,7 @@ def _handle_instance(self, instance: dict) -> tuple[BaseStorage | None, BaseCach cache_instance = dynamic_class_loader( cache_conf["module"], cache_conf["class"], - cache_conf.get("init_params", {}) + cache_conf.get("init_params", {}), ) return storage_instance, cache_instance diff --git a/pyeudiw/storage/exceptions.py b/pyeudiw/storage/exceptions.py index ae765980..7e2d2d6a 100644 --- a/pyeudiw/storage/exceptions.py +++ b/pyeudiw/storage/exceptions.py @@ -1,7 +1,3 @@ -class ChainAlreadyExist(Exception): - pass - - class ChainNotExist(Exception): pass diff --git a/pyeudiw/storage/mongo_cache.py b/pyeudiw/storage/mongo_cache.py index 99b47866..afe13ec0 100644 --- a/pyeudiw/storage/mongo_cache.py +++ b/pyeudiw/storage/mongo_cache.py @@ -2,11 +2,11 @@ from typing import Callable import pymongo - -from pyeudiw.storage.base_cache import BaseCache, RetrieveStatus from pymongo.collection import Collection -from pymongo.mongo_client import MongoClient from pymongo.database import Database +from pymongo.mongo_client import MongoClient + +from pyeudiw.storage.base_cache import BaseCache, RetrieveStatus class MongoCache(BaseCache): @@ -39,7 +39,9 @@ def close(self) -> None: self._connect() self.client.close() - def try_retrieve(self, object_name: str, on_not_found: Callable[[], str]) -> tuple[dict, RetrieveStatus]: + def try_retrieve( + self, object_name: str, on_not_found: Callable[[], str] + ) -> tuple[dict, RetrieveStatus]: self._connect() query = {"object_name": object_name} @@ -62,17 +64,14 @@ def overwrite(self, object_name: str, value_gen_fn: Callable[[], str]) -> dict: cache_object = { "object_name": object_name, "data": new_data, - "creation_date": update_time + 
"creation_date": update_time, } query = {"object_name": object_name} - self.collection.update_one(query, { - "$set": { - "data": new_data, - "creation_date": update_time - } - }) + self.collection.update_one( + query, {"$set": {"data": new_data, "creation_date": update_time}} + ) return cache_object @@ -83,8 +82,7 @@ def set(self, data: dict) -> dict: def _connect(self) -> None: if not self.client or not self.client.server_info(): - self.client = pymongo.MongoClient( - self.url, **self.connection_params) + self.client = pymongo.MongoClient(self.url, **self.connection_params) self.db = getattr(self.client, self.storage_conf["db_name"]) self.collection = getattr(self.db, "cache_storage") @@ -101,5 +99,5 @@ def _gen_cache_object(self, object_name: str, data: str) -> dict: return { "object_name": object_name, "data": data, - "creation_date": datetime.now().isoformat() + "creation_date": datetime.now().isoformat(), } diff --git a/pyeudiw/storage/mongo_storage.py b/pyeudiw/storage/mongo_storage.py index 2834194b..721c3320 100644 --- a/pyeudiw/storage/mongo_storage.py +++ b/pyeudiw/storage/mongo_storage.py @@ -1,21 +1,18 @@ -import pymongo import datetime as dt from datetime import datetime +from typing import Union +import pymongo from pymongo.results import UpdateResult from pyeudiw.storage.base_storage import ( BaseStorage, TrustType, - trust_type_map, + trust_anchor_field_map, trust_attestation_field_map, - trust_anchor_field_map -) -from pyeudiw.storage.exceptions import ( - ChainNotExist, - StorageEntryUpdateFailed + trust_type_map, ) -from typing import Union +from pyeudiw.storage.exceptions import ChainNotExist, StorageEntryUpdateFailed class MongoStorage(BaseStorage): @@ -43,9 +40,7 @@ def is_connected(self) -> bool: def _connect(self): if not self.is_connected: - self.client = pymongo.MongoClient( - self.url, **self.connection_params - ) + self.client = pymongo.MongoClient(self.url, **self.connection_params) self.db = getattr(self.client, 
self.storage_conf["db_name"]) self.sessions = getattr( self.db, self.storage_conf["db_sessions_collection"] @@ -69,7 +64,7 @@ def get_by_id(self, document_id: str) -> dict: document = self.sessions.find_one({"document_id": document_id}) if document is None: - raise ValueError(f'Document with id {document_id} not found') + raise ValueError(f"Document with id {document_id} not found") return document @@ -77,13 +72,12 @@ def get_by_nonce_state(self, nonce: str, state: str | None) -> dict: self._connect() query = {"state": state, "nonce": nonce} if not state: - query.pop('state') + query.pop("state") document = self.sessions.find_one(query) if document is None: - raise ValueError( - f'Document with nonce {nonce} and state {state} not found') + raise ValueError(f"Document with nonce {nonce} and state {state} not found") return document @@ -93,13 +87,13 @@ def get_by_session_id(self, session_id: str) -> Union[dict, None]: document = self.sessions.find_one(query) if document is None: - raise ValueError( - f'Document with session id {session_id} not found.' - ) + raise ValueError(f"Document with session id {session_id} not found.") return document - def get_by_state_and_session_id(self, state: str, session_id: str = "") -> Union[dict, None]: + def get_by_state_and_session_id( + self, state: str, session_id: str = "" + ) -> Union[dict, None]: self._connect() query = {"state": state} if session_id: @@ -107,13 +101,13 @@ def get_by_state_and_session_id(self, state: str, session_id: str = "") -> Union document = self.sessions.find_one(query) if document is None: - raise ValueError( - f'Document with state {state} not found.' 
- ) + raise ValueError(f"Document with state {state} not found.") return document - def init_session(self, document_id: str, session_id: str, state: str, remote_flow_typ: str) -> str: + def init_session( + self, document_id: str, session_id: str, state: str, remote_flow_typ: str + ) -> str: entity = { "document_id": document_id, "creation_date": dt.datetime.now(tz=dt.timezone.utc), @@ -141,7 +135,8 @@ def set_session_retention_ttl(self, ttl: int) -> None: self.sessions.drop_index("creation_date_1") else: self.sessions.create_index( - [("creation_date", pymongo.ASCENDING)], expireAfterSeconds=ttl) + [("creation_date", pymongo.ASCENDING)], expireAfterSeconds=ttl + ) def get_session_retention_ttl(self) -> dict: return self.sessions.index_information().get("creation_date_1") @@ -150,7 +145,9 @@ def has_session_retention_ttl(self) -> bool: self._connect() return self.sessions.index_information().get("creation_date_1") is not None - def add_dpop_proof_and_attestation(self, document_id: str, dpop_proof: dict, attestation: dict) -> UpdateResult: + def add_dpop_proof_and_attestation( + self, document_id: str, dpop_proof: dict, attestation: dict + ) -> UpdateResult: self._connect() update_result: UpdateResult = self.sessions.update_one( {"document_id": document_id}, @@ -159,15 +156,16 @@ def add_dpop_proof_and_attestation(self, document_id: str, dpop_proof: dict, att "dpop_proof": dpop_proof, "attestation": attestation, } - }) + }, + ) if update_result.matched_count != 1 or update_result.modified_count != 1: - raise ValueError( - f"Cannot update document {document_id}'." 
- ) + raise ValueError(f"Cannot update document {document_id}'.") return update_result - def update_request_object(self, document_id: str, request_object: dict) -> UpdateResult: + def update_request_object( + self, document_id: str, request_object: dict + ) -> UpdateResult: self.get_by_id(document_id) documentStatus = self.sessions.update_one( {"document_id": document_id}, @@ -177,12 +175,10 @@ def update_request_object(self, document_id: str, request_object: dict) -> Updat "nonce": request_object["nonce"], "state": request_object["state"], } - } + }, ) if documentStatus.matched_count != 1 or documentStatus.modified_count != 1: - raise ValueError( - f"Cannot update document {document_id}')" - ) + raise ValueError(f"Cannot update document {document_id}')") return documentStatus def set_finalized(self, document_id: str): @@ -191,27 +187,24 @@ def set_finalized(self, document_id: str): update_result: UpdateResult = self.sessions.update_one( {"document_id": document_id}, { - "$set": { - "finalized": True - }, - } + "$set": {"finalized": True}, + }, ) if update_result.matched_count != 1: # or update_result.modified_count != 1: - raise ValueError( - f"Cannot update document {document_id}'" - ) + raise ValueError(f"Cannot update document {document_id}'") return update_result - def update_response_object(self, nonce: str, state: str, internal_response: dict) -> UpdateResult: + def update_response_object( + self, nonce: str, state: str, internal_response: dict + ) -> UpdateResult: document = self.get_by_nonce_state(nonce, state) document_id = document["_id"] document_status = self.sessions.update_one( {"_id": document_id}, - {"$set": - { - "internal_response": internal_response - }, - }) + { + "$set": {"internal_response": internal_response}, + }, + ) return document_status @@ -227,14 +220,12 @@ def get_trust_source(self, entity_id: str) -> dict | None: def get_trust_attestation(self, entity_id: str) -> dict | None: return self._get_db_entity( - 
self.storage_conf["db_trust_attestations_collection"], - entity_id + self.storage_conf["db_trust_attestations_collection"], entity_id ) def get_trust_anchor(self, entity_id: str) -> dict | None: return self._get_db_entity( - self.storage_conf["db_trust_anchors_collection"], - entity_id + self.storage_conf["db_trust_anchors_collection"], entity_id ) def _has_db_entity(self, collection: str, entity_id: str) -> bool: @@ -242,44 +233,41 @@ def _has_db_entity(self, collection: str, entity_id: str) -> bool: def has_trust_attestation(self, entity_id: str) -> bool: return self._has_db_entity( - self.storage_conf["db_trust_attestations_collection"], - entity_id + self.storage_conf["db_trust_attestations_collection"], entity_id ) def has_trust_anchor(self, entity_id: str) -> bool: return self._has_db_entity( - self.storage_conf["db_trust_anchors_collection"], - entity_id + self.storage_conf["db_trust_anchors_collection"], entity_id ) def has_trust_source(self, entity_id: str) -> bool: return self._has_db_entity( - self.storage_conf["db_trust_sources_collection"], - entity_id + self.storage_conf["db_trust_sources_collection"], entity_id ) def _upsert_entry( - self, - key_label: str, - collection: str, - data: Union[str, dict] + self, key_label: str, collection: str, data: Union[str, dict] ) -> tuple[str, dict]: db_collection = getattr(self, collection) document_status = db_collection.update_one( - {key_label: data[key_label]}, - {"$set": data}, - upsert=True + {key_label: data[key_label]}, {"$set": data}, upsert=True ) if not document_status.acknowledged: - raise StorageEntryUpdateFailed( - "Trust Anchor matched count is ZERO" - ) + raise StorageEntryUpdateFailed("Trust Anchor matched count is ZERO") return document_status - def _update_attestation_metadata(self, entity: dict, attestation: list[str], exp: datetime, trust_type: TrustType, jwks: list[dict]): + def _update_attestation_metadata( + self, + entity: dict, + attestation: list[str], + exp: datetime, + trust_type: 
TrustType, + jwks: list[dict], + ): trust_name = trust_type_map[trust_type] trust_field = trust_attestation_field_map.get(trust_type, None) @@ -296,7 +284,14 @@ def _update_attestation_metadata(self, entity: dict, attestation: list[str], exp return entity - def _update_anchor_metadata(self, entity: dict, attestation: list[str], exp: datetime, trust_type: TrustType, entity_id: str): + def _update_anchor_metadata( + self, + entity: dict, + attestation: list[str], + exp: datetime, + trust_type: TrustType, + entity_id: str, + ): if entity.get("entity_id", None) is None: entity["entity_id"] = entity_id @@ -313,20 +308,30 @@ def _update_anchor_metadata(self, entity: dict, attestation: list[str], exp: dat return entity - def add_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType, jwks: list[dict]) -> str: + def add_trust_attestation( + self, + entity_id: str, + attestation: list[str], + exp: datetime, + trust_type: TrustType, + jwks: list[dict], + ) -> str: entity = { "entity_id": entity_id, "federation": {}, "x509": {}, "direct_trust_sd_jwt_vc": {}, - "metadata": {} + "metadata": {}, } updated_entity = self._update_attestation_metadata( - entity, attestation, exp, trust_type, jwks) + entity, attestation, exp, trust_type, jwks + ) self._upsert_entry( - "entity_id", self.storage_conf["db_trust_attestations_collection"], updated_entity + "entity_id", + self.storage_conf["db_trust_attestations_collection"], + updated_entity, ) return entity_id @@ -336,14 +341,15 @@ def add_trust_source(self, trust_source: dict) -> str: "entity_id", self.storage_conf["db_trust_sources_collection"], trust_source ) - def add_trust_attestation_metadata(self, entity_id: str, metadata_type: str, metadata: dict): + def add_trust_attestation_metadata( + self, entity_id: str, metadata_type: str, metadata: dict + ): entity = self._get_db_entity( - self.storage_conf["db_trust_attestations_collection"], entity_id) + 
self.storage_conf["db_trust_attestations_collection"], entity_id + ) if entity is None: - raise ValueError( - f'Document with entity_id {entity_id} not found.' - ) + raise ValueError(f"Document with entity_id {entity_id} not found.") entity["metadata"][metadata_type] = metadata @@ -358,36 +364,67 @@ def add_trust_attestation_metadata(self, entity_id: str, metadata_type: str, met return documentStatus - def add_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType): - entity = { - "entity_id": entity_id, - "federation": {}, - "x509": {} - } + def add_trust_anchor( + self, + entity_id: str, + entity_configuration: str, + exp: datetime, + trust_type: TrustType, + ): + entity = {"entity_id": entity_id, "federation": {}, "x509": {}} updated_entity = self._update_anchor_metadata( - entity, entity_configuration, exp, trust_type, entity_id) + entity, entity_configuration, exp, trust_type, entity_id + ) self._upsert_entry( - "entity_id", self.storage_conf["db_trust_anchors_collection"], updated_entity) + "entity_id", + self.storage_conf["db_trust_anchors_collection"], + updated_entity, + ) return entity_id - def update_trust_attestation(self, entity_id: str, attestation: list[str], exp: datetime, trust_type: TrustType, jwks: list[dict]) -> str: - old_entity = self._get_db_entity( - self.storage_conf["db_trust_attestations_collection"], entity_id) or {} + def update_trust_attestation( + self, + entity_id: str, + attestation: list[str], + exp: datetime, + trust_type: TrustType, + jwks: list[dict], + ) -> str: + old_entity = ( + self._get_db_entity( + self.storage_conf["db_trust_attestations_collection"], entity_id + ) + or {} + ) upd_entity = self._update_attestation_metadata( - old_entity, attestation, exp, trust_type, jwks) + old_entity, attestation, exp, trust_type, jwks + ) return self._upsert_entry( - "entity_id", self.storage_conf["db_trust_attestations_collection"], upd_entity + "entity_id", + 
self.storage_conf["db_trust_attestations_collection"], + upd_entity, ) - def update_trust_anchor(self, entity_id: str, entity_configuration: str, exp: datetime, trust_type: TrustType) -> str: - old_entity = self._get_db_entity( - self.storage_conf["db_trust_attestations_collection"], entity_id) or {} + def update_trust_anchor( + self, + entity_id: str, + entity_configuration: str, + exp: datetime, + trust_type: TrustType, + ) -> str: + old_entity = ( + self._get_db_entity( + self.storage_conf["db_trust_attestations_collection"], entity_id + ) + or {} + ) upd_entity = self._update_anchor_metadata( - old_entity, entity_configuration, exp, trust_type, entity_id) + old_entity, entity_configuration, exp, trust_type, entity_id + ) if not self.has_trust_anchor(entity_id): raise ChainNotExist(f"Chain with entity id {entity_id} not exist") diff --git a/pyeudiw/storage/schemas/storage.py b/pyeudiw/storage/schemas/storage.py index 1fcd22a5..611b93c1 100644 --- a/pyeudiw/storage/schemas/storage.py +++ b/pyeudiw/storage/schemas/storage.py @@ -8,7 +8,7 @@ class InitParams(BaseModel): class StorageConfig(BaseModel): module: str - class_: str = Field(..., alias='class') + class_: str = Field(..., alias="class") init_params: InitParams diff --git a/pyeudiw/tests/federation/base.py b/pyeudiw/tests/federation/base.py index 625e1d9d..97889d47 100644 --- a/pyeudiw/tests/federation/base.py +++ b/pyeudiw/tests/federation/base.py @@ -1,9 +1,10 @@ +import json + from cryptojwt.jwk.ec import new_ec_key from cryptojwt.jws.jws import JWS -import json import pyeudiw.federation.trust_chain_validator as tcv_test -from pyeudiw.tools.utils import iat_now, exp_from_now +from pyeudiw.tools.utils import exp_from_now, iat_now httpc_params = { "connection": {"ssl": True}, @@ -30,28 +31,23 @@ "iat": NOW, "iss": "https://credential_issuer.example.org", "sub": "https://credential_issuer.example.org", - 'jwks': {"keys": []}, + "jwks": {"keys": []}, "metadata": { - "openid_credential_issuer": { - 'jwks': 
{"keys": []} - }, + "openid_credential_issuer": {"jwks": {"keys": []}}, "federation_entity": { "organization_name": "OpenID Credential Issuer example", "homepage_uri": "https://credential_issuer.example.org/home", "policy_uri": "https://credential_issuer.example.org/policy", "logo_uri": "https://credential_issuer.example.org/static/logo.svg", - "contacts": [ - "tech@credential_issuer.example.org" - ] - } + "contacts": ["tech@credential_issuer.example.org"], + }, }, - "authority_hints": [ - "https://intermediate.eidas.example.org" - ] + "authority_hints": ["https://intermediate.eidas.example.org"], } -leaf_cred['jwks']['keys'] = [leaf_cred_jwk.serialize()] -leaf_cred['metadata']['openid_credential_issuer']['jwks']['keys'] = [ - leaf_cred_jwk_prot.serialize()] +leaf_cred["jwks"]["keys"] = [leaf_cred_jwk.serialize()] +leaf_cred["metadata"]["openid_credential_issuer"]["jwks"]["keys"] = [ + leaf_cred_jwk_prot.serialize() +] # Define intermediate Entity Statement for credential @@ -60,9 +56,9 @@ "iat": NOW, "iss": "https://intermediate.eidas.example.org", "sub": "https://credential_issuer.example.org", - 'jwks': {"keys": []} + "jwks": {"keys": []}, } -intermediate_es_cred["jwks"]['keys'] = [leaf_cred_jwk.serialize()] +intermediate_es_cred["jwks"]["keys"] = [leaf_cred_jwk.serialize()] # Define leaf Wallet Provider leaf_wallet_jwk = new_ec_key(ec_crv, alg=ec_alg) @@ -71,27 +67,21 @@ "iat": NOW, "iss": "https://wallet-provider.example.org", "sub": "https://wallet-provider.example.org", - 'jwks': {"keys": []}, + "jwks": {"keys": []}, "metadata": { - "wallet_provider": { - "jwks": {"keys": []} - }, + "wallet_provider": {"jwks": {"keys": []}}, "federation_entity": { "organization_name": "OpenID Wallet Verifier example", "homepage_uri": "https://wallet-provider.example.org/home", "policy_uri": "https://wallet-provider.example.org/policy", "logo_uri": "https://wallet-provider.example.org/static/logo.svg", - "contacts": [ - "tech@wallet-provider.example.org" - ] - } + "contacts": 
["tech@wallet-provider.example.org"], + }, }, - "authority_hints": [ - "https://intermediate.eidas.example.org" - ] + "authority_hints": ["https://intermediate.eidas.example.org"], } -leaf_wallet['jwks']['keys'] = [leaf_wallet_jwk.serialize()] -leaf_wallet['metadata']['wallet_provider'] = [leaf_wallet_jwk.serialize()] +leaf_wallet["jwks"]["keys"] = [leaf_wallet_jwk.serialize()] +leaf_wallet["metadata"]["wallet_provider"] = [leaf_wallet_jwk.serialize()] # Define intermediate Entity Statement for wallet provider intermediate_es_wallet = { @@ -99,29 +89,27 @@ "iat": NOW, "iss": "https://intermediate.eidas.example.org", "sub": "https://wallet-provider.example.org", - 'jwks': {"keys": [leaf_wallet_jwk.serialize()]} + "jwks": {"keys": [leaf_wallet_jwk.serialize()]}, } # Intermediate EC intermediate_ec = { "exp": EXP, "iat": NOW, - 'iss': 'https://intermediate.eidas.example.org', - 'sub': 'https://intermediate.eidas.example.org', - 'jwks': {"keys": [intermediate_jwk.serialize()]}, - 'metadata': { - 'federation_entity': { - 'contacts': ['soggetto@intermediate.eidas.example.it'], - 'federation_fetch_endpoint': 'https://intermediate.eidas.example.org/fetch', - 'federation_resolve_endpoint': 'https://intermediate.eidas.example.org/resolve', - 'federation_list_endpoint': 'https://intermediate.eidas.example.org/list', - 'homepage_uri': 'https://soggetto.intermediate.eidas.example.it', - 'name': 'Example Intermediate intermediate.eidas.example' + "iss": "https://intermediate.eidas.example.org", + "sub": "https://intermediate.eidas.example.org", + "jwks": {"keys": [intermediate_jwk.serialize()]}, + "metadata": { + "federation_entity": { + "contacts": ["soggetto@intermediate.eidas.example.it"], + "federation_fetch_endpoint": "https://intermediate.eidas.example.org/fetch", + "federation_resolve_endpoint": "https://intermediate.eidas.example.org/resolve", + "federation_list_endpoint": "https://intermediate.eidas.example.org/list", + "homepage_uri": 
"https://soggetto.intermediate.eidas.example.it", + "name": "Example Intermediate intermediate.eidas.example", } }, - "authority_hints": [ - "https://trust-anchor.example.org" - ] + "authority_hints": ["https://trust-anchor.example.org"], } @@ -131,7 +119,7 @@ "iat": NOW, "iss": "https://trust-anchor.example.org", "sub": "https://intermediate.eidas.example.org", - 'jwks': {"keys": [intermediate_jwk.serialize()]} + "jwks": {"keys": [intermediate_jwk.serialize()]}, } ta_ec = { @@ -139,51 +127,47 @@ "iat": NOW, "iss": "https://trust-anchor.example.org", "sub": "https://trust-anchor.example.org", - 'jwks': {"keys": [ta_jwk.serialize()]}, + "jwks": {"keys": [ta_jwk.serialize()]}, "metadata": { "federation_entity": { - 'federation_fetch_endpoint': 'https://trust-anchor.example.org/fetch', - 'federation_resolve_endpoint': 'https://trust-anchor.example.org/resolve', - 'federation_list_endpoint': 'https://trust-anchor.example.org/list', + "federation_fetch_endpoint": "https://trust-anchor.example.org/fetch", + "federation_resolve_endpoint": "https://trust-anchor.example.org/resolve", + "federation_list_endpoint": "https://trust-anchor.example.org/list", "organization_name": "TA example", "homepage_uri": "https://trust-anchor.example.org/home", "policy_uri": "https://trust-anchor.example.org/policy", "logo_uri": "https://trust-anchor.example.org/static/logo.svg", - "contacts": [ - "tech@trust-anchor.example.org" - ] + "contacts": ["tech@trust-anchor.example.org"], } }, - 'constraints': {'max_path_length': 1} + "constraints": {"max_path_length": 1}, } # Sign step -leaf_cred_signer = JWS(leaf_cred, alg=ec_alg, - typ='entity-statement+jwt') +leaf_cred_signer = JWS(leaf_cred, alg=ec_alg, typ="entity-statement+jwt") leaf_cred_signed = leaf_cred_signer.sign_compact([leaf_cred_jwk]) -leaf_wallet_signer = JWS(leaf_wallet, alg=ec_alg, - typ='entity-statement+jwt') +leaf_wallet_signer = JWS(leaf_wallet, alg=ec_alg, typ="entity-statement+jwt") leaf_wallet_signed = 
leaf_wallet_signer.sign_compact([leaf_wallet_jwk]) -intermediate_signer_ec = JWS( - intermediate_ec, alg=ec_alg, - typ="entity-statement+jwt" -) -intermediate_ec_signed = intermediate_signer_ec.sign_compact([ - intermediate_jwk]) +intermediate_signer_ec = JWS(intermediate_ec, alg=ec_alg, typ="entity-statement+jwt") +intermediate_ec_signed = intermediate_signer_ec.sign_compact([intermediate_jwk]) intermediate_signer_es_cred = JWS( - intermediate_es_cred, alg=ec_alg, typ='entity-statement+jwt') -intermediate_es_cred_signed = intermediate_signer_es_cred.sign_compact([ - intermediate_jwk]) + intermediate_es_cred, alg=ec_alg, typ="entity-statement+jwt" +) +intermediate_es_cred_signed = intermediate_signer_es_cred.sign_compact( + [intermediate_jwk] +) intermediate_signer_es_wallet = JWS( - intermediate_es_wallet, alg=ec_alg, typ='entity-statement+jwt') -intermediate_es_wallet_signed = intermediate_signer_es_wallet.sign_compact([ - intermediate_jwk]) + intermediate_es_wallet, alg=ec_alg, typ="entity-statement+jwt" +) +intermediate_es_wallet_signed = intermediate_signer_es_wallet.sign_compact( + [intermediate_jwk] +) ta_es_signer = JWS(ta_es, alg=ec_alg, typ="entity-statement+jwt") ta_es_signed = ta_es_signer.sign_compact([ta_jwk]) @@ -196,14 +180,10 @@ leaf_cred_signed, intermediate_es_cred_signed, ta_es_signed, - ta_ec_signed + ta_ec_signed, ] -trust_chain_wallet = [ - leaf_wallet_signed, - intermediate_es_wallet_signed, - ta_es_signed -] +trust_chain_wallet = [leaf_wallet_signed, intermediate_es_wallet_signed, ta_es_signed] test_cred = tcv_test.StaticTrustChainValidator( trust_chain_issuer, [ta_jwk.serialize()], httpc_params=httpc_params diff --git a/pyeudiw/tests/federation/mocked_response.py b/pyeudiw/tests/federation/mocked_response.py index 3645ca05..b7caa5db 100644 --- a/pyeudiw/tests/federation/mocked_response.py +++ b/pyeudiw/tests/federation/mocked_response.py @@ -1,8 +1,13 @@ - -from . 
base import intermediate_ec_signed, intermediate_es_wallet_signed, leaf_wallet_signed, ta_ec_signed, ta_es_signed - import logging +from .base import ( + intermediate_ec_signed, + intermediate_es_wallet_signed, + leaf_wallet_signed, + ta_ec_signed, + ta_es_signed, +) + logger = logging.getLogger(__name__) @@ -28,7 +33,7 @@ def content(self): 1: leaf_wallet_signed, 2: intermediate_ec_signed, 3: intermediate_es_wallet_signed, - 4: ta_es_signed + 4: ta_es_signed, } self.result = resp_seq.get(self.req_counter, None) @@ -40,112 +45,113 @@ def content(self): "The mocked resposes seems to be not aligned with the correct flow" ) + # class EntityResponseNoIntermediateSignedJwksUri(EntityResponse): - # @property - # def content(self): - # if self.req_counter == 0: - # self.result = self.trust_anchor_ec() - # elif self.req_counter == 1: - # self.result = self.rp_ec() - # elif self.req_counter == 2: - # self.result = self.fetch_rp_from_ta() - # elif self.req_counter == 3: - # metadata = copy.deepcopy( - # rp_conf['metadata']['openid_relying_party']) - # _jwks = metadata.pop('jwks') - # fed_jwks = rp_conf['jwks_fed'][0] - # self.result = create_jws(_jwks, fed_jwks) - # return self.result.encode() - # elif self.req_counter > 3: - # raise NotImplementedError( - # "The mocked resposes seems to be not aligned with the correct flow" - # ) - - # return self.result_as_jws() +# @property +# def content(self): +# if self.req_counter == 0: +# self.result = self.trust_anchor_ec() +# elif self.req_counter == 1: +# self.result = self.rp_ec() +# elif self.req_counter == 2: +# self.result = self.fetch_rp_from_ta() +# elif self.req_counter == 3: +# metadata = copy.deepcopy( +# rp_conf['metadata']['openid_relying_party']) +# _jwks = metadata.pop('jwks') +# fed_jwks = rp_conf['jwks_fed'][0] +# self.result = create_jws(_jwks, fed_jwks) +# return self.result.encode() +# elif self.req_counter > 3: +# raise NotImplementedError( +# "The mocked resposes seems to be not aligned with the correct flow" 
+# ) + +# return self.result_as_jws() # class EntityResponseWithIntermediate(EntityResponse): - # @property - # def content(self): - # if self.req_counter == 0: - # self.result = self.trust_anchor_ec() - # elif self.req_counter == 1: - # self.result = self.rp_ec() - # elif self.req_counter == 2: - # sa = FederationEntityConfiguration.objects.get( - # sub=intermediary_conf["sub"]) - # self.result = DummyContent(sa.entity_configuration_as_jws) - # elif self.req_counter == 3: - # url = reverse("oidcfed_fetch") - # self.result = self.client.get( - # url, - # data={ - # "sub": rp_onboarding_data["sub"], - # "iss": intermediary_conf["sub"], - # }, - # ) - # elif self.req_counter == 4: - # url = reverse("oidcfed_fetch") - # self.result = self.client.get( - # url, data={"sub": intermediary_conf["sub"]}) - # elif self.req_counter == 5: - # url = reverse("entity_configuration") - # self.result = self.client.get( - # url, data={"sub": ta_conf_data["sub"]}) - # elif self.req_counter > 5: - # raise NotImplementedError( - # "The mocked resposes seems to be not aligned with the correct flow" - # ) - - # if self.result.status_code != 200: - # raise HttpError( - # f"Something went wrong with Http Request: {self.result.__dict__}") - - # logger.info("-------------------------------------------------") - # logger.info("") - # return self.result_as_jws() +# @property +# def content(self): +# if self.req_counter == 0: +# self.result = self.trust_anchor_ec() +# elif self.req_counter == 1: +# self.result = self.rp_ec() +# elif self.req_counter == 2: +# sa = FederationEntityConfiguration.objects.get( +# sub=intermediary_conf["sub"]) +# self.result = DummyContent(sa.entity_configuration_as_jws) +# elif self.req_counter == 3: +# url = reverse("oidcfed_fetch") +# self.result = self.client.get( +# url, +# data={ +# "sub": rp_onboarding_data["sub"], +# "iss": intermediary_conf["sub"], +# }, +# ) +# elif self.req_counter == 4: +# url = reverse("oidcfed_fetch") +# self.result = self.client.get( 
+# url, data={"sub": intermediary_conf["sub"]}) +# elif self.req_counter == 5: +# url = reverse("entity_configuration") +# self.result = self.client.get( +# url, data={"sub": ta_conf_data["sub"]}) +# elif self.req_counter > 5: +# raise NotImplementedError( +# "The mocked resposes seems to be not aligned with the correct flow" +# ) + +# if self.result.status_code != 200: +# raise HttpError( +# f"Something went wrong with Http Request: {self.result.__dict__}") + +# logger.info("-------------------------------------------------") +# logger.info("") +# return self.result_as_jws() # class EntityResponseWithIntermediateManyHints(EntityResponse): - # @property - # def content(self): - # if self.req_counter == 0: - # self.result = self.trust_anchor_ec() - # elif self.req_counter == 1: - # self.result = self.rp_ec() - # elif self.req_counter == 2: - # sa = FederationEntityConfiguration.objects.get( - # sub=intermediary_conf["sub"]) - # self.result = DummyContent(sa.entity_configuration_as_jws) - # elif self.req_counter == 3: - # self.result = DummyContent("crap") - - # elif self.req_counter == 4: - # url = reverse("oidcfed_fetch") - # self.result = self.client.get( - # url, - # data={ - # "sub": rp_onboarding_data["sub"], - # "iss": intermediary_conf["sub"], - # }, - # ) - # elif self.req_counter == 5: - # url = reverse("oidcfed_fetch") - # self.result = self.client.get( - # url, data={"sub": intermediary_conf["sub"]}) - # elif self.req_counter == 6: - # url = reverse("entity_configuration") - # self.result = self.client.get( - # url, data={"sub": ta_conf_data["sub"]}) - # elif self.req_counter > 6: - # raise NotImplementedError( - # "The mocked resposes seems to be not aligned with the correct flow" - # ) - # if self.result.status_code != 200: - # raise HttpError( - # f"Something went wrong with Http Request: {self.result.__dict__}") - - # try: - # return self.result_as_jws() - # except Exception: - # return self.result_as_it_is() +# @property +# def content(self): +# if 
self.req_counter == 0: +# self.result = self.trust_anchor_ec() +# elif self.req_counter == 1: +# self.result = self.rp_ec() +# elif self.req_counter == 2: +# sa = FederationEntityConfiguration.objects.get( +# sub=intermediary_conf["sub"]) +# self.result = DummyContent(sa.entity_configuration_as_jws) +# elif self.req_counter == 3: +# self.result = DummyContent("crap") + +# elif self.req_counter == 4: +# url = reverse("oidcfed_fetch") +# self.result = self.client.get( +# url, +# data={ +# "sub": rp_onboarding_data["sub"], +# "iss": intermediary_conf["sub"], +# }, +# ) +# elif self.req_counter == 5: +# url = reverse("oidcfed_fetch") +# self.result = self.client.get( +# url, data={"sub": intermediary_conf["sub"]}) +# elif self.req_counter == 6: +# url = reverse("entity_configuration") +# self.result = self.client.get( +# url, data={"sub": ta_conf_data["sub"]}) +# elif self.req_counter > 6: +# raise NotImplementedError( +# "The mocked resposes seems to be not aligned with the correct flow" +# ) +# if self.result.status_code != 200: +# raise HttpError( +# f"Something went wrong with Http Request: {self.result.__dict__}") + +# try: +# return self.result_as_jws() +# except Exception: +# return self.result_as_it_is() diff --git a/pyeudiw/tests/federation/schemas/test_entity_configuration.py b/pyeudiw/tests/federation/schemas/test_entity_configuration.py index 0d3890d1..14a1f19a 100644 --- a/pyeudiw/tests/federation/schemas/test_entity_configuration.py +++ b/pyeudiw/tests/federation/schemas/test_entity_configuration.py @@ -1,14 +1,16 @@ import pytest - from pydantic import ValidationError -from pyeudiw.federation.schemas.entity_configuration import EntityConfigurationHeader, EntityConfigurationPayload +from pyeudiw.federation.schemas.entity_configuration import ( + EntityConfigurationHeader, + EntityConfigurationPayload, +) ENTITY_CONFIGURATION = { "header": { "alg": "RS256", "kid": "2HnoFS3YnC9tjiCaivhWLVUJ3AxwGGz_98uRFaqMEEs", - "typ": "entity-statement+jwt" + "typ": 
"entity-statement+jwt", }, "payload": { "exp": 1649590602, @@ -21,12 +23,12 @@ "kty": "RSA", "n": "5s4qi …", "e": "AQAB", - "kid": "2HnoFS3YnC9tjiCaivhWLVUJ3AxwGGz_98uRFaqMEEs" + "kid": "2HnoFS3YnC9tjiCaivhWLVUJ3AxwGGz_98uRFaqMEEs", } ] }, "metadata": { - "wallet_relying_party": { + "openid_credential_verifier": { "application_type": "web", "client_id": "https://rp.example.it", "client_name": "Name of an example organization", @@ -38,75 +40,45 @@ "n": "1Ta-sE …", "e": "AQAB", "kid": "YhNFS3YnC9tjiCaivhWLVUJ3AxwGGz_98uRFaqMEEs", - "x5c": ["..."] + "x5c": ["..."], } ] }, - - "contacts": [ - "ops@verifier.example.org" - ], - - "request_uris": [ - "https://verifier.example.org/request_uri" - ], - "redirect_uris": [ - "https://verifier.example.org/callback" - ], - + "contacts": ["ops@verifier.example.org"], + "request_uris": ["https://verifier.example.org/request_uri"], + "redirect_uris": ["https://verifier.example.org/callback"], "default_acr_values": [ "https://www.spid.gov.it/SpidL2", - "https://www.spid.gov.it/SpidL3" + "https://www.spid.gov.it/SpidL3", ], "vp_formats": { "vc+sd-jwt": { - "sd-jwt_alg_values": [ - "ES256", - "ES384" - ], - "kb-jwt_alg_values": [ - "ES256", - "ES384" - ] + "sd-jwt_alg_values": ["ES256", "ES384"], + "kb-jwt_alg_values": ["ES256", "ES384"], } }, - "default_max_age": 1111, - - "authorization_signed_response_alg": [ - "RS256", - "ES256" - ], - "authorization_encrypted_response_alg": [ - "RSA-OAEP", - "RSA-OAEP-256" - ], + "authorization_signed_response_alg": ["RS256", "ES256"], + "authorization_encrypted_response_alg": ["RSA-OAEP", "RSA-OAEP-256"], "authorization_encrypted_response_enc": [ "A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512", "A128GCM", "A192GCM", - "A256GCM" + "A256GCM", ], - "subject_type": "pairwise", "require_auth_time": True, - "id_token_signed_response_alg": [ - "RS256", - "ES256" - ], - "id_token_encrypted_response_alg": [ - "RSA-OAEP", - "RSA-OAEP-256" - ], + "id_token_signed_response_alg": ["RS256", "ES256"], + 
"id_token_encrypted_response_alg": ["RSA-OAEP", "RSA-OAEP-256"], "id_token_encrypted_response_enc": [ "A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512", "A128GCM", "A192GCM", - "A256GCM" + "A256GCM", ], }, "federation_entity": { @@ -114,15 +86,11 @@ "homepage_uri": "https://verifier.example.org/home", "policy_uri": "https://verifier.example.org/policy", "logo_uri": "https://verifier.example.org/static/logo.svg", - "contacts": [ - "tech@verifier.example.org" - ] - } + "contacts": ["tech@verifier.example.org"], + }, }, - "authority_hints": [ - "https://registry.eudi-wallet.example.it" - ] - } + "authority_hints": ["https://registry.eudi-wallet.example.it"], + }, } @@ -131,7 +99,8 @@ def test_entity_configuration_header(): with pytest.raises(ValidationError): EntityConfigurationHeader.model_validate( - ENTITY_CONFIGURATION["header"], context={"supported_algorithms": ["ES256"]}) + ENTITY_CONFIGURATION["header"], context={"supported_algorithms": ["ES256"]} + ) ENTITY_CONFIGURATION["header"]["typ"] = "NOT-entity-statement+jwt" with pytest.raises(ValidationError): @@ -147,9 +116,11 @@ def test_entity_configuration_payload(): with pytest.raises(ValidationError): EntityConfigurationPayload(**ENTITY_CONFIGURATION["payload"]) - ENTITY_CONFIGURATION["payload"]["jwks"]["keys"] = [{ - "kty": "RSA", - "n": "5s4qi …", - "e": "AQAB", - "kid": "2HnoFS3YnC9tjiCaivhWLVUJ3AxwGGz_98uRFaqMEEs" - }] + ENTITY_CONFIGURATION["payload"]["jwks"]["keys"] = [ + { + "kty": "RSA", + "n": "5s4qi …", + "e": "AQAB", + "kid": "2HnoFS3YnC9tjiCaivhWLVUJ3AxwGGz_98uRFaqMEEs", + } + ] diff --git a/pyeudiw/tests/federation/test_metadata.py b/pyeudiw/tests/federation/test_metadata.py index 7e4734a1..fb25204d 100644 --- a/pyeudiw/tests/federation/test_metadata.py +++ b/pyeudiw/tests/federation/test_metadata.py @@ -1,8 +1,7 @@ import pytest -from pyeudiw.federation.policy import combine -from pyeudiw.federation.policy import combine_claim_policy -from pyeudiw.federation.policy import TrustChainPolicy + from 
pyeudiw.federation.exceptions import PolicyError +from pyeudiw.federation.policy import TrustChainPolicy, combine, combine_claim_policy __author__ = "Roland Hedberg" __license__ = "Apache 2.0" @@ -11,277 +10,166 @@ SIMPLE = [ ( "SUBSET_OF", - {"subset_of": ['X', 'Y', 'Z']}, - {"subset_of": ['X', 'Y']}, - {"subset_of": ['X', 'Y']}), - ( - "SUBSET_OF", - {"subset_of": ['X', 'Y', 'Z']}, - {"subset_of": ['X', 'Y', 'W']}, - {"subset_of": ['X', 'Y']} + {"subset_of": ["X", "Y", "Z"]}, + {"subset_of": ["X", "Y"]}, + {"subset_of": ["X", "Y"]}, ), ( "SUBSET_OF", - {"subset_of": ['A', 'X', 'Y', 'Z']}, - {"subset_of": ['X', 'Y', 'W']}, - {"subset_of": ['X', 'Y']} + {"subset_of": ["X", "Y", "Z"]}, + {"subset_of": ["X", "Y", "W"]}, + {"subset_of": ["X", "Y"]}, ), ( "SUBSET_OF", - {"subset_of": ['Y', 'Z']}, - {"subset_of": ['X', 'Y']}, - {"subset_of": ['Y']} + {"subset_of": ["A", "X", "Y", "Z"]}, + {"subset_of": ["X", "Y", "W"]}, + {"subset_of": ["X", "Y"]}, ), ( "SUBSET_OF", - {"subset_of": ['X', 'Y']}, - {"subset_of": ['Z', 'Y']}, - {"subset_of": ['Y']} + {"subset_of": ["Y", "Z"]}, + {"subset_of": ["X", "Y"]}, + {"subset_of": ["Y"]}, ), ( "SUBSET_OF", - {"subset_of": ['X', 'Y']}, - {"subset_of": ['Z', 'W']}, - PolicyError + {"subset_of": ["X", "Y"]}, + {"subset_of": ["Z", "Y"]}, + {"subset_of": ["Y"]}, ), + ("SUBSET_OF", {"subset_of": ["X", "Y"]}, {"subset_of": ["Z", "W"]}, PolicyError), ( "SUPERSET_OF", - {"superset_of": ['X', 'Y', 'Z']}, - {"superset_of": ['X', 'Y']}, - {"superset_of": ['X', 'Y']} + {"superset_of": ["X", "Y", "Z"]}, + {"superset_of": ["X", "Y"]}, + {"superset_of": ["X", "Y"]}, ), ( "SUPERSET_OF", - {"superset_of": ['X', 'Y', 'Z']}, - {"superset_of": ['X', 'Y', 'W']}, - {"superset_of": ['X', 'Y']} + {"superset_of": ["X", "Y", "Z"]}, + {"superset_of": ["X", "Y", "W"]}, + {"superset_of": ["X", "Y"]}, ), ( "SUPERSET_OF", - {"superset_of": ['A', 'X', 'Y', 'Z']}, - {"superset_of": ['X', 'Y', 'W']}, - {"superset_of": ['X', 'Y']} + {"superset_of": ["A", "X", "Y", 
"Z"]}, + {"superset_of": ["X", "Y", "W"]}, + {"superset_of": ["X", "Y"]}, ), ( "SUPERSET_OF", - {"superset_of": ['Y', 'Z']}, - {"superset_of": ['X', 'Y']}, - {"superset_of": ['Y']} + {"superset_of": ["Y", "Z"]}, + {"superset_of": ["X", "Y"]}, + {"superset_of": ["Y"]}, ), ( "SUPERSET_OF", - {"superset_of": ['X', 'Y']}, - {"superset_of": ['Z', 'Y']}, - {"superset_of": ['Y']} + {"superset_of": ["X", "Y"]}, + {"superset_of": ["Z", "Y"]}, + {"superset_of": ["Y"]}, ), ( "SUPERSET_OF", - {"superset_of": ['X', 'Y']}, - {"superset_of": ['Z', 'W']}, - PolicyError - ), - - ( - "ONE_OF", - {"one_of": ['X', 'Y', 'Z']}, - {"one_of": ['X', 'Y']}, - {"one_of": ['X', 'Y']} - ), - ( - "ONE_OF", - {"one_of": ['X', 'Y', 'Z']}, - {"one_of": ['X', 'Y', 'W']}, - {"one_of": ['X', 'Y']} + {"superset_of": ["X", "Y"]}, + {"superset_of": ["Z", "W"]}, + PolicyError, ), ( "ONE_OF", - {"one_of": ['A', 'X', 'Y', 'Z']}, - {"one_of": ['X', 'Y', 'W']}, - {"one_of": ['X', 'Y']} + {"one_of": ["X", "Y", "Z"]}, + {"one_of": ["X", "Y"]}, + {"one_of": ["X", "Y"]}, ), ( "ONE_OF", - {"one_of": ['Y', 'Z']}, - {"one_of": ['X', 'Y']}, - {"one_of": ['Y']} + {"one_of": ["X", "Y", "Z"]}, + {"one_of": ["X", "Y", "W"]}, + {"one_of": ["X", "Y"]}, ), ( "ONE_OF", - {"one_of": ['X', 'Y']}, - {"one_of": ['Z', 'Y']}, - {"one_of": ['Y']} - ), - ( - "ONE_OF", - {"one_of": ['X', 'Y']}, - {"one_of": ['Z', 'W']}, - PolicyError - ), - ( - "ADD", - {"add": "X"}, - {"add": "B"}, - {"add": ["X", "B"]} - ), - ( - "ADD", - {"add": "X"}, - {"add": "X"}, - {"add": "X"} - ), - ( - "VALUE", - {"value": "X"}, - {"value": "B"}, - PolicyError - ), - ( - "VALUE", - {"value": "X"}, - {"value": "X"}, - {"value": "X"} - ), - ( - "VALUE", - {"value": ["X", "Y"]}, - {"value": ["X", "Z"]}, - PolicyError - ), - ( - "DEFAULT", - {"default": "X"}, - {"default": "B"}, - PolicyError - ), - ( - "DEFAULT", - {"default": ["X", "B"]}, - {"default": ["B", "Y"]}, - PolicyError - ), - ( - "DEFAULT", - {"default": "X"}, - {"default": "X"}, - {"default": "X"} 
- ), - ( - "ESSENTIAL", - {"essential": True}, - {"essential": False}, - PolicyError - ), - ( - "ESSENTIAL", - {"essential": False}, - {"essential": True}, - {"essential": True} - ), - ( - "ESSENTIAL", - {"essential": True}, - {"essential": True}, - {"essential": True} - ), - ( - "ESSENTIAL", - {"essential": False}, - {"essential": False}, - {"essential": False} - ) + {"one_of": ["A", "X", "Y", "Z"]}, + {"one_of": ["X", "Y", "W"]}, + {"one_of": ["X", "Y"]}, + ), + ("ONE_OF", {"one_of": ["Y", "Z"]}, {"one_of": ["X", "Y"]}, {"one_of": ["Y"]}), + ("ONE_OF", {"one_of": ["X", "Y"]}, {"one_of": ["Z", "Y"]}, {"one_of": ["Y"]}), + ("ONE_OF", {"one_of": ["X", "Y"]}, {"one_of": ["Z", "W"]}, PolicyError), + ("ADD", {"add": "X"}, {"add": "B"}, {"add": ["X", "B"]}), + ("ADD", {"add": "X"}, {"add": "X"}, {"add": "X"}), + ("VALUE", {"value": "X"}, {"value": "B"}, PolicyError), + ("VALUE", {"value": "X"}, {"value": "X"}, {"value": "X"}), + ("VALUE", {"value": ["X", "Y"]}, {"value": ["X", "Z"]}, PolicyError), + ("DEFAULT", {"default": "X"}, {"default": "B"}, PolicyError), + ("DEFAULT", {"default": ["X", "B"]}, {"default": ["B", "Y"]}, PolicyError), + ("DEFAULT", {"default": "X"}, {"default": "X"}, {"default": "X"}), + ("ESSENTIAL", {"essential": True}, {"essential": False}, PolicyError), + ("ESSENTIAL", {"essential": False}, {"essential": True}, {"essential": True}), + ("ESSENTIAL", {"essential": True}, {"essential": True}, {"essential": True}), + ("ESSENTIAL", {"essential": False}, {"essential": False}, {"essential": False}), ] COMPLEX = [ + ({"essential": False}, {"default": "A"}, {"essential": False, "default": "A"}), + ({"essential": True}, {"default": "A"}, {"essential": True, "default": "A"}), ( - {"essential": False}, - {"default": 'A'}, - {"essential": False, "default": 'A'} - ), - ( - {"essential": True}, - {"default": 'A'}, - {"essential": True, "default": 'A'} - ), - ( - {"essential": False, "default": 'A'}, - {"default": 'A', "essential": True}, - {"essential": True, 
"default": 'A'} + {"essential": False, "default": "A"}, + {"default": "A", "essential": True}, + {"essential": True, "default": "A"}, ), ( - {"essential": True, "default": 'A'}, - {"default": 'B', "essential": True}, - PolicyError + {"essential": True, "default": "A"}, + {"default": "B", "essential": True}, + PolicyError, ), ( {"essential": False}, - {"subset_of": ['B']}, - {"essential": False, "subset_of": ['B']} - ), - ( - {"subset_of": ['X', 'Y', 'Z']}, - {"superset_of": ['Y', 'Z']}, - {"subset_of": ['X', 'Y', 'Z'], "superset_of": ['Y', 'Z']} - ), - ( - {"superset_of": ['Y', 'Z']}, - {"subset_of": ['X', 'Y']}, - PolicyError - ), - ( - {"subset_of": ['X', 'Y']}, - {"superset_of": ['X', 'Y']}, - {"subset_of": ['X', 'Y'], "superset_of": ['X', 'Y']} - ), - ( - {"superset_of": ['X', 'Y']}, - {"subset_of": ['X', 'Y']}, - {"subset_of": ['X', 'Y'], "superset_of": ['X', 'Y']} - ), - ( - {"subset_of": ['X', 'Y', 'Z']}, - {"superset_of": ['Y', 'A']}, - PolicyError - ), - ( - {"subset_of": ['X', 'Y', ]}, - {"superset_of": ['X', 'Y', 'A']}, - PolicyError + {"subset_of": ["B"]}, + {"essential": False, "subset_of": ["B"]}, ), ( - {"subset_of": ['X', 'Y']}, - {"default": ['X']}, - {"subset_of": ['X', 'Y'], "default": ['X']} + {"subset_of": ["X", "Y", "Z"]}, + {"superset_of": ["Y", "Z"]}, + {"subset_of": ["X", "Y", "Z"], "superset_of": ["Y", "Z"]}, ), + ({"superset_of": ["Y", "Z"]}, {"subset_of": ["X", "Y"]}, PolicyError), ( - {"superset_of": ['X', 'Y']}, - {"default": ['X', 'Y', 'Z']}, - {"superset_of": ['X', 'Y'], "default": ['X', 'Y', 'Z']} + {"subset_of": ["X", "Y"]}, + {"superset_of": ["X", "Y"]}, + {"subset_of": ["X", "Y"], "superset_of": ["X", "Y"]}, ), ( - {"one_of": ['X', 'Y']}, - {"default": 'X'}, - {"one_of": ['X', 'Y'], "default": 'X'} + {"superset_of": ["X", "Y"]}, + {"subset_of": ["X", "Y"]}, + {"subset_of": ["X", "Y"], "superset_of": ["X", "Y"]}, ), + ({"subset_of": ["X", "Y", "Z"]}, {"superset_of": ["Y", "A"]}, PolicyError), ( - {"subset_of": ['X', 'Y']}, - 
{"default": ['X', 'Z']}, - PolicyError + { + "subset_of": [ + "X", + "Y", + ] + }, + {"superset_of": ["X", "Y", "A"]}, + PolicyError, ), ( - {"subset_of": ['X', 'Y']}, - {"one_of": ['X', 'Y']}, - PolicyError + {"subset_of": ["X", "Y"]}, + {"default": ["X"]}, + {"subset_of": ["X", "Y"], "default": ["X"]}, ), ( - {"superset_of": ['X', 'Y']}, - {"default": ['X', 'Z']}, - PolicyError + {"superset_of": ["X", "Y"]}, + {"default": ["X", "Y", "Z"]}, + {"superset_of": ["X", "Y"], "default": ["X", "Y", "Z"]}, ), - ( - {"one_of": ['X', 'Y']}, - {"default": 'Z'}, - PolicyError - ) + ({"one_of": ["X", "Y"]}, {"default": "X"}, {"one_of": ["X", "Y"], "default": "X"}), + ({"subset_of": ["X", "Y"]}, {"default": ["X", "Z"]}, PolicyError), + ({"subset_of": ["X", "Y"]}, {"one_of": ["X", "Y"]}, PolicyError), + ({"superset_of": ["X", "Y"]}, {"default": ["X", "Z"]}, PolicyError), + ({"one_of": ["X", "Y"]}, {"default": "Z"}, PolicyError), ] @@ -321,50 +209,48 @@ def test_complex_policy_combinations(superior, subordinate, result): "scopes": { "subset_of": ["openid", "eduperson", "phone"], "superset_of": ["openid"], - "default": ["openid", "eduperson"]}, + "default": ["openid", "eduperson"], + }, "id_token_signed_response_alg": { "one_of": ["ES256", "ES384", "ES512"], - "default": "ES256" + "default": "ES256", }, - "contacts": { - "add": "helpdesk@federation.example.org"}, - "application_type": {"value": "web"} + "contacts": {"add": "helpdesk@federation.example.org"}, + "application_type": {"value": "web"}, } ORG = { "scopes": { "subset_of": ["openid", "eduperson", "address"], - "default": ["openid", "eduperson"]}, - "id_token_signed_response_alg": { - "one_of": ["ES256", "ES384"], - "default": "ES256"}, - "contacts": { - "add": "helpdesk@org.example.org"}, + "default": ["openid", "eduperson"], + }, + "id_token_signed_response_alg": {"one_of": ["ES256", "ES384"], "default": "ES256"}, + "contacts": {"add": "helpdesk@org.example.org"}, } RES = { "scopes": { "subset_of": ["openid", 
"eduperson"], "superset_of": ["openid"], - "default": ["openid", "eduperson"]}, - "id_token_signed_response_alg": { - "one_of": ["ES256", "ES384"], - "default": "ES256"}, + "default": ["openid", "eduperson"], + }, + "id_token_signed_response_alg": {"one_of": ["ES256", "ES384"], "default": "ES256"}, "contacts": { - "add": ["helpdesk@federation.example.org", - "helpdesk@org.example.org"]}, - "application_type": { - "value": "web"} + "add": ["helpdesk@federation.example.org", "helpdesk@org.example.org"] + }, + "application_type": {"value": "web"}, } def test_combine_policies(): - res = combine({'metadata_policy': FED, 'metadata': {}}, - {'metadata_policy': ORG, 'metadata': {}}) + res = combine( + {"metadata_policy": FED, "metadata": {}}, + {"metadata_policy": ORG, "metadata": {}}, + ) - assert set(res['metadata_policy'].keys()) == set(RES.keys()) + assert set(res["metadata_policy"].keys()) == set(RES.keys()) - for claim, policy in res['metadata_policy'].items(): + for claim, policy in res["metadata_policy"].items(): assert set(policy.keys()) == set(RES[claim].keys()) assert assert_equal(policy, RES[claim]) @@ -372,33 +258,29 @@ def test_combine_policies(): RP = { "contacts": ["rp_admins@cs.example.com"], "redirect_uris": ["https://cs.example.com/rp1"], - "response_types": ["code"] + "response_types": ["code"], } FED1 = { "scopes": { "superset_of": ["openid", "eduperson"], - "default": ["openid", "eduperson"] + "default": ["openid", "eduperson"], }, - "response_types": { - "subset_of": ["code", "code id_token"]}, - "id_token_signed_response_alg": { - "one_of": ["ES256", "ES384"], - "default": "ES256"} + "response_types": {"subset_of": ["code", "code id_token"]}, + "id_token_signed_response_alg": {"one_of": ["ES256", "ES384"], "default": "ES256"}, } ORG1 = { - "contacts": { - "add": "helpdesk@example.com"}, + "contacts": {"add": "helpdesk@example.com"}, "logo_uri": { - "one_of": ["https://example.com/logo_small.jpg", - "https://example.com/logo_big.jpg"], - "default": 
"https://example.com/logo_small.jpg" + "one_of": [ + "https://example.com/logo_small.jpg", + "https://example.com/logo_big.jpg", + ], + "default": "https://example.com/logo_small.jpg", }, - "policy_uri": { - "value": "https://example.com/policy.html"}, - "tos_uri": { - "value": "https://example.com/tos.html"} + "policy_uri": {"value": "https://example.com/policy.html"}, + "tos_uri": {"value": "https://example.com/tos.html"}, } RES1 = { @@ -409,13 +291,15 @@ def test_combine_policies(): "scopes": ["openid", "eduperson"], "response_types": ["code"], "redirect_uris": ["https://cs.example.com/rp1"], - "id_token_signed_response_alg": "ES256" + "id_token_signed_response_alg": "ES256", } def test_apply_policies(): - comb_policy = combine({'metadata_policy': FED1, 'metadata': {}}, - {'metadata_policy': ORG1, 'metadata': {}}) + comb_policy = combine( + {"metadata_policy": FED1, "metadata": {}}, + {"metadata_policy": ORG1, "metadata": {}}, + ) res = TrustChainPolicy().apply_policy(RP, comb_policy) @@ -434,32 +318,25 @@ def test_apply_policies(): assert value == RES1[claim] -@pytest.mark.parametrize("policy, metadata, result", - [ - ( - [{ - 'metadata': {'B': 123}, - 'metadata_policy': { - "A": {"subset_of": ['a', 'b']} - }}, - { - 'metadata': {'C': 'foo'}, - 'metadata_policy': { - "A": {"subset_of": ['a']} - } - } - ], - { - "A": ['a', 'b', 'e'], - "C": 'foo' - }, - { - 'A': ['a'], - 'B': 123, - 'C': 'foo' - } - ) - ]) +@pytest.mark.parametrize( + "policy, metadata, result", + [ + ( + [ + { + "metadata": {"B": 123}, + "metadata_policy": {"A": {"subset_of": ["a", "b"]}}, + }, + { + "metadata": {"C": "foo"}, + "metadata_policy": {"A": {"subset_of": ["a"]}}, + }, + ], + {"A": ["a", "b", "e"], "C": "foo"}, + {"A": ["a"], "B": 123, "C": "foo"}, + ) + ], +) def test_combine_metadata_and_metadata_policy_OK(policy, metadata, result): comb_policy = policy[0] for pol in policy[1:]: @@ -470,33 +347,24 @@ def test_combine_metadata_and_metadata_policy_OK(policy, metadata, result): # 1 a 
subordinate can not change something a superior has set -@pytest.mark.parametrize("policy", - [ - [ - { - 'metadata': {'B': 123}, - 'metadata_policy': { - "A": {"subset_of": ['a', 'b']} - } - }, - { - 'metadata': {'B': 'foo'}, - 'metadata_policy': { - "A": {"subset_of": ['a']} - } - } - ], - [ - { - 'metadata': {'B': 123}, - }, - { - 'metadata_policy': { - "B": {"subset_of": [12, 6]} - } - } - ] - ]) +@pytest.mark.parametrize( + "policy", + [ + [ + { + "metadata": {"B": 123}, + "metadata_policy": {"A": {"subset_of": ["a", "b"]}}, + }, + {"metadata": {"B": "foo"}, "metadata_policy": {"A": {"subset_of": ["a"]}}}, + ], + [ + { + "metadata": {"B": 123}, + }, + {"metadata_policy": {"B": {"subset_of": [12, 6]}}}, + ], + ], +) def test_combine_metadata_and_metadata_policy_NOT_OK(policy): with pytest.raises(PolicyError): combine(policy[0], policy[1]) @@ -505,14 +373,14 @@ def test_combine_metadata_and_metadata_policy_NOT_OK(policy): POLICY_1 = { "scopes": { "superset_of": ["openid", "eduperson"], - "subset_of": ["openid", "eduperson"] + "subset_of": ["openid", "eduperson"], } } POLICY_2 = { "response_types": { "subset_of": ["code", "code id_token"], - "superset_of": ["code", "code id_token"] + "superset_of": ["code", "code id_token"], } } @@ -520,15 +388,17 @@ def test_combine_metadata_and_metadata_policy_NOT_OK(policy): "contacts": ["rp_admins@cs.example.com"], "redirect_uris": ["https://cs.example.com/rp1"], "response_types": ["code", "code id_token", "id_token"], - "scopes": ["openid", "eduperson", "email", "address"] + "scopes": ["openid", "eduperson", "email", "address"], } def test_set_equality(): - comb_policy = combine({'metadata_policy': POLICY_1, 'metadata': {}}, - {'metadata_policy': POLICY_2, 'metadata': {}}) + comb_policy = combine( + {"metadata_policy": POLICY_1, "metadata": {}}, + {"metadata_policy": POLICY_2, "metadata": {}}, + ) res = TrustChainPolicy().apply_policy(ENT, comb_policy) - assert set(res['scopes']) == {"openid", "eduperson"} - assert 
set(res['response_types']) == {"code", "code id_token"} + assert set(res["scopes"]) == {"openid", "eduperson"} + assert set(res["response_types"]) == {"code", "code id_token"} diff --git a/pyeudiw/tests/federation/test_policy.py b/pyeudiw/tests/federation/test_policy.py index fc7c6f1f..48bfe7e2 100644 --- a/pyeudiw/tests/federation/test_policy.py +++ b/pyeudiw/tests/federation/test_policy.py @@ -1,32 +1,20 @@ - -from pyeudiw.federation.policy import ( - do_sub_one_super_add, do_value -) - from pyeudiw.federation.exceptions import PolicyError +from pyeudiw.federation.policy import do_sub_one_super_add, do_value def test_do_sub_one_super_add_subset_of(): - SUPERIOR = { - "subset_of": set(["test_a", "test_b"]) - } + SUPERIOR = {"subset_of": set(["test_a", "test_b"])} - CHILD = { - "subset_of": set(["test_a", "test_d"]) - } + CHILD = {"subset_of": set(["test_a", "test_d"])} policy = do_sub_one_super_add(SUPERIOR, CHILD, "subset_of") - assert policy == ['test_a'] + assert policy == ["test_a"] def test_do_sub_one_super_add_subset_of_fail(): - SUPERIOR = { - "subset_of": set(["test_a", "test_b"]) - } + SUPERIOR = {"subset_of": set(["test_a", "test_b"])} - CHILD = { - "subset_of": set(["test_q", "test_d"]) - } + CHILD = {"subset_of": set(["test_q", "test_d"])} try: do_sub_one_super_add(SUPERIOR, CHILD, "subset_of") @@ -35,26 +23,18 @@ def test_do_sub_one_super_add_subset_of_fail(): def test_do_sub_one_super_add_combine_superset_of(): - SUPERIOR = { - "superset_of": set(["test_a", "test_b"]) - } + SUPERIOR = {"superset_of": set(["test_a", "test_b"])} - CHILD = { - "superset_of": set(["test_a", "test_d"]) - } + CHILD = {"superset_of": set(["test_a", "test_d"])} policy = do_sub_one_super_add(SUPERIOR, CHILD, "superset_of") - assert policy == ['test_a'] + assert policy == ["test_a"] def test_do_superset_of_fail(): - SUPERIOR = { - "superset_of": set(["test_a", "test_b"]) - } + SUPERIOR = {"superset_of": set(["test_a", "test_b"])} - CHILD = { - "superset_of": set(["test_q", 
"test_d"]) - } + CHILD = {"superset_of": set(["test_q", "test_d"])} try: do_sub_one_super_add(SUPERIOR, CHILD, "superset_of") @@ -63,26 +43,18 @@ def test_do_superset_of_fail(): def test_do_value_superset_of(): - SUPERIOR = { - "superset_of": set(["test_a", "test_b"]) - } + SUPERIOR = {"superset_of": set(["test_a", "test_b"])} - CHILD = { - "superset_of": set(["test_a", "test_b"]) - } + CHILD = {"superset_of": set(["test_a", "test_b"])} policy = do_value(SUPERIOR, CHILD, "superset_of") assert policy == set(["test_a", "test_b"]) def test_do_value_superset_of_fail(): - SUPERIOR = { - "superset_of": set(["test_a", "test_b"]) - } + SUPERIOR = {"superset_of": set(["test_a", "test_b"])} - CHILD = { - "superset_of": set(["test_q", "test_d"]) - } + CHILD = {"superset_of": set(["test_q", "test_d"])} try: do_value(SUPERIOR, CHILD, "superset_of") diff --git a/pyeudiw/tests/federation/test_schema.py b/pyeudiw/tests/federation/test_schema.py index ad288d1a..4eff43e0 100644 --- a/pyeudiw/tests/federation/test_schema.py +++ b/pyeudiw/tests/federation/test_schema.py @@ -1,7 +1,9 @@ - -from pyeudiw.tools.utils import iat_now, exp_from_now -from pyeudiw.federation import is_es, is_ec -from pyeudiw.federation.exceptions import InvalidEntityStatement, InvalidEntityConfiguration +from pyeudiw.federation.exceptions import ( + InvalidEntityConfiguration, + InvalidEntityStatement, +) +from pyeudiw.federation.utils import is_es +from pyeudiw.tools.utils import exp_from_now, iat_now NOW = iat_now() EXP = exp_from_now(5) @@ -11,18 +13,18 @@ "iat": NOW, "iss": "https://trust-anchor.example.eu", "sub": "https://intermediate.eidas.example.org", - 'jwks': {"keys": []}, - "source_endpoint": "https://rp.example.it" + "jwks": {"keys": []}, + "source_endpoint": "https://rp.example.it", } ta_ec = { "exp": EXP, "iat": NOW, - 'iss': 'https://registry.eidas.trust-anchor.example.eu/', - 'sub': 'https://registry.eidas.trust-anchor.example.eu/', - 'jwks': {"keys": []}, - 'metadata': { - 
'wallet_relying_party': { + "iss": "https://registry.eidas.trust-anchor.example.eu/", + "sub": "https://registry.eidas.trust-anchor.example.eu/", + "jwks": {"keys": []}, + "metadata": { + "openid_credential_verifier": { "application_type": "web", "client_id": "https://rp.example.it", "client_name": "Name of an example organization", @@ -34,62 +36,56 @@ "n": "1Ta-sE …", "e": "AQAB", "kid": "YhNFS3YnC9tjiCaivhWLVUJ3AxwGGz_98uRFaqMEEs", - "x5c": ["..."] + "x5c": ["..."], } ] }, - 'contacts': [], - 'request_uris': [], - 'redirect_uris': [], - 'default_acr_values': [], - 'authorization_signed_response_alg': ['RS256'], - 'authorization_encrypted_response_alg': ["RSA-OAEP"], - 'authorization_encrypted_response_enc': ["A128CBC-HS256"], - 'subject_type': '', - 'require_auth_time': True, - 'id_token_encrypted_response_alg': ["RSA-OAEP"], - 'id_token_encrypted_response_enc': ["A128CBC-HS256"], - 'id_token_signed_response_alg': ["ES256"], - 'default_max_age': 5000, + "contacts": [], + "request_uris": [], + "redirect_uris": [], + "default_acr_values": [], + "authorization_signed_response_alg": ["RS256"], + "authorization_encrypted_response_alg": ["RSA-OAEP"], + "authorization_encrypted_response_enc": ["A128CBC-HS256"], + "subject_type": "", + "require_auth_time": True, + "id_token_encrypted_response_alg": ["RSA-OAEP"], + "id_token_encrypted_response_enc": ["A128CBC-HS256"], + "id_token_signed_response_alg": ["ES256"], + "default_max_age": 5000, "vp_formats": { "vc+sd-jwt": { - "sd-jwt_alg_values": [ - "ES256", - "ES384" - ], - "kb-jwt_alg_values": [ - "ES256", - "ES384" - ] + "sd-jwt_alg_values": ["ES256", "ES384"], + "kb-jwt_alg_values": ["ES256", "ES384"], } }, - 'policy_uri': '' + "policy_uri": "", }, - 'federation_entity': { - 'organization_name': 'example TA', - 'contacts': ['tech@eidas.trust-anchor.example.eu'], - 'homepage_uri': 'https://registry.eidas.trust-anchor.example.eu/', - 'logo_uri': 'https://registry.eidas.trust-anchor.example.eu/static/svg/logo.svg', - 
'policy_uri': 'https://registry.eidas.trust-anchor.example.eu/policy/', - 'federation_fetch_endpoint': 'https://registry.eidas.trust-anchor.example.eu/fetch/', - 'federation_resolve_endpoint': 'https://registry.eidas.trust-anchor.example.eu/resolve/', - 'federation_list_endpoint': 'https://registry.eidas.trust-anchor.example.eu/list/', - 'federation_trust_mark_status_endpoint': 'https://registry.eidas.trust-anchor.example.eu/trust_mark_status/', + "federation_entity": { + "organization_name": "example TA", + "contacts": ["tech@eidas.trust-anchor.example.eu"], + "homepage_uri": "https://registry.eidas.trust-anchor.example.eu/", + "logo_uri": "https://registry.eidas.trust-anchor.example.eu/static/svg/logo.svg", + "policy_uri": "https://registry.eidas.trust-anchor.example.eu/policy/", + "federation_fetch_endpoint": "https://registry.eidas.trust-anchor.example.eu/fetch/", + "federation_resolve_endpoint": "https://registry.eidas.trust-anchor.example.eu/resolve/", + "federation_list_endpoint": "https://registry.eidas.trust-anchor.example.eu/list/", + "federation_trust_mark_status_endpoint": "https://registry.eidas.trust-anchor.example.eu/trust_mark_status/", }, - 'authority_hints': [] + "authority_hints": [], }, - 'trust_marks_issuers': { - 'https://registry.eidas.trust-anchor.example.eu/openid_relying_party/public/': [ - 'https://registry.spid.eidas.trust-anchor.example.eu/', - 'https://public.intermediary.spid.org/' + "trust_marks_issuers": { + "https://registry.eidas.trust-anchor.example.eu/openid_relying_party/public/": [ + "https://registry.spid.eidas.trust-anchor.example.eu/", + "https://public.intermediary.spid.org/", + ], + "https://registry.eidas.trust-anchor.example.eu/openid_relying_party/private/": [ + "https://registry.spid.eidas.trust-anchor.example.eu/", + "https://private.other.intermediary.org/", ], - 'https://registry.eidas.trust-anchor.example.eu/openid_relying_party/private/': [ - 'https://registry.spid.eidas.trust-anchor.example.eu/', - 
'https://private.other.intermediary.org/' - ] }, - 'constraints': {'max_path_length': 1}, - 'authority_hints': [], + "constraints": {"max_path_length": 1}, + "authority_hints": [], } @@ -102,14 +98,3 @@ def test_is_es_false(): is_es(ta_ec) except InvalidEntityStatement: pass - - -def test_is_ec(): - is_ec(ta_ec) - - -def test_is_ec_false(): - try: - is_ec(ta_es) - except InvalidEntityConfiguration: - pass diff --git a/pyeudiw/tests/federation/test_static_trust_chain_validator.py b/pyeudiw/tests/federation/test_static_trust_chain_validator.py index 05fddef4..53955510 100644 --- a/pyeudiw/tests/federation/test_static_trust_chain_validator.py +++ b/pyeudiw/tests/federation/test_static_trust_chain_validator.py @@ -1,48 +1,46 @@ import copy -import uuid import unittest.mock as mock +import uuid from unittest.mock import Mock -from pyeudiw.federation.trust_chain_validator import StaticTrustChainValidator + import pyeudiw.federation.trust_chain_validator as tcv from pyeudiw.federation.exceptions import HttpError - +from pyeudiw.federation.trust_chain_validator import StaticTrustChainValidator from pyeudiw.tests.settings import httpc_params - -from . 
base import ( +from .base import ( EXP, JWS, NOW, intermediate_es_wallet, intermediate_es_wallet_signed, intermediate_jwk, - leaf_wallet_signed, leaf_wallet_jwk, + leaf_wallet_signed, ta_es, ta_es_signed, ta_jwk, - trust_chain_wallet + trust_chain_wallet, ) - trust_anchor_example = "https://trust-anchor.example.org" intermediate_example = "https://intermediate.eidas.example.org" def test_is_valid(): assert StaticTrustChainValidator( - trust_chain_wallet, [ta_jwk.serialize()], httpc_params=httpc_params).is_valid + trust_chain_wallet, [ta_jwk.serialize()], httpc_params=httpc_params + ).is_valid invalid_intermediate = copy.deepcopy(intermediate_es_wallet) invalid_leaf_jwk = copy.deepcopy(leaf_wallet_jwk.serialize()) invalid_leaf_jwk["kid"] = str(uuid.uuid4()) -invalid_intermediate["jwks"]['keys'] = [invalid_leaf_jwk] +invalid_intermediate["jwks"]["keys"] = [invalid_leaf_jwk] intermediate_signer = JWS( - invalid_intermediate, alg="ES256", - typ="application/entity-statement+jwt" + invalid_intermediate, alg="ES256", typ="application/entity-statement+jwt" ) invalid_intermediate_es_wallet_signed = intermediate_signer.sign_compact( [intermediate_jwk] @@ -51,7 +49,7 @@ def test_is_valid(): invalid_trust_chain = [ leaf_wallet_signed, invalid_intermediate_es_wallet_signed, - ta_es_signed + ta_es_signed, ] @@ -64,7 +62,8 @@ def test_is_valid_equals_false(): def test_retrieve_ec_fails(): try: StaticTrustChainValidator( - invalid_trust_chain, [ta_jwk.serialize()], httpc_params=httpc_params)._retrieve_ec(trust_anchor_example) + invalid_trust_chain, [ta_jwk.serialize()], httpc_params=httpc_params + )._retrieve_ec(trust_anchor_example) except HttpError: return @@ -72,22 +71,34 @@ def test_retrieve_ec_fails(): def test_retrieve_ec(): tcv.get_entity_configurations = Mock(return_value=[leaf_wallet_signed]) - assert tcv.StaticTrustChainValidator( - invalid_trust_chain, [ta_jwk.serialize()], httpc_params=httpc_params)._retrieve_ec(trust_anchor_example) == leaf_wallet_signed + assert ( + 
tcv.StaticTrustChainValidator( + invalid_trust_chain, [ta_jwk.serialize()], httpc_params=httpc_params + )._retrieve_ec(trust_anchor_example) + == leaf_wallet_signed + ) def test_retrieve_es(): tcv.get_entity_statements = Mock(return_value=[ta_es]) - assert tcv.StaticTrustChainValidator( - invalid_trust_chain, [ta_jwk.serialize()], httpc_params=httpc_params)._retrieve_es(trust_anchor_example, trust_anchor_example) == ta_es + assert ( + tcv.StaticTrustChainValidator( + invalid_trust_chain, [ta_jwk.serialize()], httpc_params=httpc_params + )._retrieve_es(trust_anchor_example, trust_anchor_example) + == ta_es + ) def test_retrieve_es_output_is_none(): tcv.get_entity_statements = Mock(return_value=[None]) - assert tcv.StaticTrustChainValidator( - invalid_trust_chain, [ta_jwk.serialize()], httpc_params=httpc_params)._retrieve_es(trust_anchor_example, trust_anchor_example) is None + assert ( + tcv.StaticTrustChainValidator( + invalid_trust_chain, [ta_jwk.serialize()], httpc_params=httpc_params + )._retrieve_es(trust_anchor_example, trust_anchor_example) + is None + ) def test_update_st_ec_case(): @@ -98,8 +109,12 @@ def mock_method(*args, **kwargs): # raise Exception("Wrong issuer") with mock.patch.object(tcv, "get_entity_configurations", mock_method): - assert tcv.StaticTrustChainValidator( - invalid_trust_chain, [ta_jwk.serialize()], httpc_params=httpc_params)._update_st(leaf_wallet_signed) == leaf_wallet_signed + assert ( + tcv.StaticTrustChainValidator( + invalid_trust_chain, [ta_jwk.serialize()], httpc_params=httpc_params + )._update_st(leaf_wallet_signed) + == leaf_wallet_signed + ) def test_update_st_es_case_source_endpoint(): @@ -108,8 +123,8 @@ def test_update_st_es_case_source_endpoint(): "iat": NOW, "iss": trust_anchor_example, "sub": intermediate_example, - 'jwks': {"keys": []}, - "source_endpoint": trust_anchor_example+"/fetch" + "jwks": {"keys": []}, + "source_endpoint": trust_anchor_example + "/fetch", } ta_signer = JWS(ta_es, alg="ES256", 
typ="application/entity-statement+jwt") @@ -132,7 +147,7 @@ def test_update_st_es_case_no_source_endpoint(): "iat": NOW, "iss": trust_anchor_example, "sub": intermediate_example, - 'jwks': {"keys": []}, + "jwks": {"keys": []}, } ta_signer = JWS(ta_es, alg="ES256", typ="application/entity-statement+jwt") diff --git a/pyeudiw/tests/federation/test_trust_chain_builder.py b/pyeudiw/tests/federation/test_trust_chain_builder.py index 6fe84cba..bb1b242a 100644 --- a/pyeudiw/tests/federation/test_trust_chain_builder.py +++ b/pyeudiw/tests/federation/test_trust_chain_builder.py @@ -1,18 +1,17 @@ +from unittest.mock import patch + +from pyeudiw.federation.statements import EntityStatement, get_entity_configurations from pyeudiw.federation.trust_chain_builder import TrustChainBuilder -from pyeudiw.federation.statements import get_entity_configurations, EntityStatement from pyeudiw.tests.settings import httpc_params -from . base import ta_ec, leaf_wallet -from . mocked_response import EntityResponseWithIntermediate - -from unittest.mock import patch +from .base import leaf_wallet, ta_ec +from .mocked_response import EntityResponseWithIntermediate @patch("requests.get", return_value=EntityResponseWithIntermediate()) def test_trust_chain_valid_with_intermediaries(self, mocker): - jwt = get_entity_configurations( - [ta_ec["sub"]], httpc_params=httpc_params)[0] + jwt = get_entity_configurations([ta_ec["sub"]], httpc_params=httpc_params)[0] trust_anchor_ec = EntityStatement(jwt, httpc_params=httpc_params) trust_anchor_ec.validate_by_itself() @@ -20,7 +19,7 @@ def test_trust_chain_valid_with_intermediaries(self, mocker): subject=leaf_wallet["sub"], trust_anchor=trust_anchor_ec.sub, trust_anchor_configuration=trust_anchor_ec, - httpc_params=httpc_params + httpc_params=httpc_params, ) trust_chain.start() diff --git a/pyeudiw/tests/jwk/test_jwks.py b/pyeudiw/tests/jwk/test_jwks.py index 6c8d3b71..d1f49896 100644 --- a/pyeudiw/tests/jwk/test_jwks.py +++ b/pyeudiw/tests/jwk/test_jwks.py 
@@ -1,7 +1,26 @@ from dataclasses import dataclass + from pyeudiw.jwk import JWK from pyeudiw.jwk.jwks import find_jwk_by_kid, find_jwk_by_thumbprint +raw_key_2 = { + "crv": "P-256", + "d": "dMCVfcZLPDMInj10w_aQdp-m4jZgwdZjDPwe5nKp-Lw", + "kid": "m_r7iPJLNZmQN5sEbILXr41xjSjSzfa3PgM5yURIh2Y", + "kty": "EC", + "use": "sig", + "x": "PA0jE_-Sxhdon9MGmjpMqlUykAbNIBcRgSvgL0eOoJQ", + "y": "PG-xPWEvEQxljYkBON1vGw9RTtDiDkMsRE1AOSo4ark", +} +raw_key_no_kid = { + "crv": "P-256", + "d": "Sz4XNTXk0JaUs6hoyMMUxCSqe9Jx_ciXyVGQj7JSW50", + "kty": "EC", + "use": "sig", + "x": "qojguJYLuM7ZtGspBfZ2SSrGgTnCgCUzjwUkOyOjGMk", + "y": "uRUCqLQjngS0iBZlhHLEGMqpUAe4AMpmMMr6BUkRD50", +} + def test_find_jwk_by_kid(): @dataclass @@ -11,49 +30,50 @@ class TestCase: expected: dict | None explanation: str - raw_key_1 = {"crv": "P-256", "d": "eTEvyBCxriRg6juv_H4bLRgRkdMaCF91k4bLEsdB2yI", "kid": "adeyyLKVrJyu3CLC9ewDHrobulXBZNOfPYM_4bERHqk", - "kty": "EC", "use": "sig", "x": "--7isDCDQZF7cZL-UrvRCLV5Rfo2Di1gaPX6_5uGalA", "y": "e2svMtnHH4s5dOPg8YhuHw2lEPlnVpkKJO7PGQeMTFw"} - raw_key_2 = {"crv": "P-256", "d": "dMCVfcZLPDMInj10w_aQdp-m4jZgwdZjDPwe5nKp-Lw", "kid": "m_r7iPJLNZmQN5sEbILXr41xjSjSzfa3PgM5yURIh2Y", - "kty": "EC", "use": "sig", "x": "PA0jE_-Sxhdon9MGmjpMqlUykAbNIBcRgSvgL0eOoJQ", "y": "PG-xPWEvEQxljYkBON1vGw9RTtDiDkMsRE1AOSo4ark"} - raw_key_no_kid = {"crv": "P-256", "d": "Sz4XNTXk0JaUs6hoyMMUxCSqe9Jx_ciXyVGQj7JSW50", "kty": "EC", "use": "sig", - "x": "qojguJYLuM7ZtGspBfZ2SSrGgTnCgCUzjwUkOyOjGMk", "y": "uRUCqLQjngS0iBZlhHLEGMqpUAe4AMpmMMr6BUkRD50"} + raw_key_1 = { + "crv": "P-256", + "d": "eTEvyBCxriRg6juv_H4bLRgRkdMaCF91k4bLEsdB2yI", + "kid": "adeyyLKVrJyu3CLC9ewDHrobulXBZNOfPYM_4bERHqk", + "kty": "EC", + "use": "sig", + "x": "--7isDCDQZF7cZL-UrvRCLV5Rfo2Di1gaPX6_5uGalA", + "y": "e2svMtnHH4s5dOPg8YhuHw2lEPlnVpkKJO7PGQeMTFw", + } test_cases: list[TestCase] = [ - TestCase( - jwks=[], - kid="NMrR5wD0p-VqbRbR9ej6M16v5Fs7hLXwonO9vhJYsn8", - expected=None, - explanation="no keys" - ), TestCase( 
jwks=[raw_key_1], kid=raw_key_1["kid"], expected=raw_key_1, - explanation="one matching key" + explanation="one matching key", ), TestCase( jwks=[raw_key_1, raw_key_2], kid=raw_key_2["kid"], expected=raw_key_2, - explanation="one matching key ot ouf two" + explanation="one matching key ot ouf two", ), - TestCase( - jwks=[raw_key_2], - kid="NMrR5wD0p-VqbRbR9ej6M16v5Fs7hLXwonO9vhJYsn8", - expected=None, - explanation="no matching key" - ), - TestCase( - jwks=[raw_key_no_kid], - kid="NMrR5wD0p-VqbRbR9ej6M16v5Fs7hLXwonO9vhJYsn8", - expected=None, - explanation="no matching on key without explicit kid (note: here kid=thumbprint)" - ) ] for i, case in enumerate(test_cases): obt = find_jwk_by_kid(case.jwks, case.kid) assert obt == case.expected, f"failed case {i}, testcase: {case.expected}" +def test_jwk_not_found(): + try: + find_jwk_by_kid([], "NMrR5wD0p-VqbRbR9ej6M16v5Fs7hLXwonO9vhJYsn8") + except Exception as e: + assert str(e) == "Key with Kid NMrR5wD0p-VqbRbR9ej6M16v5Fs7hLXwonO9vhJYsn8 not found" + + try: + find_jwk_by_kid([raw_key_2], "NMrR5wD0p-VqbRbR9ej6M16v5Fs7hLXwonO9vhJYsn8") + except Exception as e: + assert str(e) == "Key with Kid NMrR5wD0p-VqbRbR9ej6M16v5Fs7hLXwonO9vhJYsn8 not found" + + try: + find_jwk_by_kid([raw_key_no_kid], "NMrR5wD0p-VqbRbR9ej6M16v5Fs7hLXwonO9vhJYsn8") + except Exception as e: + assert str(e) == "Key with Kid NMrR5wD0p-VqbRbR9ej6M16v5Fs7hLXwonO9vhJYsn8 not found" + def test_find_jwk_by_thumbprint(): @dataclass @@ -63,10 +83,24 @@ class TestCase: expected: dict | None explanation: str - raw_key_1 = {"crv": "P-256", "d": "eTEvyBCxriRg6juv_H4bLRgRkdMaCF91k4bLEsdB2yI", "kid": "adeyyLKVrJyu3CLC9ewDHrobulXBZNOfPYM_4bERHqk", - "kty": "EC", "use": "sig", "x": "--7isDCDQZF7cZL-UrvRCLV5Rfo2Di1gaPX6_5uGalA", "y": "e2svMtnHH4s5dOPg8YhuHw2lEPlnVpkKJO7PGQeMTFw"} - raw_key_2 = {"crv": "P-256", "d": "dMCVfcZLPDMInj10w_aQdp-m4jZgwdZjDPwe5nKp-Lw", "kid": "m_r7iPJLNZmQN5sEbILXr41xjSjSzfa3PgM5yURIh2Y", - "kty": "EC", "use": "sig", "x": 
"PA0jE_-Sxhdon9MGmjpMqlUykAbNIBcRgSvgL0eOoJQ", "y": "PG-xPWEvEQxljYkBON1vGw9RTtDiDkMsRE1AOSo4ark"} + raw_key_1 = { + "crv": "P-256", + "d": "eTEvyBCxriRg6juv_H4bLRgRkdMaCF91k4bLEsdB2yI", + "kid": "adeyyLKVrJyu3CLC9ewDHrobulXBZNOfPYM_4bERHqk", + "kty": "EC", + "use": "sig", + "x": "--7isDCDQZF7cZL-UrvRCLV5Rfo2Di1gaPX6_5uGalA", + "y": "e2svMtnHH4s5dOPg8YhuHw2lEPlnVpkKJO7PGQeMTFw", + } + raw_key_2 = { + "crv": "P-256", + "d": "dMCVfcZLPDMInj10w_aQdp-m4jZgwdZjDPwe5nKp-Lw", + "kid": "m_r7iPJLNZmQN5sEbILXr41xjSjSzfa3PgM5yURIh2Y", + "kty": "EC", + "use": "sig", + "x": "PA0jE_-Sxhdon9MGmjpMqlUykAbNIBcRgSvgL0eOoJQ", + "y": "PG-xPWEvEQxljYkBON1vGw9RTtDiDkMsRE1AOSo4ark", + } # expected values obtained using an online calculator raw_thumprint_1 = b"adeyyLKVrJyu3CLC9ewDHrobulXBZNOfPYM_4bERHqk" raw_thumprint_2 = b"m_r7iPJLNZmQN5sEbILXr41xjSjSzfa3PgM5yURIh2Y" @@ -80,60 +114,50 @@ class TestCase: test_cases: list[TestCase] = [ TestCase( - jwks=[ - raw_key_1 - ], + jwks=[raw_key_1], thumbrpint=raw_thumprint_1, expected=raw_key_1, - explanation="one matching key" + explanation="one matching key", ), TestCase( - jwks=[ - raw_key_2, - raw_key_1 - ], + jwks=[raw_key_2, raw_key_1], thumbrpint=raw_thumprint_1, expected=raw_key_1, - explanation="one matching key out of two" + explanation="one matching key out of two", ), TestCase( - jwks=[], - thumbrpint=raw_thumprint_1, - expected=None, - explanation="no key" + jwks=[], thumbrpint=raw_thumprint_1, expected=None, explanation="no key" ), TestCase( - jwks=[ - raw_key_1 - ], + jwks=[raw_key_1], thumbrpint=raw_thumprint_2, expected=None, - explanation="no matching key" + explanation="no matching key", ), TestCase( jwks=[auto_key_1], thumbrpint=auto_thumprint_1, expected=auto_key_1, - explanation="one matching autorgenerated ECDAS key" + explanation="one matching autorgenerated ECDAS key", ), TestCase( jwks=[auto_key_2], thumbrpint=auto_thumprint_2, expected=auto_key_2, - explanation="one matching autorgenerated RSA key" + explanation="one 
matching autorgenerated RSA key", ), TestCase( jwks=[raw_key_1, raw_key_2, auto_key_1, auto_key_2], thumbrpint=auto_thumprint_1, expected=auto_key_1, - explanation="generic matching test" + explanation="generic matching test", ), TestCase( jwks=[raw_key_2, auto_key_1, auto_key_2], thumbrpint=raw_thumprint_1, expected=None, - explanation="generic non matching test" - ) + explanation="generic non matching test", + ), ] for i, case in enumerate(test_cases): obt = find_jwk_by_thumbprint(case.jwks, case.thumbrpint) diff --git a/pyeudiw/tests/jwk/test_parse.py b/pyeudiw/tests/jwk/test_parse.py index ab4392de..c29b6b91 100644 --- a/pyeudiw/tests/jwk/test_parse.py +++ b/pyeudiw/tests/jwk/test_parse.py @@ -6,13 +6,13 @@ def test_parse_key_from_x5c(): x5c = [ "MIIE3jCCA8agAwIBAgICAwEwDQYJKoZIhvcNAQEFBQAwYzELMAkGA1UEBhMCVVMxITAfBgNVBAoTGFRoZSBHbyBEYWRkeSBHcm91cCwgSW5jLjExMC8GA1UECxMoR28gRGFkZHkgQ2xhc3MgMiBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTAeFw0wNjExMTYwMTU0MzdaFw0yNjExMTYwMTU0MzdaMIHKMQswCQYDVQQGEwJVUzEQMA4GA1UECBMHQXJpem9uYTETMBEGA1UEBxMKU2NvdHRzZGFsZTEaMBgGA1UEChMRR29EYWRkeS5jb20sIEluYy4xMzAxBgNVBAsTKmh0dHA6Ly9jZXJ0aWZpY2F0ZXMuZ29kYWRkeS5jb20vcmVwb3NpdG9yeTEwMC4GA1UEAxMnR28gRGFkZHkgU2VjdXJlIENlcnRpZmljYXRpb24gQXV0aG9yaXR5MREwDwYDVQQFEwgwNzk2OTI4NzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMQt1RWMnCZM7DI161+4WQFapmGBWTtwY6vj3D3HKrjJM9N55DrtPDAjhI6zMBS2sofDPZVUBJ7fmd0LJR4h3mUpfjWoqVTr9vcyOdQmVZWt7/v+WIbXnvQAjYwqDL1CBM6nPwT27oDyqu9SoWlm2r4arV3aLGbqGmu75RpRSgAvSMeYddi5Kcju+GZtCpyz8/x4fKL4o/K1w/O5epHBp+YlLpyo7RJlbmr2EkRTcDCVw5wrWCs9CHRK8r5RsL+H0EwnWGu1NcWdrxcx+AuP7q2BNgWJCJjPOq8lh8BJ6qf9Z/dFjpfMFDniNoW1fho3/Rb2cRGadDAW/hOUoz+EDU8CAwEAAaOCATIwggEuMB0GA1UdDgQWBBT9rGEyk2xF1uLuhV+auud2mWjM5zAfBgNVHSMEGDAWgBTSxLDSkdRMEXGzYcs9of7dqGrU4zASBgNVHRMBAf8ECDAGAQH/AgEAMDMGCCsGAQUFBwEBBCcwJTAjBggrBgEFBQcwAYYXaHR0cDovL29jc3AuZ29kYWRkeS5jb20wRgYDVR0fBD8wPTA7oDmgN4Y1aHR0cDovL2NlcnRpZmljYXRlcy5nb2RhZGR5LmNvbS9yZXBvc2l0b3J5L2dkcm9vdC5jcmwwSwYDVR0gBEQwQjBABgRVHSAAMDgwNgYIKwYBBQUHAgEWKmh0dHA6Ly9jZXJ0aWZ
pY2F0ZXMuZ29kYWRkeS5jb20vcmVwb3NpdG9yeTAOBgNVHQ8BAf8EBAMCAQYwDQYJKoZIhvcNAQEFBQADggEBANKGwOy9+aG2Z+5mC6IGOgRQjhVyrEp0lVPLN8tESe8HkGsz2ZbwlFalEzAFPIUyIXvJxwqoJKSQ3kbTJSMUA2fCENZvD117esyfxVgqwcSeIaha86ykRvOe5GPLL5CkKSkB2XIsKd83ASe8T+5o0yGPwLPk9Qnt0hCqU7S+8MxZC9Y7lhyVJEnfzuz9p0iRFEUOOjZv2kWzRaJBydTXRE4+uXR21aITVSzGh6O1mawGhId/dQb8vxRMDsxuxN89txJx9OjxUUAiKEngHUuHqDTMBqLdElrRhjZkAzVvb3du6/KFUJheqwNTrZEjYx8WnM25sgVjOuH0aBsXBTWVU+4=", "MIIE+zCCBGSgAwIBAgICAQ0wDQYJKoZIhvcNAQEFBQAwgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0IFZhbGlkYXRpb24gTmV0d29yazEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAzBgNVBAsTLFZhbGlDZXJ0IENsYXNzIDIgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9yaXR5MSEwHwYDVQQDExhodHRwOi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG9w0BCQEWEWluZm9AdmFsaWNlcnQuY29tMB4XDTA0MDYyOTE3MDYyMFoXDTI0MDYyOTE3MDYyMFowYzELMAkGA1UEBhMCVVMxITAfBgNVBAoTGFRoZSBHbyBEYWRkeSBHcm91cCwgSW5jLjExMC8GA1UECxMoR28gRGFkZHkgQ2xhc3MgMiBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTCCASAwDQYJKoZIhvcNAQEBBQADggENADCCAQgCggEBAN6d1+pXGEmhW+vXX0iG6r7d/+TvZxz0ZWizV3GgXne77ZtJ6XCAPVYYYwhv2vLM0D9/AlQiVBDYsoHUwHU9S3/Hd8M+eKsaA7Ugay9qK7HFiH7Eux6wwdhFJ2+qN1j3hybX2C32qRe3H3I2TqYXP2WYktsqbl2i/ojgC95/5Y0V4evLOtXiEqITLdiOr18SPaAIBQi2XKVlOARFmR6jYGB0xUGlcmIbYsUfb18aQr4CUWWoriMYavx4A6lNf4DD+qta/KFApMoZFv6yyO9ecw3ud72a9nmYvLEHZ6IVDd2gWMZEewo+YihfukEHU1jPEX44dMX4/7VpkI+EdOqXG68CAQOjggHhMIIB3TAdBgNVHQ4EFgQU0sSw0pHUTBFxs2HLPaH+3ahq1OMwgdIGA1UdIwSByjCBx6GBwaSBvjCBuzEkMCIGA1UEBxMbVmFsaUNlcnQgVmFsaWRhdGlvbiBOZXR3b3JrMRcwFQYDVQQKEw5WYWxpQ2VydCwgSW5jLjE1MDMGA1UECxMsVmFsaUNlcnQgQ2xhc3MgMiBQb2xpY3kgVmFsaWRhdGlvbiBBdXRob3JpdHkxITAfBgNVBAMTGGh0dHA6Ly93d3cudmFsaWNlcnQuY29tLzEgMB4GCSqGSIb3DQEJARYRaW5mb0B2YWxpY2VydC5jb22CAQEwDwYDVR0TAQH/BAUwAwEB/zAzBggrBgEFBQcBAQQnMCUwIwYIKwYBBQUHMAGGF2h0dHA6Ly9vY3NwLmdvZGFkZHkuY29tMEQGA1UdHwQ9MDswOaA3oDWGM2h0dHA6Ly9jZXJ0aWZpY2F0ZXMuZ29kYWRkeS5jb20vcmVwb3NpdG9yeS9yb290LmNybDBLBgNVHSAERDBCMEAGBFUdIAAwODA2BggrBgEFBQcCARYqaHR0cDovL2NlcnRpZmljYXRlcy5nb2RhZGR5LmNvbS9yZXBvc2l0b3J5MA4GA1UdDwEB/wQEAwIBBjANBgkqhkiG9w0BAQUFAAOBgQC1QPmnHfbq/qQaQlpE9xXUhUaJwL6e4+P
rxeNYiY+Sn1eocSxI0YGyeR+sBjUZsE4OWBsUs5iB0QQeyAfJg594RAoYC5jcdnplDQ1tgMQLARzLrUc+cb53S8wGd9D0VmsfSxOaFIqII6hR8INMqzW/Rn453HWkrugp++85j09VZw==", - "MIIC5zCCAlACAQEwDQYJKoZIhvcNAQEFBQAwgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0IFZhbGlkYXRpb24gTmV0d29yazEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAzBgNVBAsTLFZhbGlDZXJ0IENsYXNzIDIgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9yaXR5MSEwHwYDVQQDExhodHRwOi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG9w0BCQEWEWluZm9AdmFsaWNlcnQuY29tMB4XDTk5MDYyNjAwMTk1NFoXDTE5MDYyNjAwMTk1NFowgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0IFZhbGlkYXRpb24gTmV0d29yazEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAzBgNVBAsTLFZhbGlDZXJ0IENsYXNzIDIgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9yaXR5MSEwHwYDVQQDExhodHRwOi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG9w0BCQEWEWluZm9AdmFsaWNlcnQuY29tMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDOOnHK5avIWZJV16vYdA757tn2VUdZZUcOBVXc65g2PFxTXdMwzzjsvUGJ7SVCCSRrCl6zfN1SLUzm1NZ9WlmpZdRJEy0kTRxQb7XBhVQ7/nHk01xC+YDgkRoKWzk2Z/M/VXwbP7RfZHM047QSv4dk+NoS/zcnwbNDu+97bi5p9wIDAQABMA0GCSqGSIb3DQEBBQUAA4GBADt/UG9vUJSZSWI4OB9L+KXIPqeCgfYrx+jFzug6EILLGACOTb2oWH+heQC1u+mNr0HZDzTuIYEZoDJJKPTEjlbVUjP9UNV+mWwD5MlM/Mtsq2azSiGM5bUMMj4QssxsodyamEwCW/POuZ6lcg5Ktz885hZo+L7tdEy8W9ViH0Pd" + 
"MIIC5zCCAlACAQEwDQYJKoZIhvcNAQEFBQAwgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0IFZhbGlkYXRpb24gTmV0d29yazEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAzBgNVBAsTLFZhbGlDZXJ0IENsYXNzIDIgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9yaXR5MSEwHwYDVQQDExhodHRwOi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG9w0BCQEWEWluZm9AdmFsaWNlcnQuY29tMB4XDTk5MDYyNjAwMTk1NFoXDTE5MDYyNjAwMTk1NFowgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0IFZhbGlkYXRpb24gTmV0d29yazEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAzBgNVBAsTLFZhbGlDZXJ0IENsYXNzIDIgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9yaXR5MSEwHwYDVQQDExhodHRwOi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG9w0BCQEWEWluZm9AdmFsaWNlcnQuY29tMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDOOnHK5avIWZJV16vYdA757tn2VUdZZUcOBVXc65g2PFxTXdMwzzjsvUGJ7SVCCSRrCl6zfN1SLUzm1NZ9WlmpZdRJEy0kTRxQb7XBhVQ7/nHk01xC+YDgkRoKWzk2Z/M/VXwbP7RfZHM047QSv4dk+NoS/zcnwbNDu+97bi5p9wIDAQABMA0GCSqGSIb3DQEBBQUAA4GBADt/UG9vUJSZSWI4OB9L+KXIPqeCgfYrx+jFzug6EILLGACOTb2oWH+heQC1u+mNr0HZDzTuIYEZoDJJKPTEjlbVUjP9UNV+mWwD5MlM/Mtsq2azSiGM5bUMMj4QssxsodyamEwCW/POuZ6lcg5Ktz885hZo+L7tdEy8W9ViH0Pd", ] # these values are hand crafted from x5c[0] exp_key = { "kty": "RSA", "e": "AQAB", - "n": "xC3VFYycJkzsMjXrX7hZAVqmYYFZO3Bjq-PcPccquMkz03nkOu08MCOEjrMwFLayh8M9lVQEnt-Z3QslHiHeZSl-NaipVOv29zI51CZVla3v-_5Yhtee9ACNjCoMvUIEzqc_BPbugPKq71KhaWbavhqtXdosZuoaa7vlGlFKAC9Ix5h12LkpyO74Zm0KnLPz_Hh8ovij8rXD87l6kcGn5iUunKjtEmVuavYSRFNwMJXDnCtYKz0IdEryvlGwv4fQTCdYa7U1xZ2vFzH4C4_urYE2BYkImM86ryWHwEnqp_1n90WOl8wUOeI2hbV-Gjf9FvZxEZp0MBb-E5SjP4QNTw" + "n": "xC3VFYycJkzsMjXrX7hZAVqmYYFZO3Bjq-PcPccquMkz03nkOu08MCOEjrMwFLayh8M9lVQEnt-Z3QslHiHeZSl-NaipVOv29zI51CZVla3v-_5Yhtee9ACNjCoMvUIEzqc_BPbugPKq71KhaWbavhqtXdosZuoaa7vlGlFKAC9Ix5h12LkpyO74Zm0KnLPz_Hh8ovij8rXD87l6kcGn5iUunKjtEmVuavYSRFNwMJXDnCtYKz0IdEryvlGwv4fQTCdYa7U1xZ2vFzH4C4_urYE2BYkImM86ryWHwEnqp_1n90WOl8wUOeI2hbV-Gjf9FvZxEZp0MBb-E5SjP4QNTw", } obt_key = parse_key_from_x5c(x5c).as_dict() assert exp_key["kty"] == obt_key["kty"] diff --git a/pyeudiw/tests/jwk/test_schema.py b/pyeudiw/tests/jwk/test_schema.py index f3d997fc..1a36f17a 100644 --- 
a/pyeudiw/tests/jwk/test_schema.py +++ b/pyeudiw/tests/jwk/test_schema.py @@ -1,4 +1,4 @@ -from pyeudiw.jwk.schemas.public import JwkSchema, ECJwkSchema, RSAJwkSchema +from pyeudiw.jwk.schemas.public import ECJwkSchema, JwkSchema, RSAJwkSchema def test_valid_rsa_jwk(): @@ -29,12 +29,7 @@ def test_valid_ec_jwk(): def test_invalid_keys(): # table with keys that should fail jwk parsing bad_keys_table: list[tuple[dict, str]] = [ - ( - { - "aaaa": "1" - }, - "non-sense key" - ), + ({"aaaa": "1"}, "non-sense key"), ( { "kty": "RSA", @@ -42,7 +37,7 @@ def test_invalid_keys(): "alg": "RS256", "kid": "2011-04-29", }, - "rsa key with missing attribute [n]" + "rsa key with missing attribute [n]", ), ( { @@ -53,7 +48,7 @@ def test_invalid_keys(): "alg": "RS256", "kid": "2011-04-29", }, - "rsa key with unexpected attribute [x]" + "rsa key with unexpected attribute [x]", ), ( { @@ -63,7 +58,7 @@ def test_invalid_keys(): "use": "enc", "kid": "1", }, - "ec key with missing attribute [crv]" + "ec key with missing attribute [crv]", ), ( { @@ -73,7 +68,7 @@ def test_invalid_keys(): "use": "enc", "kid": "1", }, - "ec key with missing attribute [x]" + "ec key with missing attribute [x]", ), ( { @@ -85,8 +80,8 @@ def test_invalid_keys(): "use": "enc", "kid": "1", }, - "ec key with unexpected attribute [e]" - ) + "ec key with unexpected attribute [e]", + ), ] for i, (bad_key, reason) in enumerate(bad_keys_table): try: diff --git a/pyeudiw/tests/jwt/test_helper.py b/pyeudiw/tests/jwt/test_helper.py index b71c8ce9..aa4e7fbe 100644 --- a/pyeudiw/tests/jwt/test_helper.py +++ b/pyeudiw/tests/jwt/test_helper.py @@ -4,52 +4,47 @@ def test_validate_jwt_timestamps_claims_ok(): now = iat_now() - payload = { - "iat": now - 10, - "nbf": now - 10, - "exp": now + 9999 - } + payload = {"iat": now - 10, "nbf": now - 10, "exp": now + 9999} try: validate_jwt_timestamps_claims(payload) except Exception as e: - assert True, f"encountered unexpeted error when validating the lifetime of a good token payload: 
{e}" + assert ( + True + ), f"encountered unexpeted error when validating the lifetime of a good token payload: {e}" def test_validate_jwt_timestamps_claims_bad_iat(): now = iat_now() - payload = { - "iat": now + 100, - "exp": now + 9999 - } + payload = {"iat": now + 100, "exp": now + 9999} try: validate_jwt_timestamps_claims(payload) - assert False, "failed to raise exception when validating a token payload with bad iat" + assert ( + False + ), "failed to raise exception when validating a token payload with bad iat" except Exception: pass def test_validate_jwt_timestamps_claims_bad_nbf(): now = iat_now() - payload = { - "nbf": now + 100, - "exp": now + 9999 - } + payload = {"nbf": now + 100, "exp": now + 9999} try: validate_jwt_timestamps_claims(payload) - assert False, "failed to raise exception when validating a token payload with bad nbf" + assert ( + False + ), "failed to raise exception when validating a token payload with bad nbf" except Exception: pass def test_validate_jwt_timestamps_claims_bad_exp(): now = iat_now() - payload = { - "iat": now - 100, - "exp": now - 10 - } + payload = {"iat": now - 100, "exp": now - 10} try: validate_jwt_timestamps_claims(payload) - assert False, "failed to raise exception when validating a token payload with bad exp" + assert ( + False + ), "failed to raise exception when validating a token payload with bad exp" except Exception: pass @@ -59,24 +54,20 @@ def test_test_validate_jwt_timestamps_claims_tolerance_window(): # case 0: tolerance window covers a token issuer "slightly" in the future now = iat_now() - payload = { - "iat": now + 15, - "nbf": now + 15, - "exp": now + 9999 - } + payload = {"iat": now + 15, "nbf": now + 15, "exp": now + 9999} try: validate_jwt_timestamps_claims(payload, tolerance_window) except Exception as e: - assert False, f"encountered unexpeted error when validating the lifetime of a token payload with a tolerance window (for iat, nbf): {e}" + assert ( + False + ), f"encountered unexpeted error 
when validating the lifetime of a token payload with a tolerance window (for iat, nbf): {e}" # case 1: tolerance window covers a token "slightly" expired now = iat_now() - payload = { - "iat": now - 100, - "nbf": now - 100, - "exp": now - 15 - } + payload = {"iat": now - 100, "nbf": now - 100, "exp": now - 15} try: validate_jwt_timestamps_claims(payload, tolerance_window) except Exception as e: - assert False, f"encountered unexpeted error when validating the lifetime of a token payload with a tolerance window (for exp): {e}" + assert ( + False + ), f"encountered unexpeted error when validating the lifetime of a token payload with a tolerance window (for exp): {e}" diff --git a/pyeudiw/tests/jwt/test_parse.py b/pyeudiw/tests/jwt/test_parse.py index c5bcb14e..ff7c3b6d 100644 --- a/pyeudiw/tests/jwt/test_parse.py +++ b/pyeudiw/tests/jwt/test_parse.py @@ -6,15 +6,11 @@ def test_kid_jwt(): decoded_jwt = DecodedJwt.parse(VALID_KID_JWT) assert decoded_jwt.jwt == VALID_KID_JWT - assert decoded_jwt.header == { - "kid": "123456", - "alg": "HS256", - "typ": "JWT" - } + assert decoded_jwt.header == {"kid": "123456", "alg": "HS256", "typ": "JWT"} assert decoded_jwt.payload == { "sub": "1234567890", "name": "John Doe", - "iat": 1516239022 + "iat": 1516239022, } assert decoded_jwt.signature == "bjM57L1H4gB60_020lKBVvVEhiYCOeEWGzMVEt-XNjc" @@ -28,10 +24,10 @@ def test_tc_jwt(): 
"eyJhbGciOiJFUzI1NiIsImtpZCI6ImFrNVBOMGR1WjNCeVlVUkVNMWszWm1RM1RFVnlOSEowWTJWVFlUWk1TSFI0VWsxSVExQk9USEpQU1EiLCJ0eXAiOiJlbnRpdHktc3RhdGVtZW50K2p3dCJ9.eyJleHAiOjE3Mjk5MDQzNDIsImlhdCI6MTcyOTYwNDM0MiwiaXNzIjoiaHR0cHM6Ly9jcmVkZW50aWFsX2lzc3Vlci5leGFtcGxlLm9yZyIsInN1YiI6Imh0dHBzOi8vY3JlZGVudGlhbF9pc3N1ZXIuZXhhbXBsZS5vcmciLCJqd2tzIjp7ImtleXMiOlt7Imt0eSI6IkVDIiwia2lkIjoiYWs1UE4wZHVaM0J5WVVSRU0xazNabVEzVEVWeU5ISjBZMlZUWVRaTVNIUjRVazFJUTFCT1RISlBTUSIsImFsZyI6IkVTMjU2IiwiY3J2IjoiUC0yNTYiLCJ4IjoiYjBIcmV6bTVxN1MzUE96ZVNobU9WRjJVV18zbnJvR0RNWnBaeFhlS1B0USIsInkiOiItME9HV0xnOGNoaVItQndPQ2pZeng1Mm1MZlE1b3BSVjVYQ0lVamlpaVRRIn1dfSwibWV0YWRhdGEiOnsib3BlbmlkX2NyZWRlbnRpYWxfaXNzdWVyIjp7Imp3a3MiOnsia2V5cyI6W3sia3R5IjoiRUMiLCJraWQiOiJNblE0VUdKbmVWUldYMDl5ZWpCUGIyeDBVMU50YUZabFgwMU9PVTlIU0d0MVVpMU5NRE5VV0dsU1JRIiwiYWxnIjoiRVMyNTYiLCJjcnYiOiJQLTI1NiIsIngiOiJ6VHBjNDYxN1dLSUF0UUVXWllYeDFFRjZGOEpnV3ozdHllaHc4MUJ3bG84IiwieSI6ImNITy1DaDZseUUyYmwzMTNrelRhS3JEbC14N3ZXbkU0dkU0VTdWUUF5ak0ifV19fSwiZmVkZXJhdGlvbl9lbnRpdHkiOnsib3JnYW5pemF0aW9uX25hbWUiOiJPcGVuSUQgQ3JlZGVudGlhbCBJc3N1ZXIgZXhhbXBsZSIsImhvbWVwYWdlX3VyaSI6Imh0dHBzOi8vY3JlZGVudGlhbF9pc3N1ZXIuZXhhbXBsZS5vcmcvaG9tZSIsInBvbGljeV91cmkiOiJodHRwczovL2NyZWRlbnRpYWxfaXNzdWVyLmV4YW1wbGUub3JnL3BvbGljeSIsImxvZ29fdXJpIjoiaHR0cHM6Ly9jcmVkZW50aWFsX2lzc3Vlci5leGFtcGxlLm9yZy9zdGF0aWMvbG9nby5zdmciLCJjb250YWN0cyI6WyJ0ZWNoQGNyZWRlbnRpYWxfaXNzdWVyLmV4YW1wbGUub3JnIl19fSwiYXV0aG9yaXR5X2hpbnRzIjpbImh0dHBzOi8vaW50ZXJtZWRpYXRlLmVpZGFzLmV4YW1wbGUub3JnIl19.ke58LCSSFvyi6daoaRR346aF3TCn4lCA86GXHhFa09uVE6Gkt6jUJhB8tFlvvdZberhqbvatoGECPCPeCK26Mw", 
"eyJhbGciOiJFUzI1NiIsImtpZCI6Ik9UTnRTRTgyVld4YWFqSjFWbGxXWVRSUGJIWkRZblF5UWxwZmQyVmliRU0yVEVwVGVqRk5WWGRSWnciLCJ0eXAiOiJlbnRpdHktc3RhdGVtZW50K2p3dCJ9.eyJleHAiOjE3Mjk5MDQzNDIsImlhdCI6MTcyOTYwNDM0MiwiaXNzIjoiaHR0cHM6Ly9pbnRlcm1lZGlhdGUuZWlkYXMuZXhhbXBsZS5vcmciLCJzdWIiOiJodHRwczovL2NyZWRlbnRpYWxfaXNzdWVyLmV4YW1wbGUub3JnIiwiandrcyI6eyJrZXlzIjpbeyJrdHkiOiJFQyIsImtpZCI6ImFrNVBOMGR1WjNCeVlVUkVNMWszWm1RM1RFVnlOSEowWTJWVFlUWk1TSFI0VWsxSVExQk9USEpQU1EiLCJhbGciOiJFUzI1NiIsImNydiI6IlAtMjU2IiwieCI6ImIwSHJlem01cTdTM1BPemVTaG1PVkYyVVdfM25yb0dETVpwWnhYZUtQdFEiLCJ5IjoiLTBPR1dMZzhjaGlSLUJ3T0NqWXp4NTJtTGZRNW9wUlY1WENJVWppaWlUUSJ9XX19.9m1i9qcDLSnpbwiNbGZJozovRTxhF6Qb-EvSZfYNe7csnhY_auTDKDieYoZBfainYGiHM2xw98-wgkygLV7KHw", "eyJhbGciOiJFUzI1NiIsImtpZCI6IlZtdzJZbGc0TVRrNWNWbHRiRWxHUkROVk16ZzRUV1pPTTBGUWMwNDFjM0JJZFVkc1lYRm9TbVJLTkEiLCJ0eXAiOiJlbnRpdHktc3RhdGVtZW50K2p3dCJ9.eyJleHAiOjE3Mjk5MDQzNDIsImlhdCI6MTcyOTYwNDM0MiwiaXNzIjoiaHR0cHM6Ly90cnVzdC1hbmNob3IuZXhhbXBsZS5vcmciLCJzdWIiOiJodHRwczovL2ludGVybWVkaWF0ZS5laWRhcy5leGFtcGxlLm9yZyIsImp3a3MiOnsia2V5cyI6W3sia3R5IjoiRUMiLCJraWQiOiJPVE50U0U4MlZXeGFhakoxVmxsV1lUUlBiSFpEWW5ReVFscGZkMlZpYkVNMlRFcFRlakZOVlhkUlp3IiwiYWxnIjoiRVMyNTYiLCJjcnYiOiJQLTI1NiIsIngiOiJrN1RMWVF1SXE5eGNnbGVSd05vYXBGc1Q1eDVjd3B0OExST2d1MEhSZE8wIiwieSI6Ilh4MTBhWnZxeFFrVWxGZUQxdkx1bnhWSndvbGZpUGxqQi1wOXRfY0hLOWMifV19fQ.b7xyGtDp2-ZMWlNBNOjEeUgECL_oP7TQjdHlj2me_Y6js_AeoEhlQ-2eMzWtcuYK4GV8xLGoH7Cln7pFI1OxTg", - 
"eyJhbGciOiJFUzI1NiIsImtpZCI6IlZtdzJZbGc0TVRrNWNWbHRiRWxHUkROVk16ZzRUV1pPTTBGUWMwNDFjM0JJZFVkc1lYRm9TbVJLTkEiLCJ0eXAiOiJlbnRpdHktc3RhdGVtZW50K2p3dCJ9.eyJleHAiOjE3Mjk5MDQzNDIsImlhdCI6MTcyOTYwNDM0MiwiaXNzIjoiaHR0cHM6Ly90cnVzdC1hbmNob3IuZXhhbXBsZS5vcmciLCJzdWIiOiJodHRwczovL3RydXN0LWFuY2hvci5leGFtcGxlLm9yZyIsImp3a3MiOnsia2V5cyI6W3sia3R5IjoiRUMiLCJraWQiOiJWbXcyWWxnNE1UazVjVmx0YkVsR1JETlZNemc0VFdaT00wRlFjMDQxYzNCSWRVZHNZWEZvU21SS05BIiwiYWxnIjoiRVMyNTYiLCJjcnYiOiJQLTI1NiIsIngiOiJNQmxWX1NmX1N2aWsxWjJ4ZkxkdjJzNkdHbzZuQlpYMUNpQU9WWV9Ca3N3IiwieSI6ImNLdjEwYTRnT2JVNVluaU10ZU1QQTdpZjhwbDRyZ3hTTXJ0bC1WNDBRVHMifV19LCJtZXRhZGF0YSI6eyJmZWRlcmF0aW9uX2VudGl0eSI6eyJmZWRlcmF0aW9uX2ZldGNoX2VuZHBvaW50IjoiaHR0cHM6Ly90cnVzdC1hbmNob3IuZXhhbXBsZS5vcmcvZmV0Y2giLCJmZWRlcmF0aW9uX3Jlc29sdmVfZW5kcG9pbnQiOiJodHRwczovL3RydXN0LWFuY2hvci5leGFtcGxlLm9yZy9yZXNvbHZlIiwiZmVkZXJhdGlvbl9saXN0X2VuZHBvaW50IjoiaHR0cHM6Ly90cnVzdC1hbmNob3IuZXhhbXBsZS5vcmcvbGlzdCIsIm9yZ2FuaXphdGlvbl9uYW1lIjoiVEEgZXhhbXBsZSIsImhvbWVwYWdlX3VyaSI6Imh0dHBzOi8vdHJ1c3QtYW5jaG9yLmV4YW1wbGUub3JnL2hvbWUiLCJwb2xpY3lfdXJpIjoiaHR0cHM6Ly90cnVzdC1hbmNob3IuZXhhbXBsZS5vcmcvcG9saWN5IiwibG9nb191cmkiOiJodHRwczovL3RydXN0LWFuY2hvci5leGFtcGxlLm9yZy9zdGF0aWMvbG9nby5zdmciLCJjb250YWN0cyI6WyJ0ZWNoQHRydXN0LWFuY2hvci5leGFtcGxlLm9yZyJdfX0sImNvbnN0cmFpbnRzIjp7Im1heF9wYXRoX2xlbmd0aCI6MX19.MbpXfe_NpPgbdWL_zN30SXA88aWrewaJyMWJFAegNrN-8Vy2umcpq3MQph7Yz3ZTawGgi6OGWX7UTDFOWWmf9w" + 
"eyJhbGciOiJFUzI1NiIsImtpZCI6IlZtdzJZbGc0TVRrNWNWbHRiRWxHUkROVk16ZzRUV1pPTTBGUWMwNDFjM0JJZFVkc1lYRm9TbVJLTkEiLCJ0eXAiOiJlbnRpdHktc3RhdGVtZW50K2p3dCJ9.eyJleHAiOjE3Mjk5MDQzNDIsImlhdCI6MTcyOTYwNDM0MiwiaXNzIjoiaHR0cHM6Ly90cnVzdC1hbmNob3IuZXhhbXBsZS5vcmciLCJzdWIiOiJodHRwczovL3RydXN0LWFuY2hvci5leGFtcGxlLm9yZyIsImp3a3MiOnsia2V5cyI6W3sia3R5IjoiRUMiLCJraWQiOiJWbXcyWWxnNE1UazVjVmx0YkVsR1JETlZNemc0VFdaT00wRlFjMDQxYzNCSWRVZHNZWEZvU21SS05BIiwiYWxnIjoiRVMyNTYiLCJjcnYiOiJQLTI1NiIsIngiOiJNQmxWX1NmX1N2aWsxWjJ4ZkxkdjJzNkdHbzZuQlpYMUNpQU9WWV9Ca3N3IiwieSI6ImNLdjEwYTRnT2JVNVluaU10ZU1QQTdpZjhwbDRyZ3hTTXJ0bC1WNDBRVHMifV19LCJtZXRhZGF0YSI6eyJmZWRlcmF0aW9uX2VudGl0eSI6eyJmZWRlcmF0aW9uX2ZldGNoX2VuZHBvaW50IjoiaHR0cHM6Ly90cnVzdC1hbmNob3IuZXhhbXBsZS5vcmcvZmV0Y2giLCJmZWRlcmF0aW9uX3Jlc29sdmVfZW5kcG9pbnQiOiJodHRwczovL3RydXN0LWFuY2hvci5leGFtcGxlLm9yZy9yZXNvbHZlIiwiZmVkZXJhdGlvbl9saXN0X2VuZHBvaW50IjoiaHR0cHM6Ly90cnVzdC1hbmNob3IuZXhhbXBsZS5vcmcvbGlzdCIsIm9yZ2FuaXphdGlvbl9uYW1lIjoiVEEgZXhhbXBsZSIsImhvbWVwYWdlX3VyaSI6Imh0dHBzOi8vdHJ1c3QtYW5jaG9yLmV4YW1wbGUub3JnL2hvbWUiLCJwb2xpY3lfdXJpIjoiaHR0cHM6Ly90cnVzdC1hbmNob3IuZXhhbXBsZS5vcmcvcG9saWN5IiwibG9nb191cmkiOiJodHRwczovL3RydXN0LWFuY2hvci5leGFtcGxlLm9yZy9zdGF0aWMvbG9nby5zdmciLCJjb250YWN0cyI6WyJ0ZWNoQHRydXN0LWFuY2hvci5leGFtcGxlLm9yZyJdfX0sImNvbnN0cmFpbnRzIjp7Im1heF9wYXRoX2xlbmd0aCI6MX19.MbpXfe_NpPgbdWL_zN30SXA88aWrewaJyMWJFAegNrN-8Vy2umcpq3MQph7Yz3ZTawGgi6OGWX7UTDFOWWmf9w", ], "alg": "HS256", - "typ": "JWT" + "typ": "JWT", } diff --git a/pyeudiw/tests/jwt/test_sign_verify.py b/pyeudiw/tests/jwt/test_sign_verify.py index c5aa7dfc..86fcd980 100644 --- a/pyeudiw/tests/jwt/test_sign_verify.py +++ b/pyeudiw/tests/jwt/test_sign_verify.py @@ -9,10 +9,24 @@ class TestJWSHeperSelectSigningKey: @pytest.fixture def sign_jwks(self): return [ - {"crv": "P-256", "d": "qIVMRJ0ioosFjCFhBw-kLBuip9tV0Y2D6iYD42nCKBA", "kid": "ppBQZHPUTaEPdiLsj99gadhfqLtYMwiU9bmDCfAsWfI", - "kty": "EC", "use": "sig", "x": "_336mq5GanihcG_V40tiLDq2sFJ83w-vxaPAZtfCr40", "y": 
"CYUM4Q1YlSTTgSp6OnJZt-O4YlzPf430AgVAM0oNlQk"}, - {"crv": "P-256", "d": "SW976Rpuse5crOTbM5yBifa7u1tgw46XlJCJRwon4kA", "kid": "35DgiI1eugPL1QB7sHG826YLLLLGDogvHmDa2jUilas", - "kty": "EC", "use": "sig", "x": "RXQ0lfXVXikgi00Yy8Qm2EX83_1JbLTXhyUXj9M21lk", "y": "xTfCwP-eelZXMBFNKwiEUQaUJeebHWcVgnGyB7fOF1M"} + { + "crv": "P-256", + "d": "qIVMRJ0ioosFjCFhBw-kLBuip9tV0Y2D6iYD42nCKBA", + "kid": "ppBQZHPUTaEPdiLsj99gadhfqLtYMwiU9bmDCfAsWfI", + "kty": "EC", + "use": "sig", + "x": "_336mq5GanihcG_V40tiLDq2sFJ83w-vxaPAZtfCr40", + "y": "CYUM4Q1YlSTTgSp6OnJZt-O4YlzPf430AgVAM0oNlQk", + }, + { + "crv": "P-256", + "d": "SW976Rpuse5crOTbM5yBifa7u1tgw46XlJCJRwon4kA", + "kid": "35DgiI1eugPL1QB7sHG826YLLLLGDogvHmDa2jUilas", + "kty": "EC", + "use": "sig", + "x": "RXQ0lfXVXikgi00Yy8Qm2EX83_1JbLTXhyUXj9M21lk", + "y": "xTfCwP-eelZXMBFNKwiEUQaUJeebHWcVgnGyB7fOF1M", + }, ] def test_JWSHelper_select_signing_key_undefined(self, sign_jwks): @@ -42,14 +56,28 @@ def test_JWSHelper_select_signing_key_unique(self, sign_jwks): assert k == exp_k -class TestJWSHelperSignerHeader(): +class TestJWSHelperSignerHeader: @pytest.fixture def sign_jwks(self): return [ - {"crv": "P-256", "d": "qIVMRJ0ioosFjCFhBw-kLBuip9tV0Y2D6iYD42nCKBA", "kid": "ppBQZHPUTaEPdiLsj99gadhfqLtYMwiU9bmDCfAsWfI", - "kty": "EC", "use": "sig", "x": "_336mq5GanihcG_V40tiLDq2sFJ83w-vxaPAZtfCr40", "y": "CYUM4Q1YlSTTgSp6OnJZt-O4YlzPf430AgVAM0oNlQk"}, - {"crv": "P-256", "d": "SW976Rpuse5crOTbM5yBifa7u1tgw46XlJCJRwon4kA", "kid": "35DgiI1eugPL1QB7sHG826YLLLLGDogvHmDa2jUilas", - "kty": "EC", "use": "sig", "x": "RXQ0lfXVXikgi00Yy8Qm2EX83_1JbLTXhyUXj9M21lk", "y": "xTfCwP-eelZXMBFNKwiEUQaUJeebHWcVgnGyB7fOF1M"} + { + "crv": "P-256", + "d": "qIVMRJ0ioosFjCFhBw-kLBuip9tV0Y2D6iYD42nCKBA", + "kid": "ppBQZHPUTaEPdiLsj99gadhfqLtYMwiU9bmDCfAsWfI", + "kty": "EC", + "use": "sig", + "x": "_336mq5GanihcG_V40tiLDq2sFJ83w-vxaPAZtfCr40", + "y": "CYUM4Q1YlSTTgSp6OnJZt-O4YlzPf430AgVAM0oNlQk", + }, + { + "crv": "P-256", + "d": 
"SW976Rpuse5crOTbM5yBifa7u1tgw46XlJCJRwon4kA", + "kid": "35DgiI1eugPL1QB7sHG826YLLLLGDogvHmDa2jUilas", + "kty": "EC", + "use": "sig", + "x": "RXQ0lfXVXikgi00Yy8Qm2EX83_1JbLTXhyUXj9M21lk", + "y": "xTfCwP-eelZXMBFNKwiEUQaUJeebHWcVgnGyB7fOF1M", + }, ] def test_signed_header_add_kid(self, sign_jwks): @@ -72,14 +100,26 @@ def test_signed_header_add_alg(selg, sign_jwks): assert "alg" in dec_header -class TestJWSHelperSelectVerifyingKey(): +class TestJWSHelperSelectVerifyingKey: @pytest.fixture def verify_jwks(self): return [ - {"crv": "P-256", "kid": "ppBQZHPUTaEPdiLsj99gadhfqLtYMwiU9bmDCfAsWfI", "kty": "EC", "use": "sig", - "x": "_336mq5GanihcG_V40tiLDq2sFJ83w-vxaPAZtfCr40", "y": "CYUM4Q1YlSTTgSp6OnJZt-O4YlzPf430AgVAM0oNlQk"}, - {"crv": "P-256", "kid": "35DgiI1eugPL1QB7sHG826YLLLLGDogvHmDa2jUilas", "kty": "EC", "use": "sig", - "x": "RXQ0lfXVXikgi00Yy8Qm2EX83_1JbLTXhyUXj9M21lk", "y": "xTfCwP-eelZXMBFNKwiEUQaUJeebHWcVgnGyB7fOF1M"} + { + "crv": "P-256", + "kid": "ppBQZHPUTaEPdiLsj99gadhfqLtYMwiU9bmDCfAsWfI", + "kty": "EC", + "use": "sig", + "x": "_336mq5GanihcG_V40tiLDq2sFJ83w-vxaPAZtfCr40", + "y": "CYUM4Q1YlSTTgSp6OnJZt-O4YlzPf430AgVAM0oNlQk", + }, + { + "crv": "P-256", + "kid": "35DgiI1eugPL1QB7sHG826YLLLLGDogvHmDa2jUilas", + "kty": "EC", + "use": "sig", + "x": "RXQ0lfXVXikgi00Yy8Qm2EX83_1JbLTXhyUXj9M21lk", + "y": "xTfCwP-eelZXMBFNKwiEUQaUJeebHWcVgnGyB7fOF1M", + }, ] def test_JWSHelper_select_verifying_key_undefined(self, verify_jwks): @@ -100,11 +140,18 @@ def test_JWSHelper_select_verifying_key_unique(self, verify_jwks): assert k == exp_k -class TestJWSHelperSignVerify(): +class TestJWSHelperSignVerify: @pytest.fixture def signing_key(self): - return {"crv": "P-256", "d": "1Fpynl9yQN88xI_AIkna0PiO0-5y5vUtNwC7rbg-BHE", "kid": "lfnXwtreAr8zgUE9CUFr9rGZsS5f52I7whhfiPr5I1o", - "kty": "EC", "use": "sig", "x": "2I-JeMD_JgNw95NORslAFUElmwMHWbT4uOdDCy99mac", "y": "Oy7Cyg2O_4GsLt475BbD5m71-snr52uMneUUHRiodBY"} + return { + "crv": "P-256", + "d": 
"1Fpynl9yQN88xI_AIkna0PiO0-5y5vUtNwC7rbg-BHE", + "kid": "lfnXwtreAr8zgUE9CUFr9rGZsS5f52I7whhfiPr5I1o", + "kty": "EC", + "use": "sig", + "x": "2I-JeMD_JgNw95NORslAFUElmwMHWbT4uOdDCy99mac", + "y": "Oy7Cyg2O_4GsLt475BbD5m71-snr52uMneUUHRiodBY", + } def test_JWSHelper_sign_then_verify(self, signing_key): helper = JWSHelper(signing_key) @@ -113,7 +160,7 @@ def test_JWSHelper_sign_then_verify(self, signing_key): "exp": iat_now() + 999, "iss": "token-issuer", "sub": "token-subject", - "aud": "token-audience" + "aud": "token-audience", } token = helper.sign(claims, kid_in_header=True) assert "alg" in decode_jwt_header(token) @@ -132,7 +179,7 @@ def test_JWSHelper_sign_then_verify_clock_skewed(self, signing_key): "exp": iat_now() + 999, "iss": "token-issuer", "sub": "token-subject", - "aud": "token-audience" + "aud": "token-audience", } token = helper.sign(claims, kid_in_header=True) @@ -144,11 +191,13 @@ def test_JWSHelper_sign_then_verify_clock_skewed(self, signing_key): # case 1: using global configured tolerance DEFAULT_TOKEN_TIME_TOLERANCE claims = { - "iat": iat_now() + DEFAULT_TOKEN_TIME_TOLERANCE//2, # oops, issuer clock is slightly skewed! + "iat": iat_now() + + DEFAULT_TOKEN_TIME_TOLERANCE + // 2, # oops, issuer clock is slightly skewed! 
"exp": iat_now() + 999, "iss": "token-issuer", "sub": "token-subject", - "aud": "token-audience" + "aud": "token-audience", } try: helper.verify(token) diff --git a/pyeudiw/tests/jwt/test_utils.py b/pyeudiw/tests/jwt/test_utils.py index 10524a1b..258d625d 100644 --- a/pyeudiw/tests/jwt/test_utils.py +++ b/pyeudiw/tests/jwt/test_utils.py @@ -1,7 +1,12 @@ -from pyeudiw.tests.jwt import VALID_TC_JWT, VALID_JWE -from pyeudiw.jwt.exceptions import JWTInvalidElementPosition, JWTDecodeError - -from pyeudiw.jwt.utils import decode_jwt_element, decode_jwt_header, decode_jwt_payload, is_jwt_format, is_jwe_format +from pyeudiw.jwt.exceptions import JWTDecodeError, JWTInvalidElementPosition +from pyeudiw.jwt.utils import ( + decode_jwt_element, + decode_jwt_header, + decode_jwt_payload, + is_jwe_format, + is_jwt_format, +) +from pyeudiw.tests.jwt import VALID_JWE, VALID_TC_JWT def test_decode_jwt_element(): diff --git a/pyeudiw/tests/jwt/test_verification.py b/pyeudiw/tests/jwt/test_verification.py index d1fea582..bc40bf0b 100644 --- a/pyeudiw/tests/jwt/test_verification.py +++ b/pyeudiw/tests/jwt/test_verification.py @@ -1,14 +1,13 @@ -from pyeudiw.jwt.helper import is_jwt_expired -from pyeudiw.jwt.jws_helper import JWSHelper - from cryptojwt.jwk.ec import new_ec_key +from pyeudiw.jwt.helper import is_jwt_expired +from pyeudiw.jwt.jws_helper import JWSHelper from pyeudiw.jwt.verification import verify_jws_with_key from pyeudiw.tools.utils import iat_now def test_is_jwt_expired(): - jwk = new_ec_key('P-256') + jwk = new_ec_key("P-256") payload = {"exp": 1516239022} helper = JWSHelper(jwk) @@ -18,7 +17,7 @@ def test_is_jwt_expired(): def test_is_jwt_not_expired(): - jwk = new_ec_key('P-256') + jwk = new_ec_key("P-256") payload = {"exp": 999999999999} helper = JWSHelper(jwk) @@ -28,8 +27,8 @@ def test_is_jwt_not_expired(): def test_verify_jws_with_key(): - jwk = new_ec_key('P-256') - payload = {"exp": iat_now()+5000} + jwk = new_ec_key("P-256") + payload = {"exp": iat_now() + 
5000} helper = JWSHelper(jwk) jws = helper.sign(payload) diff --git a/pyeudiw/tests/oauth2/test_dpop.py b/pyeudiw/tests/oauth2/test_dpop.py index 14abe1f5..8aa54415 100644 --- a/pyeudiw/tests/oauth2/test_dpop.py +++ b/pyeudiw/tests/oauth2/test_dpop.py @@ -1,17 +1,16 @@ import base64 import hashlib -import pytest +import pytest +from cryptojwt.jwk.ec import new_ec_key +from cryptojwt.jwk.rsa import new_rsa_key from pyeudiw.jwt.jws_helper import JWSHelper from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload from pyeudiw.oauth2.dpop import DPoPIssuer, DPoPVerifier from pyeudiw.tools.utils import iat_now -from cryptojwt.jwk.ec import new_ec_key -from cryptojwt.jwk.rsa import new_rsa_key - -PRIVATE_JWK_EC = new_ec_key('P-256') +PRIVATE_JWK_EC = new_ec_key("P-256") PRIVATE_JWK = PRIVATE_JWK_EC.serialize(private=True) PUBLIC_JWK = PRIVATE_JWK_EC.serialize() @@ -24,34 +23,23 @@ "tos_uri": "https://wallet-provider.example.org/info_policy", "logo_uri": "https://wallet-provider.example.org/logo.svg", "aal": "https://wallet-provider.example.org/LoA/basic", - "cnf": - { - "jwk": PUBLIC_JWK - }, + "cnf": {"jwk": PUBLIC_JWK}, "authorization_endpoint": "haip:", - "response_types_supported": [ - "vp_token" - ], + "response_types_supported": ["vp_token"], "vp_formats_supported": { - "jwt_vp_json": { - "alg_values_supported": ["ES256"] - }, - "jwt_vc_json": { - "alg_values_supported": ["ES256"] - } + "jwt_vp_json": {"alg_values_supported": ["ES256"]}, + "jwt_vc_json": {"alg_values_supported": ["ES256"]}, }, - "request_object_signing_alg_values_supported": [ - "ES256" - ], + "request_object_signing_alg_values_supported": ["ES256"], "presentation_definition_uri_supported": False, "iat": iat_now(), - "exp": iat_now() + 1024 + "exp": iat_now() + 1024, } @pytest.fixture def private_jwk(): - return new_ec_key('P-256') + return new_ec_key("P-256") @pytest.fixture @@ -62,8 +50,7 @@ def jwshelper(private_jwk): @pytest.fixture def wia_jws(jwshelper): wia = jwshelper.sign( - 
WALLET_INSTANCE_ATTESTATION, - protected={'trust_chain': [], 'x5c': []} + WALLET_INSTANCE_ATTESTATION, protected={"trust_chain": [], "x5c": []} ) return wia @@ -77,9 +64,7 @@ def test_create_validate_dpop_http_headers(wia_jws, private_jwk=PRIVATE_JWK_EC): assert header["alg"] new_dpop = DPoPIssuer( - htu='https://example.org/redirect', - token=wia_jws, - private_jwk=private_jwk + htu="https://example.org/redirect", token=wia_jws, private_jwk=private_jwk ) proof = new_dpop.proof assert proof @@ -91,9 +76,12 @@ def test_create_validate_dpop_http_headers(wia_jws, private_jwk=PRIVATE_JWK_EC): assert "d" not in header["jwk"] payload = decode_jwt_payload(proof) - assert payload["ath"] == base64.urlsafe_b64encode( - hashlib.sha256(wia_jws.encode() - ).digest()).rstrip(b'=').decode() + assert ( + payload["ath"] + == base64.urlsafe_b64encode(hashlib.sha256(wia_jws.encode()).digest()) + .rstrip(b"=") + .decode() + ) assert payload["htm"] in ["GET", "POST", "get", "post"] assert payload["htu"] == "https://example.org/redirect" assert payload["jti"] @@ -103,7 +91,7 @@ def test_create_validate_dpop_http_headers(wia_jws, private_jwk=PRIVATE_JWK_EC): dpop = DPoPVerifier( public_jwk=PUBLIC_JWK, http_header_authz=f"DPoP {wia_jws}", - http_header_dpop=proof + http_header_dpop=proof, ) assert dpop.is_valid @@ -111,7 +99,7 @@ def test_create_validate_dpop_http_headers(wia_jws, private_jwk=PRIVATE_JWK_EC): dpop = DPoPVerifier( public_jwk=other_jwk, http_header_authz=f"DPoP {wia_jws}", - http_header_dpop=proof + http_header_dpop=proof, ) with pytest.raises(Exception): dpop.validate() @@ -120,7 +108,7 @@ def test_create_validate_dpop_http_headers(wia_jws, private_jwk=PRIVATE_JWK_EC): dpop = DPoPVerifier( public_jwk=PUBLIC_JWK, http_header_authz=f"DPoP {wia_jws}", - http_header_dpop="aaa" + http_header_dpop="aaa", ) assert dpop.is_valid is False @@ -128,5 +116,5 @@ def test_create_validate_dpop_http_headers(wia_jws, private_jwk=PRIVATE_JWK_EC): dpop = DPoPVerifier( public_jwk=PUBLIC_JWK, 
http_header_authz=f"DPoP {wia_jws}", - http_header_dpop="aaa" + proof[3:] + http_header_dpop="aaa" + proof[3:], ) diff --git a/pyeudiw/tests/openid4vp/schemas/test_schema.py b/pyeudiw/tests/openid4vp/schemas/test_schema.py index 1494c6b5..bdfc8ca3 100644 --- a/pyeudiw/tests/openid4vp/schemas/test_schema.py +++ b/pyeudiw/tests/openid4vp/schemas/test_schema.py @@ -1,12 +1,14 @@ import pytest - from pydantic import ValidationError from pyeudiw.federation.schemas.entity_configuration import ( - EntityConfigurationHeader, EntityConfigurationPayload) + EntityConfigurationHeader, + EntityConfigurationPayload, +) from pyeudiw.openid4vp.schemas.wallet_instance_attestation_request import ( WalletInstanceAttestationRequestHeader, - WalletInstanceAttestationRequestPayload) + WalletInstanceAttestationRequestPayload, +) def test_wir(): @@ -14,7 +16,7 @@ def test_wir(): "header": { "alg": "RS256", "kid": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg", - "typ": "var+jwt" + "typ": "var+jwt", }, "payload": { "iss": "vbeXJksM45xphtANnCiG6mCyuU4jfGNzopGuKvogg9c", @@ -42,67 +44,74 @@ def test_wir(): "RRfVrbxVxiZHjU6zL6jY5QJdh1QCmENoejj_ytspMmGW7yMRxzUqgxcAqOBpVm0b-_mW3HoBdjQ", "e": "AQAB", "kid": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg", - "x5t": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg" + "x5t": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg", } }, "iat": 1686645115, - "exp": 1686652315 - }} + "exp": 1686652315, + }, + } WalletInstanceAttestationRequestHeader(**wir_dict["header"]) - WalletInstanceAttestationRequestPayload( - **wir_dict["payload"]) + WalletInstanceAttestationRequestPayload(**wir_dict["payload"]) WalletInstanceAttestationRequestHeader.model_validate( - wir_dict["header"], context={"supported_algorithms": ["RS256"]}) + wir_dict["header"], context={"supported_algorithms": ["RS256"]} + ) with pytest.raises(ValidationError): WalletInstanceAttestationRequestHeader.model_validate( - wir_dict["header"], 
context={"supported_algorithms": []}) + wir_dict["header"], context={"supported_algorithms": []} + ) with pytest.raises(ValidationError): WalletInstanceAttestationRequestHeader.model_validate( - wir_dict["header"], context={"supported_algorithms": None}) + wir_dict["header"], context={"supported_algorithms": None} + ) with pytest.raises(ValidationError): WalletInstanceAttestationRequestHeader.model_validate( - wir_dict["header"], context={"supported_algorithms": ["RS384"]}) + wir_dict["header"], context={"supported_algorithms": ["RS384"]} + ) wir_dict["payload"]["type"] = "NOT_WalletInstanceAttestationRequest" with pytest.raises(ValidationError): WalletInstanceAttestationRequestPayload.model_validate( - wir_dict["payload"], context={"supported_algorithms": ["RS256"]}) + wir_dict["payload"], context={"supported_algorithms": ["RS256"]} + ) wir_dict["payload"]["type"] = "WalletInstanceAttestationRequest" - wir_dict["payload"]["cnf"] = { - "wrong_name_jwk": wir_dict["payload"]["cnf"]["jwk"]} + wir_dict["payload"]["cnf"] = {"wrong_name_jwk": wir_dict["payload"]["cnf"]["jwk"]} with pytest.raises(ValidationError): WalletInstanceAttestationRequestPayload.model_validate( - wir_dict["payload"], context={"supported_algorithms": ["RS256"]}) - wir_dict["payload"]["cnf"] = { - "jwk": wir_dict["payload"]["cnf"]["wrong_name_jwk"]} + wir_dict["payload"], context={"supported_algorithms": ["RS256"]} + ) + wir_dict["payload"]["cnf"] = {"jwk": wir_dict["payload"]["cnf"]["wrong_name_jwk"]} def test_entity_config_header(): header = { "alg": "RS256", "kid": "2HnoFS3YnC9tjiCaivhWLVUJ3AxwGGz_98uRFaqMEEs", - "typ": "entity-statement+jwt" + "typ": "entity-statement+jwt", } EntityConfigurationHeader(**header) - header['typ'] = "entity-config+jwt" + header["typ"] = "entity-config+jwt" with pytest.raises(ValidationError): EntityConfigurationHeader(**header) - header['typ'] = "entity-statement+jwt" + header["typ"] = "entity-statement+jwt" with pytest.raises(ValidationError): 
EntityConfigurationHeader.model_validate( - header, context={"supported_algorithms": []}) + header, context={"supported_algorithms": []} + ) with pytest.raises(ValidationError): EntityConfigurationHeader.model_validate( - header, context={"supported_algorithms": ["asd"]}) + header, context={"supported_algorithms": ["asd"]} + ) EntityConfigurationHeader.model_validate( - header, context={"supported_algorithms": ["RS256"]}) + header, context={"supported_algorithms": ["RS256"]} + ) def test_entity_config_payload(): @@ -117,12 +126,12 @@ def test_entity_config_payload(): "kty": "RSA", "n": "5s4qi …", "e": "AQAB", - "kid": "2HnoFS3YnC9tjiCaivhWLVUJ3AxwGGz_98uRFaqMEEs" + "kid": "2HnoFS3YnC9tjiCaivhWLVUJ3AxwGGz_98uRFaqMEEs", } ] }, "metadata": { - "wallet_relying_party": { + "openid_credential_verifier": { "application_type": "web", "client_id": "https://rp.example.it", "client_name": "Name of an example organization", @@ -134,46 +143,29 @@ def test_entity_config_payload(): "n": "1Ta-sE …", "e": "AQAB", "kid": "YhNFS3YnC9tjiCaivhWLVUJ3AxwGGz_98uRFaqMEEs", - "x5c": [ - "..." 
- ] + "x5c": ["..."], } ] }, - "contacts": [ - "ops@verifier.example.org" - ], - "request_uris": [ - "https://verifier.example.org/request_uri" - ], - "redirect_uris": [ - "https://verifier.example.org/callback" - ], + "contacts": ["ops@verifier.example.org"], + "request_uris": ["https://verifier.example.org/request_uri"], + "redirect_uris": ["https://verifier.example.org/callback"], "default_acr_values": [ "https://www.spid.gov.it/SpidL2", - "https://www.spid.gov.it/SpidL3" + "https://www.spid.gov.it/SpidL3", ], "vp_formats": { "vc+sd-jwt": { - "sd-jwt_alg_values": [ - "ES256", - "ES384" - ], - "kb-jwt_alg_values": [ - "ES256", - "ES384" - ] + "sd-jwt_alg_values": ["ES256", "ES384"], + "kb-jwt_alg_values": ["ES256", "ES384"], } }, "default_max_age": 1111, - "authorization_signed_response_alg": [ - "RS256", - "ES256" - ], + "authorization_signed_response_alg": ["RS256", "ES256"], "authorization_encrypted_response_alg": [ "RSA-OAEP", "RSA-OAEP-256", - "ECDH-ES" + "ECDH-ES", ], "authorization_encrypted_response_enc": [ "A128CBC-HS256", @@ -181,39 +173,29 @@ def test_entity_config_payload(): "A256CBC-HS512", "A128GCM", "A192GCM", - "A256GCM" + "A256GCM", ], "subject_type": "pairwise", "require_auth_time": True, - "id_token_signed_response_alg": [ - "RS256", - "ES256" - ], - "id_token_encrypted_response_alg": [ - "RSA-OAEP", - "RSA-OAEP-256" - ], + "id_token_signed_response_alg": ["RS256", "ES256"], + "id_token_encrypted_response_alg": ["RSA-OAEP", "RSA-OAEP-256"], "id_token_encrypted_response_enc": [ "A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512", "A128GCM", "A192GCM", - "A256GCM" - ] + "A256GCM", + ], }, "federation_entity": { "organization_name": "OpenID Wallet Verifier example", "homepage_uri": "https://verifier.example.org/home", "policy_uri": "https://verifier.example.org/policy", "logo_uri": "https://verifier.example.org/static/logo.svg", - "contacts": [ - "tech@verifier.example.org" - ] - } + "contacts": ["tech@verifier.example.org"], + }, }, - 
"authority_hints": [ - "https://registry.eudi-wallet.example.it" - ] + "authority_hints": ["https://registry.eudi-wallet.example.it"], } EntityConfigurationPayload(**payload) diff --git a/pyeudiw/tests/openid4vp/schemas/test_vp_token.py b/pyeudiw/tests/openid4vp/schemas/test_vp_token.py index 68251fd0..3b40199c 100644 --- a/pyeudiw/tests/openid4vp/schemas/test_vp_token.py +++ b/pyeudiw/tests/openid4vp/schemas/test_vp_token.py @@ -1,5 +1,4 @@ import pytest - from pydantic import ValidationError from pyeudiw.openid4vp.schemas.vp_token import VPTokenHeader, VPTokenPayload @@ -8,7 +7,7 @@ "header": { "alg": "ES256", "typ": "JWT", - "kid": "e0bbf2f1-8c3a-4eab-a8ac-2e8f34db8a47" + "kid": "e0bbf2f1-8c3a-4eab-a8ac-2e8f34db8a47", }, "payload": { "iss": "https://wallet-provider.example.org/instance/vbeXJksM45xphtANnCiG6mCyuU4jfGNzopGuKvogg9c", @@ -25,37 +24,41 @@ "kZHJlc3MiLCB7InN0cmVldF9hZGRyZXNzIjogIlNjaHVsc3RyLiAxMiIsICJsb2NhbGl0eSI6ICJTY2h1bHBmb3J0YSIsICJyZWdpb24iOiAiU2FjaHNlbi1BbmhhbHQiLCAiY291bnRyeSI6ICJER" "SJ9XQ~WyJjR1ctZl9NVmlJUnp6M0Q1QVNxOUt3IiwgImVtYWlsIiwgIm1heEBob21lLmNvbSJd~eyJhbGciOiAiRVMyNTYiLCAidHlwIjogImtiK2p3dCJ9.eyJub25jZSI6ICIyZjlkZWE4YTBkYm" "Y1ZGRiN2NlOWQyZmRlOWZiOGJkNiIsICJhdWQiOiAiaHR0cHM6Ly9leGFtcGxlLmNvbS92ZXJpZmllciIsICJpYXQiOiAxNjkwOTYyNzM1fQ.ScCgejwnR7fdF2trKDSJooNKWiz6-dLQGlQzRK-NV" - "MSayKWXxj6Ebxwleb2MS_SbSHYHN2GygLw5NNyXV_3TlA" + "MSayKWXxj6Ebxwleb2MS_SbSHYHN2GygLw5NNyXV_3TlA", # "vp": "~~~...~" - } + }, } def test_vp_token_header(): - VPTokenHeader(**VP_TOKEN['header']) + VPTokenHeader(**VP_TOKEN["header"]) # alg is ES256 # it should fail if alg is not in supported_algorithms with pytest.raises(ValidationError): VPTokenHeader.model_validate( - VP_TOKEN['header'], context={"supported_algorithms": None}) + VP_TOKEN["header"], context={"supported_algorithms": None} + ) with pytest.raises(ValidationError): VPTokenHeader.model_validate( - VP_TOKEN['header'], context={"supported_algorithms": []}) + VP_TOKEN["header"], 
context={"supported_algorithms": []} + ) with pytest.raises(ValidationError): VPTokenHeader.model_validate( - VP_TOKEN['header'], context={"supported_algorithms": ["asd"]}) + VP_TOKEN["header"], context={"supported_algorithms": ["asd"]} + ) VPTokenHeader.model_validate( - VP_TOKEN['header'], context={"supported_algorithms": ["ES256"]}) + VP_TOKEN["header"], context={"supported_algorithms": ["ES256"]} + ) def test_vp_token_payload(): - VPTokenPayload(**VP_TOKEN['payload']) + VPTokenPayload(**VP_TOKEN["payload"]) # it should fail on SD-JWT format or missing vp VP_TOKEN["payload"]["vp"] = VP_TOKEN["payload"]["vp"].replace("~", ".") with pytest.raises(ValidationError): - VPTokenPayload(**VP_TOKEN['payload']) + VPTokenPayload(**VP_TOKEN["payload"]) VP_TOKEN["payload"]["vp"] = VP_TOKEN["payload"]["vp"].replace(".", "~") del VP_TOKEN["payload"]["vp"] with pytest.raises(ValidationError): - VPTokenPayload(**VP_TOKEN['payload']) + VPTokenPayload(**VP_TOKEN["payload"]) diff --git a/pyeudiw/tests/openid4vp/schemas/test_wallet_instance_attestation.py b/pyeudiw/tests/openid4vp/schemas/test_wallet_instance_attestation.py index 0e3e7316..d799ac39 100644 --- a/pyeudiw/tests/openid4vp/schemas/test_wallet_instance_attestation.py +++ b/pyeudiw/tests/openid4vp/schemas/test_wallet_instance_attestation.py @@ -1,10 +1,9 @@ import pytest - from pydantic import ValidationError from pyeudiw.openid4vp.schemas.wallet_instance_attestation import ( WalletInstanceAttestationHeader, - WalletInstanceAttestationPayload + WalletInstanceAttestationPayload, ) WALLET_INSTANCE_ATTESTATION = { @@ -17,7 +16,7 @@ "eyJhbGciOiJFUz...H9gw", ], "typ": "wallet-attestation+jwt", - "x5c": ["MIIBjDCC ... XFehgKQA=="] + "x5c": ["MIIBjDCC ... 
XFehgKQA=="], }, "payload": { "iss": "https://wallet-provider.example.org", @@ -47,102 +46,107 @@ "rbxVxiZHjU6zL6jY5QJdh1QCmENoejj_ytspMmGW7yMRxzUqgxcAqOBpVm0b-_mW3HoBdjQ", "e": "AQAB", "kid": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg", - "x5t": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg" + "x5t": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg", } }, "authorization_endpoint": "haip:", - "response_types_supported": [ - "vp_token" - ], + "response_types_supported": ["vp_token"], "vp_formats_supported": { - "jwt_vp_json": { - "alg_values_supported": ["RS256"] - }, - "jwt_vc_json": { - "alg_values_supported": ["RS256"] - } + "jwt_vp_json": {"alg_values_supported": ["RS256"]}, + "jwt_vc_json": {"alg_values_supported": ["RS256"]}, }, - "request_object_signing_alg_values_supported": [ - "RS256" - ], + "request_object_signing_alg_values_supported": ["RS256"], "presentation_definition_uri_supported": False, "iat": 1687281195, - "exp": 1687288395 - } + "exp": 1687288395, + }, } def test_header(): - WalletInstanceAttestationHeader(**WALLET_INSTANCE_ATTESTATION['header']) + WalletInstanceAttestationHeader(**WALLET_INSTANCE_ATTESTATION["header"]) # alg is RS256 # it should fail if alg is not in supported_algorithms with pytest.raises(ValidationError): WalletInstanceAttestationHeader.model_validate( - WALLET_INSTANCE_ATTESTATION['header'], context={"supported_algorithms": None}) + WALLET_INSTANCE_ATTESTATION["header"], + context={"supported_algorithms": None}, + ) with pytest.raises(ValidationError): WalletInstanceAttestationHeader.model_validate( - WALLET_INSTANCE_ATTESTATION['header'], context={"supported_algorithms": []}) + WALLET_INSTANCE_ATTESTATION["header"], context={"supported_algorithms": []} + ) with pytest.raises(ValidationError): WalletInstanceAttestationHeader.model_validate( - WALLET_INSTANCE_ATTESTATION['header'], context={"supported_algorithms": ["asd"]}) + WALLET_INSTANCE_ATTESTATION["header"], + 
context={"supported_algorithms": ["asd"]}, + ) WalletInstanceAttestationHeader.model_validate( - WALLET_INSTANCE_ATTESTATION['header'], context={"supported_algorithms": ["RS256"]}) + WALLET_INSTANCE_ATTESTATION["header"], + context={"supported_algorithms": ["RS256"]}, + ) # x5c and trust_chain are not required - WALLET_INSTANCE_ATTESTATION['header']['x5c'] = None - WALLET_INSTANCE_ATTESTATION['header']['trust_chain'] = None - WalletInstanceAttestationHeader(**WALLET_INSTANCE_ATTESTATION['header']) - del WALLET_INSTANCE_ATTESTATION['header']['x5c'] - del WALLET_INSTANCE_ATTESTATION['header']['trust_chain'] - WalletInstanceAttestationHeader(**WALLET_INSTANCE_ATTESTATION['header']) + WALLET_INSTANCE_ATTESTATION["header"]["x5c"] = None + WALLET_INSTANCE_ATTESTATION["header"]["trust_chain"] = None + WalletInstanceAttestationHeader(**WALLET_INSTANCE_ATTESTATION["header"]) + del WALLET_INSTANCE_ATTESTATION["header"]["x5c"] + del WALLET_INSTANCE_ATTESTATION["header"]["trust_chain"] + WalletInstanceAttestationHeader(**WALLET_INSTANCE_ATTESTATION["header"]) # kid is required - WALLET_INSTANCE_ATTESTATION['header']['kid'] = None + WALLET_INSTANCE_ATTESTATION["header"]["kid"] = None with pytest.raises(ValidationError): - WalletInstanceAttestationHeader( - **WALLET_INSTANCE_ATTESTATION['header']) - del WALLET_INSTANCE_ATTESTATION['header']['kid'] + WalletInstanceAttestationHeader(**WALLET_INSTANCE_ATTESTATION["header"]) + del WALLET_INSTANCE_ATTESTATION["header"]["kid"] with pytest.raises(ValidationError): - WalletInstanceAttestationHeader( - **WALLET_INSTANCE_ATTESTATION['header']) + WalletInstanceAttestationHeader(**WALLET_INSTANCE_ATTESTATION["header"]) # typ must be "wallet-attestation-jwt" - WALLET_INSTANCE_ATTESTATION['header']['typ'] = "asd" + WALLET_INSTANCE_ATTESTATION["header"]["typ"] = "asd" with pytest.raises(ValidationError): - WalletInstanceAttestationHeader( - **WALLET_INSTANCE_ATTESTATION['header']) + 
WalletInstanceAttestationHeader(**WALLET_INSTANCE_ATTESTATION["header"]) def test_payload(): - WalletInstanceAttestationPayload(**WALLET_INSTANCE_ATTESTATION['payload']) + WalletInstanceAttestationPayload(**WALLET_INSTANCE_ATTESTATION["payload"]) WalletInstanceAttestationPayload.model_validate( - WALLET_INSTANCE_ATTESTATION['payload']) + WALLET_INSTANCE_ATTESTATION["payload"] + ) # iss is not HttpUrl - WALLET_INSTANCE_ATTESTATION['payload']['iss'] = WALLET_INSTANCE_ATTESTATION['payload']['iss'][4:] + WALLET_INSTANCE_ATTESTATION["payload"]["iss"] = WALLET_INSTANCE_ATTESTATION[ + "payload" + ]["iss"][4:] with pytest.raises(ValidationError): WalletInstanceAttestationPayload.model_validate( - WALLET_INSTANCE_ATTESTATION['payload']) - WALLET_INSTANCE_ATTESTATION['payload']['iss'] = "http" + \ - WALLET_INSTANCE_ATTESTATION['payload']['iss'] + WALLET_INSTANCE_ATTESTATION["payload"] + ) + WALLET_INSTANCE_ATTESTATION["payload"]["iss"] = ( + "http" + WALLET_INSTANCE_ATTESTATION["payload"]["iss"] + ) WalletInstanceAttestationPayload.model_validate( - WALLET_INSTANCE_ATTESTATION['payload']) + WALLET_INSTANCE_ATTESTATION["payload"] + ) # empty cnf - cnf = WALLET_INSTANCE_ATTESTATION['payload']['cnf'] - WALLET_INSTANCE_ATTESTATION['payload']['cnf'] = {} + cnf = WALLET_INSTANCE_ATTESTATION["payload"]["cnf"] + WALLET_INSTANCE_ATTESTATION["payload"]["cnf"] = {} with pytest.raises(ValidationError): WalletInstanceAttestationPayload.model_validate( - WALLET_INSTANCE_ATTESTATION['payload']) - del WALLET_INSTANCE_ATTESTATION['payload']['cnf'] + WALLET_INSTANCE_ATTESTATION["payload"] + ) + del WALLET_INSTANCE_ATTESTATION["payload"]["cnf"] with pytest.raises(ValidationError): WalletInstanceAttestationPayload.model_validate( - WALLET_INSTANCE_ATTESTATION['payload']) - WALLET_INSTANCE_ATTESTATION['payload']['cnf'] = cnf + WALLET_INSTANCE_ATTESTATION["payload"] + ) + WALLET_INSTANCE_ATTESTATION["payload"]["cnf"] = cnf # cnf jwk is not a JWK - 
WALLET_INSTANCE_ATTESTATION['payload']['cnf']['jwk'] = {} + WALLET_INSTANCE_ATTESTATION["payload"]["cnf"]["jwk"] = {} with pytest.raises(ValidationError): WalletInstanceAttestationPayload.model_validate( - WALLET_INSTANCE_ATTESTATION['payload']) + WALLET_INSTANCE_ATTESTATION["payload"] + ) diff --git a/pyeudiw/tests/openid4vp/schemas/test_wallet_instance_attestation_request.py b/pyeudiw/tests/openid4vp/schemas/test_wallet_instance_attestation_request.py index 0d8b8293..d0613d73 100644 --- a/pyeudiw/tests/openid4vp/schemas/test_wallet_instance_attestation_request.py +++ b/pyeudiw/tests/openid4vp/schemas/test_wallet_instance_attestation_request.py @@ -1,15 +1,16 @@ import pytest - from pydantic import ValidationError -from pyeudiw.openid4vp.schemas.wallet_instance_attestation_request import WalletInstanceAttestationRequestHeader, \ - WalletInstanceAttestationRequestPayload +from pyeudiw.openid4vp.schemas.wallet_instance_attestation_request import ( + WalletInstanceAttestationRequestHeader, + WalletInstanceAttestationRequestPayload, +) WALLET_INSTANCE_ATTESTATION_REQUEST = { "header": { "alg": "RS256", "kid": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg", - "typ": "var+jwt" + "typ": "var+jwt", }, "payload": { "iss": "vbeXJksM45xphtANnCiG6mCyuU4jfGNzopGuKvogg9c", @@ -18,57 +19,68 @@ "type": "WalletInstanceAttestationRequest", "nonce": ".....", "cnf": { - "jwk": { - "alg": "RS256", - "kty": "RSA", - "use": "sig", - "x5c": [ - "MIIC+DCCAeCgAwIBAgIJBIGjYW6hFpn2MA0GCSqGSIb3DQEBBQUAMCMxITAfBgNVBAMTGGN1c3RvbWVyLWRlbW9zLmF1dGgwLmNvbTAeFw0xNjExMjIyMjIyMDVaFw0zMDA4MD" - "EyMjIyMDVaMCMxITAfBgNVBAMTGGN1c3RvbWVyLWRlbW9zLmF1dGgwLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMnjZc5bm/eGIHq09N9HKHahM7Y31P0u" - "l+A2wwP4lSpIwFrWHzxw88/7Dwk9QMc+orGXX95R6av4GF+Es/nG3uK45ooMVMa/hYCh0Mtx3gnSuoTavQEkLzCvSwTqVwzZ+5noukWVqJuMKNwjL77GNcPLY7Xy2/skMCT5bR" - 
"8UoWaufooQvYq6SyPcRAU4BtdquZRiBT4U5f+4pwNTxSvey7ki50yc1tG49Per/0zA4O6Tlpv8x7Red6m1bCNHt7+Z5nSl3RX/QYyAEUX1a28VcYmR41Osy+o2OUCXYdUAphDa" - "Ho4/8rbKTJhlu8jEcc1KoMXAKjgaVZtG/v5ltx6AXY0CAwEAAaMvMC0wDAYDVR0TBAUwAwEB/zAdBgNVHQ4EFgQUQxFG602h1cG+pnyvJoy9pGJJoCswDQYJKoZIhvcNAQEFBQ" - "ADggEBAGvtCbzGNBUJPLICth3mLsX0Z4z8T8iu4tyoiuAshP/Ry/ZBnFnXmhD8vwgMZ2lTgUWwlrvlgN+fAtYKnwFO2G3BOCFw96Nm8So9sjTda9CCZ3dhoH57F/hVMBB0K6xh" - "klAc0b5ZxUpCIN92v/w+xZoz1XQBHe8ZbRHaP1HpRM4M7DJk2G5cgUCyu3UBvYS41sHvzrxQ3z7vIePRA4WF4bEkfX12gvny0RsPkrbVMXX1Rj9t6V7QXrbPYBAO+43JvDGYaw" - "xYVvLhz+BJ45x50GFQmHszfY3BR9TPK8xmMmQwtIvLu1PMttNCs7niCYkSiUv2sc2mlq1i3IashGkkgmo=" - ], - "n": "yeNlzlub94YgerT030codqEztjfU_S6X4DbDA_iVKkjAWtYfPHDzz_sPCT1Axz6isZdf3lHpq_gYX4Sz-cbe4rjmigxUxr-FgKHQy3HeCdK6hNq9ASQvMK9LBOpXDNn7mei6R" - "ZWom4wo3CMvvsY1w8tjtfLb-yQwJPltHxShZq5-ihC9irpLI9xEBTgG12q5lGIFPhTl_7inA1PFK97LuSLnTJzW0bj096v_TMDg7pOWm_zHtF53qbVsI0e3v5nmdKXdFf9BjIARRfV" - "rbxVxiZHjU6zL6jY5QJdh1QCmENoejj_ytspMmGW7yMRxzUqgxcAqOBpVm0b-_mW3HoBdjQ", - "e": "AQAB", - "kid": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg", - "x5t": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg" - } + "jwk": { + "alg": "RS256", + "kty": "RSA", + "use": "sig", + "x5c": [ + "MIIC+DCCAeCgAwIBAgIJBIGjYW6hFpn2MA0GCSqGSIb3DQEBBQUAMCMxITAfBgNVBAMTGGN1c3RvbWVyLWRlbW9zLmF1dGgwLmNvbTAeFw0xNjExMjIyMjIyMDVaFw0zMDA4MD" + "EyMjIyMDVaMCMxITAfBgNVBAMTGGN1c3RvbWVyLWRlbW9zLmF1dGgwLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMnjZc5bm/eGIHq09N9HKHahM7Y31P0u" + "l+A2wwP4lSpIwFrWHzxw88/7Dwk9QMc+orGXX95R6av4GF+Es/nG3uK45ooMVMa/hYCh0Mtx3gnSuoTavQEkLzCvSwTqVwzZ+5noukWVqJuMKNwjL77GNcPLY7Xy2/skMCT5bR" + "8UoWaufooQvYq6SyPcRAU4BtdquZRiBT4U5f+4pwNTxSvey7ki50yc1tG49Per/0zA4O6Tlpv8x7Red6m1bCNHt7+Z5nSl3RX/QYyAEUX1a28VcYmR41Osy+o2OUCXYdUAphDa" + "Ho4/8rbKTJhlu8jEcc1KoMXAKjgaVZtG/v5ltx6AXY0CAwEAAaMvMC0wDAYDVR0TBAUwAwEB/zAdBgNVHQ4EFgQUQxFG602h1cG+pnyvJoy9pGJJoCswDQYJKoZIhvcNAQEFBQ" + 
"ADggEBAGvtCbzGNBUJPLICth3mLsX0Z4z8T8iu4tyoiuAshP/Ry/ZBnFnXmhD8vwgMZ2lTgUWwlrvlgN+fAtYKnwFO2G3BOCFw96Nm8So9sjTda9CCZ3dhoH57F/hVMBB0K6xh" + "klAc0b5ZxUpCIN92v/w+xZoz1XQBHe8ZbRHaP1HpRM4M7DJk2G5cgUCyu3UBvYS41sHvzrxQ3z7vIePRA4WF4bEkfX12gvny0RsPkrbVMXX1Rj9t6V7QXrbPYBAO+43JvDGYaw" + "xYVvLhz+BJ45x50GFQmHszfY3BR9TPK8xmMmQwtIvLu1PMttNCs7niCYkSiUv2sc2mlq1i3IashGkkgmo=" + ], + "n": "yeNlzlub94YgerT030codqEztjfU_S6X4DbDA_iVKkjAWtYfPHDzz_sPCT1Axz6isZdf3lHpq_gYX4Sz-cbe4rjmigxUxr-FgKHQy3HeCdK6hNq9ASQvMK9LBOpXDNn7mei6R" + "ZWom4wo3CMvvsY1w8tjtfLb-yQwJPltHxShZq5-ihC9irpLI9xEBTgG12q5lGIFPhTl_7inA1PFK97LuSLnTJzW0bj096v_TMDg7pOWm_zHtF53qbVsI0e3v5nmdKXdFf9BjIARRfV" + "rbxVxiZHjU6zL6jY5QJdh1QCmENoejj_ytspMmGW7yMRxzUqgxcAqOBpVm0b-_mW3HoBdjQ", + "e": "AQAB", + "kid": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg", + "x5t": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg", + } }, "iat": 1686645115, - "exp": 1686652315 - }} + "exp": 1686652315, + }, +} def test_header(): WalletInstanceAttestationRequestHeader( - **WALLET_INSTANCE_ATTESTATION_REQUEST['header']) + **WALLET_INSTANCE_ATTESTATION_REQUEST["header"] + ) with pytest.raises(ValidationError): WalletInstanceAttestationRequestHeader.model_validate( - WALLET_INSTANCE_ATTESTATION_REQUEST['header'], context={"supported_algorithms": ["RS128", "ES128"]}) + WALLET_INSTANCE_ATTESTATION_REQUEST["header"], + context={"supported_algorithms": ["RS128", "ES128"]}, + ) WalletInstanceAttestationRequestHeader.model_validate( - WALLET_INSTANCE_ATTESTATION_REQUEST['header'], context={"supported_algorithms": ["RS256", "ES256"]}) - WALLET_INSTANCE_ATTESTATION_REQUEST['header']['typ'] = 'wrong' + WALLET_INSTANCE_ATTESTATION_REQUEST["header"], + context={"supported_algorithms": ["RS256", "ES256"]}, + ) + WALLET_INSTANCE_ATTESTATION_REQUEST["header"]["typ"] = "wrong" with pytest.raises(ValidationError): WalletInstanceAttestationRequestHeader( - **WALLET_INSTANCE_ATTESTATION_REQUEST['header']) + 
**WALLET_INSTANCE_ATTESTATION_REQUEST["header"] + ) def test_payload(): WalletInstanceAttestationRequestPayload( - **WALLET_INSTANCE_ATTESTATION_REQUEST['payload']) - WALLET_INSTANCE_ATTESTATION_REQUEST['payload']['type'] = 'wrong' + **WALLET_INSTANCE_ATTESTATION_REQUEST["payload"] + ) + WALLET_INSTANCE_ATTESTATION_REQUEST["payload"]["type"] = "wrong" with pytest.raises(ValidationError): WalletInstanceAttestationRequestPayload( - **WALLET_INSTANCE_ATTESTATION_REQUEST['payload']) + **WALLET_INSTANCE_ATTESTATION_REQUEST["payload"] + ) WALLET_INSTANCE_ATTESTATION_REQUEST["payload"]["cnf"] = { - "wrong_name_jwk": WALLET_INSTANCE_ATTESTATION_REQUEST["payload"]["cnf"]["jwk"]} + "wrong_name_jwk": WALLET_INSTANCE_ATTESTATION_REQUEST["payload"]["cnf"]["jwk"] + } with pytest.raises(ValidationError): WalletInstanceAttestationRequestPayload.model_validate( - WALLET_INSTANCE_ATTESTATION_REQUEST["payload"]) + WALLET_INSTANCE_ATTESTATION_REQUEST["payload"] + ) diff --git a/pyeudiw/tests/openid4vp/test_authorization_request.py b/pyeudiw/tests/openid4vp/test_authorization_request.py index aacd8834..42aa52cd 100644 --- a/pyeudiw/tests/openid4vp/test_authorization_request.py +++ b/pyeudiw/tests/openid4vp/test_authorization_request.py @@ -1,6 +1,9 @@ from dataclasses import dataclass -from pyeudiw.openid4vp.authorization_request import build_authorization_request_claims, build_authorization_request_url +from pyeudiw.openid4vp.authorization_request import ( + build_authorization_request_claims, + build_authorization_request_url, +) def test_build_authoriation_request_url(): @@ -14,18 +17,22 @@ class TestCase: test_cases: list[TestCase] = [ TestCase( scheme="haip", - params={"client_id": "https://rp.example", - "request_uri": "https://rp.example/resource_location.jwt"}, + params={ + "client_id": "https://rp.example", + "request_uri": "https://rp.example/resource_location.jwt", + }, exp="haip://?client_id=https%3A%2F%2Frp.example&https%3A%2F%2Frp.example%2Fresource_location.jwt", - 
explanation="base scheme like haip or eudiw" + explanation="base scheme like haip or eudiw", ), TestCase( scheme="https://walletsolution.example", - params={"client_id": "https://rp.example", - "request_uri": "https://rp.example/resource_location.jwt"}, + params={ + "client_id": "https://rp.example", + "request_uri": "https://rp.example/resource_location.jwt", + }, exp="https://walletsolution.example?client_id=https%3A%2F%2Frp.example.org&https%3A%2F%2Frp.example.org%2Fresource_location.jwt", - explanation="base scheme is a complete URI location" - ) + explanation="base scheme is a complete URI location", + ), ] for i, case in enumerate(test_cases): @@ -50,36 +57,30 @@ def test_build_authorization_request_claims(): { "id": "specific-id", "purpose": "Request presentation holding Power of Representation attestation", - "format": { - "vc+sd-jwt": {} - }, + "format": {"vc+sd-jwt": {}}, "constraints": { "fields": [ { - "path": [ - "$.vct" - ], + "path": ["$.vct"], "filter": { "type": "string", - "pattern": "urn:eu.europa.ec.eudi:por:1" - } + "pattern": "urn:eu.europa.ec.eudi:por:1", + }, } ] - } + }, } - ] - } + ], + }, } - claims = build_authorization_request_claims( - client_id, state, response_uri, config) + claims = build_authorization_request_claims(client_id, state, response_uri, config) assert "aud" not in claims assert "nonce" in claims assert "presentation_definition" in claims assert claims["response_mode"] == "direct_post.jwt" - assert claims["scope"] in ( - "family_name given_name", "given_name family_name") + assert claims["scope"] in ("family_name given_name", "given_name family_name") assert claims["exp"] > claims["iat"] assert claims["client_id"] == client_id assert claims["response_type"] == "vp_token" @@ -95,36 +96,32 @@ def test_build_authorization_request_claims(): { "id": "specific-id", "purpose": "Request presentation holding Power of Representation attestation", - "format": { - "vc+sd-jwt": {} - }, + "format": {"vc+sd-jwt": {}}, "constraints": { 
"fields": [ { - "path": [ - "$.vct" - ], + "path": ["$.vct"], "filter": { "type": "string", - "pattern": "urn:eu.europa.ec.eudi:por:1" - } + "pattern": "urn:eu.europa.ec.eudi:por:1", + }, } ] - } + }, } - ] - } + ], + }, } claims = build_authorization_request_claims( - client_id, state, response_uri, config_aud) + client_id, state, response_uri, config_aud + ) assert claims["aud"] == "https://self-issued.me/v2" assert "nonce" in claims assert "presentation_definition" in claims assert claims["response_mode"] == "direct_post.jwt" - assert claims["scope"] in ( - "family_name given_name", "given_name family_name") + assert claims["scope"] in ("family_name given_name", "given_name family_name") assert claims["exp"] > claims["iat"] assert claims["client_id"] == client_id assert claims["response_type"] == "vp_token" @@ -140,35 +137,31 @@ def test_build_authorization_request_claims(): { "id": "specific-id", "purpose": "Request presentation holding Power of Representation attestation", - "format": { - "vc+sd-jwt": {} - }, + "format": {"vc+sd-jwt": {}}, "constraints": { "fields": [ { - "path": [ - "$.vct" - ], + "path": ["$.vct"], "filter": { "type": "string", - "pattern": "urn:eu.europa.ec.eudi:por:1" - } + "pattern": "urn:eu.europa.ec.eudi:por:1", + }, } ] - } + }, } - ] - } + ], + }, } claims = build_authorization_request_claims( - client_id, state, response_uri, config_rmode) + client_id, state, response_uri, config_rmode + ) assert claims["response_mode"] == "direct_post" assert "nonce" in claims assert "presentation_definition" in claims - assert claims["scope"] in ( - "family_name given_name", "given_name family_name") + assert claims["scope"] in ("family_name given_name", "given_name family_name") assert claims["exp"] > claims["iat"] assert claims["client_id"] == client_id assert claims["response_type"] == "vp_token" @@ -177,19 +170,18 @@ def test_build_authorization_request_claims(): config_noscope = { "expiration_time": 1, "aud": "https://self-issued.me/v2", - 
"presentation_definition": { - "id": "global-id", - "input_descriptors": [] - } + "presentation_definition": {"id": "global-id", "input_descriptors": []}, } claims = build_authorization_request_claims( - client_id, state, response_uri, config_noscope) + client_id, state, response_uri, config_noscope + ) assert "scope" not in claims # case 4: force nonce claims = build_authorization_request_claims( - client_id, state, response_uri, config_noscope, nonce="predetermined-nonce") + client_id, state, response_uri, config_noscope, nonce="predetermined-nonce" + ) assert claims["nonce"] == "predetermined-nonce" # case 5: custom client_id @@ -204,27 +196,24 @@ def test_build_authorization_request_claims(): { "id": "specific-id", "purpose": "Request presentation holding Power of Representation attestation", - "format": { - "vc+sd-jwt": {} - }, + "format": {"vc+sd-jwt": {}}, "constraints": { "fields": [ { - "path": [ - "$.vct" - ], + "path": ["$.vct"], "filter": { "type": "string", - "pattern": "urn:eu.europa.ec.eudi:por:1" - } + "pattern": "urn:eu.europa.ec.eudi:por:1", + }, } ] - } + }, } - ] - } + ], + }, } claims = build_authorization_request_claims( - "custom-client-id", state, response_uri, config_custom_id) + "custom-client-id", state, response_uri, config_custom_id + ) assert claims["iss"] != client_id diff --git a/pyeudiw/tests/openid4vp/test_authorization_response.py b/pyeudiw/tests/openid4vp/test_authorization_response.py index b351f692..fded18a8 100644 --- a/pyeudiw/tests/openid4vp/test_authorization_response.py +++ b/pyeudiw/tests/openid4vp/test_authorization_response.py @@ -2,8 +2,14 @@ import satosa.context from pyeudiw.jwt.jwe_helper import JWEHelper -from pyeudiw.openid4vp.authorization_response import DirectPostJwtJweParser, DirectPostParser -from pyeudiw.openid4vp.exceptions import AuthRespParsingException, AuthRespValidationException +from pyeudiw.openid4vp.authorization_response import ( + DirectPostJwtJweParser, + DirectPostParser, +) +from 
pyeudiw.openid4vp.exceptions import ( + AuthRespParsingException, + AuthRespValidationException, +) @pytest.fixture @@ -16,7 +22,7 @@ def jwe_helper(): "n": "vxPbeX_PgRI5mpOoANQkC1HU09LZePPygjaUtrXrZkx0rhK-2LaoNLns5EoBkJOuj8JrHU_W2UOZBIE-tpLJ8UUSuJNXNZhrWqezhjO0aIue_JgyMWjp2IZ_BofyhrMenqYI6oA8B9eKdD-1zxF0vflCHylq7vdYcKKPZcr0QjyVTltuEbRiS8WHjFV5_sWuYJkDt-5bXW4ZapDV4NG2OcwcRROR5gwU1EhWX-kQbNPw84wZpEGXy5fFTaosfUVbvKSXP_d8IZ4fizqMi69bk6IjLqfw3JZcIKnzz7ou302dq2sIH_R1gQHNJ6b-oTh5nq3JLCMViGHZAH0sIMu6eiLzvcADX5PpbMO9Y83l-UbMEqH0iv7wtW3gVE7OjolZCiFXwhhntWj5ccomlzgYFreebRevQItHZUxiN7n7tJMOWouV1LHecWfixHbweaBFooGSzY9hlFvERKmbfPqNaIce8PHd-dDWw9Yxq5RdNpcMQKm5ruYlV3pWGxoQaHdX", "p": "zxOeQgvnCZzIIsZoz18ROo8fCRZT7K6h4uvz6fvTz6e_6TKWUYohjsBfYAkbNgJamkxdSJyaQWDMPs0mIG2j2IcpqG0JJGQ68QCqai8H-o-_wb0hjp3fY6TofVaEzFvQiOJE2ZyqtSn7hDrEFEmceJB_VzIvOZbS-AZo9--R2PTQ4h7CurkWKCOKAbeBfOWEX_s6UOaLQzMYDykiHSlPQmF9BZAsaoRHXdZLlsgxgQ3nbPbsBx2d1axlIIhb7Xkl", "q": "7DiV8cgKFnOGCTHT3cr_8xIKwD0LoWnGibqdA2p0XSQTXTLUr532DKK_3YMdm0F0YtCyPQBbsJoLspbnK6yTo3RPFr0zooJ5eCYnLO_qOBYFwYbOhhXrYPjfpEDSXls9BD6cHhCCQtiNAADIjaMpmoEexPD6lMoijTF7qzEst2fszvYTToczroWHPe8RuJR7J_FEM-soD99ERqJanj5BUs-ruuNwshM_Fp7C-ubt9PbMo9gv5p9nrsT8oAI7wJvL", - "use": "enc" + "use": "enc", } jwe_helper = JWEHelper(private_key) return jwe_helper @@ -33,17 +39,13 @@ def test_direct_post_parser_good_case(): "id": "submit-id", "definition_id": "definition-id", "descriptor_map": [ - { - "id": "verifiable-credential-type", - "format": "dc+sd-jwt", - "path": "$.vct" - } - ] + {"id": "verifiable-credential-type", "format": "dc+sd-jwt", "path": "$.vct"} + ], } ctx.request = { "vp_token": vp_token, "state": state, - "presentation_submission": presentation_submission + "presentation_submission": presentation_submission, } resp = parser.parse_and_validate(ctx) @@ -64,17 +66,13 @@ def test_direct_post_response_bad_parse_case(): "id": "submit-id", "definition_id": "definition-id", "descriptor_map": [ - { - "id": "verifiable-credential-type", - "format": 
"dc+sd-jwt", - "path": "$.vct" - } - ] + {"id": "verifiable-credential-type", "format": "dc+sd-jwt", "path": "$.vct"} + ], } ctx.qs_params = { "vp_token": vp_token, "state": state, - "presentation_submission": presentation_submission + "presentation_submission": presentation_submission, } try: @@ -113,21 +111,15 @@ def test_direct_post_jwt_jwe_parser_good_case(jwe_helper): "id": "submit-id", "definition_id": "definition-id", "descriptor_map": [ - { - "id": "verifiable-credential-type", - "format": "dc+sd-jwt", - "path": "$.vct" - } - ] + {"id": "verifiable-credential-type", "format": "dc+sd-jwt", "path": "$.vct"} + ], } data = { "vp_token": vp_token, "state": state, - "presentation_submission": presentation_submission - } - ctx.request = { - "response": jwe_helper.encrypt(data) + "presentation_submission": presentation_submission, } + ctx.request = {"response": jwe_helper.encrypt(data)} resp = parser.parse_and_validate(ctx) assert resp.vp_token == vp_token @@ -147,19 +139,17 @@ def test_direct_post_jwt_jwe_parser_bad_parse_case(jwe_helper): "id": "submit-id", "definition_id": "definition-id", "descriptor_map": [ - { - "id": "verifiable-credential-type", - "format": "dc+sd-jwt", - "path": "$.vct" - } - ] + {"id": "verifiable-credential-type", "format": "dc+sd-jwt", "path": "$.vct"} + ], } ctx.qs_params = { - "response": jwe_helper.encrypt({ - "vp_token": vp_token, - "state": state, - "presentation_submission": presentation_submission - }) + "response": jwe_helper.encrypt( + { + "vp_token": vp_token, + "state": state, + "presentation_submission": presentation_submission, + } + ) } try: @@ -194,7 +184,7 @@ def test_direct_post_jwt_jwe_parser_bad_validation_case(jwe_helper): "e": "AQAB", "kty": "RSA", "n": 
"sCLDmvDKr4y7EHLf4TbjNqa3_p4GnTLqPXdvi0ce2BW2NIK1vYtz9uk8oIlResIWJk1T59LAS8YGF5BLkjLLSyMjrhHySoyRDrBEk_cz-F3Mabc7x-5GDAbxvFDZKQ2n5UVQUWgboFISGp2zpmrYzvewv2WCxZ4a3mS6kwAvjl_S9kahD-SFjiNyHsSaA0lDrF5xQpT2MaMha0dPwgNrChCcG4TTG5YBy4zgktlfA9GRnrEKUJioiKYapMAotziNRBoH128CJGAdMxaO5SVYC0PVLnmKd3cv4bPqGYMRszI6x3i5YUTLk8HwWPL9SUV25pAFp_nDRlgQTdvxssClhZ8VbMZQ3x2I738ixGud_1ggBVFTGDDGDQem4jOz6AsPBVrwtwWStVpA5V5FyEhbgZmE7Orb0cNsmIBjVIPBuFtLBmSELAiJ_WK7ajo3xKtIMTFB-JVX1PVawZOkUzS94BnJ0i7RGc4uzZBhhiOWxBHQGIFhfJnD1OggXnHkVYRn", - "use": "enc" + "use": "enc", } wrong_helper = JWEHelper(wrong_public_key) @@ -206,21 +196,15 @@ def test_direct_post_jwt_jwe_parser_bad_validation_case(jwe_helper): "id": "submit-id", "definition_id": "definition-id", "descriptor_map": [ - { - "id": "verifiable-credential-type", - "format": "dc+sd-jwt", - "path": "$.vct" - } - ] + {"id": "verifiable-credential-type", "format": "dc+sd-jwt", "path": "$.vct"} + ], } data = { "vp_token": vp_token, "state": state, - "presentation_submission": presentation_submission - } - ctx.request = { - "response": wrong_helper.encrypt(data) + "presentation_submission": presentation_submission, } + ctx.request = {"response": wrong_helper.encrypt(data)} try: parser.parse_and_validate(ctx) diff --git a/pyeudiw/tests/presentation_exchange/schemas/test_presentation_definition.py b/pyeudiw/tests/presentation_exchange/schemas/test_presentation_definition.py index 40d874ee..689b839a 100644 --- a/pyeudiw/tests/presentation_exchange/schemas/test_presentation_definition.py +++ b/pyeudiw/tests/presentation_exchange/schemas/test_presentation_definition.py @@ -1,12 +1,13 @@ import json from pathlib import Path -from pyeudiw.presentation_exchange.schemas.oid4vc_presentation_definition import \ - PresentationDefinition +from pyeudiw.presentation_exchange.schemas.oid4vc_presentation_definition import ( + PresentationDefinition, +) def test_presentation_definition(): - p = Path(__file__).with_name('presentation_definition_sd_jwt_vc.json') + p 
= Path(__file__).with_name("presentation_definition_sd_jwt_vc.json") with open(p) as json_file: data = json.load(json_file) PresentationDefinition(**data) diff --git a/pyeudiw/tests/satosa/__init__.py b/pyeudiw/tests/satosa/__init__.py index c535371e..f07a7032 100644 --- a/pyeudiw/tests/satosa/__init__.py +++ b/pyeudiw/tests/satosa/__init__.py @@ -1,38 +1,35 @@ -from pyeudiw.tests.federation.base import ( - leaf_wallet_jwk, - leaf_cred_jwk_prot -) +from io import StringIO + +from cryptojwt.jwk.jwk import key_from_jwk_dict +from satosa.context import Context + from pyeudiw.jwk import JWK +from pyeudiw.sd_jwt.holder import SDJWTHolder +from pyeudiw.sd_jwt.issuer import SDJWTIssuer +from pyeudiw.sd_jwt.utils.yaml_specification import _yaml_load_specification +from pyeudiw.storage.db_engine import DBEngine +from pyeudiw.tests.federation.base import leaf_cred_jwk_prot, leaf_wallet_jwk from pyeudiw.tests.settings import ( CONFIG, - CREDENTIAL_ISSUER_ENTITY_ID, CREDENTIAL_ISSUER_CONF, + CREDENTIAL_ISSUER_ENTITY_ID, ) -from pyeudiw.sd_jwt.holder import SDJWTHolder -from pyeudiw.sd_jwt.issuer import SDJWTIssuer from pyeudiw.tools.utils import exp_from_now, iat_now -from satosa.context import Context -from pyeudiw.storage.db_engine import DBEngine - -from pyeudiw.sd_jwt.utils.yaml_specification import _yaml_load_specification -from cryptojwt.jwk.jwk import key_from_jwk_dict -from io import StringIO issuer_jwk = leaf_cred_jwk_prot.serialize(private=True) holder_jwk = leaf_wallet_jwk.serialize(private=True) settings = CREDENTIAL_ISSUER_CONF -settings['issuer'] = CREDENTIAL_ISSUER_ENTITY_ID -settings['default_exp'] = CONFIG['jwt']['default_exp'] +settings["issuer"] = CREDENTIAL_ISSUER_ENTITY_ID +settings["default_exp"] = CONFIG["jwt"]["default_exp"] -sd_specification = _yaml_load_specification( - StringIO(settings["sd_specification"])) +sd_specification = _yaml_load_specification(StringIO(settings["sd_specification"])) user_claims = { "iss": settings["issuer"], "iat": 
iat_now(), - "exp": exp_from_now(settings["default_exp"]) # in seconds + "exp": exp_from_now(settings["default_exp"]), # in seconds } issued_jwt = SDJWTIssuer( @@ -40,8 +37,7 @@ issuer_jwk, holder_jwk, add_decoy_claims=sd_specification.get("add_decoy_claims", True), - serialization_format=sd_specification.get( - "serialization_format", "compact"), + serialization_format=sd_specification.get("serialization_format", "compact"), extra_header_parameters={"typ": "vc+sd-jwt"}, ) @@ -51,8 +47,11 @@ serialization_format="compact", ) -ec_key = key_from_jwk_dict(holder_jwk) if sd_specification.get( - "key_binding", False) else None +ec_key = ( + key_from_jwk_dict(holder_jwk) + if sd_specification.get("key_binding", False) + else None +) def _create_vp_token(nonce: str, aud: str, holder_jwk: JWK, sign_alg: str) -> str: @@ -83,30 +82,30 @@ def _generate_response(state: str, vp_token: str) -> dict: { "id": "pid-sd-jwt:unique_id+given_name+family_name", "path": "$.vp_token.verified_claims.claims._sd[0]", - "format": "vc+sd-jwt" + "format": "vc+sd-jwt", } - ] - } + ], + }, } -def _generate_post_context(context: Context, request_uri: str, encrypted_response: str) -> Context: +def _generate_post_context( + context: Context, request_uri: str, encrypted_response: str +) -> Context: context.request_method = "POST" context.request_uri = request_uri context.request = {"response": encrypted_response} - context.http_headers = { - "HTTP_CONTENT_TYPE": "application/x-www-form-urlencoded"} + context.http_headers = {"HTTP_CONTENT_TYPE": "application/x-www-form-urlencoded"} return context -def _initialize_session(db_engine: DBEngine, state: str, session_id: str, nonce: str) -> None: - db_engine.init_session( - state=state, - session_id=session_id - ) +def _initialize_session( + db_engine: DBEngine, state: str, session_id: str, nonce: str +) -> None: + db_engine.init_session(state=state, session_id=session_id) doc_id = db_engine.get_by_state(state)["document_id"] 
db_engine.update_request_object( - document_id=doc_id, - request_object={"nonce": nonce, "state": state}) + document_id=doc_id, request_object={"nonce": nonce, "state": state} + ) diff --git a/pyeudiw/tests/satosa/test_backend.py b/pyeudiw/tests/satosa/test_backend.py index 6d0b2577..ccbc2a66 100644 --- a/pyeudiw/tests/satosa/test_backend.py +++ b/pyeudiw/tests/satosa/test_backend.py @@ -6,24 +6,25 @@ import pytest from bs4 import BeautifulSoup +from cryptojwt.jws.jws import JWS from satosa.context import Context from satosa.internal import InternalData from satosa.state import State -from pyeudiw.jwt.jws_helper import JWSHelper -from cryptojwt.jws.jws import JWS +from pyeudiw.jwt.jws_helper import JWSHelper from pyeudiw.jwt.utils import decode_jwt_header, decode_jwt_payload from pyeudiw.oauth2.dpop import DPoPIssuer from pyeudiw.satosa.backend import OpenID4VPBackend from pyeudiw.storage.base_storage import TrustType from pyeudiw.storage.db_engine import DBEngine from pyeudiw.tests.federation.base import ( - trust_chain_wallet, - ta_ec, EXP, NOW, + leaf_cred_jwk_prot, + ta_ec, + ta_ec_signed, ta_jwk, - ta_ec_signed, leaf_cred_jwk_prot + trust_chain_wallet, ) from pyeudiw.tests.settings import ( BASE_URL, @@ -31,9 +32,9 @@ CREDENTIAL_ISSUER_ENTITY_ID, INTERNAL_ATTRIBUTES, PRIVATE_JWK, - WALLET_INSTANCE_ATTESTATION + WALLET_INSTANCE_ATTESTATION, ) - +from pyeudiw.trust.handler.interface import TrustHandlerInterface from pyeudiw.trust.model.trust_source import TrustSourceData @@ -42,29 +43,31 @@ class TestOpenID4VPBackend: @pytest.fixture(autouse=True) def create_backend(self): - db_engine_inst = DBEngine(CONFIG['storage']) + db_engine_inst = DBEngine(CONFIG["storage"]) + + # TODO - not necessary if federation is not tested db_engine_inst.add_trust_anchor( - entity_id=ta_ec['iss'], + entity_id=ta_ec["iss"], entity_configuration=ta_ec_signed, exp=EXP, ) issuer_jwk = leaf_cred_jwk_prot.serialize(private=True) + db_engine_inst.add_or_update_trust_attestation( 
entity_id=CREDENTIAL_ISSUER_ENTITY_ID, trust_type=TrustType.DIRECT_TRUST_SD_JWT_VC, - jwks=[issuer_jwk] + jwks=[issuer_jwk], ) tsd = TrustSourceData.empty(CREDENTIAL_ISSUER_ENTITY_ID) tsd.add_key(issuer_jwk) - db_engine_inst.add_trust_source( - tsd.serialize() - ) + db_engine_inst.add_trust_source(tsd.serialize()) self.backend = OpenID4VPBackend( - Mock(), INTERNAL_ATTRIBUTES, CONFIG, BASE_URL, "name") + Mock(), INTERNAL_ATTRIBUTES, CONFIG, BASE_URL, "name" + ) url_map = self.backend.register_endpoints() assert len(url_map) == 7 @@ -76,27 +79,36 @@ def internal_attributes(self): "givenname": {"openid": ["given_name"]}, "mail": {"openid": ["email"]}, "edupersontargetedid": {"openid": ["sub"]}, - "surname": {"openid": ["family_name"]} + "surname": {"openid": ["family_name"]}, } } @pytest.fixture def context(self): context = Context() - context.target_frontend = 'someFrontend' + context.target_frontend = "someFrontend" context.state = State() return context def test_backend_init(self): assert self.backend.name == "name" + # TODO: Move to trust evaluation handlers tests def test_entity_configuration(self, context): context.qs_params = {} - entity_config = self.backend.entity_configuration_endpoint(context) + + _fedback: TrustHandlerInterface = self.backend.get_trust_backend_by_class_name( + "FederationHandler" + ) + assert _fedback + + entity_config = _fedback.entity_configuration_endpoint(context) assert entity_config assert entity_config.status == "200" assert entity_config.message + # TODO: decode EC jwt, validate signature and both header and payload schema validation + def test_pre_request_without_frontend(self): context = Context() context.state = State() @@ -113,23 +125,24 @@ def test_pre_request_endpoint(self, context): context.http_headers = dict( HTTP_USER_AGENT="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/77.0.3865.90 Safari/537.36" ) - pre_request_endpoint = self.backend.pre_request_endpoint( - context, 
internal_data) + pre_request_endpoint = self.backend.pre_request_endpoint(context, internal_data) assert pre_request_endpoint assert pre_request_endpoint.status == "200" assert pre_request_endpoint.message assert "src='data:image/svg+xml;base64," in pre_request_endpoint.message - soup = BeautifulSoup(pre_request_endpoint.message, 'html.parser') + soup = BeautifulSoup(pre_request_endpoint.message, "html.parser") # get the img tag with src attribute starting with data:image/svg+xml;base64, img_tag = soup.find( - lambda tag: tag.name == 'img' and tag.get('src', '').startswith('data:image/svg+xml;base64,')) + lambda tag: tag.name == "img" + and tag.get("src", "").startswith("data:image/svg+xml;base64,") + ) assert img_tag # get the src attribute - src = img_tag['src'] + src = img_tag["src"] # remove the data:image/svg+xml;base64, part - data = src.replace('data:image/svg+xml;base64,', '') + data = src.replace("data:image/svg+xml;base64,", "") # decode the base64 data base64.b64decode(data).decode("utf-8") @@ -143,25 +156,27 @@ def test_pre_request_endpoint_mobile(self, context): context.http_headers = dict( HTTP_USER_AGENT="Mozilla/5.0 (Linux; Android 10; SM-G960F) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/77.0.3865.92 Mobile Safari/537.36" ) - pre_request_endpoint = self.backend.pre_request_endpoint( - context, internal_data) + pre_request_endpoint = self.backend.pre_request_endpoint(context, internal_data) assert pre_request_endpoint assert "302" in pre_request_endpoint.status - assert f"{CONFIG['authorization']['url_scheme']}://" in pre_request_endpoint.message + assert ( + f"{CONFIG['authorization']['url_scheme']}://" + in pre_request_endpoint.message + ) unquoted = urllib.parse.unquote( - pre_request_endpoint.message, encoding='utf-8', errors='replace') + pre_request_endpoint.message, encoding="utf-8", errors="replace" + ) parsed = urllib.parse.urlparse(unquoted) - assert parsed.scheme == CONFIG['authorization']['url_scheme'] + assert parsed.scheme == 
CONFIG["authorization"]["url_scheme"] assert parsed.path == "" assert parsed.query qs = urllib.parse.parse_qs(parsed.query) assert qs["client_id"][0] == CONFIG["metadata"]["client_id"] - assert qs["request_uri"][0].startswith( - CONFIG["metadata"]["request_uris"][0]) + assert qs["request_uri"][0].startswith(CONFIG["metadata"]["request_uris"][0]) # def test_vp_validation_in_response_endpoint(self, context): # TODO: re enable or delete the following commented @@ -205,7 +220,7 @@ def test_pre_request_endpoint_mobile(self, context): # vp_token = sdjwt_at_holder.sd_jwt_presentation # context.request_method = "POST" - # context.request_uri = CONFIG["metadata"]["response_uris_supported"][0].removeprefix( + # context.request_uri = CONFIG["metadata"]["response_uris"][0].removeprefix( # CONFIG["base_url"]) # state = str(uuid.uuid4()) @@ -323,7 +338,7 @@ def test_pre_request_endpoint_mobile(self, context): # vp_token_bad_nonce = sdjwt_at_holder.sd_jwt_presentation # context.request_method = "POST" - # context.request_uri = CONFIG["metadata"]["response_uris_supported"][0].removeprefix( + # context.request_uri = CONFIG["metadata"]["response_uris"][0].removeprefix( # CONFIG["base_url"]) # response_with_bad_nonce = { @@ -487,33 +502,29 @@ def test_request_endpoint(self, context): context.http_headers = dict( HTTP_USER_AGENT="Mozilla/5.0 (Linux; Android 10; SM-G960F) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/77.0.3865.92 Mobile Safari/537.36" ) - pre_request_endpoint = self.backend.pre_request_endpoint( - context, internal_data - ) - state = urllib.parse.unquote( - pre_request_endpoint.message).split("=")[-1] + pre_request_endpoint = self.backend.pre_request_endpoint(context, internal_data) + state = urllib.parse.unquote(pre_request_endpoint.message).split("=")[-1] jwshelper = JWSHelper(PRIVATE_JWK) wia = jwshelper.sign( plain_dict=WALLET_INSTANCE_ATTESTATION, protected={ - 'trust_chain': trust_chain_wallet, - 'x5c': [], - } + "trust_chain": trust_chain_wallet, + "x5c": [], + 
}, ) dpop_wia = wia dpop_proof = DPoPIssuer( - htu=CONFIG['metadata']['request_uris'][0], + htu=CONFIG["metadata"]["request_uris"][0], token=dpop_wia, - private_jwk=PRIVATE_JWK + private_jwk=PRIVATE_JWK, ).proof context.http_headers = dict( - HTTP_AUTHORIZATION=f"DPoP {dpop_wia}", - HTTP_DPOP=dpop_proof + HTTP_AUTHORIZATION=f"DPoP {dpop_wia}", HTTP_DPOP=dpop_proof ) context.qs_params = {"id": state} @@ -521,26 +532,30 @@ def test_request_endpoint(self, context): # put a trust attestation related itself into the storage # this then is used as trust_chain header parameter in the signed # request object - db_engine_inst = DBEngine(CONFIG['storage']) + db_engine_inst = DBEngine(CONFIG["storage"]) + + _fedback: TrustHandlerInterface = self.backend.get_trust_backend_by_class_name( + "FederationHandler" + ) + assert _fedback _es = { "exp": EXP, "iat": NOW, "iss": "https://trust-anchor.example.org", "sub": self.backend.client_id, - 'jwks': self.backend.entity_configuration_as_dict['jwks'] + "jwks": _fedback.entity_configuration_as_dict["jwks"], } - ta_signer = JWS(_es, alg="ES256", - typ="application/entity-statement+jwt") + ta_signer = JWS(_es, alg="ES256", typ="application/entity-statement+jwt") its_trust_chain = [ - self.backend.entity_configuration, - ta_signer.sign_compact([ta_jwk]) + _fedback.entity_configuration, + ta_signer.sign_compact([ta_jwk]), ] db_engine_inst.add_or_update_trust_attestation( entity_id=self.backend.client_id, attestation=its_trust_chain, - exp=datetime.datetime.now().isoformat() + exp=datetime.datetime.now().isoformat(), ) # End RP trust chain @@ -556,7 +571,7 @@ def test_request_endpoint(self, context): context.request_method = "GET" context.qs_params = {"id": state} - request_uri = CONFIG['metadata']['request_uris'][0] + request_uri = CONFIG["metadata"]["request_uris"][0] context.request_uri = request_uri req_resp = self.backend.request_endpoint(context) @@ -565,19 +580,25 @@ def test_request_endpoint(self, context): map( lambda 
header_name_value_pair: header_name_value_pair[1], filter( - lambda header_name_value_pair: header_name_value_pair[0].lower( - ) == "content-type", - req_resp.headers - ) + lambda header_name_value_pair: header_name_value_pair[0].lower() + == "content-type", + req_resp.headers, + ), ) ) assert req_resp - assert req_resp.status == "200", f"invalid status in request object response {req_resp_str}" - assert len( - obtained_content_types) > 0, f"missing Content-Type in request object response {req_resp_str}" - assert obtained_content_types[ - 0] == "application/oauth-authz-req+jwt", f"invalid Content-Type in request object response {req_resp_str}" - assert req_resp.message, f"invalid message in request object response {req_resp_str}" + assert ( + req_resp.status == "200" + ), f"invalid status in request object response {req_resp_str}" + assert ( + len(obtained_content_types) > 0 + ), f"missing Content-Type in request object response {req_resp_str}" + assert ( + obtained_content_types[0] == "application/oauth-authz-req+jwt" + ), f"invalid Content-Type in request object response {req_resp_str}" + assert ( + req_resp.message + ), f"invalid message in request object response {req_resp_str}" request_object_jwt = req_resp.message header = decode_jwt_header(request_object_jwt) @@ -587,11 +608,13 @@ def test_request_endpoint(self, context): assert header["typ"] == "oauth-authz-req+jwt" assert payload["scope"] == " ".join(CONFIG["authorization"]["scopes"]) assert payload["client_id"] == CONFIG["metadata"]["client_id"] - assert payload["response_uri"] == CONFIG["metadata"]["response_uris_supported"][0] + assert ( + payload["response_uri"] == CONFIG["metadata"]["response_uris"][0] + ) datetime_mock = Mock(wraps=datetime.datetime) datetime_mock.now.return_value = datetime.datetime(2999, 1, 1) - with patch('datetime.datetime', new=datetime_mock): + with patch("datetime.datetime", new=datetime_mock): self.backend.status_endpoint(context) state_endpoint_response = 
self.backend.status_endpoint(context) assert state_endpoint_response.status == "403" @@ -608,8 +631,7 @@ def test_request_endpoint(self, context): def test_handle_error(self, context): error_message = "server_error" - error_resp = self.backend._handle_500( - context, error_message, Exception()) + error_resp = self.backend._handle_500(context, error_message, Exception()) assert error_resp.status == "500" assert error_resp.message err = json.loads(error_resp.message) diff --git a/pyeudiw/tests/satosa/test_backend_trust.py b/pyeudiw/tests/satosa/test_backend_trust.py index d29f3990..3f057676 100644 --- a/pyeudiw/tests/satosa/test_backend_trust.py +++ b/pyeudiw/tests/satosa/test_backend_trust.py @@ -1,14 +1,9 @@ -import pytest - from unittest.mock import Mock -from pyeudiw.satosa.backend import OpenID4VPBackend +import pytest -from pyeudiw.tests.settings import ( - BASE_URL, - CONFIG, - INTERNAL_ATTRIBUTES, -) +from pyeudiw.satosa.backend import OpenID4VPBackend +from pyeudiw.tests.settings import BASE_URL, CONFIG, INTERNAL_ATTRIBUTES class TestOpenID4VPBackend: @@ -16,11 +11,7 @@ class TestOpenID4VPBackend: @pytest.fixture(autouse=True) def setup_direct_trust(self): self.backend = OpenID4VPBackend( - Mock(), - INTERNAL_ATTRIBUTES, - CONFIG, - BASE_URL, - "name" + Mock(), INTERNAL_ATTRIBUTES, CONFIG, BASE_URL, "name" ) def test_response_endpoint(self): diff --git a/pyeudiw/tests/satosa/test_dynamic_backed.py b/pyeudiw/tests/satosa/test_dynamic_backed.py index ff311106..29b359ea 100644 --- a/pyeudiw/tests/satosa/test_dynamic_backed.py +++ b/pyeudiw/tests/satosa/test_dynamic_backed.py @@ -1,47 +1,47 @@ import json -from pyeudiw.satosa.backend import OpenID4VPBackend +from unittest.mock import Mock + from satosa.context import Context +from satosa.response import Redirect from satosa.state import State -from pyeudiw.tests.settings import ( - CONFIG, - BASE_URL, - INTERNAL_ATTRIBUTES -) - -from satosa.response import Redirect -from pyeudiw.satosa.utils.response import 
JsonResponse +from pyeudiw.satosa.backend import OpenID4VPBackend from pyeudiw.satosa.interfaces.request_handler import RequestHandlerInterface from pyeudiw.satosa.interfaces.response_handler import ResponseHandlerInterface - -from unittest.mock import Mock +from pyeudiw.satosa.utils.response import JsonResponse +from pyeudiw.tests.settings import BASE_URL, CONFIG, INTERNAL_ATTRIBUTES class RequestHandler(RequestHandlerInterface): - def request_endpoint(self, context: Context, *args: tuple) -> Redirect | JsonResponse: - return self._handle_400(context, "Request endpoint not implemented.", NotImplementedError()) + def request_endpoint( + self, context: Context, *args: tuple + ) -> Redirect | JsonResponse: + return self._handle_400( + context, "Request endpoint not implemented.", NotImplementedError() + ) class ResponseHandler(ResponseHandlerInterface): def response_endpoint(self, context: Context, *args) -> JsonResponse: - return self._handle_400(context, "Response endpoint not implemented.", NotImplementedError()) + return self._handle_400( + context, "Response endpoint not implemented.", NotImplementedError() + ) def test_dynamic_backend_creation(): CONFIG["endpoints"]["request"] = { "module": "pyeudiw.tests.satosa.test_dynamic_backed", "class": "RequestHandler", - "path": "/request_test" + "path": "/request_test", } CONFIG["endpoints"]["response"] = { "module": "pyeudiw.tests.satosa.test_dynamic_backed", "class": "ResponseHandler", - "path": "/response_test" + "path": "/response_test", } - backend = OpenID4VPBackend( - Mock(), INTERNAL_ATTRIBUTES, CONFIG, BASE_URL, "name") + backend = OpenID4VPBackend(Mock(), INTERNAL_ATTRIBUTES, CONFIG, BASE_URL, "name") handlers = backend.register_endpoints() published_endpoints = [handlers[i][0] for i in range(len(handlers))] @@ -53,11 +53,15 @@ def test_dynamic_backend_creation(): context.state = State() response = backend.request_endpoint(context) - assert response.status == '400' - assert json.loads(response.message)[ - 
'error_description'] == "Request endpoint not implemented." + assert response.status == "400" + assert ( + json.loads(response.message)["error_description"] + == "Request endpoint not implemented." + ) response = backend.response_endpoint(context) - assert response.status == '400' - assert json.loads(response.message)[ - 'error_description'] == "Response endpoint not implemented." + assert response.status == "400" + assert ( + json.loads(response.message)["error_description"] + == "Response endpoint not implemented." + ) diff --git a/pyeudiw/tests/satosa/utils/test_respcode.py b/pyeudiw/tests/satosa/utils/test_respcode.py index 1cb7260a..3e8ad44b 100644 --- a/pyeudiw/tests/satosa/utils/test_respcode.py +++ b/pyeudiw/tests/satosa/utils/test_respcode.py @@ -42,8 +42,7 @@ def test_valid_code(self): def test_invalid_code(self): try: - self.respose_code_helper.create_code( - "this_is_an_invalid_response_code") + self.respose_code_helper.create_code("this_is_an_invalid_response_code") assert False except Exception: assert True diff --git a/pyeudiw/tests/sd_jwt/conftest.py b/pyeudiw/tests/sd_jwt/conftest.py index 84b50f75..b4073cb7 100644 --- a/pyeudiw/tests/sd_jwt/conftest.py +++ b/pyeudiw/tests/sd_jwt/conftest.py @@ -1,8 +1,9 @@ from pathlib import Path + import pytest -from pyeudiw.sd_jwt.utils.yaml_specification import load_yaml_specification from pyeudiw.sd_jwt.utils.demo_utils import load_yaml_settings +from pyeudiw.sd_jwt.utils.yaml_specification import load_yaml_specification tc_basedir = Path(__file__).parent / "testcases" @@ -13,7 +14,9 @@ def pytest_generate_tests(metafunc): if "testcase" in metafunc.fixturenames: testcases = list(tc_basedir.glob("*/specification.yml")) metafunc.parametrize( - "testcase", [load_yaml_specification(t) for t in testcases], ids=[t.parent.name for t in testcases] + "testcase", + [load_yaml_specification(t) for t in testcases], + ids=[t.parent.name for t in testcases], ) diff --git a/pyeudiw/tests/sd_jwt/test_disclose_all_shortcut.py 
b/pyeudiw/tests/sd_jwt/test_disclose_all_shortcut.py index 9e7e496f..3494d95e 100644 --- a/pyeudiw/tests/sd_jwt/test_disclose_all_shortcut.py +++ b/pyeudiw/tests/sd_jwt/test_disclose_all_shortcut.py @@ -2,8 +2,8 @@ from pyeudiw.sd_jwt.issuer import SDJWTIssuer from pyeudiw.sd_jwt.utils.demo_utils import get_jwk -from pyeudiw.sd_jwt.verifier import SDJWTVerifier from pyeudiw.sd_jwt.utils.yaml_specification import remove_sdobj_wrappers +from pyeudiw.sd_jwt.verifier import SDJWTVerifier def test_e2e(testcase, settings): @@ -25,8 +25,7 @@ def test_e2e(testcase, settings): sdjwt_at_issuer = SDJWTIssuer( user_claims, demo_keys["issuer_keys"], - demo_keys["holder_key"] if testcase.get( - "key_binding", False) else None, + demo_keys["holder_key"] if testcase.get("key_binding", False) else None, add_decoy_claims=use_decoys, serialization_format=serialization_format, extra_header_parameters=extra_header_parameters, @@ -75,7 +74,7 @@ def cb_get_issuer_key(issuer, header_parameters): expected_header_parameters = { "alg": testcase.get("sign_alg", "ES256"), - "typ": "testcase+sd-jwt" + "typ": "testcase+sd-jwt", } expected_header_parameters.update(extra_header_parameters) diff --git a/pyeudiw/tests/sd_jwt/test_e2e_testcases.py b/pyeudiw/tests/sd_jwt/test_e2e_testcases.py index 6571ec42..b7e290eb 100644 --- a/pyeudiw/tests/sd_jwt/test_e2e_testcases.py +++ b/pyeudiw/tests/sd_jwt/test_e2e_testcases.py @@ -1,8 +1,9 @@ +from cryptojwt.jwk.jwk import key_from_jwk_dict + from pyeudiw.sd_jwt.holder import SDJWTHolder from pyeudiw.sd_jwt.issuer import SDJWTIssuer from pyeudiw.sd_jwt.utils.demo_utils import get_jwk from pyeudiw.sd_jwt.verifier import SDJWTVerifier -from cryptojwt.jwk.jwk import key_from_jwk_dict def test_e2e(testcase, settings): @@ -25,8 +26,7 @@ def test_e2e(testcase, settings): sdjwt_at_issuer = SDJWTIssuer( user_claims, demo_keys["issuer_keys"], - demo_keys["holder_key"] if testcase.get( - "key_binding", False) else None, + demo_keys["holder_key"] if 
testcase.get("key_binding", False) else None, add_decoy_claims=use_decoys, serialization_format=serialization_format, extra_header_parameters=extra_header_parameters, @@ -42,15 +42,13 @@ def test_e2e(testcase, settings): sdjwt_at_holder.create_presentation( testcase["holder_disclosed_claims"], - settings["key_binding_nonce"] if testcase.get( - "key_binding", False) else None, + settings["key_binding_nonce"] if testcase.get("key_binding", False) else None, ( settings["identifiers"]["verifier"] if testcase.get("key_binding", False) else None ), - demo_keys["holder_key"] if testcase.get( - "key_binding", False) else None, + demo_keys["holder_key"] if testcase.get("key_binding", False) else None, ) output_holder = sdjwt_at_holder.sd_jwt_presentation @@ -60,8 +58,8 @@ def test_e2e(testcase, settings): def cb_get_issuer_key(issuer, header_parameters): if isinstance(header_parameters, dict): - if 'kid' in header_parameters: - header_parameters.pop('kid') + if "kid" in header_parameters: + header_parameters.pop("kid") sdjwt_header_parameters.update(header_parameters) return demo_keys["issuer_public_keys"] @@ -73,8 +71,7 @@ def cb_get_issuer_key(issuer, header_parameters): if testcase.get("key_binding", False) else None ), - settings["key_binding_nonce"] if testcase.get( - "key_binding", False) else None, + settings["key_binding_nonce"] if testcase.get("key_binding", False) else None, serialization_format=serialization_format, ) @@ -88,7 +85,9 @@ def cb_get_issuer_key(issuer, header_parameters): "jwk": key_from_jwk_dict(demo_keys["holder_key"], private=False).serialize() } - assert verified == expected_claims, f"Verified payload mismatch: {verified} != {expected_claims}" + assert ( + verified == expected_claims + ), f"Verified payload mismatch: {verified} != {expected_claims}" # We don't compare header parameters for JSON Serialization for now if serialization_format == "compact": @@ -99,6 +98,6 @@ def cb_get_issuer_key(issuer, header_parameters): 
expected_header_parameters.update(extra_header_parameters) # Assert degli header JWS - assert sdjwt_header_parameters == expected_header_parameters, ( - f"Header parameters mismatch: {sdjwt_header_parameters} != {expected_header_parameters}" - ) + assert ( + sdjwt_header_parameters == expected_header_parameters + ), f"Header parameters mismatch: {sdjwt_header_parameters} != {expected_header_parameters}" diff --git a/pyeudiw/tests/sd_jwt/test_sdjwt_schema.py b/pyeudiw/tests/sd_jwt/test_sdjwt_schema.py index fc0cbb1a..539f5bfc 100644 --- a/pyeudiw/tests/sd_jwt/test_sdjwt_schema.py +++ b/pyeudiw/tests/sd_jwt/test_sdjwt_schema.py @@ -14,67 +14,53 @@ class TestCase: TestCase( "eyJhbGciOiAiRVMyNTYiLCAidHlwIjogImV4YW1wbGUrc2Qtand0In0.eyJfc2QiOiBbIkNyUWU3UzVrcUJBSHQtbk1ZWGdjNmJkdDJTSDVhVFkxc1VfTS1QZ2tqUEkiLCAiSnpZakg0c3ZsaUgwUjNQeUVNZmVadTZKdDY5dTVxZWhabzdGN0VQWWxTRSIsICJQb3JGYnBLdVZ1Nnh5bUphZ3ZrRnNGWEFiUm9jMkpHbEFVQTJCQTRvN2NJIiwgIlRHZjRvTGJnd2Q1SlFhSHlLVlFaVTlVZEdFMHc1cnREc3JaemZVYW9tTG8iLCAiWFFfM2tQS3QxWHlYN0tBTmtxVlI2eVoyVmE1TnJQSXZQWWJ5TXZSS0JNTSIsICJYekZyendzY002R242Q0pEYzZ2Vks4QmtNbmZHOHZPU0tmcFBJWmRBZmRFIiwgImdiT3NJNEVkcTJ4Mkt3LXc1d1BFemFrb2I5aFYxY1JEMEFUTjNvUUw5Sk0iLCAianN1OXlWdWx3UVFsaEZsTV8zSmx6TWFTRnpnbGhRRzBEcGZheVF3TFVLNCJdLCAiaXNzIjogImh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwgImlhdCI6IDE2ODMwMDAwMDAsICJleHAiOiAxODgzMDAwMDAwLCAic3ViIjogInVzZXJfNDIiLCAibmF0aW9uYWxpdGllcyI6IFt7Ii4uLiI6ICJwRm5kamtaX1ZDem15VGE2VWpsWm8zZGgta284YUlLUWM5RGxHemhhVllvIn0sIHsiLi4uIjogIjdDZjZKa1B1ZHJ5M2xjYndIZ2VaOGtoQXYxVTFPU2xlclAwVmtCSnJXWjAifV0sICJfc2RfYWxnIjogInNoYS0yNTYiLCAiY25mIjogeyJqd2siOiB7Imt0eSI6ICJFQyIsICJjcnYiOiAiUC0yNTYiLCAieCI6ICJUQ0FFUjE5WnZ1M09IRjRqNFc0dmZTVm9ISVAxSUxpbERsczd2Q2VHZW1jIiwgInkiOiAiWnhqaVdXYlpNUUdIVldLVlE0aGJTSWlyc1ZmdWVjQ0U2dDRqVDlGMkhaUSJ9fX0.GhUbWZvUw7I6ndaIK6rtU1rDEOkjqFuSXl_J5cuDkmnKt9jhZ3v5sghdwZ1WKH-g-4kna57RgWEl7YHlr-IZqw~WyIyR0xDNDJzS1F2ZUNmR2ZyeU5STjl3IiwgImdpdmVuX25hbWUiLCAiSm9obiJd~WyJlbHVWNU9nM2dTTklJOEVZbnN4QV9BIiwgImZhbWlseV9uYW1lIiwgIkRvZSJd~Wy
I2SWo3dE0tYTVpVlBHYm9TNXRtdlZBIiwgImVtYWlsIiwgImpvaG5kb2VAZXhhbXBsZS5jb20iXQ~WyJlSThaV205UW5LUHBOUGVOZW5IZGhRIiwgInBob25lX251bWJlciIsICIrMS0yMDItNTU1LTAxMDEiXQ~WyJRZ19PNjR6cUF4ZTQxMmExMDhpcm9BIiwgInBob25lX251bWJlcl92ZXJpZmllZCIsIHRydWVd~WyJBSngtMDk1VlBycFR0TjRRTU9xUk9BIiwgImFkZHJlc3MiLCB7InN0cmVldF9hZGRyZXNzIjogIjEyMyBNYWluIFN0IiwgImxvY2FsaXR5IjogIkFueXRvd24iLCAicmVnaW9uIjogIkFueXN0YXRlIiwgImNvdW50cnkiOiAiVVMifV0~WyJQYzMzSk0yTGNoY1VfbEhnZ3ZfdWZRIiwgImJpcnRoZGF0ZSIsICIxOTQwLTAxLTAxIl0~WyJHMDJOU3JRZmpGWFE3SW8wOXN5YWpBIiwgInVwZGF0ZWRfYXQiLCAxNTcwMDAwMDAwXQ~WyJsa2x4RjVqTVlsR1RQVW92TU5JdkNBIiwgIlVTIl0~WyJuUHVvUW5rUkZxM0JJZUFtN0FuWEZBIiwgIkRFIl0~", True, - "sd-jwt without key binding" + "sd-jwt without key binding", ), TestCase( "eyJhbGciOiAiRVMyNTYiLCAidHlwIjogImV4YW1wbGUrc2Qtand0In0.eyJfc2QiOiBbIkNyUWU3UzVrcUJBSHQtbk1ZWGdjNmJkdDJTSDVhVFkxc1VfTS1QZ2tqUEkiLCAiSnpZakg0c3ZsaUgwUjNQeUVNZmVadTZKdDY5dTVxZWhabzdGN0VQWWxTRSIsICJQb3JGYnBLdVZ1Nnh5bUphZ3ZrRnNGWEFiUm9jMkpHbEFVQTJCQTRvN2NJIiwgIlRHZjRvTGJnd2Q1SlFhSHlLVlFaVTlVZEdFMHc1cnREc3JaemZVYW9tTG8iLCAiWFFfM2tQS3QxWHlYN0tBTmtxVlI2eVoyVmE1TnJQSXZQWWJ5TXZSS0JNTSIsICJYekZyendzY002R242Q0pEYzZ2Vks4QmtNbmZHOHZPU0tmcFBJWmRBZmRFIiwgImdiT3NJNEVkcTJ4Mkt3LXc1d1BFemFrb2I5aFYxY1JEMEFUTjNvUUw5Sk0iLCAianN1OXlWdWx3UVFsaEZsTV8zSmx6TWFTRnpnbGhRRzBEcGZheVF3TFVLNCJdLCAiaXNzIjogImh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwgImlhdCI6IDE2ODMwMDAwMDAsICJleHAiOiAxODgzMDAwMDAwLCAic3ViIjogInVzZXJfNDIiLCAibmF0aW9uYWxpdGllcyI6IFt7Ii4uLiI6ICJwRm5kamtaX1ZDem15VGE2VWpsWm8zZGgta284YUlLUWM5RGxHemhhVllvIn0sIHsiLi4uIjogIjdDZjZKa1B1ZHJ5M2xjYndIZ2VaOGtoQXYxVTFPU2xlclAwVmtCSnJXWjAifV0sICJfc2RfYWxnIjogInNoYS0yNTYiLCAiY25mIjogeyJqd2siOiB7Imt0eSI6ICJFQyIsICJjcnYiOiAiUC0yNTYiLCAieCI6ICJUQ0FFUjE5WnZ1M09IRjRqNFc0dmZTVm9ISVAxSUxpbERsczd2Q2VHZW1jIiwgInkiOiAiWnhqaVdXYlpNUUdIVldLVlE0aGJTSWlyc1ZmdWVjQ0U2dDRqVDlGMkhaUSJ9fX0.GhUbWZvUw7I6ndaIK6rtU1rDEOkjqFuSXl_J5cuDkmnKt9jhZ3v5sghdwZ1WKH-g-4kna57RgWEl7YHlr-IZqw~WyJlbHVWNU9nM2dTTklJOEVZbnN4QV9BIiwgImZhbWlseV9uYW1lIiwgIkRvZSJd~WyJBSngtMDk1Vl
BycFR0TjRRTU9xUk9BIiwgImFkZHJlc3MiLCB7InN0cmVldF9hZGRyZXNzIjogIjEyMyBNYWluIFN0IiwgImxvY2FsaXR5IjogIkFueXRvd24iLCAicmVnaW9uIjogIkFueXN0YXRlIiwgImNvdW50cnkiOiAiVVMifV0~WyIyR0xDNDJzS1F2ZUNmR2ZyeU5STjl3IiwgImdpdmVuX25hbWUiLCAiSm9obiJd~WyJsa2x4RjVqTVlsR1RQVW92TU5JdkNBIiwgIlVTIl0~eyJhbGciOiAiRVMyNTYiLCAidHlwIjogImtiK2p3dCJ9.eyJub25jZSI6ICIxMjM0NTY3ODkwIiwgImF1ZCI6ICJodHRwczovL3ZlcmlmaWVyLmV4YW1wbGUub3JnIiwgImlhdCI6IDE3MjA0MzU5MzYsICJzZF9oYXNoIjogIjNvc0U3T2I0VHJqcW80ckJHdkRfMWJkQ0lhTEY4cFB5Wlk3RUctcVBYN1UifQ.MBXm8GktzZw6GKCD_X2zUqec9oWXHGD0K77HzKLehoOsIVUI8dbP_ruNWDX8nfZ6G3s06iFlxRNLO76nGo51Mw", True, - "sd-jwt with key binding" + "sd-jwt with key binding", ), TestCase( "eyJhbGciOiAiRVMyNTYiLCAidHlwIjogImV4YW1wbGUrc2Qtand0In0.eyJfc2QiOiBbIkNyUWU3UzVrcUJBSHQtbk1ZWGdjNmJkdDJTSDVhVFkxc1VfTS1QZ2tqUEkiLCAiSnpZakg0c3ZsaUgwUjNQeUVNZmVadTZKdDY5dTVxZWhabzdGN0VQWWxTRSIsICJQb3JGYnBLdVZ1Nnh5bUphZ3ZrRnNGWEFiUm9jMkpHbEFVQTJCQTRvN2NJIiwgIlRHZjRvTGJnd2Q1SlFhSHlLVlFaVTlVZEdFMHc1cnREc3JaemZVYW9tTG8iLCAiWFFfM2tQS3QxWHlYN0tBTmtxVlI2eVoyVmE1TnJQSXZQWWJ5TXZSS0JNTSIsICJYekZyendzY002R242Q0pEYzZ2Vks4QmtNbmZHOHZPU0tmcFBJWmRBZmRFIiwgImdiT3NJNEVkcTJ4Mkt3LXc1d1BFemFrb2I5aFYxY1JEMEFUTjNvUUw5Sk0iLCAianN1OXlWdWx3UVFsaEZsTV8zSmx6TWFTRnpnbGhRRzBEcGZheVF3TFVLNCJdLCAiaXNzIjogImh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwgImlhdCI6IDE2ODMwMDAwMDAsICJleHAiOiAxODgzMDAwMDAwLCAic3ViIjogInVzZXJfNDIiLCAibmF0aW9uYWxpdGllcyI6IFt7Ii4uLiI6ICJwRm5kamtaX1ZDem15VGE2VWpsWm8zZGgta284YUlLUWM5RGxHemhhVllvIn0sIHsiLi4uIjogIjdDZjZKa1B1ZHJ5M2xjYndIZ2VaOGtoQXYxVTFPU2xlclAwVmtCSnJXWjAifV0sICJfc2RfYWxnIjogInNoYS0yNTYiLCAiY25mIjogeyJqd2siOiB7Imt0eSI6ICJFQyIsICJjcnYiOiAiUC0yNTYiLCAieCI6ICJUQ0FFUjE5WnZ1M09IRjRqNFc0dmZTVm9ISVAxSUxpbERsczd2Q2VHZW1jIiwgInkiOiAiWnhqaVdXYlpNUUdIVldLVlE0aGJTSWlyc1ZmdWVjQ0U2dDRqVDlGMkhaUSJ9fX0.GhUbWZvUw7I6ndaIK6rtU1rDEOkjqFuSXl_J5cuDkmnKt9jhZ3v5sghdwZ1WKH-g-4kna57RgWEl7YHlr-IZqw~", True, - "sd-jwt with 0 disclosures" + "sd-jwt with 0 disclosures", ), TestCase( 
"eyJhbGciOiAiRVMyNTYiLCAidHlwIjogImV4YW1wbGUrc2Qtand0In0.eyJfc2QiOiBbIkNyUWU3UzVrcUJBSHQtbk1ZWGdjNmJkdDJTSDVhVFkxc1VfTS1QZ2tqUEkiLCAiSnpZakg0c3ZsaUgwUjNQeUVNZmVadTZKdDY5dTVxZWhabzdGN0VQWWxTRSIsICJQb3JGYnBLdVZ1Nnh5bUphZ3ZrRnNGWEFiUm9jMkpHbEFVQTJCQTRvN2NJIiwgIlRHZjRvTGJnd2Q1SlFhSHlLVlFaVTlVZEdFMHc1cnREc3JaemZVYW9tTG8iLCAiWFFfM2tQS3QxWHlYN0tBTmtxVlI2eVoyVmE1TnJQSXZQWWJ5TXZSS0JNTSIsICJYekZyendzY002R242Q0pEYzZ2Vks4QmtNbmZHOHZPU0tmcFBJWmRBZmRFIiwgImdiT3NJNEVkcTJ4Mkt3LXc1d1BFemFrb2I5aFYxY1JEMEFUTjNvUUw5Sk0iLCAianN1OXlWdWx3UVFsaEZsTV8zSmx6TWFTRnpnbGhRRzBEcGZheVF3TFVLNCJdLCAiaXNzIjogImh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwgImlhdCI6IDE2ODMwMDAwMDAsICJleHAiOiAxODgzMDAwMDAwLCAic3ViIjogInVzZXJfNDIiLCAibmF0aW9uYWxpdGllcyI6IFt7Ii4uLiI6ICJwRm5kamtaX1ZDem15VGE2VWpsWm8zZGgta284YUlLUWM5RGxHemhhVllvIn0sIHsiLi4uIjogIjdDZjZKa1B1ZHJ5M2xjYndIZ2VaOGtoQXYxVTFPU2xlclAwVmtCSnJXWjAifV0sICJfc2RfYWxnIjogInNoYS0yNTYiLCAiY25mIjogeyJqd2siOiB7Imt0eSI6ICJFQyIsICJjcnYiOiAiUC0yNTYiLCAieCI6ICJUQ0FFUjE5WnZ1M09IRjRqNFc0dmZTVm9ISVAxSUxpbERsczd2Q2VHZW1jIiwgInkiOiAiWnhqaVdXYlpNUUdIVldLVlE0aGJTSWlyc1ZmdWVjQ0U2dDRqVDlGMkhaUSJ9fX0.GhUbWZvUw7I6ndaIK6rtU1rDEOkjqFuSXl_J5cuDkmnKt9jhZ3v5sghdwZ1WKH-g-4kna57RgWEl7YHlr-IZqw", False, - "sd-jwt with missing disclosure symbol ~" + "sd-jwt with missing disclosure symbol ~", ), TestCase( 
"eyJhbGciOiAiRVMyNTYiLCAidHlwIjogImV4YW1wbGUrc2Qtand0In0.eyJfc2QiOiBbIkNyUWU3UzVrcUJBSHQtbk1ZWGdjNmJkdDJTSDVhVFkxc1VfTS1QZ2tqUEkiLCAiSnpZakg0c3ZsaUgwUjNQeUVNZmVadTZKdDY5dTVxZWhabzdGN0VQWWxTRSIsICJQb3JGYnBLdVZ1Nnh5bUphZ3ZrRnNGWEFiUm9jMkpHbEFVQTJCQTRvN2NJIiwgIlRHZjRvTGJnd2Q1SlFhSHlLVlFaVTlVZEdFMHc1cnREc3JaemZVYW9tTG8iLCAiWFFfM2tQS3QxWHlYN0tBTmtxVlI2eVoyVmE1TnJQSXZQWWJ5TXZSS0JNTSIsICJYekZyendzY002R242Q0pEYzZ2Vks4QmtNbmZHOHZPU0tmcFBJWmRBZmRFIiwgImdiT3NJNEVkcTJ4Mkt3LXc1d1BFemFrb2I5aFYxY1JEMEFUTjNvUUw5Sk0iLCAianN1OXlWdWx3UVFsaEZsTV8zSmx6TWFTRnpnbGhRRzBEcGZheVF3TFVLNCJdLCAiaXNzIjogImh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwgImlhdCI6IDE2ODMwMDAwMDAsICJleHAiOiAxODgzMDAwMDAwLCAic3ViIjogInVzZXJfNDIiLCAibmF0aW9uYWxpdGllcyI6IFt7Ii4uLiI6ICJwRm5kamtaX1ZDem15VGE2VWpsWm8zZGgta284YUlLUWM5RGxHemhhVllvIn0sIHsiLi4uIjogIjdDZjZKa1B1ZHJ5M2xjYndIZ2VaOGtoQXYxVTFPU2xlclAwVmtCSnJXWjAifV0sICJfc2RfYWxnIjogInNoYS0yNTYiLCAiY25mIjogeyJqd2siOiB7Imt0eSI6ICJFQyIsICJjcnYiOiAiUC0yNTYiLCAieCI6ICJUQ0FFUjE5WnZ1M09IRjRqNFc0dmZTVm9ISVAxSUxpbERsczd2Q2VHZW1jIiwgInkiOiAiWnhqaVdXYlpNUUdIVldLVlE0aGJTSWlyc1ZmdWVjQ0U2dDRqVDlGMkhaUSJ9fX0.~", False, - "sd-jwt without signature" + "sd-jwt without signature", ), TestCase( 
".eyJfc2QiOiBbIkNyUWU3UzVrcUJBSHQtbk1ZWGdjNmJkdDJTSDVhVFkxc1VfTS1QZ2tqUEkiLCAiSnpZakg0c3ZsaUgwUjNQeUVNZmVadTZKdDY5dTVxZWhabzdGN0VQWWxTRSIsICJQb3JGYnBLdVZ1Nnh5bUphZ3ZrRnNGWEFiUm9jMkpHbEFVQTJCQTRvN2NJIiwgIlRHZjRvTGJnd2Q1SlFhSHlLVlFaVTlVZEdFMHc1cnREc3JaemZVYW9tTG8iLCAiWFFfM2tQS3QxWHlYN0tBTmtxVlI2eVoyVmE1TnJQSXZQWWJ5TXZSS0JNTSIsICJYekZyendzY002R242Q0pEYzZ2Vks4QmtNbmZHOHZPU0tmcFBJWmRBZmRFIiwgImdiT3NJNEVkcTJ4Mkt3LXc1d1BFemFrb2I5aFYxY1JEMEFUTjNvUUw5Sk0iLCAianN1OXlWdWx3UVFsaEZsTV8zSmx6TWFTRnpnbGhRRzBEcGZheVF3TFVLNCJdLCAiaXNzIjogImh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwgImlhdCI6IDE2ODMwMDAwMDAsICJleHAiOiAxODgzMDAwMDAwLCAic3ViIjogInVzZXJfNDIiLCAibmF0aW9uYWxpdGllcyI6IFt7Ii4uLiI6ICJwRm5kamtaX1ZDem15VGE2VWpsWm8zZGgta284YUlLUWM5RGxHemhhVllvIn0sIHsiLi4uIjogIjdDZjZKa1B1ZHJ5M2xjYndIZ2VaOGtoQXYxVTFPU2xlclAwVmtCSnJXWjAifV0sICJfc2RfYWxnIjogInNoYS0yNTYiLCAiY25mIjogeyJqd2siOiB7Imt0eSI6ICJFQyIsICJjcnYiOiAiUC0yNTYiLCAieCI6ICJUQ0FFUjE5WnZ1M09IRjRqNFc0dmZTVm9ISVAxSUxpbERsczd2Q2VHZW1jIiwgInkiOiAiWnhqaVdXYlpNUUdIVldLVlE0aGJTSWlyc1ZmdWVjQ0U2dDRqVDlGMkhaUSJ9fX0.GhUbWZvUw7I6ndaIK6rtU1rDEOkjqFuSXl_J5cuDkmnKt9jhZ3v5sghdwZ1WKH-g-4kna57RgWEl7YHlr-IZqw~", False, - "sd-jwt without header" + "sd-jwt without header", ), + TestCase("qwertyuiop", False, "non sense string: not a jwt"), + TestCase("qwe.rty.uio.azs~", False, "too many dots in base jwt"), + TestCase("qwe.rty~", False, "too few dots in base jwt"), TestCase( - "qwertyuiop", - False, - "non sense string: not a jwt" - ), - TestCase( - "qwe.rty.uio.azs~", - False, - "too many dots in base jwt" - ), - TestCase( - "qwe.rty~", - False, - "too few dots in base jwt" - ), - TestCase( - "qwe.rty~asd~", - False, - "too few dots in base jwt and diclosure lookalike" + "qwe.rty~asd~", False, "too few dots in base jwt and diclosure lookalike" ), TestCase( - "qwe.rty.uio~asd~qwe", - False, - "sd-jwt with key binding that is not a jwt" + "qwe.rty.uio~asd~qwe", False, "sd-jwt with key binding that is not a jwt" ), TestCase( "qwe.rty.uio~asd~qwe.asd.", False, - 
"sd-jwt with key binding without signature" + "sd-jwt with key binding without signature", ), ] for i, case in enumerate(test_table): obt_result = is_sd_jwt_format(case.input) - assert obt_result == case.expected_result, f"failed test case {i}: scenario: {case.explanation}" + assert ( + obt_result == case.expected_result + ), f"failed test case {i}: scenario: {case.explanation}" def test_is_sd_jwt_kb_format(): @@ -88,79 +74,61 @@ class TestCase: TestCase( "eyJhbGciOiAiRVMyNTYiLCAidHlwIjogImV4YW1wbGUrc2Qtand0In0.eyJfc2QiOiBbIkNyUWU3UzVrcUJBSHQtbk1ZWGdjNmJkdDJTSDVhVFkxc1VfTS1QZ2tqUEkiLCAiSnpZakg0c3ZsaUgwUjNQeUVNZmVadTZKdDY5dTVxZWhabzdGN0VQWWxTRSIsICJQb3JGYnBLdVZ1Nnh5bUphZ3ZrRnNGWEFiUm9jMkpHbEFVQTJCQTRvN2NJIiwgIlRHZjRvTGJnd2Q1SlFhSHlLVlFaVTlVZEdFMHc1cnREc3JaemZVYW9tTG8iLCAiWFFfM2tQS3QxWHlYN0tBTmtxVlI2eVoyVmE1TnJQSXZQWWJ5TXZSS0JNTSIsICJYekZyendzY002R242Q0pEYzZ2Vks4QmtNbmZHOHZPU0tmcFBJWmRBZmRFIiwgImdiT3NJNEVkcTJ4Mkt3LXc1d1BFemFrb2I5aFYxY1JEMEFUTjNvUUw5Sk0iLCAianN1OXlWdWx3UVFsaEZsTV8zSmx6TWFTRnpnbGhRRzBEcGZheVF3TFVLNCJdLCAiaXNzIjogImh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwgImlhdCI6IDE2ODMwMDAwMDAsICJleHAiOiAxODgzMDAwMDAwLCAic3ViIjogInVzZXJfNDIiLCAibmF0aW9uYWxpdGllcyI6IFt7Ii4uLiI6ICJwRm5kamtaX1ZDem15VGE2VWpsWm8zZGgta284YUlLUWM5RGxHemhhVllvIn0sIHsiLi4uIjogIjdDZjZKa1B1ZHJ5M2xjYndIZ2VaOGtoQXYxVTFPU2xlclAwVmtCSnJXWjAifV0sICJfc2RfYWxnIjogInNoYS0yNTYiLCAiY25mIjogeyJqd2siOiB7Imt0eSI6ICJFQyIsICJjcnYiOiAiUC0yNTYiLCAieCI6ICJUQ0FFUjE5WnZ1M09IRjRqNFc0dmZTVm9ISVAxSUxpbERsczd2Q2VHZW1jIiwgInkiOiAiWnhqaVdXYlpNUUdIVldLVlE0aGJTSWlyc1ZmdWVjQ0U2dDRqVDlGMkhaUSJ9fX0.GhUbWZvUw7I6ndaIK6rtU1rDEOkjqFuSXl_J5cuDkmnKt9jhZ3v5sghdwZ1WKH-g-4kna57RgWEl7YHlr-IZqw~WyJlbHVWNU9nM2dTTklJOEVZbnN4QV9BIiwgImZhbWlseV9uYW1lIiwgIkRvZSJd~WyJBSngtMDk1VlBycFR0TjRRTU9xUk9BIiwgImFkZHJlc3MiLCB7InN0cmVldF9hZGRyZXNzIjogIjEyMyBNYWluIFN0IiwgImxvY2FsaXR5IjogIkFueXRvd24iLCAicmVnaW9uIjogIkFueXN0YXRlIiwgImNvdW50cnkiOiAiVVMifV0~WyIyR0xDNDJzS1F2ZUNmR2ZyeU5STjl3IiwgImdpdmVuX25hbWUiLCAiSm9obiJd~WyJsa2x4RjVqTVlsR1RQVW92TU5JdkNBIiwgIlVTIl
0~eyJhbGciOiAiRVMyNTYiLCAidHlwIjogImtiK2p3dCJ9.eyJub25jZSI6ICIxMjM0NTY3ODkwIiwgImF1ZCI6ICJodHRwczovL3ZlcmlmaWVyLmV4YW1wbGUub3JnIiwgImlhdCI6IDE3MjA0MzU5MzYsICJzZF9oYXNoIjogIjNvc0U3T2I0VHJqcW80ckJHdkRfMWJkQ0lhTEY4cFB5Wlk3RUctcVBYN1UifQ.MBXm8GktzZw6GKCD_X2zUqec9oWXHGD0K77HzKLehoOsIVUI8dbP_ruNWDX8nfZ6G3s06iFlxRNLO76nGo51Mw", True, - "sd-jwt with key binding" + "sd-jwt with key binding", ), TestCase( - "qwe.rty.uio~asd.fgh.jkl", - True, - "sd-jwt lookalike with 0 disclosures" + "qwe.rty.uio~asd.fgh.jkl", True, "sd-jwt lookalike with 0 disclosures" ), TestCase( - "qwe.rty.uio~zxc~asd.fgh.jkl", - True, - "sd-jwt lookalike with 1 disclosure" + "qwe.rty.uio~zxc~asd.fgh.jkl", True, "sd-jwt lookalike with 1 disclosure" ), TestCase( "qwe.rty.uio~zxc~vbn~asd.fgh.jkl", True, - "sd-jwt lookalike with 2 disclosures" + "sd-jwt lookalike with 2 disclosures", ), TestCase( "eyJhbGciOiAiRVMyNTYiLCAidHlwIjogImV4YW1wbGUrc2Qtand0In0.eyJfc2QiOiBbIkNyUWU3UzVrcUJBSHQtbk1ZWGdjNmJkdDJTSDVhVFkxc1VfTS1QZ2tqUEkiLCAiSnpZakg0c3ZsaUgwUjNQeUVNZmVadTZKdDY5dTVxZWhabzdGN0VQWWxTRSIsICJQb3JGYnBLdVZ1Nnh5bUphZ3ZrRnNGWEFiUm9jMkpHbEFVQTJCQTRvN2NJIiwgIlRHZjRvTGJnd2Q1SlFhSHlLVlFaVTlVZEdFMHc1cnREc3JaemZVYW9tTG8iLCAiWFFfM2tQS3QxWHlYN0tBTmtxVlI2eVoyVmE1TnJQSXZQWWJ5TXZSS0JNTSIsICJYekZyendzY002R242Q0pEYzZ2Vks4QmtNbmZHOHZPU0tmcFBJWmRBZmRFIiwgImdiT3NJNEVkcTJ4Mkt3LXc1d1BFemFrb2I5aFYxY1JEMEFUTjNvUUw5Sk0iLCAianN1OXlWdWx3UVFsaEZsTV8zSmx6TWFTRnpnbGhRRzBEcGZheVF3TFVLNCJdLCAiaXNzIjogImh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwgImlhdCI6IDE2ODMwMDAwMDAsICJleHAiOiAxODgzMDAwMDAwLCAic3ViIjogInVzZXJfNDIiLCAibmF0aW9uYWxpdGllcyI6IFt7Ii4uLiI6ICJwRm5kamtaX1ZDem15VGE2VWpsWm8zZGgta284YUlLUWM5RGxHemhhVllvIn0sIHsiLi4uIjogIjdDZjZKa1B1ZHJ5M2xjYndIZ2VaOGtoQXYxVTFPU2xlclAwVmtCSnJXWjAifV0sICJfc2RfYWxnIjogInNoYS0yNTYiLCAiY25mIjogeyJqd2siOiB7Imt0eSI6ICJFQyIsICJjcnYiOiAiUC0yNTYiLCAieCI6ICJUQ0FFUjE5WnZ1M09IRjRqNFc0dmZTVm9ISVAxSUxpbERsczd2Q2VHZW1jIiwgInkiOiAiWnhqaVdXYlpNUUdIVldLVlE0aGJTSWlyc1ZmdWVjQ0U2dDRqVDlGMkhaUSJ9fX0.GhUbWZvUw7I6ndaIK6rtU1rDEOkjqFuSX
l_J5cuDkmnKt9jhZ3v5sghdwZ1WKH-g-4kna57RgWEl7YHlr-IZqw~WyIyR0xDNDJzS1F2ZUNmR2ZyeU5STjl3IiwgImdpdmVuX25hbWUiLCAiSm9obiJd~WyJlbHVWNU9nM2dTTklJOEVZbnN4QV9BIiwgImZhbWlseV9uYW1lIiwgIkRvZSJd~WyI2SWo3dE0tYTVpVlBHYm9TNXRtdlZBIiwgImVtYWlsIiwgImpvaG5kb2VAZXhhbXBsZS5jb20iXQ~WyJlSThaV205UW5LUHBOUGVOZW5IZGhRIiwgInBob25lX251bWJlciIsICIrMS0yMDItNTU1LTAxMDEiXQ~WyJRZ19PNjR6cUF4ZTQxMmExMDhpcm9BIiwgInBob25lX251bWJlcl92ZXJpZmllZCIsIHRydWVd~WyJBSngtMDk1VlBycFR0TjRRTU9xUk9BIiwgImFkZHJlc3MiLCB7InN0cmVldF9hZGRyZXNzIjogIjEyMyBNYWluIFN0IiwgImxvY2FsaXR5IjogIkFueXRvd24iLCAicmVnaW9uIjogIkFueXN0YXRlIiwgImNvdW50cnkiOiAiVVMifV0~WyJQYzMzSk0yTGNoY1VfbEhnZ3ZfdWZRIiwgImJpcnRoZGF0ZSIsICIxOTQwLTAxLTAxIl0~WyJHMDJOU3JRZmpGWFE3SW8wOXN5YWpBIiwgInVwZGF0ZWRfYXQiLCAxNTcwMDAwMDAwXQ~WyJsa2x4RjVqTVlsR1RQVW92TU5JdkNBIiwgIlVTIl0~WyJuUHVvUW5rUkZxM0JJZUFtN0FuWEZBIiwgIkRFIl0~", False, - "sd-jwt without key binding" + "sd-jwt without key binding", ), TestCase( "eyJhbGciOiAiRVMyNTYiLCAidHlwIjogImV4YW1wbGUrc2Qtand0In0.eyJfc2QiOiBbIkNyUWU3UzVrcUJBSHQtbk1ZWGdjNmJkdDJTSDVhVFkxc1VfTS1QZ2tqUEkiLCAiSnpZakg0c3ZsaUgwUjNQeUVNZmVadTZKdDY5dTVxZWhabzdGN0VQWWxTRSIsICJQb3JGYnBLdVZ1Nnh5bUphZ3ZrRnNGWEFiUm9jMkpHbEFVQTJCQTRvN2NJIiwgIlRHZjRvTGJnd2Q1SlFhSHlLVlFaVTlVZEdFMHc1cnREc3JaemZVYW9tTG8iLCAiWFFfM2tQS3QxWHlYN0tBTmtxVlI2eVoyVmE1TnJQSXZQWWJ5TXZSS0JNTSIsICJYekZyendzY002R242Q0pEYzZ2Vks4QmtNbmZHOHZPU0tmcFBJWmRBZmRFIiwgImdiT3NJNEVkcTJ4Mkt3LXc1d1BFemFrb2I5aFYxY1JEMEFUTjNvUUw5Sk0iLCAianN1OXlWdWx3UVFsaEZsTV8zSmx6TWFTRnpnbGhRRzBEcGZheVF3TFVLNCJdLCAiaXNzIjogImh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwgImlhdCI6IDE2ODMwMDAwMDAsICJleHAiOiAxODgzMDAwMDAwLCAic3ViIjogInVzZXJfNDIiLCAibmF0aW9uYWxpdGllcyI6IFt7Ii4uLiI6ICJwRm5kamtaX1ZDem15VGE2VWpsWm8zZGgta284YUlLUWM5RGxHemhhVllvIn0sIHsiLi4uIjogIjdDZjZKa1B1ZHJ5M2xjYndIZ2VaOGtoQXYxVTFPU2xlclAwVmtCSnJXWjAifV0sICJfc2RfYWxnIjogInNoYS0yNTYiLCAiY25mIjogeyJqd2siOiB7Imt0eSI6ICJFQyIsICJjcnYiOiAiUC0yNTYiLCAieCI6ICJUQ0FFUjE5WnZ1M09IRjRqNFc0dmZTVm9ISVAxSUxpbERsczd2Q2VHZW1jIiwgInkiOiAiWnhqaVdXYlpNUUdIVldLVlE0aGJTSWlyc1ZmdWVjQ0U
2dDRqVDlGMkhaUSJ9fX0.GhUbWZvUw7I6ndaIK6rtU1rDEOkjqFuSXl_J5cuDkmnKt9jhZ3v5sghdwZ1WKH-g-4kna57RgWEl7YHlr-IZqw", False, - "sd-jwt with missing disclosure symbol ~" + "sd-jwt with missing disclosure symbol ~", ), TestCase( "eyJhbGciOiAiRVMyNTYiLCAidHlwIjogImV4YW1wbGUrc2Qtand0In0.eyJfc2QiOiBbIkNyUWU3UzVrcUJBSHQtbk1ZWGdjNmJkdDJTSDVhVFkxc1VfTS1QZ2tqUEkiLCAiSnpZakg0c3ZsaUgwUjNQeUVNZmVadTZKdDY5dTVxZWhabzdGN0VQWWxTRSIsICJQb3JGYnBLdVZ1Nnh5bUphZ3ZrRnNGWEFiUm9jMkpHbEFVQTJCQTRvN2NJIiwgIlRHZjRvTGJnd2Q1SlFhSHlLVlFaVTlVZEdFMHc1cnREc3JaemZVYW9tTG8iLCAiWFFfM2tQS3QxWHlYN0tBTmtxVlI2eVoyVmE1TnJQSXZQWWJ5TXZSS0JNTSIsICJYekZyendzY002R242Q0pEYzZ2Vks4QmtNbmZHOHZPU0tmcFBJWmRBZmRFIiwgImdiT3NJNEVkcTJ4Mkt3LXc1d1BFemFrb2I5aFYxY1JEMEFUTjNvUUw5Sk0iLCAianN1OXlWdWx3UVFsaEZsTV8zSmx6TWFTRnpnbGhRRzBEcGZheVF3TFVLNCJdLCAiaXNzIjogImh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwgImlhdCI6IDE2ODMwMDAwMDAsICJleHAiOiAxODgzMDAwMDAwLCAic3ViIjogInVzZXJfNDIiLCAibmF0aW9uYWxpdGllcyI6IFt7Ii4uLiI6ICJwRm5kamtaX1ZDem15VGE2VWpsWm8zZGgta284YUlLUWM5RGxHemhhVllvIn0sIHsiLi4uIjogIjdDZjZKa1B1ZHJ5M2xjYndIZ2VaOGtoQXYxVTFPU2xlclAwVmtCSnJXWjAifV0sICJfc2RfYWxnIjogInNoYS0yNTYiLCAiY25mIjogeyJqd2siOiB7Imt0eSI6ICJFQyIsICJjcnYiOiAiUC0yNTYiLCAieCI6ICJUQ0FFUjE5WnZ1M09IRjRqNFc0dmZTVm9ISVAxSUxpbERsczd2Q2VHZW1jIiwgInkiOiAiWnhqaVdXYlpNUUdIVldLVlE0aGJTSWlyc1ZmdWVjQ0U2dDRqVDlGMkhaUSJ9fX0.GhUbWZvUw7I6ndaIK6rtU1rDEOkjqFuSXl_J5cuDkmnKt9jhZ3v5sghdwZ1WKH-g-4kna57RgWEl7YHlr-IZqw~", False, - "sd-jwt with 0 disclosures" + "sd-jwt with 0 disclosures", ), TestCase( 
"eyJhbGciOiAiRVMyNTYiLCAidHlwIjogImV4YW1wbGUrc2Qtand0In0.eyJfc2QiOiBbIkNyUWU3UzVrcUJBSHQtbk1ZWGdjNmJkdDJTSDVhVFkxc1VfTS1QZ2tqUEkiLCAiSnpZakg0c3ZsaUgwUjNQeUVNZmVadTZKdDY5dTVxZWhabzdGN0VQWWxTRSIsICJQb3JGYnBLdVZ1Nnh5bUphZ3ZrRnNGWEFiUm9jMkpHbEFVQTJCQTRvN2NJIiwgIlRHZjRvTGJnd2Q1SlFhSHlLVlFaVTlVZEdFMHc1cnREc3JaemZVYW9tTG8iLCAiWFFfM2tQS3QxWHlYN0tBTmtxVlI2eVoyVmE1TnJQSXZQWWJ5TXZSS0JNTSIsICJYekZyendzY002R242Q0pEYzZ2Vks4QmtNbmZHOHZPU0tmcFBJWmRBZmRFIiwgImdiT3NJNEVkcTJ4Mkt3LXc1d1BFemFrb2I5aFYxY1JEMEFUTjNvUUw5Sk0iLCAianN1OXlWdWx3UVFsaEZsTV8zSmx6TWFTRnpnbGhRRzBEcGZheVF3TFVLNCJdLCAiaXNzIjogImh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwgImlhdCI6IDE2ODMwMDAwMDAsICJleHAiOiAxODgzMDAwMDAwLCAic3ViIjogInVzZXJfNDIiLCAibmF0aW9uYWxpdGllcyI6IFt7Ii4uLiI6ICJwRm5kamtaX1ZDem15VGE2VWpsWm8zZGgta284YUlLUWM5RGxHemhhVllvIn0sIHsiLi4uIjogIjdDZjZKa1B1ZHJ5M2xjYndIZ2VaOGtoQXYxVTFPU2xlclAwVmtCSnJXWjAifV0sICJfc2RfYWxnIjogInNoYS0yNTYiLCAiY25mIjogeyJqd2siOiB7Imt0eSI6ICJFQyIsICJjcnYiOiAiUC0yNTYiLCAieCI6ICJUQ0FFUjE5WnZ1M09IRjRqNFc0dmZTVm9ISVAxSUxpbERsczd2Q2VHZW1jIiwgInkiOiAiWnhqaVdXYlpNUUdIVldLVlE0aGJTSWlyc1ZmdWVjQ0U2dDRqVDlGMkhaUSJ9fX0.~", False, - "sd-jwt without signature" + "sd-jwt without signature", ), TestCase( 
".eyJfc2QiOiBbIkNyUWU3UzVrcUJBSHQtbk1ZWGdjNmJkdDJTSDVhVFkxc1VfTS1QZ2tqUEkiLCAiSnpZakg0c3ZsaUgwUjNQeUVNZmVadTZKdDY5dTVxZWhabzdGN0VQWWxTRSIsICJQb3JGYnBLdVZ1Nnh5bUphZ3ZrRnNGWEFiUm9jMkpHbEFVQTJCQTRvN2NJIiwgIlRHZjRvTGJnd2Q1SlFhSHlLVlFaVTlVZEdFMHc1cnREc3JaemZVYW9tTG8iLCAiWFFfM2tQS3QxWHlYN0tBTmtxVlI2eVoyVmE1TnJQSXZQWWJ5TXZSS0JNTSIsICJYekZyendzY002R242Q0pEYzZ2Vks4QmtNbmZHOHZPU0tmcFBJWmRBZmRFIiwgImdiT3NJNEVkcTJ4Mkt3LXc1d1BFemFrb2I5aFYxY1JEMEFUTjNvUUw5Sk0iLCAianN1OXlWdWx3UVFsaEZsTV8zSmx6TWFTRnpnbGhRRzBEcGZheVF3TFVLNCJdLCAiaXNzIjogImh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwgImlhdCI6IDE2ODMwMDAwMDAsICJleHAiOiAxODgzMDAwMDAwLCAic3ViIjogInVzZXJfNDIiLCAibmF0aW9uYWxpdGllcyI6IFt7Ii4uLiI6ICJwRm5kamtaX1ZDem15VGE2VWpsWm8zZGgta284YUlLUWM5RGxHemhhVllvIn0sIHsiLi4uIjogIjdDZjZKa1B1ZHJ5M2xjYndIZ2VaOGtoQXYxVTFPU2xlclAwVmtCSnJXWjAifV0sICJfc2RfYWxnIjogInNoYS0yNTYiLCAiY25mIjogeyJqd2siOiB7Imt0eSI6ICJFQyIsICJjcnYiOiAiUC0yNTYiLCAieCI6ICJUQ0FFUjE5WnZ1M09IRjRqNFc0dmZTVm9ISVAxSUxpbERsczd2Q2VHZW1jIiwgInkiOiAiWnhqaVdXYlpNUUdIVldLVlE0aGJTSWlyc1ZmdWVjQ0U2dDRqVDlGMkhaUSJ9fX0.GhUbWZvUw7I6ndaIK6rtU1rDEOkjqFuSXl_J5cuDkmnKt9jhZ3v5sghdwZ1WKH-g-4kna57RgWEl7YHlr-IZqw~", False, - "sd-jwt without header" - ), - TestCase( - "qwertyuiop", - False, - "non sense string: not a jwt" - ), - TestCase( - "qwe.rty.uio.azs~", - False, - "too many dots in base jwt" - ), - TestCase( - "qwe.rty~", - False, - "too few dots in base jwt" + "sd-jwt without header", ), + TestCase("qwertyuiop", False, "non sense string: not a jwt"), + TestCase("qwe.rty.uio.azs~", False, "too many dots in base jwt"), + TestCase("qwe.rty~", False, "too few dots in base jwt"), TestCase( - "qwe.rty~asd~", - False, - "too few dots in base jwt and diclosure lookalike" + "qwe.rty~asd~", False, "too few dots in base jwt and diclosure lookalike" ), TestCase( - "qwe.rty.uio~asd~qwe", - False, - "sd-jwt with key binding that is not a jwt" + "qwe.rty.uio~asd~qwe", False, "sd-jwt with key binding that is not a jwt" ), TestCase( "qwe.rty.uio~asd~qwe.asd.", False, - 
"sd-jwt with key binding without signature" + "sd-jwt with key binding without signature", ), ] for i, case in enumerate(test_table): obt_result = is_sd_jwt_kb_format(case.input) - assert obt_result == case.expected_result, f"failed test case {i}: scenario: {case.explanation}" + assert ( + obt_result == case.expected_result + ), f"failed test case {i}: scenario: {case.explanation}" diff --git a/pyeudiw/tests/sd_jwt/test_utils_yaml_specification.py b/pyeudiw/tests/sd_jwt/test_utils_yaml_specification.py index 90d16d23..0c8c4065 100644 --- a/pyeudiw/tests/sd_jwt/test_utils_yaml_specification.py +++ b/pyeudiw/tests/sd_jwt/test_utils_yaml_specification.py @@ -1,8 +1,9 @@ -import pytest import io -from pyeudiw.sd_jwt.utils.yaml_specification import _yaml_load_specification +import pytest + from pyeudiw.sd_jwt.common import SDObj +from pyeudiw.sd_jwt.utils.yaml_specification import _yaml_load_specification YAML_TESTCASES = [ """ @@ -16,7 +17,7 @@ yaml_parsing: | Multiline text is also supported -""" +""", ] YAML_TESTCASES_EXPECTED = [ @@ -29,9 +30,7 @@ } } }, - { - "yaml_parsing": "Multiline text\nis also supported\n" - } + {"yaml_parsing": "Multiline text\nis also supported\n"}, ] diff --git a/pyeudiw/tests/settings.py b/pyeudiw/tests/settings.py index 274bb925..06be671c 100644 --- a/pyeudiw/tests/settings.py +++ b/pyeudiw/tests/settings.py @@ -1,9 +1,9 @@ -import pathlib import os +import pathlib -from pyeudiw.tools.utils import exp_from_now, iat_now from cryptojwt.jwk.ec import new_ec_key +from pyeudiw.tools.utils import exp_from_now, iat_now BASE_URL = "https://example.com" AUTHZ_PAGE = "example.com" @@ -15,18 +15,58 @@ "session": {"timeout": 1}, } +_METADATA = { + "application_type": "web", + "authorization_encrypted_response_alg": ["RSA-OAEP", "RSA-OAEP-256"], + "authorization_encrypted_response_enc": [ + "A128CBC-HS256", + "A192CBC-HS384", + "A256CBC-HS512", + "A128GCM", + "A192GCM", + "A256GCM", + ], + "authorization_signed_response_alg": ["RS256", "ES256"], + 
"client_id": f"{BASE_URL}/OpenID4VP", + "client_name": "Name of an example organization", + "contacts": ["ops@verifier.example.org"], + "default_acr_values": [ + "https://www.spid.gov.it/SpidL2", + "https://www.spid.gov.it/SpidL3", + ], + "default_max_age": 1111, + "id_token_encrypted_response_alg": ["RSA-OAEP", "RSA-OAEP-256"], + "id_token_encrypted_response_enc": [ + "A128CBC-HS256", + "A192CBC-HS384", + "A256CBC-HS512", + "A128GCM", + "A192GCM", + "A256GCM", + ], + "id_token_signed_response_alg": ["RS256", "ES256"], + "response_uris": [f"{BASE_URL}/OpenID4VP/response-uri"], + "request_uris": [f"{BASE_URL}/OpenID4VP/request-uri"], + "require_auth_time": True, + "subject_type": "pairwise", + "vp_formats": { + "vc+sd-jwt": { + "sd-jwt_alg_values": ["ES256", "ES384"], + "kb-jwt_alg_values": ["ES256", "ES384"], + } + }, +} + CONFIG = { "base_url": BASE_URL, - "ui": { "static_storage_url": BASE_URL, "template_folder": f"{pathlib.Path().absolute().__str__()}/pyeudiw/tests/satosa/templates", "qrcode_template": "qrcode.html", "error_template": "error.html", - "error_url": "https://localhost:9999/error_page.html" + "error_url": "https://localhost:9999/error_page.html", }, "endpoints": { - "entity_configuration": "/.well-known/openid-federation", "pre_request": "/pre-request", "response": "/response-uri", "request": "/request-uri", @@ -40,12 +80,9 @@ "size": 100, "color": "#2B4375", "expiration_time": 120, - "logo_path": "pyeudiw/tests/satosa/static/logo.png" - }, - "jwt": { - "default_sig_alg": "ES256", - "default_exp": 6 + "logo_path": "pyeudiw/tests/satosa/static/logo.png", }, + "jwt": {"default_sig_alg": "ES256", "default_exp": 6}, "authorization": { "url_scheme": "haip", # haip:// "scopes": ["pid-sd-jwt:unique_id+given_name+family_name"], @@ -62,51 +99,34 @@ { "filter": { "const": "PersonIdentificationData", - "type": "string" + "type": "string", }, - "path": [ - "$.sd-jwt.type" - ] + "path": ["$.sd-jwt.type"], }, { - "filter": { - "type": "object" - }, - "path": [ - 
"$.sd-jwt.cnf" - ] + "filter": {"type": "object"}, + "path": ["$.sd-jwt.cnf"], }, { "intent_to_retain": "true", - "path": [ - "$.sd-jwt.family_name" - ] + "path": ["$.sd-jwt.family_name"], }, { "intent_to_retain": "true", - "path": [ - "$.sd-jwt.given_name" - ] + "path": ["$.sd-jwt.given_name"], }, { "intent_to_retain": "true", - "path": [ - "$.sd-jwt.unique_id" - ] - } + "path": ["$.sd-jwt.unique_id"], + }, ], - "limit_disclosure": "required" + "limit_disclosure": "required", }, - "jwt": { - "alg": [ - "EdDSA", - "ES256" - ] - } + "jwt": {"alg": ["EdDSA", "ES256"]}, }, - "id": "sd-jwt" + "id": "sd-jwt", } - ] + ], }, { "id": "mDL-sample-req", @@ -118,62 +138,45 @@ { "filter": { "const": "org.iso.18013.5.1.mDL", - "type": "string" + "type": "string", }, - "path": [ - "$.mdoc.doctype" - ] + "path": ["$.mdoc.doctype"], }, { "filter": { "const": "org.iso.18013.5.1", - "type": "string" + "type": "string", }, - "path": [ - "$.mdoc.namespace" - ] + "path": ["$.mdoc.namespace"], }, { "intent_to_retain": "false", - "path": [ - "$.mdoc.family_name" - ] + "path": ["$.mdoc.family_name"], }, { "intent_to_retain": "false", - "path": [ - "$.mdoc.portrait" - ] + "path": ["$.mdoc.portrait"], }, { "intent_to_retain": "false", - "path": [ - "$.mdoc.driving_privileges" - ] - } + "path": ["$.mdoc.driving_privileges"], + }, ], - "limit_disclosure": "required" + "limit_disclosure": "required", }, - "mso_mdoc": { - "alg": [ - "EdDSA", - "ES256" - ] - } + "mso_mdoc": {"alg": ["EdDSA", "ES256"]}, }, - "id": "mDL" + "id": "mDL", } - ] - } + ], + }, ], }, - 'user_attributes': { + "user_attributes": { "unique_identifiers": ["tax_id_code", "unique_id"], - "subject_id_random_value": "CHANGEME!" 
- }, - 'network': { - "httpc_params": httpc_params + "subject_id_random_value": "CHANGEME!", }, + "network": {"httpc_params": httpc_params}, "trust": { "direct_trust_sd_jwt_vc": { "module": "pyeudiw.trust.handler.direct_trust_sd_jwt_vc", @@ -181,26 +184,20 @@ "config": { "jwk_endpoint": "/.well-known/jwt-vc-issuer", "httpc_params": { - "connection": { - "ssl": True - }, - "session": { - "timeout": 6 - } - } - } + "connection": {"ssl": True}, + "session": {"timeout": 6}, + }, + }, }, "federation": { "module": "pyeudiw.trust.handler.federation", "class": "FederationHandler", "config": { - "metadata_type": "wallet_relying_party", - "authority_hints": [ - "https://trust-anchor.example.org" - ], - "trust_anchors": [ - "https://trust-anchor.example.org" - ], + "entity_configuration_exp": 600, + "metadata": _METADATA, + "metadata_type": "openid_credential_verifier", + "authority_hints": ["https://trust-anchor.example.org"], + "trust_anchors": ["https://trust-anchor.example.org"], "default_sig_alg": "RS256", "federation_jwks": [ { @@ -216,28 +213,26 @@ "p": "2zmGXIMCEHPphw778YjVTar1eycih6fFSJ4I4bl1iq167GqO0PjlOx6CZ1-OdBTVU7HfrYRiUK_BnGRdPDn-DQghwwkB79ZdHWL14wXnpB5y-boHz_LxvjsEqXtuQYcIkidOGaMG6" "8XNT1nM4F9a8UKFr5hHYT5_UIQSwsxlRQ0", "q": "2jMFt2iFrdaYabdXuB4QMboVjPvbLA-IVb6_0hSG_-EueGBvgcBxdFGIZaG6kqHqlB7qMsSzdptU0vn6IgmCZnX-Hlt6c5X7JB_q91PZMLTO01pbZ2Bk58GloalCHnw_mjPh0YPvi" - "H5jGoWM5RHyl_HDDMI-UeLkzP7ImxGizrM" + "H5jGoWM5RHyl_HDDMI-UeLkzP7ImxGizrM", }, { - 'kty': 'EC', - 'kid': 'xPFTWxeGHTVTaDlzGad0MKN5JmWOSnRqEjJCtvQpoyg', - 'crv': 'P-256', - 'x': 'EkMoe7qPLGMydWO_evC3AXEeXJlLQk9tNRkYcpp7xHo', - 'y': 'VLoHFl90D1SdTTjMvNf3WssWiCBXcU1lGNPbOmcCqdU', - 'd': 'oGzjgBbIYNL9opdJ_rDPnCJF89yN8yj8wegdkYfaxw0' - } - ], - "trust_marks": [ - "..." 
+ "kty": "EC", + "kid": "xPFTWxeGHTVTaDlzGad0MKN5JmWOSnRqEjJCtvQpoyg", + "crv": "P-256", + "x": "EkMoe7qPLGMydWO_evC3AXEeXJlLQk9tNRkYcpp7xHo", + "y": "VLoHFl90D1SdTTjMvNf3WssWiCBXcU1lGNPbOmcCqdU", + "d": "oGzjgBbIYNL9opdJ_rDPnCJF89yN8yj8wegdkYfaxw0", + }, ], + "trust_marks": ["..."], "federation_entity_metadata": { "organization_name": "Example RP", "homepage_uri": "https://developers.italia.it", "policy_uri": "https://developers.italia.it/privacy-policy", "tos_uri": "https://developers.italia.it/privacy-policy", - "logo_uri": "https://developers.italia.it/assets/img/io-it-logo-white.svg" - } - } + "logo_uri": "https://developers.italia.it/assets/img/io-it-logo-white.svg", + }, + }, }, }, "metadata_jwks": [ @@ -247,7 +242,7 @@ "kid": "dDwPWXz5sCtczj7CJbqgPGJ2qQ83gZ9Sfs-tJyULi6s", "kty": "EC", "x": "TSO-KOqdnUj5SUuasdlRB2VVFSqtJOxuR5GftUTuBdk", - "y": "ByWgQt1wGBSnF56jQqLdoO1xKUynMY-BHIDB3eXlR7" + "y": "ByWgQt1wGBSnF56jQqLdoO1xKUynMY-BHIDB3eXlR7", }, { "kty": "RSA", @@ -263,8 +258,8 @@ "p": "2zmGXIMCEHPphw778YjVTar1eycih6fFSJ4I4bl1iq167GqO0PjlOx6CZ1-OdBTVU7HfrYRiUK_BnGRdPDn-DQghwwkB79ZdHWL14wXnpB5y-boHz_LxvjsEqXtuQYcIkidOGaMG68XNT" "1nM4F9a8UKFr5hHYT5_UIQSwsxlRQ0", "q": "2jMFt2iFrdaYabdXuB4QMboVjPvbLA-IVb6_0hSG_-EueGBvgcBxdFGIZaG6kqHqlB7qMsSzdptU0vn6IgmCZnX-Hlt6c5X7JB_q91PZMLTO01pbZ2Bk58GloalCHnw_mjPh0YPviH5jG" - "oWM5RHyl_HDDMI-UeLkzP7ImxGizrM" - } + "oWM5RHyl_HDDMI-UeLkzP7ImxGizrM", + }, ], "storage": { "mongo_db": { @@ -273,11 +268,9 @@ "class": "MongoCache", "init_params": { "url": f"mongodb://{os.getenv('PYEUDIW_MONGO_TEST_AUTH_INLINE', '')}localhost:27017/?timeoutMS=2000", - "conf": { - "db_name": "eudiw" - }, - "connection_params": {} - } + "conf": {"db_name": "eudiw"}, + "connection_params": {}, + }, }, "storage": { "module": "pyeudiw.storage.mongo_storage", @@ -289,78 +282,14 @@ "db_sessions_collection": "sessions", "db_trust_attestations_collection": "trust_attestations", "db_trust_anchors_collection": "trust_anchors", - "db_trust_sources_collection": 
"trust_sources" + "db_trust_sources_collection": "trust_sources", }, - "connection_params": {} - } - } + "connection_params": {}, + }, + }, } }, - "metadata": { - "application_type": "web", - "authorization_encrypted_response_alg": [ - "RSA-OAEP", - "RSA-OAEP-256" - ], - "authorization_encrypted_response_enc": [ - "A128CBC-HS256", - "A192CBC-HS384", - "A256CBC-HS512", - "A128GCM", - "A192GCM", - "A256GCM" - ], - "authorization_signed_response_alg": [ - "RS256", - "ES256" - ], - "client_id": f"{BASE_URL}/OpenID4VP", - "client_name": "Name of an example organization", - "contacts": [ - "ops@verifier.example.org" - ], - "default_acr_values": [ - "https://www.spid.gov.it/SpidL2", - "https://www.spid.gov.it/SpidL3" - ], - "default_max_age": 1111, - "id_token_encrypted_response_alg": [ - "RSA-OAEP", - "RSA-OAEP-256" - ], - "id_token_encrypted_response_enc": [ - "A128CBC-HS256", - "A192CBC-HS384", - "A256CBC-HS512", - "A128GCM", - "A192GCM", - "A256GCM" - ], - "id_token_signed_response_alg": [ - "RS256", - "ES256" - ], - "response_uris_supported": [ - f"{BASE_URL}/OpenID4VP/response-uri" - ], - "request_uris": [ - f"{BASE_URL}/OpenID4VP/request-uri" - ], - "require_auth_time": True, - "subject_type": "pairwise", - "vp_formats": { - "vc+sd-jwt": { - "sd-jwt_alg_values": [ - "ES256", - "ES384" - ], - "kb-jwt_alg_values": [ - "ES256", - "ES384" - ] - } - } - } + "metadata": _METADATA, } CREDENTIAL_ISSUER_ENTITY_ID = "https://issuer.example.com" @@ -370,26 +299,18 @@ "class": "DirectTrustSdJwtVc", "config": { "jwk_endpoint": "/.well-known/jwt-vc-issuer", - "httpc_params": { - "connection": { - "ssl": True - }, - "session": { - "timeout": 6 - } - } - } + "httpc_params": {"connection": {"ssl": True}, "session": {"timeout": 6}}, + }, } CONFIG_DIRECT_TRUST = { "base_url": BASE_URL, - "ui": { "static_storage_url": BASE_URL, "template_folder": f"{pathlib.Path().absolute().__str__()}/pyeudiw/tests/satosa/templates", "qrcode_template": "qrcode.html", "error_template": "error.html", - 
"error_url": "https://localhost:9999/error_page.html" + "error_url": "https://localhost:9999/error_page.html", }, "endpoints": { "entity_configuration": "/.well-known/openid-federation", @@ -397,7 +318,7 @@ "response": "/response-uri", "request": "/request-uri", "status": "/status-uri", - "get_response": "/get-response" + "get_response": "/get-response", }, "response_code": { "sym_key": "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" @@ -406,12 +327,9 @@ "size": 100, "color": "#2B4375", "expiration_time": 120, - "logo_path": "pyeudiw/tests/satosa/static/logo.png" - }, - "jwt": { - "default_sig_alg": "ES256", - "default_exp": 6 + "logo_path": "pyeudiw/tests/satosa/static/logo.png", }, + "jwt": {"default_sig_alg": "ES256", "default_exp": 6}, "authorization": { "url_scheme": "haip", # haip:// "scopes": ["pid-sd-jwt:unique_id+given_name+family_name"], @@ -428,51 +346,34 @@ { "filter": { "const": "PersonIdentificationData", - "type": "string" + "type": "string", }, - "path": [ - "$.sd-jwt.type" - ] + "path": ["$.sd-jwt.type"], }, { - "filter": { - "type": "object" - }, - "path": [ - "$.sd-jwt.cnf" - ] + "filter": {"type": "object"}, + "path": ["$.sd-jwt.cnf"], }, { "intent_to_retain": "true", - "path": [ - "$.sd-jwt.family_name" - ] + "path": ["$.sd-jwt.family_name"], }, { "intent_to_retain": "true", - "path": [ - "$.sd-jwt.given_name" - ] + "path": ["$.sd-jwt.given_name"], }, { "intent_to_retain": "true", - "path": [ - "$.sd-jwt.unique_id" - ] - } + "path": ["$.sd-jwt.unique_id"], + }, ], - "limit_disclosure": "required" + "limit_disclosure": "required", }, - "jwt": { - "alg": [ - "EdDSA", - "ES256" - ] - } + "jwt": {"alg": ["EdDSA", "ES256"]}, }, - "id": "sd-jwt" + "id": "sd-jwt", } - ] + ], }, { "id": "mDL-sample-req", @@ -484,65 +385,46 @@ { "filter": { "const": "org.iso.18013.5.1.mDL", - "type": "string" + "type": "string", }, - "path": [ - "$.mdoc.doctype" - ] + "path": ["$.mdoc.doctype"], }, { "filter": { "const": "org.iso.18013.5.1", - "type": 
"string" + "type": "string", }, - "path": [ - "$.mdoc.namespace" - ] + "path": ["$.mdoc.namespace"], }, { "intent_to_retain": "false", - "path": [ - "$.mdoc.family_name" - ] + "path": ["$.mdoc.family_name"], }, { "intent_to_retain": "false", - "path": [ - "$.mdoc.portrait" - ] + "path": ["$.mdoc.portrait"], }, { "intent_to_retain": "false", - "path": [ - "$.mdoc.driving_privileges" - ] - } + "path": ["$.mdoc.driving_privileges"], + }, ], - "limit_disclosure": "required" + "limit_disclosure": "required", }, - "mso_mdoc": { - "alg": [ - "EdDSA", - "ES256" - ] - } + "mso_mdoc": {"alg": ["EdDSA", "ES256"]}, }, - "id": "mDL" + "id": "mDL", } - ] - } - ] + ], + }, + ], }, - 'user_attributes': { + "user_attributes": { "unique_identifiers": ["tax_id_code", "unique_id"], - "subject_id_random_value": "CHANGEME!" - }, - 'network': { - "httpc_params": httpc_params - }, - "trust": { - "direct_trust_sd_jwt_vc": MODULE_DIRECT_TRUST_CONFIG + "subject_id_random_value": "CHANGEME!", }, + "network": {"httpc_params": httpc_params}, + "trust": {"direct_trust_sd_jwt_vc": MODULE_DIRECT_TRUST_CONFIG}, "metadata_jwks": [ { "crv": "P-256", @@ -550,7 +432,7 @@ "kid": "dDwPWXz5sCtczj7CJbqgPGJ2qQ83gZ9Sfs-tJyULi6s", "kty": "EC", "x": "TSO-KOqdnUj5SUuasdlRB2VVFSqtJOxuR5GftUTuBdk", - "y": "ByWgQt1wGBSnF56jQqLdoO1xKUynMY-BHIDB3eXlR7" + "y": "ByWgQt1wGBSnF56jQqLdoO1xKUynMY-BHIDB3eXlR7", }, { "kty": "RSA", @@ -566,8 +448,8 @@ "p": "2zmGXIMCEHPphw778YjVTar1eycih6fFSJ4I4bl1iq167GqO0PjlOx6CZ1-OdBTVU7HfrYRiUK_BnGRdPDn-DQghwwkB79ZdHWL14wXnpB5y-boHz_LxvjsEqXtuQYcIkidOGaMG68XNT" "1nM4F9a8UKFr5hHYT5_UIQSwsxlRQ0", "q": "2jMFt2iFrdaYabdXuB4QMboVjPvbLA-IVb6_0hSG_-EueGBvgcBxdFGIZaG6kqHqlB7qMsSzdptU0vn6IgmCZnX-Hlt6c5X7JB_q91PZMLTO01pbZ2Bk58GloalCHnw_mjPh0YPviH5jG" - "oWM5RHyl_HDDMI-UeLkzP7ImxGizrM" - } + "oWM5RHyl_HDDMI-UeLkzP7ImxGizrM", + }, ], "storage": { "mongo_db": { @@ -576,11 +458,9 @@ "class": "MongoCache", "init_params": { "url": "mongodb://localhost:27017/?timeoutMS=2000", - "conf": { - "db_name": 
"eudiw" - }, - "connection_params": {} - } + "conf": {"db_name": "eudiw"}, + "connection_params": {}, + }, }, "storage": { "module": "pyeudiw.storage.mongo_storage", @@ -591,78 +471,54 @@ "db_name": "test-eudiw", "db_sessions_collection": "sessions", "db_trust_attestations_collection": "trust_attestations", - "db_trust_anchors_collection": "trust_anchors" + "db_trust_anchors_collection": "trust_anchors", }, - "connection_params": {} - } - } + "connection_params": {}, + }, + }, } }, "metadata": { "application_type": "web", - "authorization_encrypted_response_alg": [ - "RSA-OAEP", - "RSA-OAEP-256" - ], + "authorization_encrypted_response_alg": ["RSA-OAEP", "RSA-OAEP-256"], "authorization_encrypted_response_enc": [ "A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512", "A128GCM", "A192GCM", - "A256GCM" - ], - "authorization_signed_response_alg": [ - "RS256", - "ES256" + "A256GCM", ], + "authorization_signed_response_alg": ["RS256", "ES256"], "client_id": f"{BASE_URL}/OpenID4VP", "client_name": "Name of an example organization", - "contacts": [ - "ops@verifier.example.org" - ], + "contacts": ["ops@verifier.example.org"], "default_acr_values": [ "https://www.spid.gov.it/SpidL2", - "https://www.spid.gov.it/SpidL3" + "https://www.spid.gov.it/SpidL3", ], "default_max_age": 1111, - "id_token_encrypted_response_alg": [ - "RSA-OAEP", - "RSA-OAEP-256" - ], + "id_token_encrypted_response_alg": ["RSA-OAEP", "RSA-OAEP-256"], "id_token_encrypted_response_enc": [ "A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512", "A128GCM", "A192GCM", - "A256GCM" - ], - "id_token_signed_response_alg": [ - "RS256", - "ES256" - ], - "response_uris_supported": [ - f"{BASE_URL}/OpenID4VP/response-uri" - ], - "request_uris": [ - f"{BASE_URL}/OpenID4VP/request-uri" + "A256GCM", ], + "id_token_signed_response_alg": ["RS256", "ES256"], + "response_uris": [f"{BASE_URL}/OpenID4VP/response-uri"], + "request_uris": [f"{BASE_URL}/OpenID4VP/request-uri"], "require_auth_time": True, "subject_type": "pairwise", 
"vp_formats": { "vc+sd-jwt": { - "sd-jwt_alg_values": [ - "ES256", - "ES384" - ], - "kb-jwt_alg_values": [ - "ES256", - "ES384" - ] + "sd-jwt_alg_values": ["ES256", "ES384"], + "kb-jwt_alg_values": ["ES256", "ES384"], } - } - } + }, + }, } CREDENTIAL_ISSUER_CONF = { @@ -685,12 +541,10 @@ } -INTERNAL_ATTRIBUTES: dict = { - 'attributes': {} -} +INTERNAL_ATTRIBUTES: dict = {"attributes": {}} -PRIVATE_JWK = new_ec_key('P-256') +PRIVATE_JWK = new_ec_key("P-256") PUBLIC_JWK = PRIVATE_JWK.serialize(private=False) @@ -702,26 +556,15 @@ "tos_uri": "https://wallet-provider.example.org/info_policy", "logo_uri": "https://wallet-provider.example.org/logo.svg", "aal": "https://wallet-provider.example.org/LoA/basic", - "cnf": - { - "jwk": PUBLIC_JWK - }, + "cnf": {"jwk": PUBLIC_JWK}, "authorization_endpoint": "haip:", - "response_types_supported": [ - "vp_token" - ], + "response_types_supported": ["vp_token"], "vp_formats_supported": { - "jwt_vp_json": { - "alg_values_supported": ["ES256"] - }, - "jwt_vc_json": { - "alg_values_supported": ["ES256"] - } + "jwt_vp_json": {"alg_values_supported": ["ES256"]}, + "jwt_vc_json": {"alg_values_supported": ["ES256"]}, }, - "request_object_signing_alg_values_supported": [ - "ES256" - ], + "request_object_signing_alg_values_supported": ["ES256"], "presentation_definition_uri_supported": False, "iat": iat_now(), - "exp": exp_from_now() + "exp": exp_from_now(), } diff --git a/pyeudiw/tests/storage/test_db_engine.py b/pyeudiw/tests/storage/test_db_engine.py index 65c74ae7..76fbacea 100755 --- a/pyeudiw/tests/storage/test_db_engine.py +++ b/pyeudiw/tests/storage/test_db_engine.py @@ -1,9 +1,10 @@ import uuid +from datetime import datetime + import pytest -from datetime import datetime -from pyeudiw.storage.db_engine import DBEngine from pyeudiw.storage.base_storage import TrustType +from pyeudiw.storage.db_engine import DBEngine from pyeudiw.storage.exceptions import StorageWriteError from pyeudiw.tests.settings import CONFIG @@ -11,7 +12,7 @@ 
class TestMongoDBEngine: @pytest.fixture(autouse=True) def create_engine_instance(self): - self.engine = DBEngine(CONFIG['storage']) + self.engine = DBEngine(CONFIG["storage"]) @pytest.fixture(autouse=True) def test_init_session(self): @@ -19,7 +20,8 @@ def test_init_session(self): session_id = str(uuid.uuid4()) document_id = self.engine.init_session( - session_id=session_id, state=state, remote_flow_typ="") + session_id=session_id, state=state, remote_flow_typ="" + ) assert document_id @@ -29,11 +31,15 @@ def test_init_session(self): def test_update_request_object(self): self.nonce = str(uuid.uuid4()) self.state = str(uuid.uuid4()) - request_object = {"request_object": "request_object", - "nonce": self.nonce, "state": self.state} + request_object = { + "request_object": "request_object", + "nonce": self.nonce, + "state": self.state, + } replica_count = self.engine.update_request_object( - self.document_id, request_object) + self.document_id, request_object + ) assert replica_count == 1 @@ -44,21 +50,20 @@ def test_update_request_object_with_unexistent_id_object(self): request_object = {"request_object": "request_object"} try: - self.engine.update_request_object( - unx_document_id, request_object) + self.engine.update_request_object(unx_document_id, request_object) except Exception: return def test_update_response_object(self): response_object = {"response_object": "response_object"} - self.engine.update_response_object( - self.nonce, self.state, response_object) + self.engine.update_response_object(self.nonce, self.state, response_object) def test_update_response_object_unexistent_id_object(self): response_object = {"response_object": "response_object"} try: self.engine.update_response_object( - str(uuid.uuid4()), str(uuid.uuid4()), response_object) + str(uuid.uuid4()), str(uuid.uuid4()), response_object + ) except Exception: return @@ -68,7 +73,8 @@ def test_insert_trusted_attestation_federation(self): date = datetime.now() replica_count = 
self.engine.add_trust_attestation( - self.federation_entity_id, ["a", "b", "c"], date) + self.federation_entity_id, ["a", "b", "c"], date + ) assert replica_count > 0 @@ -83,7 +89,8 @@ def test_insert_trusted_attestation_x509(self): date = datetime.now() replica_count = self.engine.add_trust_attestation( - self.x509_entity_id, ["a", "b", "c"], date, TrustType.X509) + self.x509_entity_id, ["a", "b", "c"], date, TrustType.X509 + ) assert replica_count > 0 @@ -96,7 +103,8 @@ def test_update_trusted_attestation_federation(self): date = datetime.now() replica_count = self.engine.update_trust_attestation( - self.federation_entity_id, ["a", "b", "d"], date) + self.federation_entity_id, ["a", "b", "d"], date + ) assert replica_count > 0 @@ -109,7 +117,8 @@ def test_update_trusted_attestation_x509(self): date = datetime.now() replica_count = self.engine.update_trust_attestation( - self.x509_entity_id, ["a", "b", "d"], date, TrustType.X509) + self.x509_entity_id, ["a", "b", "d"], date, TrustType.X509 + ) assert replica_count > 0 @@ -122,8 +131,7 @@ def test_update_unexistent_trusted_attestation(self): try: date = datetime.now() - self.engine.update_trust_attestation( - "12345", ["a", "b", "d"], date) + self.engine.update_trust_attestation("12345", ["a", "b", "d"], date) assert False @@ -132,20 +140,23 @@ def test_update_unexistent_trusted_attestation(self): def test_update_trusted_attestation_metadata(self): replica_count = self.engine.add_trust_attestation_metadata( - self.federation_entity_id, "test_metadata", {"metadata": {"data_type": "test"}}) + self.federation_entity_id, + "test_metadata", + {"metadata": {"data_type": "test"}}, + ) assert replica_count > 0 ta = self.engine.get_trust_attestation(self.federation_entity_id) assert ta.get("metadata", None) is not None - assert ta["metadata"]["test_metadata"] == { - "metadata": {"data_type": "test"}} + assert ta["metadata"]["test_metadata"] == {"metadata": {"data_type": "test"}} def 
test_update_unexistent_trusted_attestation_metadata(self): try: self.engine.add_trust_attestation_metadata( - "test", "test_metadata", {"metadata": {"data_type": "test"}}) + "test", "test_metadata", {"metadata": {"data_type": "test"}} + ) assert False except StorageWriteError: return @@ -156,7 +167,8 @@ def test_insert_trusted_anchor_federation(self): date = datetime.now() replica_count = self.engine.add_trust_anchor( - self.federation_entity_anchor_id, "test123", date) + self.federation_entity_anchor_id, "test123", date + ) assert replica_count > 0 @@ -171,7 +183,8 @@ def test_insert_trusted_anchor_x509(self): date = datetime.now() replica_count = self.engine.add_trust_anchor( - self.x509_entity_anchor_id, "test123", date, TrustType.X509) + self.x509_entity_anchor_id, "test123", date, TrustType.X509 + ) assert replica_count > 0 @@ -184,7 +197,8 @@ def test_update_trusted_anchor_federation(self): date = datetime.now() replica_count = self.engine.update_trust_anchor( - self.federation_entity_anchor_id, "test124", date) + self.federation_entity_anchor_id, "test124", date + ) assert replica_count > 0 @@ -197,7 +211,8 @@ def test_update_trusted_anchor_x509(self): date = datetime.now() replica_count = self.engine.update_trust_anchor( - self.x509_entity_anchor_id, "test124", date, TrustType.X509) + self.x509_entity_anchor_id, "test124", date, TrustType.X509 + ) assert replica_count > 0 @@ -210,8 +225,7 @@ def test_update_unexistent_trusted_anchor(self): try: date = datetime.now() - self.engine.update_trust_anchor( - "12345", "test124", date, TrustType.X509) + self.engine.update_trust_anchor("12345", "test124", date, TrustType.X509) assert False diff --git a/pyeudiw/tests/storage/test_mongo_cache.py b/pyeudiw/tests/storage/test_mongo_cache.py index 8a88f720..42830501 100644 --- a/pyeudiw/tests/storage/test_mongo_cache.py +++ b/pyeudiw/tests/storage/test_mongo_cache.py @@ -12,7 +12,7 @@ def create_storage_instance(self): self.cache = MongoCache( {"db_name": "eudiw"}, 
f"mongodb://{os.getenv('PYEUDIW_MONGO_TEST_AUTH_INLINE', '')}localhost:27017/?timeoutMS=2000", - {} + {}, ) def test_try_retrieve(self): diff --git a/pyeudiw/tests/storage/test_mongo_storage.py b/pyeudiw/tests/storage/test_mongo_storage.py index 565f17c3..4edd4d98 100644 --- a/pyeudiw/tests/storage/test_mongo_storage.py +++ b/pyeudiw/tests/storage/test_mongo_storage.py @@ -1,5 +1,6 @@ import os import uuid + import pytest from pyeudiw.storage.mongo_storage import MongoStorage @@ -14,10 +15,10 @@ def create_storage_instance(self): "db_sessions_collection": "sessions", "db_trust_attestations_collection": "trust_attestations", "db_trust_anchors_collection": "trust_anchors", - "db_trust_sources_collection": "trust_source" + "db_trust_sources_collection": "trust_source", }, f"mongodb://{os.getenv('PYEUDIW_MONGO_TEST_AUTH_INLINE', '')}localhost:27017/?timeoutMS=2000", - {} + {}, ) def test_mongo_connection(self): @@ -34,16 +35,16 @@ def test_entity_initialization(self): session_id = str(uuid.uuid4()) document_id = self.storage.init_session( - str(uuid.uuid4()), - session_id=session_id, state=state, - remote_flow_typ="") + str(uuid.uuid4()), session_id=session_id, state=state, remote_flow_typ="" + ) assert document_id dpop_proof = {"dpop": "test"} attestation = {"attestation": "test"} self.storage.add_dpop_proof_and_attestation( - document_id, dpop_proof=dpop_proof, attestation=attestation) + document_id, dpop_proof=dpop_proof, attestation=attestation + ) document = self.storage.get_by_id(document_id) @@ -58,8 +59,8 @@ def test_add_request_object(self): session_id = str(uuid.uuid4()) document_id = self.storage.init_session( - str(uuid.uuid4()), - session_id=session_id, state=state, remote_flow_typ="") + str(uuid.uuid4()), session_id=session_id, state=state, remote_flow_typ="" + ) assert document_id @@ -82,8 +83,8 @@ def test_update_response_object(self): session_id = str(uuid.uuid4()) document_id = self.storage.init_session( - str(uuid.uuid4()), - session_id=session_id, 
state=state, remote_flow_typ="") + str(uuid.uuid4()), session_id=session_id, state=state, remote_flow_typ="" + ) assert document_id @@ -91,12 +92,15 @@ def test_update_response_object(self): state = str(uuid.uuid4()) request_object = {"nonce": nonce, "state": state} - self.storage.update_request_object( - document_id, request_object) + self.storage.update_request_object(document_id, request_object) documentStatus = self.storage.update_response_object( - nonce, state, {"response": "test"}) + nonce, state, {"response": "test"} + ) self.storage.add_dpop_proof_and_attestation( - document_id, dpop_proof={"dpop": "test"}, attestation={"attestation": "test"}) + document_id, + dpop_proof={"dpop": "test"}, + attestation={"attestation": "test"}, + ) assert documentStatus document = self.storage.get_by_id(document_id) @@ -114,22 +118,22 @@ def test_update_response_object(self): assert document["internal_response"] == {"response": "test"} # def test_retention_ttl(self): - # """ - # MongoDB does not garantee that the document will be deleted at the exact time - # https://www.mongodb.com/docs/v7.0/core/index-ttl/#timing-of-the-delete-operation - # """ - # self.storage.set_session_retention_ttl(5) - # assert self.storage.has_session_retention_ttl() + # """ + # MongoDB does not garantee that the document will be deleted at the exact time + # https://www.mongodb.com/docs/v7.0/core/index-ttl/#timing-of-the-delete-operation + # """ + # self.storage.set_session_retention_ttl(5) + # assert self.storage.has_session_retention_ttl() - # state = str(uuid.uuid4()) - # session_id = str(uuid.uuid4()) + # state = str(uuid.uuid4()) + # session_id = str(uuid.uuid4()) - # document_id = self.storage.init_session( - # str(uuid.uuid4()), - # session_id=session_id, state=state) + # document_id = self.storage.init_session( + # str(uuid.uuid4()), + # session_id=session_id, state=state) - # assert document_id + # assert document_id - # document = self.storage.get_by_id(document_id) - # time.sleep(6) - # 
assert not document + # document = self.storage.get_by_id(document_id) + # time.sleep(6) + # assert not document diff --git a/pyeudiw/tests/test_jwk.py b/pyeudiw/tests/test_jwk.py index 40bb24e9..4e9fa1e6 100644 --- a/pyeudiw/tests/test_jwk.py +++ b/pyeudiw/tests/test_jwk.py @@ -7,14 +7,10 @@ @pytest.mark.parametrize( "key, key_type, hash_func", - [ - (None, None, None), - (None, "EC", None), - (None, "RSA", None) - ] + [(None, None, None), (None, "EC", None), (None, "RSA", None)], ) def test_jwk(key, key_type, hash_func): - jwk = JWK(key, key_type, hash_func if hash_func else 'SHA-256') + jwk = JWK(key, key_type, hash_func if hash_func else "SHA-256") assert jwk.key assert jwk.thumbprint assert jwk.jwk @@ -22,7 +18,7 @@ def test_jwk(key, key_type, hash_func): def test_export_public__pem(): - jwk = JWK(key_type='RSA') + jwk = JWK(key_type="RSA") assert jwk.public_key assert jwk.public_key["e"] assert jwk.public_key["n"] diff --git a/pyeudiw/tests/test_jwt.py b/pyeudiw/tests/test_jwt.py index eb42f815..55dbb4bc 100644 --- a/pyeudiw/tests/test_jwt.py +++ b/pyeudiw/tests/test_jwt.py @@ -1,16 +1,15 @@ import pytest - -from cryptojwt.jwk.rsa import new_rsa_key from cryptojwt.jwk.ec import new_ec_key +from cryptojwt.jwk.rsa import new_rsa_key from pyeudiw.jwt.jwe_helper import JWEHelper from pyeudiw.jwt.jws_helper import DEFAULT_ENC_ALG_MAP, DEFAULT_ENC_ENC_MAP, JWSHelper from pyeudiw.jwt.utils import decode_jwt_header, is_jwe_format JWKs_EC = [ - (new_ec_key('P-256'), {"key": "value"}), - (new_ec_key('P-256'), "simple string"), - (new_ec_key('P-256'), None), + (new_ec_key("P-256"), {"key": "value"}), + (new_ec_key("P-256"), "simple string"), + (new_ec_key("P-256"), None), ] JWKs_RSA = [ diff --git a/pyeudiw/tests/tools/test_qr_code.py b/pyeudiw/tests/tools/test_qr_code.py index 6670b7ea..cb253359 100644 --- a/pyeudiw/tests/tools/test_qr_code.py +++ b/pyeudiw/tests/tools/test_qr_code.py @@ -13,7 +13,7 @@ def test_to_base64(): b64 = qr.to_base64() assert isinstance(b64, 
str) assert len(b64) > 0 - assert base64.b64decode(b64.encode()).decode('utf-8') == qr.to_svg() + assert base64.b64decode(b64.encode()).decode("utf-8") == qr.to_svg() def test_to_html(): @@ -29,7 +29,7 @@ def test_to_html(): assert html.endswith(">") assert "data:image/svg+xml;base64," in html b64 = html.split("data:image/svg+xml;base64,")[1].split('"')[0] - assert base64.b64decode(b64.encode()).decode('utf-8') == qr.to_svg() + assert base64.b64decode(b64.encode()).decode("utf-8") == qr.to_svg() def test_to_svg(): @@ -58,7 +58,9 @@ def _test_to_html_file(): qr = QRCode(data, size, color) html = qr.to_html() - with tempfile.NamedTemporaryFile("w", suffix=".html", dir=".", delete=DELETE_FILES) as tmp: + with tempfile.NamedTemporaryFile( + "w", suffix=".html", dir=".", delete=DELETE_FILES + ) as tmp: tmp.writelines(html) @@ -69,5 +71,7 @@ def _test_to_svg_file(): qr = QRCode(data, size, color) svg = qr.to_svg() - with tempfile.NamedTemporaryFile("w", suffix=".svg", dir=".", delete=DELETE_FILES) as tmp: + with tempfile.NamedTemporaryFile( + "w", suffix=".svg", dir=".", delete=DELETE_FILES + ) as tmp: tmp.writelines(svg) diff --git a/pyeudiw/tests/tools/test_utils.py b/pyeudiw/tests/tools/test_utils.py index 81bbb22e..842c5d69 100644 --- a/pyeudiw/tests/tools/test_utils.py +++ b/pyeudiw/tests/tools/test_utils.py @@ -1,14 +1,19 @@ - import datetime -import requests import sys import unittest.mock import freezegun import pytest +import requests -from pyeudiw.tools.utils import exp_from_now, iat_now, make_timezone_aware, random_token -from pyeudiw.tools.utils import cacheable_get_http_url, _lru_cached_get_http_url +from pyeudiw.tools.utils import ( + _lru_cached_get_http_url, + cacheable_get_http_url, + exp_from_now, + iat_now, + make_timezone_aware, + random_token, +) def test_make_timezone_aware(): @@ -20,8 +25,7 @@ def test_make_timezone_aware(): print(aware) with pytest.raises(ValueError): make_timezone_aware(aware) - aware = make_timezone_aware( - now, 
tz=datetime.datetime.now().astimezone().tzinfo) + aware = make_timezone_aware(now, tz=datetime.datetime.now().astimezone().tzinfo) assert aware.tzinfo is not None @@ -30,31 +34,50 @@ def frozen_time(fake_now, function, *args): return function(*args) -@pytest.mark.parametrize("fake_now, timestamp", [ - ("2020-12-31 12:00:00", 1609416000), - ("2000-10-02 12:23:14", 970489394), - ("1992-09-03 22:00:00", 715557600), -]) +@pytest.mark.parametrize( + "fake_now, timestamp", + [ + ("2020-12-31 12:00:00", 1609416000), + ("2000-10-02 12:23:14", 970489394), + ("1992-09-03 22:00:00", 715557600), + ], +) def test_iat_now(fake_now, timestamp): iat = frozen_time(fake_now=fake_now, function=iat_now) assert iat == timestamp -@pytest.mark.parametrize("fake_now, delta_mins, timestamp", [ - ("2020-12-31 12:00:00", 0, 1609416000), - ("2000-10-02 12:23:14", 1, 970489454), - ("1992-09-03 22:00:00", 2, 715557720), -]) +@pytest.mark.parametrize( + "fake_now, delta_mins, timestamp", + [ + ("2020-12-31 12:00:00", 0, 1609416000), + ("2000-10-02 12:23:14", 1, 970489454), + ("1992-09-03 22:00:00", 2, 715557720), + ], +) def test_exp_from_now(fake_now, delta_mins, timestamp): exp = frozen_time(fake_now, exp_from_now, delta_mins) assert exp == timestamp -@pytest.mark.parametrize("n", [ - -1, 0, 1, 2, 3, 10, 999, 10**1000, 2., - sys.maxsize, sys.maxsize - 1, - # sys.maxsize // 2 -1, - "1"]) +@pytest.mark.parametrize( + "n", + [ + -1, + 0, + 1, + 2, + 3, + 10, + 999, + 10**1000, + 2.0, + sys.maxsize, + sys.maxsize - 1, + # sys.maxsize // 2 -1, + "1", + ], +) def test_random_token(n): if type(n) is not int: with pytest.raises(TypeError): @@ -73,8 +96,8 @@ def test_random_token(n): rand = random_token(n) - if (n == 0): - assert rand == '' + if n == 0: + assert rand == "" return assert rand @@ -89,20 +112,17 @@ def test_cacheable_get_http_url(): ok_response = requests.Response() ok_response.status_code = 200 ok_response.headers.update({"Content-Type": "text/plain"}) - ok_response._content = b'Hello 
automated test' + ok_response._content = b"Hello automated test" mocked_endpoint = unittest.mock.patch( - "pyeudiw.tools.utils.get_http_url", - return_value=[ok_response] + "pyeudiw.tools.utils.get_http_url", return_value=[ok_response] ) - cache_ttl: int = 60*60*24*365 # 1 year + cache_ttl: int = 60 * 60 * 24 * 365 # 1 year httpc_p = { "connection": { "ssl": False, }, - "session": { - "timeout": 1 - } + "session": {"timeout": 1}, } # clear cache so that it is not polluted from prev tests @@ -110,14 +130,19 @@ def test_cacheable_get_http_url(): mocked_endpoint.start() for _ in range(tries): resp = cacheable_get_http_url( - cache_ttl, "http://location.example", httpc_p, http_async=False) + cache_ttl, "http://location.example", httpc_p, http_async=False + ) assert resp.status_code == 200 - assert resp._content == b'Hello automated test' + assert resp._content == b"Hello automated test" mocked_endpoint.stop() cache_misses = _lru_cached_get_http_url.cache_info().misses exp_cache_misses = 1 cache_hits = _lru_cached_get_http_url.cache_info().hits exp_cache_hits = tries - 1 - assert cache_misses == exp_cache_misses, f"cache missed more that {exp_cache_misses} time: {cache_misses}; {_lru_cached_get_http_url.cache_info()}" - assert cache_hits == exp_cache_hits, f"cache hit less than {exp_cache_hits} times: {cache_hits}" + assert ( + cache_misses == exp_cache_misses + ), f"cache missed more that {exp_cache_misses} time: {cache_misses}; {_lru_cached_get_http_url.cache_info()}" + assert ( + cache_hits == exp_cache_hits + ), f"cache hit less than {exp_cache_hits} times: {cache_hits}" diff --git a/pyeudiw/tests/trust/__init__.py b/pyeudiw/tests/trust/__init__.py index adcf836a..d337f7bb 100644 --- a/pyeudiw/tests/trust/__init__.py +++ b/pyeudiw/tests/trust/__init__.py @@ -2,22 +2,15 @@ "mock": { "module": "pyeudiw.tests.trust.mock_trust_handler", "class": "MockTrustHandler", - "config": {} + "config": {}, }, "direct_trust_sd_jwt_vc": { "module": 
"pyeudiw.trust.handler.direct_trust_sd_jwt_vc", "class": "DirectTrustSdJwtVc", "config": { "jwk_endpoint": "/.well-known/jwt-vc-issuer", - "httpc_params": { - "connection": { - "ssl": True - }, - "session": { - "timeout": 6 - } - } - } + "httpc_params": {"connection": {"ssl": True}, "session": {"timeout": 6}}, + }, }, } @@ -25,6 +18,6 @@ "not_conformant": { "module": "pyeudiw.tests.trust.mock_trust_handler", "class": "MockTrustEvaluator", - "config": {} + "config": {}, } } diff --git a/pyeudiw/tests/trust/handler/__init__.py b/pyeudiw/tests/trust/handler/__init__.py index d093d949..684c868d 100644 --- a/pyeudiw/tests/trust/handler/__init__.py +++ b/pyeudiw/tests/trust/handler/__init__.py @@ -1,23 +1,15 @@ import json + import requests def _generate_response(issuer: str, issuer_jwk: dict) -> requests.Response: - issuer_vct_md = { - "issuer": issuer, - "jwks": { - "keys": [ - issuer_jwk - ] - } - } + issuer_vct_md = {"issuer": issuer, "jwks": {"keys": [issuer_jwk]}} jwt_vc_issuer_endpoint_response = requests.Response() jwt_vc_issuer_endpoint_response.status_code = 200 - jwt_vc_issuer_endpoint_response.headers.update( - {"Content-Type": "application/json"}) - jwt_vc_issuer_endpoint_response._content = json.dumps( - issuer_vct_md).encode('utf-8') + jwt_vc_issuer_endpoint_response.headers.update({"Content-Type": "application/json"}) + jwt_vc_issuer_endpoint_response._content = json.dumps(issuer_vct_md).encode("utf-8") return jwt_vc_issuer_endpoint_response @@ -26,7 +18,7 @@ def _generate_empty_json_ok_response() -> requests.Response: resp = requests.Response() resp.status_code = 200 resp.headers.update({"Content-Type": "application/json"}) - resp._content = json.dumps({}).encode('utf-8') + resp._content = json.dumps({}).encode("utf-8") return resp @@ -36,19 +28,10 @@ def _generate_empty_json_ok_response() -> requests.Response: "kid": "MGaAh57cQghnevfWusalp0lNFXTzz2kHnkzO9wOjHq4", "crv": "P-256", "x": "S57KP4yGauTJJuNvO-wgWr2h_BYsatYUA1xW8Nae8i4", - "y": 
"66DmArglfyJODHAzZsIiPTY24gK70eeXPbpT4Nk0768" -} -issuer_vct_md = { - "issuer": issuer, - "jwks": { - "keys": [ - issuer_jwk - ] - } + "y": "66DmArglfyJODHAzZsIiPTY24gK70eeXPbpT4Nk0768", } +issuer_vct_md = {"issuer": issuer, "jwks": {"keys": [issuer_jwk]}} jwt_vc_issuer_endpoint_response = requests.Response() jwt_vc_issuer_endpoint_response.status_code = 200 -jwt_vc_issuer_endpoint_response.headers.update( - {"Content-Type": "application/json"}) -jwt_vc_issuer_endpoint_response._content = json.dumps( - issuer_vct_md).encode('utf-8') +jwt_vc_issuer_endpoint_response.headers.update({"Content-Type": "application/json"}) +jwt_vc_issuer_endpoint_response._content = json.dumps(issuer_vct_md).encode("utf-8") diff --git a/pyeudiw/tests/trust/handler/test_direct_trust.py b/pyeudiw/tests/trust/handler/test_direct_trust.py index 5a3aba4d..71f606eb 100644 --- a/pyeudiw/tests/trust/handler/test_direct_trust.py +++ b/pyeudiw/tests/trust/handler/test_direct_trust.py @@ -1,16 +1,23 @@ +import json import unittest.mock - -from pyeudiw.trust.handler._direct_trust_jwk import build_jwk_issuer_endpoint -from pyeudiw.trust.handler.direct_trust_sd_jwt_vc import DirectTrustSdJwtVc, build_metadata_issuer_endpoint -from pyeudiw.tests.trust.handler import _generate_empty_json_ok_response, issuer -from pyeudiw.trust.model.trust_source import TrustSourceData -from pyeudiw.tests.trust.handler import issuer_jwk as expected_jwk +import uuid from dataclasses import dataclass + import requests -import json + +from pyeudiw.tests.trust.handler import ( + _generate_empty_json_ok_response, + _generate_response, + issuer, +) +from pyeudiw.tests.trust.handler import issuer_jwk as expected_jwk +from pyeudiw.trust.handler._direct_trust_jwk import build_jwk_issuer_endpoint +from pyeudiw.trust.handler.direct_trust_sd_jwt_vc import ( + DirectTrustSdJwtVc, + build_metadata_issuer_endpoint, +) from pyeudiw.trust.handler.exception import InvalidJwkMetadataException -from pyeudiw.tests.trust.handler import 
_generate_response -import uuid +from pyeudiw.trust.model.trust_source import TrustSourceData def test_direct_trust_build_issuer_jwk_endpoint(): @@ -32,60 +39,41 @@ class TestCase: TestCase( "https://entity-id.example/path", "https://entity-id.example/path/.well-known/openid-credential-issuer", - explanation="the entity id does NOT have a trailing path separator" + explanation="the entity id does NOT have a trailing path separator", ), TestCase( "https://entity-id.example/path/", "https://entity-id.example/path/.well-known/openid-credential-issuer", - explanation="the entity id DOES have a trailing path separator" - ) + explanation="the entity id DOES have a trailing path separator", + ), ] metadata_endpoint = "/.well-known/openid-credential-issuer" for i, case in enumerate(test_cases): - obtained = build_metadata_issuer_endpoint( - case.entity_id, metadata_endpoint) + obtained = build_metadata_issuer_endpoint(case.entity_id, metadata_endpoint) assert case.expected == obtained, f"failed case {i}: {case.explanation}" def test_direct_trust_extract_jwks_from_jwk_metadata_by_value(): trust_source = DirectTrustSdJwtVc() - jwk_metadata = { - "issuer": issuer, - "jwks": { - "keys": [ - expected_jwk - ] - } - } + jwk_metadata = {"issuer": issuer, "jwks": {"keys": [expected_jwk]}} obt_jwks = trust_source._extract_jwks_from_jwk_metadata(jwk_metadata) - exp_jwks = { - "keys": [ - expected_jwk - ] - } + exp_jwks = {"keys": [expected_jwk]} assert obt_jwks == exp_jwks def test_direct_trust_extract_jwks_from_jwk_metadata_by_reference(): trust_source = DirectTrustSdJwtVc() - jwk_metadata = { - "issuer": issuer, - "jwks_uri": issuer + "jwks" - } - expected_jwks = { - "keys": [ - expected_jwk - ] - } + jwk_metadata = {"issuer": issuer, "jwks_uri": issuer + "jwks"} + expected_jwks = {"keys": [expected_jwk]} jwks_uri_response = requests.Response() jwks_uri_response.status_code = 200 jwks_uri_response.headers.update({"Content-Type": "application/json"}) - jwks_uri_response._content = 
json.dumps(expected_jwks).encode('utf-8') + jwks_uri_response._content = json.dumps(expected_jwks).encode("utf-8") mocked_jwks_document_endpoint = unittest.mock.patch( "pyeudiw.trust.handler._direct_trust_jwk.get_http_url", - return_value=[jwks_uri_response] + return_value=[jwks_uri_response], ) mocked_jwks_document_endpoint.start() obtained_jwks = trust_source._extract_jwks_from_jwk_metadata(jwk_metadata) @@ -96,12 +84,12 @@ def test_direct_trust_extract_jwks_from_jwk_metadata_by_reference(): def test_direct_trust_extract_jwks_from_jwk_metadata_invalid(): trust_source = DirectTrustSdJwtVc() - jwk_metadata = { - "issuer": issuer - } + jwk_metadata = {"issuer": issuer} try: trust_source._extract_jwks_from_jwk_metadata(jwk_metadata) - assert False, "parsed invalid metadata: should have raised InvalidJwkMetadataException instead" + assert ( + False + ), "parsed invalid metadata: should have raised InvalidJwkMetadataException instead" except InvalidJwkMetadataException: assert True @@ -113,12 +101,12 @@ def test_direct_trust_jwk(): mocked_issuer_jwt_vc_issuer_endpoint = unittest.mock.patch( "pyeudiw.trust.handler._direct_trust_jwk.get_http_url", - return_value=[_generate_response(random_issuer, expected_jwk)] + return_value=[_generate_response(random_issuer, expected_jwk)], ) mocked_metadata_endpoint = unittest.mock.patch( "pyeudiw.trust.handler.direct_trust_sd_jwt_vc.get_http_url", - return_value=[_generate_empty_json_ok_response()] + return_value=[_generate_empty_json_ok_response()], ) mocked_metadata_endpoint.start() @@ -126,13 +114,13 @@ def test_direct_trust_jwk(): trust_source = TrustSourceData.empty(random_issuer) trust_source = trust_handler.extract_and_update_trust_materials( - random_issuer, trust_source) + random_issuer, trust_source + ) obtained_jwks = trust_source.keys mocked_issuer_jwt_vc_issuer_endpoint.stop() mocked_metadata_endpoint.stop() - assert len( - obtained_jwks) == 1, f"expected 1 jwk, obtained {len(obtained_jwks)}" + assert len(obtained_jwks) 
== 1, f"expected 1 jwk, obtained {len(obtained_jwks)}" assert expected_jwk == obtained_jwks[0] diff --git a/pyeudiw/tests/trust/handler/test_direct_trust_jar.py b/pyeudiw/tests/trust/handler/test_direct_trust_jar.py index 5e90a1b3..2be35905 100644 --- a/pyeudiw/tests/trust/handler/test_direct_trust_jar.py +++ b/pyeudiw/tests/trust/handler/test_direct_trust_jar.py @@ -1,25 +1,27 @@ -from dataclasses import dataclass import json +from dataclasses import dataclass -from pyeudiw.jwk import JWK import pytest import satosa.context import satosa.response +from pyeudiw.jwk import JWK from pyeudiw.trust.handler.direct_trust_jar import DirectTrustJar @pytest.fixture def signing_private_key() -> list[dict]: - return [{ - "crv": "P-256", - "d": "r8UhwdbIvxKLvObVE-yixibCtu-0nzBZ3QGQ_-i1owc", - "kid": "MmjIDEhSnyIha4n462iIzrmdwMnJWlRZnsOJ3LWBEC4", - "kty": "EC", - "use": "sig", - "x": "xEmx9ruaf1qycPoYQ5lIfMSXAz2qLib6n0Ar_WDEiHM", - "y": "ZkSlQyYxuVTEKNRdrnONTisTepQ-3VcCza2O2yejawQ" - }] + return [ + { + "crv": "P-256", + "d": "r8UhwdbIvxKLvObVE-yixibCtu-0nzBZ3QGQ_-i1owc", + "kid": "MmjIDEhSnyIha4n462iIzrmdwMnJWlRZnsOJ3LWBEC4", + "kty": "EC", + "use": "sig", + "x": "xEmx9ruaf1qycPoYQ5lIfMSXAz2qLib6n0Ar_WDEiHM", + "y": "ZkSlQyYxuVTEKNRdrnONTisTepQ-3VcCza2O2yejawQ", + } + ] @pytest.fixture @@ -29,8 +31,11 @@ def all_private_keys(signing_private_key: list[dict]) -> list[dict]: "e": "AQAB", "kid": "KwJFzr11BxhSmAW8D2ZGBDQRdiQXZo1YWGaxHGW5Md8", "kty": "RSA", - "n": "wDBeA9a1xgOEb_zwm05cZnblBJANfWBA7oaZRYLp1sl0030pK6jHyEJ4wrXlMMQcxvwOx80uRFJG3o9BLTQ5lPnBu-VMAxF9LTkLZRD_gAJsrHz_myCgfcCMouX9AwDtUC01p5IIZ8YgfrbYPn694RxhCmH09oGs_OwOr7f3aW2qwf7uha7LRy8UPDYULnST7eqqWgrxjSIeHnmeO9BmEfcvqZJD2EfFHwFVXkjwMk1nWnQZYRV7Yncoz3qV0rhIQ86FQ2i4BoMW54OnRrgRHGqVUBHZP2y_Z3xo6foYOXJMgkHcEasiLbiATvHHN1cVsaM0PvQjO15qZu2IvVK224MgY6YbWU88pssG0ydTSOo0bY5gDhY6ml133MKXzfES0dzLNoOALrkyFxiHrPgQiFMKBuPXZ6qk1RomEWZYR54Brd7gDyK66MkdmpHvgBJf_V1YO42U1yxUTg63shdRp4O8FNZoTmhjMT9A_ZCD5mqGo00IewLHiQzVyWDqNrPv", "p": 
"0P7Vjj2Rc8tGdxqXvg77FsUaJIpoffgHSv59zrctOF7odKRYOSKW9FvlOi01NZE8dkcdYxlnriy3jyQVdGTxbRKJKKHrJbIJpqMDj6wUrk0k-67PbBuAJhzhmU_2wlyd04U7lt8gcn55kyV5XxQta6WTHz11MgO1GKePfCIlTRyS1T26_5wq5a_Q_VcdmiKuBHm0HtkBCkSYTWxWqfQjehs8eR5xOBAasgZHNit1KCMiONeUNyFtVFWgSSjDzhfL", "q": "62ngt1zghnj8pguWq1Xx6hRtE-eFS5K0rn6hSCgkLnUQeZWpO7cB4EEHFbN5FlWFIj9bjrRIoTQtHwprtM7dMqVaBH2HcKwSDiZy9ImmW2peKrP7Ko1t-Eg9Mhm8rycuzrwu3iQdd41JH-My5Fti-IuXyhZ3IVF_JvVNQKf4_RZQD4mbslEc6KFjLT-A3V6wfMhVFw7rnR6GcyQ0YUJTjzhRP3siG1A3GYGF1eN0pqT_3Lk2tvkd174BwcifiGft", "qi": "TFYkChfG3DRtgOiPRzl_yj_CDrYNsGWM-s0GmRHy_Zl1NvHK9u8Pc4hPoS9xx_qZiBnapX_Jmkaz39Q0GsjsJqjQQRxPMIofh1SZzH6O_tJ1-YQhJO4OfsQwi_FIAoDXHetkxnnhG1Axpvfqx5UyKM18uBz1vfWVrpfqaz9EBT04roVR_RFGzzV9jzDXFaZ17SWvovGtpHKqkVrCU0z6D8FV0lhuyBTmee6jXcxfzkwizGR6VexfaVwAHj7OdDGs", - "use": "enc" + "n": "wDBeA9a1xgOEb_zwm05cZnblBJANfWBA7oaZRYLp1sl0030pK6jHyEJ4wrXlMMQcxvwOx80uRFJG3o9BLTQ5lPnBu-VMAxF9LTkLZRD_gAJsrHz_myCgfcCMouX9AwDtUC01p5IIZ8YgfrbYPn694RxhCmH09oGs_OwOr7f3aW2qwf7uha7LRy8UPDYULnST7eqqWgrxjSIeHnmeO9BmEfcvqZJD2EfFHwFVXkjwMk1nWnQZYRV7Yncoz3qV0rhIQ86FQ2i4BoMW54OnRrgRHGqVUBHZP2y_Z3xo6foYOXJMgkHcEasiLbiATvHHN1cVsaM0PvQjO15qZu2IvVK224MgY6YbWU88pssG0ydTSOo0bY5gDhY6ml133MKXzfES0dzLNoOALrkyFxiHrPgQiFMKBuPXZ6qk1RomEWZYR54Brd7gDyK66MkdmpHvgBJf_V1YO42U1yxUTg63shdRp4O8FNZoTmhjMT9A_ZCD5mqGo00IewLHiQzVyWDqNrPv", + "p": "0P7Vjj2Rc8tGdxqXvg77FsUaJIpoffgHSv59zrctOF7odKRYOSKW9FvlOi01NZE8dkcdYxlnriy3jyQVdGTxbRKJKKHrJbIJpqMDj6wUrk0k-67PbBuAJhzhmU_2wlyd04U7lt8gcn55kyV5XxQta6WTHz11MgO1GKePfCIlTRyS1T26_5wq5a_Q_VcdmiKuBHm0HtkBCkSYTWxWqfQjehs8eR5xOBAasgZHNit1KCMiONeUNyFtVFWgSSjDzhfL", + "q": "62ngt1zghnj8pguWq1Xx6hRtE-eFS5K0rn6hSCgkLnUQeZWpO7cB4EEHFbN5FlWFIj9bjrRIoTQtHwprtM7dMqVaBH2HcKwSDiZy9ImmW2peKrP7Ko1t-Eg9Mhm8rycuzrwu3iQdd41JH-My5Fti-IuXyhZ3IVF_JvVNQKf4_RZQD4mbslEc6KFjLT-A3V6wfMhVFw7rnR6GcyQ0YUJTjzhRP3siG1A3GYGF1eN0pqT_3Lk2tvkd174BwcifiGft", + "qi": 
"TFYkChfG3DRtgOiPRzl_yj_CDrYNsGWM-s0GmRHy_Zl1NvHK9u8Pc4hPoS9xx_qZiBnapX_Jmkaz39Q0GsjsJqjQQRxPMIofh1SZzH6O_tJ1-YQhJO4OfsQwi_FIAoDXHetkxnnhG1Axpvfqx5UyKM18uBz1vfWVrpfqaz9EBT04roVR_RFGzzV9jzDXFaZ17SWvovGtpHKqkVrCU0z6D8FV0lhuyBTmee6jXcxfzkwizGR6VexfaVwAHj7OdDGs", + "use": "enc", } return signing_private_key + [rsa_pkey] @@ -51,24 +56,25 @@ class TestCase: TestCase( backend_name="", expected_path=".well-known/jar-issuer", - explanation="empty backend name" + explanation="empty backend name", ), TestCase( backend_name="openid4vp", expected_path="openid4vp/.well-known/jar-issuer", - explanation="regular backend name" + explanation="regular backend name", ), TestCase( backend_name="/openid4vp/", expected_path="openid4vp/.well-known/jar-issuer", - explanation="backend name with usual slashes" - ) + explanation="backend name with usual slashes", + ), ] for i, case in enumerate(test_cases): - path_component = direct_trust_jar._build_metadata_path( - case.backend_name) - assert path_component == case.expected_path, f"failed case {i+1}: test scenario: {case.explanation}" + path_component = direct_trust_jar._build_metadata_path(case.backend_name) + assert ( + path_component == case.expected_path + ), f"failed case {i+1}: test scenario: {case.explanation}" def test_direct_trust_jat_custom_path(all_private_keys): @@ -84,20 +90,23 @@ class TestCase: endpoint_component="custom", backend_name="openid4vp", expected_path="openid4vp/custom", - explanation="custom path" + explanation="custom path", ), TestCase( endpoint_component="/custom-with-slashes/", backend_name="openid4vp", expected_path="openid4vp/custom-with-slashes", - explanation="custom path with prepending and appending forward slash" - ) + explanation="custom path with prepending and appending forward slash", + ), ] for i, case in enumerate(test_cases): - dtj = DirectTrustJar(jwks=all_private_keys, - jwk_endpoint=case.endpoint_component) + dtj = DirectTrustJar( + jwks=all_private_keys, jwk_endpoint=case.endpoint_component + 
) path_component = dtj._build_metadata_path(case.backend_name) - assert path_component == case.expected_path, f"failed case {i+1}: test scenario: {case.explanation}" + assert ( + path_component == case.expected_path + ), f"failed case {i+1}: test scenario: {case.explanation}" def test_direct_trust_jar_metadata(direct_trust_jar): @@ -114,8 +123,7 @@ def test_direct_trust_jar_metadata(direct_trust_jar): def test_direct_trust_metadata_handler(direct_trust_jar, signing_private_key): backend = "openid4vp" entity_id = f"https://rp.example/{backend}" - registered_methods = direct_trust_jar.build_metadata_endpoints( - backend, entity_id) + registered_methods = direct_trust_jar.build_metadata_endpoints(backend, entity_id) assert len(registered_methods) == 1 endpoint_regexp = registered_methods[0][0] @@ -129,7 +137,7 @@ def test_direct_trust_metadata_handler(direct_trust_jar, signing_private_key): try: response.headers.index(("Content-Type", "application/json")) except Exception as e: - assert True, f"unable to find application/json in repsonse content type: {e}" + assert True, f"unable to find application/json in response content type: {e}" response_data = json.loads(response.message) assert response_data["issuer"] == entity_id diff --git a/pyeudiw/tests/trust/handler/test_federation.py b/pyeudiw/tests/trust/handler/test_federation.py new file mode 100644 index 00000000..33b2d4c0 --- /dev/null +++ b/pyeudiw/tests/trust/handler/test_federation.py @@ -0,0 +1 @@ +# TODO: move legacy test about entity configurations and endpoints we still have in openid4vp backend tests diff --git a/pyeudiw/tests/trust/mock_trust_handler.py b/pyeudiw/tests/trust/mock_trust_handler.py index 7c320a0b..21de13db 100644 --- a/pyeudiw/tests/trust/mock_trust_handler.py +++ b/pyeudiw/tests/trust/mock_trust_handler.py @@ -7,7 +7,7 @@ "kty": "EC", "use": "sig", "x": "xu0FC3OQLgsea27rL0-d2CpVyKijjwl8tF6HB-3zLUg", - "y": "fUEsB8IrX2DgzqABfVsCody1RypAXX54fXQ1keoPP5Y" + "y": 
"fUEsB8IrX2DgzqABfVsCody1RypAXX54fXQ1keoPP5Y", } @@ -17,12 +17,12 @@ class MockTrustHandler(TrustHandlerInterface): """ def get_metadata(self, issuer: str, trust_source: TrustSourceData) -> dict: - trust_source.metadata = { - "json_key": "json_value" - } + trust_source.metadata = {"json_key": "json_value"} return trust_source - def extract_and_update_trust_materials(self, issuer: str, trust_source: TrustSourceData) -> TrustSourceData: + def extract_and_update_trust_materials( + self, issuer: str, trust_source: TrustSourceData + ) -> TrustSourceData: trust_source = self.get_metadata(issuer, trust_source) trust_source.keys.append(mock_jwk) return trust_source diff --git a/pyeudiw/tests/trust/test_dynamic.py b/pyeudiw/tests/trust/test_dynamic.py index 84ec3eef..c4df18e4 100644 --- a/pyeudiw/tests/trust/test_dynamic.py +++ b/pyeudiw/tests/trust/test_dynamic.py @@ -1,16 +1,20 @@ from uuid import uuid4 -from pyeudiw.trust.dynamic import CombinedTrustEvaluator -from pyeudiw.tests.trust import correct_config, not_conformant -from pyeudiw.tests.settings import CONFIG + from pyeudiw.storage.db_engine import DBEngine +from pyeudiw.tests.settings import CONFIG +from pyeudiw.tests.trust import correct_config, not_conformant from pyeudiw.tests.trust.mock_trust_handler import MockTrustHandler -from pyeudiw.trust.handler.direct_trust_sd_jwt_vc import DirectTrustSdJwtVc +from pyeudiw.trust.dynamic import CombinedTrustEvaluator from pyeudiw.trust.exceptions import TrustConfigurationError +from pyeudiw.trust.handler.direct_trust_sd_jwt_vc import DirectTrustSdJwtVc def test_trust_CombinedTrusstEvaluation_handler_loading(): trust_ev = CombinedTrustEvaluator.from_config( - correct_config, DBEngine(CONFIG["storage"])) + correct_config, + DBEngine(CONFIG["storage"]), + default_client_id="default-client-id", + ) assert trust_ev assert len(trust_ev.handlers) == 2 @@ -21,7 +25,10 @@ def test_trust_CombinedTrusstEvaluation_handler_loading(): def 
test_not_conformant_CombinedTrusstEvaluation_handler_loading(): try: CombinedTrustEvaluator.from_config( - not_conformant, DBEngine(CONFIG["storage"])) + not_conformant, + DBEngine(CONFIG["storage"]), + default_client_id="default-client-id", + ) assert False except TrustConfigurationError: assert True @@ -29,16 +36,19 @@ def test_not_conformant_CombinedTrusstEvaluation_handler_loading(): def test_if_no_conf_default_handler_instanciated(): trust_ev = CombinedTrustEvaluator.from_config( - {}, DBEngine(CONFIG["storage"])) - - assert len(trust_ev.handlers) == 1 + {}, DBEngine(CONFIG["storage"]), default_client_id="default-client-id" + ) + # both jar issuer and direct trust sd jwt vc are default if not trust handlers are configured + assert len(trust_ev.handlers) == 2 assert isinstance(trust_ev.handlers[0], DirectTrustSdJwtVc) def test_public_key_and_metadata_retrive(): db_engine = DBEngine(CONFIG["storage"]) - trust_ev = CombinedTrustEvaluator.from_config(correct_config, db_engine) + trust_ev = CombinedTrustEvaluator.from_config( + correct_config, db_engine, default_client_id="default-client-id" + ) uuid_url = f"http://{uuid4()}.issuer.it" @@ -46,7 +56,9 @@ def test_public_key_and_metadata_retrive(): trust_source = db_engine.get_trust_source(uuid_url) assert trust_source - assert trust_source["keys"][0]["kid"] == "qTo9RGpuU_CSolt6GZmndLyPXJJa48up5dH1YbxVDPs" + assert ( + trust_source["keys"][0]["kid"] == "qTo9RGpuU_CSolt6GZmndLyPXJJa48up5dH1YbxVDPs" + ) assert trust_source["metadata"] == {"json_key": "json_value"} assert pub_keys[0]["kid"] == "qTo9RGpuU_CSolt6GZmndLyPXJJa48up5dH1YbxVDPs" diff --git a/pyeudiw/tests/x509/test_x509.py b/pyeudiw/tests/x509/test_x509.py index ede98f1d..b1c0d2e4 100755 --- a/pyeudiw/tests/x509/test_x509.py +++ b/pyeudiw/tests/x509/test_x509.py @@ -1,11 +1,18 @@ +from datetime import datetime, timedelta +from ssl import DER_cert_to_PEM_cert + from cryptography import x509 -from cryptography.x509.oid import NameOID from 
cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.serialization import Encoding -from datetime import datetime, timedelta -from ssl import DER_cert_to_PEM_cert -from pyeudiw.x509.verify import verify_x509_attestation_chain, verify_x509_anchor, get_issuer_from_x5c, is_der_format +from cryptography.x509.oid import NameOID + +from pyeudiw.x509.verify import ( + get_issuer_from_x5c, + is_der_format, + verify_x509_anchor, + verify_x509_attestation_chain, +) def gen_chain() -> list[bytes]: @@ -28,51 +35,76 @@ def gen_chain() -> list[bytes]: ) # Generate the CA's certificate - ca = x509.CertificateBuilder().subject_name(x509.Name([ - x509.NameAttribute(NameOID.COMMON_NAME, u"ca.example.com"), - ])).issuer_name(x509.Name([ - x509.NameAttribute(NameOID.COMMON_NAME, u"ca.example.com"), - ])).public_key( - ca_private_key.public_key() - ).serial_number( - x509.random_serial_number() - ).not_valid_before( - datetime.utcnow() - ).not_valid_after( - datetime.utcnow() + timedelta(days=365) - ).add_extension( - x509.BasicConstraints(ca=True, path_length=1), critical=True, - ).sign(ca_private_key, hashes.SHA256()) + ca = ( + x509.CertificateBuilder() + .subject_name( + x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, "ca.example.com"), + ] + ) + ) + .issuer_name( + x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, "ca.example.com"), + ] + ) + ) + .public_key(ca_private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.utcnow()) + .not_valid_after(datetime.utcnow() + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=1), + critical=True, + ) + .sign(ca_private_key, hashes.SHA256()) + ) # Generate the intermediate's certificate - intermediate = x509.CertificateBuilder().subject_name(x509.Name([ - x509.NameAttribute(NameOID.COMMON_NAME, u"intermediate.example.net"), - ])).issuer_name(ca.subject).public_key( - 
intermediate_private_key.public_key() - ).serial_number( - x509.random_serial_number() - ).not_valid_before( - datetime.utcnow() - ).not_valid_after( - datetime.utcnow() + timedelta(days=365) - ).add_extension( - x509.BasicConstraints(ca=True, path_length=0), critical=True, - ).sign(ca_private_key, hashes.SHA256()) + intermediate = ( + x509.CertificateBuilder() + .subject_name( + x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, "intermediate.example.net"), + ] + ) + ) + .issuer_name(ca.subject) + .public_key(intermediate_private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.utcnow()) + .not_valid_after(datetime.utcnow() + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=0), + critical=True, + ) + .sign(ca_private_key, hashes.SHA256()) + ) # Generate the leaf's certificate - leaf = x509.CertificateBuilder().subject_name(x509.Name([ - x509.NameAttribute(NameOID.COMMON_NAME, u"leaf.example.org"), - ])).issuer_name(intermediate.subject).public_key( - leaf_private_key.public_key() - ).serial_number( - x509.random_serial_number() - ).not_valid_before( - datetime.utcnow() - ).not_valid_after( - datetime.utcnow() + timedelta(days=365) - ).add_extension( - x509.BasicConstraints(ca=False, path_length=None), critical=True, - ).sign(intermediate_private_key, hashes.SHA256()) + leaf = ( + x509.CertificateBuilder() + .subject_name( + x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, "leaf.example.org"), + ] + ) + ) + .issuer_name(intermediate.subject) + .public_key(leaf_private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.utcnow()) + .not_valid_after(datetime.utcnow() + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=False, path_length=None), + critical=True, + ) + .sign(intermediate_private_key, hashes.SHA256()) + ) # Here the certificate chain in DER format, then encoded in base64 to use it according to RFC 9360: @@ 
-80,7 +112,7 @@ def gen_chain() -> list[bytes]: certificate_chain = [ ca.public_bytes(Encoding.DER), intermediate.public_bytes(Encoding.DER), - leaf.public_bytes(Encoding.DER) + leaf.public_bytes(Encoding.DER), ] return certificate_chain @@ -93,8 +125,7 @@ def chain_to_pem(chain: list[bytes]) -> str: def test_valid_chain(): chain = gen_chain() - assert verify_x509_attestation_chain( - chain, datetime.fromisoformat('2050-12-04')) + assert verify_x509_attestation_chain(chain, datetime.fromisoformat("2050-12-04")) def test_valid_chain_with_none_exp(): @@ -105,13 +136,15 @@ def test_valid_chain_with_none_exp(): def test_valid_chain_invalid_date(): chain = gen_chain() assert not verify_x509_attestation_chain( - chain, datetime.fromisoformat('2014-12-04')) + chain, datetime.fromisoformat("2014-12-04") + ) def test_invalid_intermediary_chain(): chain = gen_chain() chain[ - 1] = b'''0\x82\x02\xe00\x82\x01\xc8\xa0\x03\x02\x01\x02\x02\x14c\xef!\x17\xde\x88(\xbf\xb1\xdc\xad\x17\xc2`\xad\x15S\x95\n\xb60\r\x06\t*\x86H\x86\xf7\r + 1 + ] = b"""0\x82\x02\xe00\x82\x01\xc8\xa0\x03\x02\x01\x02\x02\x14c\xef!\x17\xde\x88(\xbf\xb1\xdc\xad\x17\xc2`\xad\x15S\x95\n\xb60\r\x06\t*\x86H\x86\xf7\r \x01\x01\x0b\x05\x000\x191\x170\x15\x06\x03U\x04\x03\x0c\x0eca.example.com0\x1e\x17\r231107165050Z\x17\r241106165050Z0#1!0\x1f\x06\x03U\x04\x03\x0c \x18intermediate.example.net0\x82\x01"0\r\x06\t*\x86H\x86\xf7\r\x01\x01\x01\x05\x00\x03\x82\x01\x0f\x000\x82\x01\n\x02\x82\x01\x01\x00\x916\xde\x9b\x0b \xeb\xd4\x91\xder\x1c\x9b\x0b\x06s\xb3W\x08\xa1\x12\x19K\x05\xf9\x87\xf3Uk\x15\xfeQ\xf2#\x103\x9e6\x04]s\x87\x13cD\x9d\xed3\xd7\x1bg\xd6#Tau\x03[\xc8H\t @@ -124,9 +157,10 @@ def test_invalid_intermediary_chain(): \x92\xe6\x1d\x96NMDO6L:\xc4\xc5=%Q4\xd4\xca\xfct\xd1(6\xf1\xade~Or\xe0AM8\xbb0y=\xdc~D\x06g\x07p\x1c\x9eu)K~\xb0M\x81\xa5gfS\xfaG\xafW\x05N\xa0\x0f\x9a 
\xc9=\x06\xf7\xdb_\r\xc1\xf1\x1d\xea\xb0\x85\xf8p\x1e\xa5\xb0\xb6\xact\xb1\x86UmVNX\xb6\x8c\x07o\xc6\x0e\x88\xe7,\x9e\xbe\xb6w\xf9\x88\xca!\xb2k\xcdE \xaf%r\xfd\x1d+\xab\x1do/i\x84~\xad\xa1\x99\x80\x03\xf4\xf2s\x88\x90\xa3\x93\x83&\x1b\xa1a\xc9\xe6\\\xfe\xcar\x17\x83\x84\x8bB\x8e\x8d\xcb\xb2\x1bD\x08 - \xb5\x11y\xad\xa6~\x9ae5\xa4\x88\xac\xae\x03\xe9\xb2&\x05\x149\xa0\x86I\x84\xc1`!F\xb8''' + \xb5\x11y\xad\xa6~\x9ae5\xa4\x88\xac\xae\x03\xe9\xb2&\x05\x149\xa0\x86I\x84\xc1`!F\xb8""" assert not verify_x509_attestation_chain( - chain, datetime.fromisoformat('2050-12-04')) + chain, datetime.fromisoformat("2050-12-04") + ) def test_chain_issuer(): @@ -141,22 +175,23 @@ def test_invalid_len(): del chain[0] del chain[1] assert not verify_x509_attestation_chain( - chain, datetime.fromisoformat('2050-12-04')) + chain, datetime.fromisoformat("2050-12-04") + ) def test_invalid_chain_order(): chain = gen_chain() chain.reverse() assert not verify_x509_attestation_chain( - chain, datetime.fromisoformat('2050-12-04')) + chain, datetime.fromisoformat("2050-12-04") + ) def test_valid_anchor(): chain = gen_chain() pem = chain_to_pem(chain) - assert verify_x509_anchor( - pem, datetime.fromisoformat('2050-12-04')) + assert verify_x509_anchor(pem, datetime.fromisoformat("2050-12-04")) def test_valid_anchor_nodate(): @@ -170,14 +205,14 @@ def test_anchor_valid_chain_invalid_date(): chain = gen_chain() pem = chain_to_pem(chain) - assert not verify_x509_anchor( - pem, datetime.fromisoformat('2014-12-04')) + assert not verify_x509_anchor(pem, datetime.fromisoformat("2014-12-04")) def test_anchor_invalid_intermediary_chain(): chain = gen_chain() chain[ - 1] = b'''0\x82\x02\xe00\x82\x01\xc8\xa0\x03\x02\x01\x02\x02\x14c\xef!\x17\xde\x88(\xbf\xb1\xdc\xad\x17\xc2`\xad\x15S\x95\n\xb60\r\x06\t*\x86H\x86\xf7\r + 1 + ] = b"""0\x82\x02\xe00\x82\x01\xc8\xa0\x03\x02\x01\x02\x02\x14c\xef!\x17\xde\x88(\xbf\xb1\xdc\xad\x17\xc2`\xad\x15S\x95\n\xb60\r\x06\t*\x86H\x86\xf7\r 
\x01\x01\x0b\x05\x000\x191\x170\x15\x06\x03U\x04\x03\x0c\x0eca.example.com0\x1e\x17\r231107165050Z\x17\r241106165050Z0#1!0\x1f\x06\x03U\x04\x03\x0c \x18intermediate.example.net0\x82\x01"0\r\x06\t*\x86H\x86\xf7\r\x01\x01\x01\x05\x00\x03\x82\x01\x0f\x000\x82\x01\n\x02\x82\x01\x01\x00\x916\xde\x9b\x0b \xeb\xd4\x91\xder\x1c\x9b\x0b\x06s\xb3W\x08\xa1\x12\x19K\x05\xf9\x87\xf3Uk\x15\xfeQ\xf2#\x103\x9e6\x04]s\x87\x13cD\x9d\xed3\xd7\x1bg\xd6#Tau\x03[\xc8H\t @@ -190,11 +225,10 @@ def test_anchor_invalid_intermediary_chain(): \x92\xe6\x1d\x96NMDO6L:\xc4\xc5=%Q4\xd4\xca\xfct\xd1(6\xf1\xade~Or\xe0AM8\xbb0y=\xdc~D\x06g\x07p\x1c\x9eu)K~\xb0M\x81\xa5gfS\xfaG\xafW\x05N\xa0\x0f\x9a \xc9=\x06\xf7\xdb_\r\xc1\xf1\x1d\xea\xb0\x85\xf8p\x1e\xa5\xb0\xb6\xact\xb1\x86UmVNX\xb6\x8c\x07o\xc6\x0e\x88\xe7,\x9e\xbe\xb6w\xf9\x88\xca!\xb2k\xcdE \xaf%r\xfd\x1d+\xab\x1do/i\x84~\xad\xa1\x99\x80\x03\xf4\xf2s\x88\x90\xa3\x93\x83&\x1b\xa1a\xc9\xe6\\\xfe\xcar\x17\x83\x84\x8bB\x8e\x8d\xcb\xb2\x1bD\x08 - \xb5\x11y\xad\xa6~\x9ae5\xa4\x88\xac\xae\x03\xe9\xb2&\x05\x149\xa0\x86I\x84\xc1`!F\xb8''' + \xb5\x11y\xad\xa6~\x9ae5\xa4\x88\xac\xae\x03\xe9\xb2&\x05\x149\xa0\x86I\x84\xc1`!F\xb8""" pem = chain_to_pem(chain) - assert not verify_x509_anchor( - pem, datetime.fromisoformat('2050-12-04')) + assert not verify_x509_anchor(pem, datetime.fromisoformat("2050-12-04")) def test_anchor_invalid_len(): @@ -203,8 +237,7 @@ def test_anchor_invalid_len(): del chain[1] pem = chain_to_pem(chain) - assert not verify_x509_anchor( - pem, datetime.fromisoformat('2050-12-04')) + assert not verify_x509_anchor(pem, datetime.fromisoformat("2050-12-04")) def test_anchor_invalid_chain_order(): @@ -212,8 +245,7 @@ def test_anchor_invalid_chain_order(): chain.reverse() pem = chain_to_pem(chain) - assert not verify_x509_anchor( - pem, datetime.fromisoformat('2050-12-04')) + assert not verify_x509_anchor(pem, datetime.fromisoformat("2050-12-04")) def test_valid_der(): diff --git a/pyeudiw/tools/base_logger.py b/pyeudiw/tools/base_logger.py 
index c3735a6b..77617cf6 100644 --- a/pyeudiw/tools/base_logger.py +++ b/pyeudiw/tools/base_logger.py @@ -1,4 +1,5 @@ import logging + import satosa.logging_util as lu from satosa.context import Context @@ -21,12 +22,7 @@ def _log(self, context: str | Context, level: str, message: str) -> None: context = context if isinstance(context, str) else context.state log_level = getattr(logger, level) - log_level( - lu.LOG_FMT.format( - id=lu.get_session_id(context), - message=message - ) - ) + log_level(lu.LOG_FMT.format(id=lu.get_session_id(context), message=message)) def _log_debug(self, context: str | Context, message: str) -> None: """ @@ -40,7 +36,9 @@ def _log_debug(self, context: str | Context, message: str) -> None: self._log(context, "debug", message) - def _log_function_debug(self, fn_name: str, context: Context, args_name: str | None = None, args=None) -> None: + def _log_function_debug( + self, fn_name: str, context: Context, args_name: str | None = None, args=None + ) -> None: """ Logs a message at the start of a backend function. 
diff --git a/pyeudiw/tools/mobile.py b/pyeudiw/tools/mobile.py index 5096da6f..f6d982f2 100644 --- a/pyeudiw/tools/mobile.py +++ b/pyeudiw/tools/mobile.py @@ -11,6 +11,6 @@ def is_smartphone(useragent: str) -> bool: """ device = DeviceDetector(useragent).parse() - if device.device_type() == 'smartphone': + if device.device_type() == "smartphone": return True return False diff --git a/pyeudiw/tools/qr_code.py b/pyeudiw/tools/qr_code.py index 01e4f63a..d0b1fcf1 100644 --- a/pyeudiw/tools/qr_code.py +++ b/pyeudiw/tools/qr_code.py @@ -37,7 +37,7 @@ def to_base64(self) -> str: :return: The svg data for html, base64 encoded :rtype: str """ - return base64.b64encode(self.svg.encode()).decode('utf-8') + return base64.b64encode(self.svg.encode()).decode("utf-8") def to_html(self) -> str: """ diff --git a/pyeudiw/tools/schema_utils.py b/pyeudiw/tools/schema_utils.py index a8ba8757..6222bcb2 100644 --- a/pyeudiw/tools/schema_utils.py +++ b/pyeudiw/tools/schema_utils.py @@ -1,4 +1,4 @@ -from pydantic_core.core_schema import FieldValidationInfo +from pydantic_core.core_schema import ValidationInfo _default_supported_algorithms = [ "RS256", @@ -13,14 +13,14 @@ ] -def check_algorithm(alg: str, info: FieldValidationInfo) -> None: +def check_algorithm(alg: str, info: ValidationInfo) -> None: """ Check if the algorithm is supported by the relaying party. 
:param alg: The algorithm to check :type alg: str :param info: The field validation info - :type info: FieldValidationInfo + :type info: ValidationInfo :raises ValueError: If the algorithm is not supported """ @@ -29,7 +29,8 @@ def check_algorithm(alg: str, info: FieldValidationInfo) -> None: supported_algorithms = _default_supported_algorithms else: supported_algorithms = info.context.get( - "supported_algorithms", _default_supported_algorithms) + "supported_algorithms", _default_supported_algorithms + ) if not isinstance(supported_algorithms, list): supported_algorithms = [] if alg not in supported_algorithms: diff --git a/pyeudiw/tools/utils.py b/pyeudiw/tools/utils.py index a3bc976d..96290769 100644 --- a/pyeudiw/tools/utils.py +++ b/pyeudiw/tools/utils.py @@ -1,20 +1,24 @@ +import asyncio import datetime -from functools import lru_cache +import importlib import logging -import asyncio import os import time +from functools import lru_cache +from secrets import token_hex from typing import NamedTuple + import requests -import importlib -from secrets import token_hex -from pyeudiw.federation.http_client import http_get_sync, http_get_async +from pyeudiw.federation.http_client import http_get_async, http_get_sync logger = logging.getLogger(__name__) -def make_timezone_aware(dt: datetime.datetime, tz: datetime.timezone | datetime.tzinfo = datetime.timezone.utc) -> datetime.datetime: +def make_timezone_aware( + dt: datetime.datetime, + tz: datetime.timezone | datetime.tzinfo = datetime.timezone.utc, +) -> datetime.datetime: """ Make a datetime timezone aware. 
@@ -70,7 +74,9 @@ def datetime_from_timestamp(timestamp: int | float) -> datetime.datetime: return make_timezone_aware(datetime.datetime.fromtimestamp(timestamp)) -def get_http_url(urls: list[str] | str, httpc_params: dict, http_async: bool = True) -> list[requests.Response]: +def get_http_url( + urls: list[str] | str, httpc_params: dict, http_async: bool = True +) -> list[requests.Response]: """ Perform an HTTP Request returning the payload of the call. @@ -87,51 +93,12 @@ def get_http_url(urls: list[str] | str, httpc_params: dict, http_async: bool = T urls = urls if isinstance(urls, list) else [urls] if http_async: - responses = asyncio.run( - http_get_async(urls, httpc_params)) # pragma: no cover + responses = asyncio.run(http_get_async(urls, httpc_params)) # pragma: no cover else: responses = http_get_sync(urls, httpc_params) return responses -def get_jwks(httpc_params: dict, metadata: dict, federation_jwks: list[dict] = []) -> dict: - """ - Get jwks or jwks_uri or signed_jwks_uri - - :param httpc_params: parameters to perform http requests. - :type httpc_params: dict - :param metadata: metadata of the entity - :type metadata: dict - :param federation_jwks: jwks of the federation - :type federation_jwks: list - - :returns: A list of responses. 
- :rtype: list[dict] - """ - jwks_list = [] - if metadata.get('jwks'): - jwks_list = metadata["jwks"]["keys"] - elif metadata.get('jwks_uri'): - try: - jwks_uri = metadata["jwks_uri"] - jwks_list = get_http_url( - [jwks_uri], httpc_params=httpc_params - ) - jwks_list = jwks_list[0].json() - except Exception as e: - logger.error(f"Failed to download jwks from {jwks_uri}: {e}") - elif metadata.get('signed_jwks_uri'): - try: - signed_jwks_uri = metadata["signed_jwks_uri"] - jwks_list = get_http_url( - [signed_jwks_uri], httpc_params=httpc_params - )[0].json() - except Exception as e: - logger.error( - f"Failed to download jwks from {signed_jwks_uri}: {e}") - return jwks_list - - def random_token(n=254) -> str: """ Generate a random token. @@ -163,7 +130,9 @@ def get_dynamic_class(module_name: str, class_name: str) -> object: return instance_class -def dynamic_class_loader(module_name: str, class_name: str, init_params: dict = {}) -> object: +def dynamic_class_loader( + module_name: str, class_name: str, init_params: dict = {} +) -> object: """ Load a class dynamically. @@ -178,37 +147,16 @@ def dynamic_class_loader(module_name: str, class_name: str, init_params: dict = :rtype: object """ - storage_instance = get_dynamic_class( - module_name, class_name)(**init_params) + storage_instance = get_dynamic_class(module_name, class_name)(**init_params) return storage_instance -def satisfy_interface(o: object, interface: type) -> bool: - """ - Returns true if and only if an object satisfy an interface. 
- - :param o: an object (instance of a class) - :type o: object - :param interface: an interface type - :type interface: type - - :returns: True if the object satisfy the interface, otherwise False - """ - for cls_attr in dir(interface): - if cls_attr.startswith('_'): - continue - if not hasattr(o, cls_attr): - return False - if callable(getattr(interface, cls_attr)) and not callable(getattr(o, cls_attr)): - return False - return True - - -_HttpcParams_T = NamedTuple( - '_HttpcParams_T', [('ssl', bool), ('timeout', int)]) +_HttpcParams_T = NamedTuple("_HttpcParams_T", [("ssl", bool), ("timeout", int)]) -def cacheable_get_http_url(cache_ttl: int, url: str, httpc_params: dict, http_async: bool = True) -> requests.Response: +def cacheable_get_http_url( + cache_ttl: int, url: str, httpc_params: dict, http_async: bool = True +) -> requests.Response: """ Make a cached http GET request. The cache duration is UP TO cache_ttl. The actual duration is always @@ -237,7 +185,8 @@ def cacheable_get_http_url(cache_ttl: int, url: str, httpc_params: dict, http_as timeout: int | None = httpc_params.get("session", {}).get("timeout", None) if (ssl is None) or (timeout is None): raise ValueError( - f"invalid parameter {httpc_params=}: ['connection']['ssl'] and ['session']['timeout'] MUST be defined") + f"invalid parameter {httpc_params=}: ['connection']['ssl'] and ['session']['timeout'] MUST be defined" + ) curr_time_s = time.time_ns() // 1_000_000_000 if cache_ttl != 0: ttl_timestamp = curr_time_s // cache_ttl @@ -245,7 +194,8 @@ def cacheable_get_http_url(cache_ttl: int, url: str, httpc_params: dict, http_as ttl_timestamp = curr_time_s httpc_p_tuple = _HttpcParams_T(ssl, timeout) resp = _lru_cached_get_http_url( - ttl_timestamp, url, httpc_p_tuple, http_async=http_async) + ttl_timestamp, url, httpc_p_tuple, http_async=http_async + ) if resp.status_code != 200: _lru_cached_get_http_url.cache_clear() @@ -253,7 +203,12 @@ def cacheable_get_http_url(cache_ttl: int, url: str, 
httpc_params: dict, http_as @lru_cache(os.getenv("PYEUDIW_LRU_CACHE_MAXSIZE", 2048)) -def _lru_cached_get_http_url(timestamp: int, url: str, httpc_params_tuple: _HttpcParams_T, http_async: bool = True) -> requests.Response: +def _lru_cached_get_http_url( + timestamp: int, + url: str, + httpc_params_tuple: _HttpcParams_T, + http_async: bool = True, +) -> requests.Response: """ Wraps method 'get_http_url' around a ttl cache. This is done by including a timestamp in the function argument. For more, @@ -271,10 +226,7 @@ def _lru_cached_get_http_url(timestamp: int, url: str, httpc_params_tuple: _Http "connection": { "ssl": httpc_params_tuple.ssl, }, - "session": { - "timeout": httpc_params_tuple.timeout - } + "session": {"timeout": httpc_params_tuple.timeout}, } - resp: list[requests.Response] = get_http_url( - [url], httpc_params, http_async) + resp: list[requests.Response] = get_http_url([url], httpc_params, http_async) return resp[0] diff --git a/pyeudiw/trust/__init__.py b/pyeudiw/trust/__init__.py index 46cc25c7..e69de29b 100644 --- a/pyeudiw/trust/__init__.py +++ b/pyeudiw/trust/__init__.py @@ -1,291 +0,0 @@ -import logging -from datetime import datetime - -from pyeudiw.federation.trust_chain_builder import TrustChainBuilder -from pyeudiw.federation.trust_chain_validator import StaticTrustChainValidator -from pyeudiw.federation.exceptions import ProtocolMetadataNotFound -from pyeudiw.satosa.exceptions import DiscoveryFailedError -from pyeudiw.storage.db_engine import DBEngine -from pyeudiw.jwt.utils import decode_jwt_payload, is_jwt_format -from pyeudiw.x509.verify import verify_x509_anchor, get_issuer_from_x5c, is_der_format - -from pyeudiw.storage.exceptions import EntryNotFound -from pyeudiw.trust.exceptions import ( - MissingProtocolSpecificJwks, - UnknownTrustAnchor, - InvalidTrustType, - MissingTrustType, - InvalidAnchor -) - -from pyeudiw.federation.statements import EntityStatement -from pyeudiw.federation.exceptions import TimeValidationError -from 
pyeudiw.federation.policy import TrustChainPolicy, combine - -logger = logging.getLogger(__name__) - - -class TrustEvaluationHelper: - def __init__(self, storage: DBEngine, httpc_params, trust_anchor: str = None, **kwargs): - self.exp: int = 0 - self.trust_chain: list[str] = [] - self.trust_anchor = trust_anchor - self.storage = storage - self.entity_id: str = "" - self.httpc_params = httpc_params - self.is_trusted = False - - for k, v in kwargs.items(): - setattr(self, k, v) - - def _get_evaluation_method(self): - # The trust chain can be either federation or x509 - # If the trust_chain is empty, and we don't have a trust anchor - if not self.trust_chain and not self.trust_anchor: - raise MissingTrustType( - "Static trust chain is not available" - ) - - try: - if is_jwt_format(self.trust_chain[0]): - return self.federation - except TypeError: - pass - - if is_der_format(self.trust_chain[0]): - return self.x509 - - raise InvalidTrustType( - "Invalid Trust Type: trust type not supported" - ) - - def evaluation_method(self) -> bool: - ev_method = self._get_evaluation_method() - return ev_method() - - def _update_chain(self, entity_id: str | None = None, exp: datetime | None = None, trust_chain: list | None = None): - if entity_id is not None: - self.entity_id = entity_id - - if exp is not None: - self.exp = exp - - if trust_chain is not None: - self.trust_chain = trust_chain - - def _handle_federation_chain(self): - _first_statement = decode_jwt_payload(self.trust_chain[-1]) - trust_anchor_eid = self.trust_anchor or _first_statement.get( - 'iss', None) - - if not trust_anchor_eid: - raise UnknownTrustAnchor( - "Unknown Trust Anchor: can't find 'iss' in the " - f"first entity statement: {_first_statement} " - ) - - try: - trust_anchor = self.storage.get_trust_anchor(trust_anchor_eid) - except EntryNotFound: - raise UnknownTrustAnchor( - f"Unknown Trust Anchor: '{trust_anchor_eid}' is not " - "a recognizable Trust Anchor." 
- ) - - decoded_ec = decode_jwt_payload( - trust_anchor['federation']['entity_configuration'] - ) - jwks = decoded_ec.get('jwks', {}).get('keys', []) - - if not jwks: - raise MissingProtocolSpecificJwks( - f"Cannot find any jwks in {decoded_ec}" - ) - - tc = StaticTrustChainValidator( - self.trust_chain, jwks, self.httpc_params - ) - self._update_chain( - entity_id=tc.entity_id, - exp=tc.exp - ) - - _is_valid = False - - try: - _is_valid = tc.validate() - except TimeValidationError: - logger.warn(f"Trust Chain {tc.entity_id} is expired") - except Exception as e: - logger.warn( - f"Cannot validate Trust Chain {tc.entity_id} for the following reason: {e}") - - db_chain = None - - if not _is_valid: - try: - db_chain = self.storage.get_trust_attestation( - self.entity_id - )["federation"]["chain"] - if StaticTrustChainValidator(db_chain, jwks, self.httpc_params).is_valid: - self.is_trusted = True - return self.is_trusted - - except (EntryNotFound, Exception): - pass - - _is_valid = tc.update() - - self._update_chain( - trust_chain=tc.trust_chain, - exp=tc.exp - ) - - # the good trust chain is then stored - self.storage.add_or_update_trust_attestation( - entity_id=self.entity_id, - attestation=tc.trust_chain, - exp=datetime.fromtimestamp(tc.exp) - ) - - self.is_trusted = _is_valid - return _is_valid - - def _handle_x509_pem(self): - trust_anchor_eid = self.trust_anchor or get_issuer_from_x5c( - self.trust_chain) - _is_valid = False - - if not trust_anchor_eid: - raise UnknownTrustAnchor( - "Unknown Trust Anchor: can't find 'iss' in the " - "first entity statement" - ) - - try: - trust_anchor = self.storage.get_trust_anchor(trust_anchor_eid) - except EntryNotFound: - raise UnknownTrustAnchor( - f"Unknown Trust Anchor: '{trust_anchor_eid}' is not " - "a recognizable Trust Anchor." 
- ) - - pem = trust_anchor['x509'].get('pem') - - if pem is None: - raise MissingTrustType( - f"Trust Anchor: '{trust_anchor_eid}' has no x509 trust entity" - ) - - try: - _is_valid = verify_x509_anchor(pem) - except Exception as e: - raise InvalidAnchor( - f"Anchor verification raised the following exception: {e}" - ) - - if not self.is_trusted and trust_anchor['federation'].get("chain", None) is not None: - self._handle_federation_chain() - - self.is_trusted = _is_valid - return _is_valid - - def federation(self) -> bool: - if len(self.trust_chain) == 0: - self.discovery(self.entity_id) - - if self.trust_chain: - self.is_valid = self._handle_federation_chain() - return self.is_valid - - return False - - def x509(self) -> bool: - self.is_valid = self._handle_x509_pem() - return self.is_valid - - def get_final_metadata(self, metadata_type: str, policies: list[dict]) -> dict: - policy_acc = {"metadata": {}, "metadata_policy": {}} - - for policy in policies: - policy_acc = combine(policy, policy_acc) - - self.final_metadata = decode_jwt_payload(self.trust_chain[0]) - - try: - # TODO: there are some cases where the jwks are taken from a uri ... - selected_metadata = { - "metadata": self.final_metadata['metadata'], - "metadata_policy": {} - } - - self.final_metadata = TrustChainPolicy().apply_policy( - selected_metadata, - policy_acc - ) - - return self.final_metadata["metadata"][metadata_type] - except KeyError: - raise ProtocolMetadataNotFound( - f"{metadata_type} not found in the final metadata:" - f" {self.final_metadata['metadata']}" - ) - - def get_trusted_jwks(self, metadata_type: str, policies: list[dict] = []) -> list[dict]: - return self.get_final_metadata( - metadata_type=metadata_type, - policies=policies - ).get('jwks', {}).get('keys', []) - - def discovery(self, entity_id: str, entity_configuration: EntityStatement | None = None): - """ - Updates fields ``trust_chain`` and ``exp`` based on the discovery process. 
- - :raises: DiscoveryFailedError: raises an error if the discovery fails. - """ - trust_anchor_eid = self.trust_anchor - _ta_ec = self.storage.get_trust_anchor(entity_id=trust_anchor_eid) - ta_ec = _ta_ec['federation']['entity_configuration'] - - tcbuilder = TrustChainBuilder( - subject=entity_id, - trust_anchor=trust_anchor_eid, - trust_anchor_configuration=ta_ec, - subject_configuration=entity_configuration, - httpc_params=self.httpc_params - ) - - self._update_chain( - trust_chain=tcbuilder.get_trust_chain(), - exp=tcbuilder.exp - ) - is_good = tcbuilder.is_valid - if not is_good: - raise DiscoveryFailedError( - f"Discovery failed for entity {entity_id} with configuration {entity_configuration}" - ) - - @staticmethod - def build_trust_chain_for_entity_id(storage: DBEngine, entity_id, entity_configuration, httpc_params): - """ - Builds a ``TrustEvaluationHelper`` and returns it if the trust chain is valid. - In case the trust chain is invalid, tries to validate it in discovery before returning it. - - :return: The svg data for html, base64 encoded - :rtype: str - """ - db_chain = storage.get_trust_attestation(entity_id) - - trust_evaluation_helper = TrustEvaluationHelper( - storage=storage, - httpc_params=httpc_params, - trust_chain=db_chain - ) - - is_good = trust_evaluation_helper.evaluation_method() - if is_good: - return trust_evaluation_helper - - trust_evaluation_helper.discovery( - entity_id=entity_id, entity_configuration=entity_configuration) - return trust_evaluation_helper diff --git a/pyeudiw/trust/default/direct_trust_sd_jwt_vc.py b/pyeudiw/trust/default/direct_trust_sd_jwt_vc.py index 7e2acd3e..2cfec648 100644 --- a/pyeudiw/trust/default/direct_trust_sd_jwt_vc.py +++ b/pyeudiw/trust/default/direct_trust_sd_jwt_vc.py @@ -5,19 +5,14 @@ from pyeudiw.tools.utils import cacheable_get_http_url, get_http_url from pyeudiw.trust.interface import TrustEvaluator -from .. 
exceptions import InvalidJwkMetadataException - +from ..exceptions import InvalidJwkMetadataException DEFAULT_ISSUER_JWK_ENDPOINT = "/.well-known/jwt-vc-issuer" DEFAULT_METADATA_ENDPOINT = "/.well-known/openid-credential-issuer" DEFAULT_DIRECT_TRUST_SD_JWC_VC_PARAMS = { "httpc_params": { - "connection": { - "ssl": os.getenv("PYEUDIW_HTTPC_SSL", True) - }, - "session": { - "timeout": os.getenv("PYEUDIW_HTTPC_TIMEOUT", 6) - } + "connection": {"ssl": os.getenv("PYEUDIW_HTTPC_SSL", True)}, + "session": {"timeout": os.getenv("PYEUDIW_HTTPC_TIMEOUT", 6)}, } } @@ -35,8 +30,13 @@ class DirectTrustSdJwtVc(DirectTrust): available. """ - def __init__(self, httpc_params: Optional[dict] = None, cache_ttl: int = 0, jwk_endpoint: str = DEFAULT_ISSUER_JWK_ENDPOINT, - metadata_endpoint: str = DEFAULT_METADATA_ENDPOINT): + def __init__( + self, + httpc_params: Optional[dict] = None, + cache_ttl: int = 0, + jwk_endpoint: str = DEFAULT_ISSUER_JWK_ENDPOINT, + metadata_endpoint: str = DEFAULT_METADATA_ENDPOINT, + ): if httpc_params is None: self.httpc_params = DEFAULT_DIRECT_TRUST_SD_JWC_VC_PARAMS["httpc_params"] self.httpc_params = httpc_params @@ -56,12 +56,14 @@ def get_public_keys(self, issuer: str) -> list[dict]: md = self._get_jwk_metadata(issuer) if not issuer == (obt_issuer := md.get("issuer", None)): raise InvalidJwkMetadataException( - f"invalid jwk metadata: obtained issuer :{obt_issuer}, expected issuer: {issuer}") + f"invalid jwk metadata: obtained issuer :{obt_issuer}, expected issuer: {issuer}" + ) jwks = self._extract_jwks_from_jwk_metadata(md) jwk_l: list[dict] = jwks.get("keys", []) if not jwk_l: raise InvalidJwkMetadataException( - "unable to find jwks in issuer jwk metadata") + "unable to find jwks in issuer jwk metadata" + ) return jwk_l def _get_jwk_metadata(self, issuer: str) -> dict: @@ -69,16 +71,23 @@ def _get_jwk_metadata(self, issuer: str) -> dict: call the jwk metadata endpoint and return the whole document """ jwk_endpoint = 
DirectTrustSdJwtVc.build_issuer_jwk_endpoint( - issuer, self.jwk_endpoint) + issuer, self.jwk_endpoint + ) if self.cache_ttl: resp = cacheable_get_http_url( - self.cache_ttl, jwk_endpoint, self.httpc_params, http_async=self.http_async_calls) + self.cache_ttl, + jwk_endpoint, + self.httpc_params, + http_async=self.http_async_calls, + ) else: - resp = get_http_url([jwk_endpoint], self.httpc_params, - http_async=self.http_async_calls)[0] + resp = get_http_url( + [jwk_endpoint], self.httpc_params, http_async=self.http_async_calls + )[0] if (not resp) or (resp.status_code != 200): raise InvalidJwkMetadataException( - f"failed to fetch valid jwk metadata: obtained {resp}") + f"failed to fetch valid jwk metadata: obtained {resp}" + ) return resp.json() def _get_jwks_by_reference(self, jwks_reference_uri: str) -> dict: @@ -87,10 +96,17 @@ def _get_jwks_by_reference(self, jwks_reference_uri: str) -> dict: """ if self.cache_ttl: resp = cacheable_get_http_url( - self.cache_ttl, jwks_reference_uri, self.httpc_params, http_async=self.http_async_calls) + self.cache_ttl, + jwks_reference_uri, + self.httpc_params, + http_async=self.http_async_calls, + ) else: resp = get_http_url( - [jwks_reference_uri], self.httpc_params, http_async=self.http_async_calls)[0] + [jwks_reference_uri], + self.httpc_params, + http_async=self.http_async_calls, + )[0] return resp.json() def _extract_jwks_from_jwk_metadata(self, metadata: dict) -> dict: @@ -98,12 +114,12 @@ def _extract_jwks_from_jwk_metadata(self, metadata: dict) -> dict: parse the jwk metadata document and return the jwks NOTE: jwks might be in the document by value or by reference """ - jwks: dict[Literal["keys"], list[dict] - ] | None = metadata.get("jwks", None) + jwks: dict[Literal["keys"], list[dict]] | None = metadata.get("jwks", None) jwks_uri: str | None = metadata.get("jwks_uri", None) if (not jwks) and (not jwks_uri): raise InvalidJwkMetadataException( - "invalid issuing key metadata: missing both claims [jwks] and 
[jwks_uri]") + "invalid issuing key metadata: missing both claims [jwks] and [jwks_uri]" + ) if jwks: # get jwks by value return jwks @@ -120,26 +136,41 @@ def get_metadata(self, issuer: str) -> dict: if not issuer: raise ValueError("invalid issuer: cannot be empty value") url = DirectTrustSdJwtVc.build_issuer_metadata_endpoint( - issuer, self.metadata_endpoint) + issuer, self.metadata_endpoint + ) if self.cache_ttl == 0: return get_http_url(url, self.httpc_params, self.http_async_calls)[0].json() - return cacheable_get_http_url(self.cache_ttl, url, self.httpc_params, self.http_async_calls).json() + return cacheable_get_http_url( + self.cache_ttl, url, self.httpc_params, self.http_async_calls + ).json() - def build_issuer_jwk_endpoint(issuer_id: str, well_known_path_component: str) -> str: + def build_issuer_jwk_endpoint( + issuer_id: str, well_known_path_component: str + ) -> str: baseurl = urlparse(issuer_id) well_known_path = well_known_path_component + baseurl.path well_known_url: str = ParseResult( - baseurl.scheme, baseurl.netloc, well_known_path, baseurl.params, baseurl.query, baseurl.fragment).geturl() + baseurl.scheme, + baseurl.netloc, + well_known_path, + baseurl.params, + baseurl.query, + baseurl.fragment, + ).geturl() return well_known_url - def build_issuer_metadata_endpoint(issuer: str, metadata_path_component: str) -> str: - issuer_normalized = issuer if issuer[-1] != '/' else issuer[:-1] + def build_issuer_metadata_endpoint( + issuer: str, metadata_path_component: str + ) -> str: + issuer_normalized = issuer if issuer[-1] != "/" else issuer[:-1] return issuer_normalized + metadata_path_component def __str__(self) -> str: - return f"DirectTrustSdJwtVc(" \ - f"httpc_params={self.httpc_params}, " \ - f"cache_ttl={self.cache_ttl}, " \ - f"jwk_endpoint={self.jwk_endpoint}, " \ - f"metadata_endpoint={self.metadata_endpoint}" \ + return ( + f"DirectTrustSdJwtVc(" + f"httpc_params={self.httpc_params}, " + f"cache_ttl={self.cache_ttl}, " + 
f"jwk_endpoint={self.jwk_endpoint}, " + f"metadata_endpoint={self.metadata_endpoint}" ")" + ) diff --git a/pyeudiw/trust/default/federation.py b/pyeudiw/trust/default/federation.py deleted file mode 100644 index 3d4e7fca..00000000 --- a/pyeudiw/trust/default/federation.py +++ /dev/null @@ -1,259 +0,0 @@ -import logging -from cryptojwt.jwk.jwk import key_from_jwk_dict - -import json - -from satosa.context import Context -from satosa.response import Response - -from pyeudiw.jwk import JWK -from pyeudiw.jwt.utils import decode_jwt_header -from pyeudiw.satosa.exceptions import (DiscoveryFailedError, - NotTrustedFederationError) -from pyeudiw.storage.exceptions import EntryNotFound -from pyeudiw.trust import TrustEvaluationHelper -from pyeudiw.trust.trust_anchors import update_trust_anchors_ecs - - -from pyeudiw.federation.policy import TrustChainPolicy -from pyeudiw.jwt.utils import decode_jwt_payload -from pyeudiw.trust.interface import TrustEvaluator - -from cryptojwt.jwk.ec import ECKey -from cryptojwt.jwk.rsa import RSAKey - -logger = logging.getLogger(__name__) - - -class FederationTrustModel(TrustEvaluator): - _ISSUER_METADATA_TYPE = "openid_credential_issuer" - - def __init__(self, **kwargs): - self.metadata_policy_resolver = TrustChainPolicy() - self.federation_jwks = kwargs.get("federation_jwks", []) - - def get_public_keys(self, issuer): - public_keys = [JWK(i).as_public_dict() for i in self.federation_jwks] - - return public_keys - - def _verify_trust_chain(self, trust_chain: list[str]): - # TODO: qui c'è tutta la ciccia, ma si può fare copia incolla da terze parti (specialmente di pyeudiw.trust.__init__) - raise NotImplementedError - - def get_verified_key(self, issuer: str, token_header: dict) -> ECKey | RSAKey | dict: - # (1) verifica trust chain - kid: str = token_header.get("kid", None) - if not kid: - raise ValueError("missing claim [kid] in token header") - trust_chain: list[str] = token_header.get("trust_chain", None) - if not trust_chain: - raise 
ValueError("missing trust chain in federation token") - if not isinstance(trust_chain, list): - raise ValueError*("invalid format of header claim [trust_claim]") - # TODO: check whick exceptions this might raise - self._verify_trust_chain(trust_chain) - - # (2) metadata parsing ed estrazione Jwk set - # TODO: wrap in something that implements VciJwksSource - # apply policy of traust anchor only? - issuer_entity_configuration = trust_chain[0] - anchor_entity_configuration = trust_chain[-1] - issuer_payload: dict = decode_jwt_payload(issuer_entity_configuration) - anchor_payload = decode_jwt_payload(anchor_entity_configuration) - trust_anchor_policy = anchor_payload.get("metadata_policy", {}) - final_issuer_metadata = self.metadata_policy_resolver.apply_policy( - issuer_payload, trust_anchor_policy) - metadata: dict = final_issuer_metadata.get("metadata", None) - if not metadata: - raise ValueError( - "missing or invalid claim [metadata] in entity configuration") - issuer_metadata: dict = metadata.get( - FederationTrustModel._ISSUER_METADATA_TYPE, None) - if not issuer_metadata: - raise ValueError( - f"missing or invalid claim [metadata.{FederationTrustModel._ISSUER_METADATA_TYPE}] in entity configuration") - issuer_keys: list[dict] = issuer_metadata.get( - "jwks", {}).get("keys", []) - if not issuer_keys: - raise ValueError( - f"missing or invalid claim [metadata.{FederationTrustModel._ISSUER_METADATA_TYPE}.jwks.keys] in entity configuration") - # check issuer = entity_id - if issuer != (obt_iss := final_issuer_metadata.get("iss", "")): - raise ValueError( - f"invalid issuer metadata: expected '{issuer}', obtained '{obt_iss}'") - - # (3) dato il set completo, fa il match per kid tra l'header e il jwk set - found_jwks: list[dict] = [] - for key in issuer_keys: - obt_kid: str = key.get("kid", "") - if kid == obt_kid: - found_jwks.append(key) - if len(found_jwks) != 1: - raise ValueError( - f"unable to uniquely identify a key with kid {kid} in appropriate section of 
issuer entity configuration") - try: - return key_from_jwk_dict(**found_jwks[0]) - except Exception as e: - raise ValueError(f"unable to parse issuer jwk: {e}") - - # --------------------------- - # TODO: sistema da qui in giù - # --------------------------- - - # def __getattribute__(self, name: str) -> Any: - # if hasattr(self, name): - # return getattr(self, name) - # logger.critical("se vedi questo messaggio: sei perduto") - # return None - - def init_trust_resources(self) -> None: - """ - Initializes the trust resources. - """ - - # private keys by kid - self.federations_jwks_by_kids = { - i['kid']: i for i in self.config['trust']['federation']['config']['federation_jwks'] - } - # dumps public jwks - self.federation_public_jwks = [ - key_from_jwk_dict(i).serialize() for i in self.config['trust']['federation']['config']['federation_jwks'] - ] - # we close the connection in this constructor since it must be fork safe and - # get reinitialized later on, within each fork - self.update_trust_anchors() - - try: - self.get_backend_trust_chain() - except Exception as e: - self._log_critical( - "Backend Trust", - f"Cannot fetch the trust anchor configuration: {e}" - ) - - self.db_engine.close() - self._db_engine = None - - def entity_configuration_endpoint(self, context: Context) -> Response: - """ - Entity Configuration endpoint. - - :param context: The current context - :type context: Context - - :return: The entity configuration - :rtype: Response - """ - - if context.qs_params.get('format', '') == 'json': - return Response( - json.dumps(self.entity_configuration_as_dict), - status="200", - content="application/json" - ) - - return Response( - self.entity_configuration, - status="200", - content="application/entity-statement+jwt" - ) - - def update_trust_anchors(self): - """ - Updates the trust anchors of current instance. 
- """ - - tas = self.config['trust']['federation']['config']['trust_anchors'] - self._log_info("Trust Anchors updates", f"Trying to update: {tas}") - - for ta in tas: - try: - update_trust_anchors_ecs( - db=self.db_engine, - trust_anchors=[ta], - httpc_params=self.config['network']['httpc_params'] - ) - except Exception as e: - self._log_warning("Trust Anchor updates", - f"{ta} update failed: {e}") - - self._log_info("Trust Anchor updates", f"{ta} updated") - - def get_backend_trust_chain(self) -> list[str]: - """ - Get the backend trust chain. In case something raises an Exception (e.g. faulty storage), logs a warning message - and returns an empty list. - - :return: The trust chain - :rtype: list - """ - try: - trust_evaluation_helper = TrustEvaluationHelper.build_trust_chain_for_entity_id( - storage=self.db_engine, - entity_id=self.client_id, - entity_configuration=self.entity_configuration, - httpc_params=self.config['network']['httpc_params'] - ) - self.db_engine.add_or_update_trust_attestation( - entity_id=self.client_id, - attestation=trust_evaluation_helper.trust_chain, - exp=trust_evaluation_helper.exp - ) - return trust_evaluation_helper.trust_chain - - except (DiscoveryFailedError, EntryNotFound, Exception) as e: - message = ( - f"Error while building trust chain for client with id: {self.client_id}. " - f"{e.__class__.__name__}: {e}" - ) - self._log_warning("Trust Chain", message) - - return [] - - def _validate_trust(self, context: Context, jws: str) -> TrustEvaluationHelper: - """ - Validates the trust of the given jws. - - :param context: the request context - :type context: satosa.context.Context - :param jws: the jws to validate - :type jws: str - - :raises: NotTrustedFederationError: raises an error if the trust evaluation fails. 
- - :return: the trust evaluation helper - :rtype: TrustEvaluationHelper - """ - - self._log_debug(context, "[TRUST EVALUATION] evaluating trust.") - - headers = decode_jwt_header(jws) - trust_eval = TrustEvaluationHelper( - self.db_engine, - httpc_params=self.config['network']['httpc_params'], - **headers - ) - - try: - trust_eval.evaluation_method() - except EntryNotFound: - message = ( - "[TRUST EVALUATION] not found for " - f"{trust_eval.entity_id}" - ) - self._log_error(context, message) - raise NotTrustedFederationError( - f"{trust_eval.entity_id} not found for Trust evaluation." - ) - except Exception as e: - message = ( - "[TRUST EVALUATION] failed for " - f"{trust_eval.entity_id}: {e}" - ) - self._log_error(context, message) - raise NotTrustedFederationError( - f"{trust_eval.entity_id} is not trusted." - ) - - return trust_eval diff --git a/pyeudiw/trust/dynamic.py b/pyeudiw/trust/dynamic.py index 54bb3e9e..2f2102bf 100644 --- a/pyeudiw/trust/dynamic.py +++ b/pyeudiw/trust/dynamic.py @@ -1,17 +1,18 @@ import logging from typing import Any, Callable, Optional + import satosa.context import satosa.response from pyeudiw.storage.db_engine import DBEngine +from pyeudiw.storage.exceptions import EntryNotFound from pyeudiw.tools.base_logger import BaseLogger -from pyeudiw.trust.exceptions import TrustConfigurationError from pyeudiw.tools.utils import dynamic_class_loader +from pyeudiw.trust.exceptions import NoCriptographicMaterial, TrustConfigurationError +from pyeudiw.trust.handler.direct_trust_jar import DirectTrustJar +from pyeudiw.trust.handler.direct_trust_sd_jwt_vc import DirectTrustSdJwtVc from pyeudiw.trust.handler.interface import TrustHandlerInterface from pyeudiw.trust.model.trust_source import TrustSourceData -from pyeudiw.trust.handler.direct_trust_sd_jwt_vc import DirectTrustSdJwtVc -from pyeudiw.storage.exceptions import EntryNotFound -from pyeudiw.trust.exceptions import NoCriptographicMaterial logger = logging.getLogger(__name__) @@ -21,7 +22,9 
@@ class CombinedTrustEvaluator(BaseLogger): A trust evaluator that combines multiple trust models. """ - def __init__(self, handlers: list[TrustHandlerInterface], db_engine: DBEngine) -> None: + def __init__( + self, handlers: list[TrustHandlerInterface], db_engine: DBEngine + ) -> None: """ Initialize the CombinedTrustEvaluator. @@ -50,7 +53,9 @@ def _retrieve_trust_source(self, issuer: str) -> Optional[TrustSourceData]: except EntryNotFound: return None - def _upsert_source_trust_materials(self, issuer: str, trust_source: Optional[TrustSourceData]) -> TrustSourceData: + def _upsert_source_trust_materials( + self, issuer: str, trust_source: Optional[TrustSourceData] + ) -> TrustSourceData: """ Extract the trust material of a certain issuer from all the trust handlers. If the trust material is not found for a certain issuer the structure remain unchanged. @@ -67,7 +72,8 @@ def _upsert_source_trust_materials(self, issuer: str, trust_source: Optional[Tru for handler in self.handlers: trust_source = handler.extract_and_update_trust_materials( - issuer, trust_source) + issuer, trust_source + ) self.db_engine.add_trust_source(trust_source.serialize()) @@ -86,8 +92,7 @@ def _get_trust_source(self, issuer: str) -> TrustSourceData: trust_source = self._retrieve_trust_source(issuer) if not trust_source: - trust_source = self._upsert_source_trust_materials( - issuer, trust_source) + trust_source = self._upsert_source_trust_materials(issuer, trust_source) return trust_source @@ -105,7 +110,8 @@ def get_public_keys(self, issuer: str) -> list[dict]: if not trust_source.keys: raise NoCriptographicMaterial( - f"no trust evaluator can provide cyptographic material for {issuer}: searched among: {self.handlers_names}" + f"no trust evaluator can provide cyptographic material " + f"for {issuer}: searched among: {self.handlers_names}" ) return trust_source.public_keys @@ -118,7 +124,9 @@ def get_metadata(self, issuer: str) -> dict: if not trust_source.metadata: raise Exception( - f"no 
trust evaluator can provide metadata for {issuer}: searched among: {self.handlers_names}") + f"no trust evaluator can provide metadata for {issuer}: " + f"searched among: {self.handlers_names}" + ) return trust_source.metadata @@ -150,7 +158,9 @@ def get_policies(self, issuer: str) -> dict[str, any]: if not trust_source.policies: raise Exception( - f"no trust evaluator can provide policies for {issuer}: searched among: {self.handlers_names}") + f"no trust evaluator can provide policies for {issuer}: " + f"searched among: {self.handlers_names}" + ) return trust_source.policies @@ -166,27 +176,42 @@ def get_selfissued_jwt_header_trust_parameters(self, issuer: str) -> list[dict]: """ trust_source = self._get_trust_source(issuer) - if not trust_source.trust_params: - raise Exception( - f"no trust evaluator can provide trust parameters for {issuer}: searched among: {self.handlers_names}") - - return {type: param.trust_params for type, param in trust_source.trust_params.items()} - - def build_metadata_endpoints(self, backend_name: str, entity_uri: str) -> list[tuple[str, Callable[[satosa.context.Context, Any], satosa.response.Response]]]: + # why should we issue an exception if a configuration might work without + # any trust evaluation handler? 
+ + # if not trust_source.trust_params: + # raise Exception( + # f"no trust evaluator can provide trust parameters for {issuer}: " + # f"searched among: {self.handlers_names}" + # ) + + return { + _typ: param.trust_params + for _typ, param in trust_source.trust_params.items() + } + + def build_metadata_endpoints( + self, backend_name: str, entity_uri: str + ) -> list[ + tuple[str, Callable[[satosa.context.Context, Any], satosa.response.Response]] + ]: endpoints = [] for handler in self.handlers: - endpoints += handler.build_metadata_endpoints( - backend_name, entity_uri) + endpoints += handler.build_metadata_endpoints(backend_name, entity_uri) # Partially check for collissions in managed paths: this might happen if multiple configured # trust frameworks want to handle the same endpoints (check is not 100% exhaustive as paths are actually regexps) all_paths = [path for path, *_ in endpoints] if len(all_paths) > len(set(all_paths)): - self._log_warning("build_metadata_endpoints", - f"found collision in metadata endpoint: {all_paths}") + self._log_warning( + "build_metadata_endpoints", + f"found collision in metadata endpoint: {all_paths}", + ) return endpoints @staticmethod - def from_config(config: dict, db_engine: DBEngine) -> 'CombinedTrustEvaluator': + def from_config( + config: dict, db_engine: DBEngine, default_client_id: str + ) -> "CombinedTrustEvaluator": """ Create a CombinedTrustEvaluator from a configuration. 
@@ -202,24 +227,34 @@ def from_config(config: dict, db_engine: DBEngine) -> 'CombinedTrustEvaluator': for handler_name, handler_config in config.items(): try: + # every trust evaluation method might use their own client id + # but a default one always therefore required + if not handler_config["config"].get("client_id"): + handler_config["config"]["client_id"] = default_client_id + trust_handler = dynamic_class_loader( handler_config["module"], handler_config["class"], - handler_config["config"] + handler_config["config"], ) except Exception as e: raise TrustConfigurationError( - f"invalid configuration for {handler_name}: {e}", e) + f"invalid configuration for {handler_name}: {e}", e + ) if not isinstance(trust_handler, TrustHandlerInterface): raise TrustConfigurationError( - f"class {trust_handler.__class__} does not satisfy the interface TrustEvaluator") + f"class {trust_handler.__class__} does not satisfy the interface TrustEvaluator" + ) handlers.append(trust_handler) + logger.debug( + f"TrustHandlers loaded: [{', '.join([str(i.__class__) for i in handlers])}]." 
+ ) if not handlers: - logger.warning( - "No configured trust model, using direct trust model") + logger.warning("No configured trust model, using direct trust model") handlers.append(DirectTrustSdJwtVc()) + handlers.append(DirectTrustJar()) return CombinedTrustEvaluator(handlers, db_engine) diff --git a/pyeudiw/trust/handler/_direct_trust_jwk.py b/pyeudiw/trust/handler/_direct_trust_jwk.py index e8f04191..3290a077 100644 --- a/pyeudiw/trust/handler/_direct_trust_jwk.py +++ b/pyeudiw/trust/handler/_direct_trust_jwk.py @@ -1,5 +1,6 @@ from typing import Any, Callable, Literal from urllib.parse import urlparse + import satosa.context import satosa.response @@ -7,7 +8,6 @@ from pyeudiw.satosa.utils.response import JsonResponse from pyeudiw.tools.base_logger import BaseLogger from pyeudiw.tools.utils import cacheable_get_http_url, get_http_url - from pyeudiw.trust.handler.exception import InvalidJwkMetadataException from pyeudiw.trust.handler.interface import TrustHandlerInterface from pyeudiw.trust.model.trust_source import TrustSourceData @@ -41,12 +41,14 @@ def __init__( httpc_params: dict, jwk_endpoint: str, cache_ttl: int, - jwks: list[dict] | None + jwks: list[dict] | None, + client_id: str = None, ): self.httpc_params = httpc_params self.jwk_endpoint = jwk_endpoint self.cache_ttl = cache_ttl self.http_async_calls = False + self.client_id = client_id # input validation self.jwks = jwks if jwks else [] try: @@ -55,11 +57,8 @@ def __init__( raise ValueError("invalid argument: dictionary is not a jwk", e) def _build_issuing_public_signing_jwks(self) -> list[dict]: - signing_keys = [ - key for key in self.jwks if key.get("use", "") != "enc"] - return [ - JWK(key).as_public_dict() for key in signing_keys - ] + signing_keys = [key for key in self.jwks if key.get("use", "") != "enc"] + return [JWK(key).as_public_dict() for key in signing_keys] def _build_metadata_with_issuer_jwk(self, entity_id: str) -> dict: # This funciton assumed that the issuer is equal to the 
entity_uri; this @@ -68,9 +67,7 @@ def _build_metadata_with_issuer_jwk(self, entity_id: str) -> dict: # context; but for not we will opt for the simple option. md_dictionary = { "issuer": entity_id, - "jwks": { - "keys": self._build_issuing_public_signing_jwks() - } + "jwks": {"keys": self._build_issuing_public_signing_jwks()}, } return md_dictionary @@ -85,20 +82,20 @@ def _build_metadata_path(self, backend_name: str) -> str: level as it breaks an assuptions of the internal satosa router and there is no way to solve that problem at the satosa backend level. """ - endpoint = backend_name.strip('/') + '/' + self.jwk_endpoint.strip("/") - return endpoint.strip('/') + endpoint = f"{backend_name.strip('/')}/{self.jwk_endpoint.strip('/')}" + return endpoint.strip("/") def _extract_jwks_from_jwk_metadata(self, metadata: dict) -> dict: """ parse the jwk metadata document and return the jwks NOTE: jwks might be in the document by value or by reference """ - jwks: dict[Literal["keys"], list[dict] - ] | None = metadata.get("jwks", None) + jwks: dict[Literal["keys"], list[dict]] | None = metadata.get("jwks", None) jwks_uri: str | None = metadata.get("jwks_uri", None) if (not jwks) and (not jwks_uri): raise InvalidJwkMetadataException( - "invalid issuing key metadata: missing both claims [jwks] and [jwks_uri]") + "invalid issuing key metadata: missing both claims [jwks] and [jwks_uri]" + ) if jwks: # get jwks by value return jwks @@ -110,13 +107,19 @@ def _get_jwk_metadata(self, issuer_id: str) -> dict: endpoint = build_jwk_issuer_endpoint(issuer_id, self.jwk_endpoint) if self.cache_ttl: resp = cacheable_get_http_url( - self.cache_ttl, endpoint, self.httpc_params, http_async=self.http_async_calls) + self.cache_ttl, + endpoint, + self.httpc_params, + http_async=self.http_async_calls, + ) else: - resp = get_http_url([endpoint], self.httpc_params, - http_async=self.http_async_calls)[0] + resp = get_http_url( + [endpoint], self.httpc_params, http_async=self.http_async_calls + )[0] 
if (not resp) or (resp.status_code != 200): raise InvalidJwkMetadataException( - f"failed to fetch valid jwk metadata: obtained {resp}") + f"failed to fetch valid jwk metadata: obtained {resp}" + ) return resp.json() def _get_jwks_by_reference(self, jwks_reference_uri: str) -> dict: @@ -125,24 +128,40 @@ def _get_jwks_by_reference(self, jwks_reference_uri: str) -> dict: """ if self.cache_ttl: resp = cacheable_get_http_url( - self.cache_ttl, jwks_reference_uri, self.httpc_params, http_async=self.http_async_calls) + self.cache_ttl, + jwks_reference_uri, + self.httpc_params, + http_async=self.http_async_calls, + ) else: resp = get_http_url( - [jwks_reference_uri], self.httpc_params, http_async=self.http_async_calls)[0] + [jwks_reference_uri], + self.httpc_params, + http_async=self.http_async_calls, + )[0] return resp.json() - def build_metadata_endpoints(self, backend_name: str, entity_uri: str) -> list[tuple[str, Callable[[satosa.context.Context, Any], satosa.response.Response]]]: + def build_metadata_endpoints( + self, backend_name: str, entity_uri: str + ) -> list[ + tuple[str, Callable[[satosa.context.Context, Any], satosa.response.Response]] + ]: if not self.jwk_endpoint: return [] - metadata_path = '^' + self._build_metadata_path(backend_name) + '$' + metadata_path = "^" + self._build_metadata_path(backend_name) + "$" response_json = self._build_metadata_with_issuer_jwk(entity_uri) - def metadata_response_fn(ctx: satosa.context.Context, *args) -> satosa.response.Response: + def metadata_response_fn( + ctx: satosa.context.Context, *args + ) -> satosa.response.Response: return JsonResponse(message=response_json) + return [(metadata_path, metadata_response_fn)] - def extract_and_update_trust_materials(self, issuer: str, trust_source: TrustSourceData) -> TrustSourceData: + def extract_and_update_trust_materials( + self, issuer: str, trust_source: TrustSourceData + ) -> TrustSourceData: """ Fetches the public key of the issuer by querying a given endpoint. 
Previous responses might or might not be cached based on the cache_ttl @@ -157,27 +176,34 @@ def extract_and_update_trust_materials(self, issuer: str, trust_source: TrustSou self.get_metadata(issuer, trust_source) except Exception as e: self._log_warning( - "updating metadata", f"Exception encountered when updating metadata with {self.__class__.__name__} for issuer {issuer}: {e}") + "updating metadata", + f"Exception encountered when updating metadata with {self.__class__.__name__} for issuer {issuer}: {e}", + ) try: md = self._get_jwk_metadata(issuer) if not issuer == (obt_issuer := md.get("issuer", None)): raise InvalidJwkMetadataException( - f"invalid jwk metadata: obtained issuer :{obt_issuer}, expected issuer: {issuer}") + f"invalid jwk metadata: obtained issuer :{obt_issuer}, expected issuer: {issuer}" + ) jwks = self._extract_jwks_from_jwk_metadata(md) jwk_l: list[dict] = jwks.get("keys", []) if not jwk_l: raise InvalidJwkMetadataException( - "unable to find jwks in issuer jwk metadata") + "unable to find jwks in issuer jwk metadata" + ) trust_source.add_keys(jwk_l) except Exception as e: self._log_warning( - "Extracting JWK", f"Failed to extract jwks from issuer {issuer}: {e}") + "Extracting JWK", f"Failed to extract jwks from issuer {issuer}: {e}" + ) return trust_source - def get_metadata(self, issuer: str, trust_source: TrustSourceData) -> TrustSourceData: + def get_metadata( + self, issuer: str, trust_source: TrustSourceData + ) -> TrustSourceData: # this class does not handle generic metadata information: it fetches and exposes cryptographic material only return trust_source @@ -186,5 +212,5 @@ def build_jwk_issuer_endpoint(issuer_id: str, endpoint_component: str) -> str: if not endpoint_component: return issuer_id baseurl = urlparse(issuer_id) - full_endpoint_path = '/' + endpoint_component.strip('/') + baseurl.path + full_endpoint_path = f"/{endpoint_component.strip('/')}{baseurl.path}" return baseurl._replace(path=full_endpoint_path).geturl() diff 
--git a/pyeudiw/trust/handler/commons.py b/pyeudiw/trust/handler/commons.py new file mode 100644 index 00000000..062f9287 --- /dev/null +++ b/pyeudiw/trust/handler/commons.py @@ -0,0 +1,11 @@ +import os + +DEFAULT_HTTPC_PARAMS = { + "connection": {"ssl": os.getenv("PYEUDIW_HTTPC_SSL", True)}, + "session": {"timeout": os.getenv("PYEUDIW_HTTPC_TIMEOUT", 6)}, +} + +DEFAULT_OPENID4VCI_METADATA_ENDPOINT = "/.well-known/openid-credential-issuer" +"""Default endpoint where metadata issuer credential are exposed/ +For further reference, see https://openid.net/specs/openid-4-verifiable-credential-issuance-1_0.html#name-well-known-uri-registry +""" diff --git a/pyeudiw/trust/handler/direct_trust_jar.py b/pyeudiw/trust/handler/direct_trust_jar.py index 71c5ed6a..e944fe90 100644 --- a/pyeudiw/trust/handler/direct_trust_jar.py +++ b/pyeudiw/trust/handler/direct_trust_jar.py @@ -1,22 +1,12 @@ -import os - from pyeudiw.trust.handler._direct_trust_jwk import _DirectTrustJwkHandler +from .commons import DEFAULT_HTTPC_PARAMS DEFAULT_JARISSUER_METADATA_ENDPOINT = "/.well-known/jar-issuer" """Default endpoint adopted by potential interopebility document as of version 1.1. The endpopint should be positioned between the host component and the path component (if any) of the iss claim value in the JAR. 
""" -DEFAULT_DIRECT_TRUST_JAR_HTTPC_PARAMS = { - "connection": { - "ssl": os.getenv("PYEUDIW_HTTPC_SSL", True) - }, - "session": { - "timeout": os.getenv("PYEUDIW_HTTPC_TIMEOUT", 6) - } -} - class DirectTrustJar(_DirectTrustJwkHandler): """DirectTrustJar is specialization of _DirectTrustJwkHandler @@ -25,14 +15,16 @@ class DirectTrustJar(_DirectTrustJwkHandler): def __init__( self, - httpc_params: dict = DEFAULT_DIRECT_TRUST_JAR_HTTPC_PARAMS, + httpc_params: dict = DEFAULT_HTTPC_PARAMS, jwk_endpoint: str = DEFAULT_JARISSUER_METADATA_ENDPOINT, cache_ttl: int = 0, - jwks: list[dict] | None = None + jwks: list[dict] | None = None, + client_id: str = None, ): super().__init__( httpc_params=httpc_params, jwk_endpoint=jwk_endpoint, cache_ttl=cache_ttl, - jwks=jwks + jwks=jwks, + client_id=client_id, ) diff --git a/pyeudiw/trust/handler/direct_trust_sd_jwt_vc.py b/pyeudiw/trust/handler/direct_trust_sd_jwt_vc.py index 934c8ce4..e8e48f3f 100644 --- a/pyeudiw/trust/handler/direct_trust_sd_jwt_vc.py +++ b/pyeudiw/trust/handler/direct_trust_sd_jwt_vc.py @@ -1,29 +1,14 @@ -import os - +from pyeudiw.tools.utils import cacheable_get_http_url, get_http_url from pyeudiw.trust.handler._direct_trust_jwk import _DirectTrustJwkHandler from pyeudiw.trust.model.trust_source import TrustSourceData -from pyeudiw.tools.utils import cacheable_get_http_url, get_http_url +from .commons import DEFAULT_HTTPC_PARAMS, DEFAULT_OPENID4VCI_METADATA_ENDPOINT DEFAULT_SDJWTVC_METADATA_ENDPOINT = "/.well-known/jwt-vc-issuer" """Default endpoint where issuer keys used for sd-jwt vc are exposed. 
For further reference, see https://www.ietf.org/archive/id/draft-ietf-oauth-sd-jwt-vc-06.html#name-jwt-vc-issuer-metadata """ -DEFAULT_OPENID4VCI_METADATA_ENDPOINT = "/.well-known/openid-credential-issuer" -"""Default endpoint where metadata issuer credential are exposed/ -For further reference, see https://openid.net/specs/openid-4-verifiable-credential-issuance-1_0.html#name-well-known-uri-registry -""" - -DEFAULT_DIRECT_TRUST_SD_JWC_VC_PARAMS = { - "connection": { - "ssl": os.getenv("PYEUDIW_HTTPC_SSL", True) - }, - "session": { - "timeout": os.getenv("PYEUDIW_HTTPC_TIMEOUT", 6) - } -} - class DirectTrustSdJwtVc(_DirectTrustJwkHandler): """DirectTrustSdJwtVc is specialization of _DirectTrustJwkHandler @@ -32,21 +17,25 @@ class DirectTrustSdJwtVc(_DirectTrustJwkHandler): def __init__( self, - httpc_params: dict = DEFAULT_DIRECT_TRUST_SD_JWC_VC_PARAMS, + httpc_params: dict = DEFAULT_HTTPC_PARAMS, jwk_endpoint: str = DEFAULT_SDJWTVC_METADATA_ENDPOINT, metadata_endpoint: str = DEFAULT_OPENID4VCI_METADATA_ENDPOINT, cache_ttl: int = 0, - jwks: list[dict] | None = None + jwks: list[dict] | None = None, + client_id: str = None, ): super().__init__( httpc_params=httpc_params, jwk_endpoint=jwk_endpoint, cache_ttl=cache_ttl, - jwks=jwks + jwks=jwks, + client_id=client_id, ) self.metadata_endpoint = metadata_endpoint - def get_metadata(self, issuer: str, trust_source: TrustSourceData) -> TrustSourceData: + def get_metadata( + self, issuer: str, trust_source: TrustSourceData + ) -> TrustSourceData: """ Fetches the public metadata of an issuer by interrogating a given endpoint. 
The endpoint must yield information in a format that @@ -57,13 +46,18 @@ def get_metadata(self, issuer: str, trust_source: TrustSourceData) -> TrustSourc url = build_metadata_issuer_endpoint(issuer, self.metadata_endpoint) if self.cache_ttl == 0: trust_source.metadata = get_http_url( - url, self.httpc_params, self.http_async_calls)[0].json() + url, self.httpc_params, self.http_async_calls + )[0].json() else: trust_source.metadata = cacheable_get_http_url( - self.cache_ttl, url, self.httpc_params, self.http_async_calls).json() + self.cache_ttl, url, self.httpc_params, self.http_async_calls + ).json() return trust_source +# TODO: do you really think that this should be stay here? + + def build_metadata_issuer_endpoint(issuer_id: str, endpoint_component: str) -> str: - return issuer_id.rstrip('/') + '/' + endpoint_component.lstrip('/') + return f"{issuer_id.rstrip('/')}/{endpoint_component.lstrip('/')}" diff --git a/pyeudiw/trust/handler/federation.py b/pyeudiw/trust/handler/federation.py index 266629f2..5509de6e 100644 --- a/pyeudiw/trust/handler/federation.py +++ b/pyeudiw/trust/handler/federation.py @@ -1,13 +1,523 @@ -from pyeudiw.trust.handler.interface import TrustHandlerInterface +import json +import logging +from datetime import datetime +from typing import Any, Callable, List, Union + +import satosa +from copy import deepcopy +from cryptojwt.jwk.ec import ECKey +from cryptojwt.jwk.jwk import key_from_jwk_dict +from cryptojwt.jwk.rsa import RSAKey +from satosa.response import Response + +from pyeudiw.federation.exceptions import ProtocolMetadataNotFound, TimeValidationError +from pyeudiw.federation.policy import TrustChainPolicy, combine +from pyeudiw.federation.statements import EntityStatement, get_entity_configurations +from pyeudiw.federation.trust_chain_builder import TrustChainBuilder +from pyeudiw.federation.trust_chain_validator import StaticTrustChainValidator +from pyeudiw.jwk import JWK +from pyeudiw.jwt.jws_helper import JWSHelper +from 
pyeudiw.jwt.utils import decode_jwt_payload +from pyeudiw.satosa.exceptions import DiscoveryFailedError +from pyeudiw.satosa.utils.response import JsonResponse +from pyeudiw.storage.db_engine import DBEngine +from pyeudiw.storage.exceptions import EntryNotFound from pyeudiw.tools.base_logger import BaseLogger +from pyeudiw.tools.utils import exp_from_now, iat_now +from pyeudiw.trust.exceptions import MissingProtocolSpecificJwks, UnknownTrustAnchor +from pyeudiw.trust.handler.interface import TrustHandlerInterface + +from .commons import DEFAULT_HTTPC_PARAMS + +logger = logging.getLogger(__name__) + + +_ISSUER_METADATA_TYPE = "openid_credential_issuer" class FederationHandler(TrustHandlerInterface, BaseLogger): - def __init__(self, **kargs): - pass + def __init__( + self, + metadata: List[dict], + authority_hints: List[str], + trust_anchors: List[str], + default_sig_alg: str, + federation_jwks: List[dict[str, Union[str, List[str]]]], + trust_marks: List[dict], + federation_entity_metadata: dict[str, str], + client_id: str, + entity_configuration_exp: int = 800, + httpc_params: dict = DEFAULT_HTTPC_PARAMS, + cache_ttl: int = 0, + metadata_type: str = _ISSUER_METADATA_TYPE, + **kwargs, + ): + + self.httpc_params = httpc_params + self.cache_ttl = cache_ttl + # TODO - this MUST be handled in httpc_params ... 
+ self.http_async_calls = False + self.client_id = client_id + + self.metadata_type = metadata_type + self.metadata: dict = metadata + self.authority_hints: List[str] = authority_hints + self.trust_anchors: List[str] = trust_anchors + self.default_sig_alg: str = default_sig_alg + self.federation_jwks: List[dict[str, Union[str, List[str]]]] = federation_jwks + self.trust_marks: List[dict] = trust_marks + self.federation_entity_metadata: dict[str, str] = federation_entity_metadata + # NOTE(review): removed `self.client_id = federation_entity_metadata` — it clobbered the client_id string assigned above with a dict, breaking iss/sub in the entity configuration + self.entity_configuration_exp = entity_configuration_exp + + self.federation_public_jwks = [ + JWK(i).as_public_dict() for i in self.federation_jwks + ] + + if isinstance(self.metadata["jwks"], dict) and self.metadata["jwks"].get('keys'): + self.metadata["jwks"] = self.metadata["jwks"].pop("keys") + + self.metadata_jwks = [JWK(i) for i in self.metadata["jwks"]] + self.metadata["jwks"] = {"keys": [ + i.as_public_dict() for i in self.metadata_jwks + ]} + + self.metadata_policy_resolver = TrustChainPolicy() + + for k, v in kwargs.items(): + if not hasattr(self, k): + logger.warning( + f"Trust - FederationHandler. {k} was provided in the init but not handled."
+ ) def extract_and_update_trust_materials(self, issuer, trust_source): return trust_source def get_metadata(self, issuer, trust_source): return trust_source + + @property + def entity_configuration(self) -> dict: + """Returns the entity configuration as a JWT.""" + data = self.entity_configuration_as_dict + _jwk = self.federation_jwks[0] + jwshelper = JWSHelper(_jwk) + return jwshelper.sign( + protected={ + "alg": self.default_sig_alg, + "kid": _jwk["kid"], + "typ": "entity-statement+jwt", + }, + plain_dict=data, + ) + + @property + def entity_configuration_as_dict(self) -> dict: + """Returns the entity configuration as a dictionary.""" + ec_payload = { + "exp": exp_from_now(minutes=self.entity_configuration_exp), + "iat": iat_now(), + "iss": self.client_id, + "sub": self.client_id, + "jwks": {"keys": self.federation_public_jwks}, + "metadata": { + self.metadata_type: self.metadata, + "federation_entity": self.federation_entity_metadata, + }, + "authority_hints": self.authority_hints, + } + return ec_payload + + def entity_configuration_endpoint( + self, context: satosa.context.Context + ) -> satosa.response.Response: + """ + Entity Configuration endpoint. 
+ + :param context: The current context + :type context: Context + + :return: The entity configuration + :rtype: Response + """ + + if context.qs_params.get("format", "") == "json": + return Response( + json.dumps(self.entity_configuration_as_dict), + status="200", + content="application/json", + ) + else: + return satosa.response.Response( + self.entity_configuration, + status="200", + content="application/entity-statement+jwt", + ) + + def build_metadata_endpoints( + self, backend_name: str, entity_uri: str + ) -> list[ + tuple[str, Callable[[satosa.context.Context, Any], satosa.response.Response]] + ]: + + metadata_path = f'^{backend_name.strip("/")}/.well-known/openid-federation$' + response = self.entity_configuration + + def metadata_response_fn( + ctx: satosa.context.Context, *args + ) -> satosa.response.Response: + return JsonResponse(message=response) + + return [(metadata_path, metadata_response_fn)] + + def get_backend_trust_chain(self) -> list[str]: + """ + Get the backend trust chain. In case something raises an Exception (e.g. faulty storage), logs a warning message + and returns an empty list. + + :return: The trust chain + :rtype: list + """ + + try: + # NOTE(review): removed stray debug breakpoint() call left in the production path + trust_evaluation_helper = self.build_trust_chain_for_entity_id( + storage=self.db_engine, + entity_id=self.client_id, + entity_configuration=self.entity_configuration, + httpc_params=self.httpc_params, + ) + + self.db_engine.add_or_update_trust_attestation( + entity_id=self.client_id, + attestation=trust_evaluation_helper.trust_chain, + exp=trust_evaluation_helper.exp, + ) + return trust_evaluation_helper.trust_chain + + except (DiscoveryFailedError, EntryNotFound, Exception) as e: + message = ( + f"Error while building trust chain for client with id: {self.client_id}. 
" + f"{e.__class__.__name__}: {e}" + ) + self._log_warning("Trust Chain", message) + + return [] + + @property + def default_federation_private_jwk(self) -> dict: + """Returns the default federation private jwk.""" + return tuple(self.federations_jwks_by_kids.values())[0] + + # was: class FederationTrustModel(TrustEvaluator): + + def get_public_keys(self, issuer): + public_keys = [JWK(i).as_public_dict() for i in self.federation_jwks] + + return public_keys + + def get_verified_key( + self, issuer: str, token_header: dict + ) -> ECKey | RSAKey | dict: + # (1) verify the trust chain + kid: str = token_header.get("kid", None) + if not kid: + raise ValueError("missing claim [kid] in token header") + trust_chain: list[str] = token_header.get("trust_chain", None) + if not trust_chain: + raise ValueError("missing trust chain in federation token") + if not isinstance(trust_chain, list): + raise ValueError("invalid format of header claim [trust_chain]") + # TODO: check which exceptions this might raise + self._verify_trust_chain(trust_chain) + + # (2) metadata parsing and JWK set extraction + # TODO: wrap in something that implements VciJwksSource + # apply policy of trust anchor only?
+ issuer_entity_configuration = trust_chain[0] + anchor_entity_configuration = trust_chain[-1] + issuer_payload: dict = decode_jwt_payload(issuer_entity_configuration) + anchor_payload = decode_jwt_payload(anchor_entity_configuration) + trust_anchor_policy = anchor_payload.get("metadata_policy", {}) + final_issuer_metadata = self.metadata_policy_resolver.apply_policy( + issuer_payload, trust_anchor_policy + ) + metadata: dict = final_issuer_metadata.get("metadata", None) + if not metadata: + raise ValueError( + "missing or invalid claim [metadata] in entity configuration" + ) + issuer_metadata: dict = metadata.get(_ISSUER_METADATA_TYPE, None) + if not issuer_metadata: + raise ValueError( + f"missing or invalid claim [metadata.{_ISSUER_METADATA_TYPE}] in entity configuration" + ) + issuer_keys: list[dict] = issuer_metadata.get("jwks", {}).get("keys", []) + if not issuer_keys: + raise ValueError( + f"missing or invalid claim [metadata.{_ISSUER_METADATA_TYPE}.jwks.keys] in entity configuration" + ) + # check issuer = entity_id + if issuer != (obt_iss := final_issuer_metadata.get("iss", "")): + raise ValueError( + f"invalid issuer metadata: expected '{issuer}', obtained '{obt_iss}'" + ) + + # (3) dato il set completo, fa il match per kid tra l'header e il jwk set + found_jwks: list[dict] = [] + for key in issuer_keys: + obt_kid: str = key.get("kid", "") + if kid == obt_kid: + found_jwks.append(key) + if len(found_jwks) != 1: + raise ValueError( + f"unable to uniquely identify a key with kid {kid} in appropriate section of issuer entity configuration" + ) + try: + return key_from_jwk_dict(**found_jwks[0]) + except Exception as e: + raise ValueError(f"unable to parse issuer jwk: {e}") + + def init_trust_resources(self) -> None: + """ + Initializes the trust resources. 
+ """ + + # private keys by kid + self.federations_jwks_by_kids = { + i["kid"]: i + for i in self.config["federation_jwks"] + } + # dumps public jwks + self.federation_public_jwks = [ + key_from_jwk_dict(i).serialize() + for i in self.config["federation_jwks"] + ] + # we close the connection in this constructor since it must be fork safe and + # get reinitialized later on, within each fork + self.update_trust_anchors() + + try: + self.get_backend_trust_chain() + except Exception as e: + self._log_critical( + "Backend Trust", f"Cannot fetch the trust anchor configuration: {e}" + ) + + self.db_engine.close() + self._db_engine = None + + def update_trust_anchors(self): + """ + Updates the trust anchors of current instance. + """ + + tas = self.config["trust_anchors"] + self._log_info("Trust Anchors updates", f"Trying to update: {tas}") + + for ta in tas: + try: + self.update_trust_anchors_ecs( + db=self.db_engine, + trust_anchors=[ta], + httpc_params=self.config["httpc_params"], + ) + except Exception as e: + self._log_warning("Trust Anchor updates", f"{ta} update failed: {e}") + + self._log_info("Trust Anchor updates", f"{ta} updated") + + def _update_chain( + self, + entity_id: str | None = None, + exp: datetime | None = None, + trust_chain: list | None = None, + ): + if entity_id is not None: + self.entity_id = entity_id + + if exp is not None: + self.exp = exp + + if trust_chain is not None: + self.trust_chain = trust_chain + + def _handle_federation_chain(self, trust_chain): + _first_statement = decode_jwt_payload(trust_chain[-1]) + trust_anchor_eid = self.trust_anchor or _first_statement.get("iss", None) + + if not trust_anchor_eid: + raise UnknownTrustAnchor( + "Unknown Trust Anchor: can't find 'iss' in the " + f"first entity statement: {_first_statement} " + ) + + try: + trust_anchor = self.storage.get_trust_anchor(trust_anchor_eid) + except EntryNotFound: + raise UnknownTrustAnchor( + f"Unknown Trust Anchor: '{trust_anchor_eid}' is not " + "a recognizable 
Trust Anchor." + ) + + decoded_ec = decode_jwt_payload( + trust_anchor["federation"]["entity_configuration"] + ) + jwks = decoded_ec.get("jwks", {}).get("keys", []) + + if not jwks: + raise MissingProtocolSpecificJwks(f"Cannot find any jwks in {decoded_ec}") + + tc = StaticTrustChainValidator(self.trust_chain, jwks, self.httpc_params) + self._update_chain(entity_id=tc.entity_id, exp=tc.exp) + + _is_valid = False + + try: + _is_valid = tc.validate() + except TimeValidationError: + logger.warning(f"Trust Chain {tc.entity_id} is expired") + except Exception as e: + logger.warning( + f"Cannot validate Trust Chain {tc.entity_id} for the following reason: {e}" + ) + + db_chain = None + + if not _is_valid: + try: + db_chain = self.storage.get_trust_attestation(self.entity_id)[ + "federation" + ]["chain"] + if StaticTrustChainValidator( + db_chain, jwks, self.httpc_params + ).is_valid: + self.is_trusted = True + return self.is_trusted + + except (EntryNotFound, Exception): + pass + + _is_valid = tc.update() + + self._update_chain(trust_chain=tc.trust_chain, exp=tc.exp) + + # the good trust chain is then stored + self.storage.add_or_update_trust_attestation( + entity_id=self.entity_id, + attestation=tc.trust_chain, + exp=datetime.fromtimestamp(tc.exp), + ) + + self.is_trusted = _is_valid + return _is_valid + + def get_final_metadata(self, metadata_type: str, policies: list[dict]) -> dict: + policy_acc = {"metadata": {}, "metadata_policy": {}} + + for policy in policies: + policy_acc = combine(policy, policy_acc) + + self.final_metadata = decode_jwt_payload(self.trust_chain[0]) + + try: + # TODO: there are some cases where the jwks are taken from a uri ... 
+ selected_metadata = { + "metadata": self.final_metadata["metadata"], + "metadata_policy": {}, + } + + self.final_metadata = TrustChainPolicy().apply_policy( + selected_metadata, policy_acc + ) + + return self.final_metadata["metadata"][metadata_type] + except KeyError: + raise ProtocolMetadataNotFound( + f"{metadata_type} not found in the final metadata:" + f" {self.final_metadata['metadata']}" + ) + + def get_trusted_jwks( + self, metadata_type: str, policies: list[dict] = [] + ) -> list[dict]: + return ( + self.get_final_metadata(metadata_type=metadata_type, policies=policies) + .get("jwks", {}) + .get("keys", []) + ) + + def discovery( + self, entity_id: str, entity_configuration: EntityStatement | None = None + ): + """ + Updates fields ``trust_chain`` and ``exp`` based on the discovery process. + + :raises: DiscoveryFailedError: raises an error if the discovery fails. + """ + trust_anchor_eid = self.trust_anchor + _ta_ec = self.storage.get_trust_anchor(entity_id=trust_anchor_eid) + ta_ec = _ta_ec["federation"]["entity_configuration"] + + tcbuilder = TrustChainBuilder( + subject=entity_id, + trust_anchor=trust_anchor_eid, + trust_anchor_configuration=ta_ec, + subject_configuration=entity_configuration, + httpc_params=self.httpc_params, + ) + + self._update_chain(trust_chain=tcbuilder.get_trust_chain(), exp=tcbuilder.exp) + is_good = tcbuilder.is_valid + if not is_good: + raise DiscoveryFailedError( + f"Discovery failed for entity {entity_id} with configuration {entity_configuration}" + ) + + def build_trust_chain_for_entity_id(self, entity_id: str): + """ + Builds a ``TrustEvaluationHelper`` and returns it if the trust chain is valid. + In case the trust chain is invalid, tries to validate it in discovery before returning it. 
+ + :return: The svg data for html, base64 encoded + :rtype: str + """ + db_chain: list = self.storage.get_trust_attestation(entity_id) + + if len(db_chain) == 0: + db_chain = self.discovery(self.entity_id) + else: + self.is_valid = self._handle_federation_chain() + return self.is_valid + + return False + + def update_trust_anchors_ecs(self, trust_anchors: list[str], db: DBEngine) -> None: + """ + Update the trust anchors entity configurations. + + :param trust_anchors: The trust anchors + :type trust_anchors: list + :param db: The database engine + :type db: DBEngine + :param httpc_params: The HTTP client parameters + :type httpc_params: dict + """ + + ta_ecs = get_entity_configurations( + trust_anchors, httpc_params=self.httpc_params + ) + + for jwt in ta_ecs: + if isinstance(jwt, bytes): + jwt = jwt.decode() + + ec = EntityStatement(jwt, httpc_params=self.httpc_params) + if not ec.validate_by_itself(): + logger.warning( + f"The trust anchor failed the validation of its EntityConfiguration {ec}" + ) + + db.add_trust_anchor( + entity_id=ec.sub, entity_configuration=ec.jwt, exp=ec.exp + ) diff --git a/pyeudiw/trust/handler/interface.py b/pyeudiw/trust/handler/interface.py index 1b81c602..bb504a6c 100644 --- a/pyeudiw/trust/handler/interface.py +++ b/pyeudiw/trust/handler/interface.py @@ -1,4 +1,5 @@ from typing import Any, Callable + import satosa.context import satosa.response @@ -6,7 +7,12 @@ class TrustHandlerInterface: - def extract_and_update_trust_materials(self, issuer: str, trust_source: TrustSourceData) -> TrustSourceData: + def __init__(self, *args, **kwargs): + pass + + def extract_and_update_trust_materials( + self, issuer: str, trust_source: TrustSourceData + ) -> TrustSourceData: """ Extract the trust material of a certain issuer using a trust handler implementation. 
@@ -20,7 +26,9 @@ def extract_and_update_trust_materials(self, issuer: str, trust_source: TrustSou """ raise NotImplementedError - def get_metadata(self, issuer: str, trust_source: TrustSourceData) -> TrustSourceData: + def get_metadata( + self, issuer: str, trust_source: TrustSourceData + ) -> TrustSourceData: """ Get the metadata of a certain issuer if is needed by the specifics. @@ -35,7 +43,11 @@ def get_metadata(self, issuer: str, trust_source: TrustSourceData) -> TrustSourc raise NotImplementedError - def build_metadata_endpoints(self, backend_name: str, entity_uri: str) -> list[tuple[str, Callable[[satosa.context.Context, Any], satosa.response.Response]]]: + def build_metadata_endpoints( + self, backend_name: str, entity_uri: str + ) -> list[ + tuple[str, Callable[[satosa.context.Context, Any], satosa.response.Response]] + ]: """ Expose one or more metadata endpoint required to publish metadata information about *myself* and that are associated to a trust diff --git a/pyeudiw/trust/interface.py b/pyeudiw/trust/interface.py index 696c1698..ffef42fa 100644 --- a/pyeudiw/trust/interface.py +++ b/pyeudiw/trust/interface.py @@ -1,4 +1,5 @@ from typing import Any, Callable + import satosa.context import satosa.response @@ -36,7 +37,11 @@ def get_metadata(self, issuer: str) -> dict: """ raise NotImplementedError - def build_metadata_endpoints(self, base_path: str) -> list[tuple[str, Callable[[satosa.context.Context, Any], satosa.response.Response]]]: + def build_metadata_endpoints( + self, base_path: str + ) -> list[ + tuple[str, Callable[[satosa.context.Context, Any], satosa.response.Response]] + ]: """ Expose one or more metadata endpoint required to publish metadata information about *myself* and that are associated to a trust diff --git a/pyeudiw/trust/model/__init__.py b/pyeudiw/trust/model/__init__.py index be6e838c..72b54e5f 100644 --- a/pyeudiw/trust/model/__init__.py +++ b/pyeudiw/trust/model/__init__.py @@ -1,8 +1,10 @@ import sys + if 
float(f"{sys.version_info.major}.{sys.version_info.minor}") >= 3.12: from typing import TypedDict else: from typing_extensions import TypedDict -TrustModuleConfiguration_T = TypedDict("_DynamicTrustConfiguration", { - "module": str, "class": str, "config": dict}) +TrustModuleConfiguration_T = TypedDict( + "_DynamicTrustConfiguration", {"module": str, "class": str, "config": dict} +) diff --git a/pyeudiw/trust/model/trust_source.py b/pyeudiw/trust/model/trust_source.py index bdb744a4..f459dc04 100644 --- a/pyeudiw/trust/model/trust_source.py +++ b/pyeudiw/trust/model/trust_source.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from datetime import datetime from typing import Optional + from cryptojwt.jwk.jwk import key_from_jwk_dict @@ -50,7 +51,7 @@ def serialize(self) -> dict[str, any]: return { "type": self.type, "trust_params": self.trust_params, - "expiration_date": self.expiration_date + "expiration_date": self.expiration_date, } @property @@ -104,8 +105,9 @@ def __init__( self.additional_data = kwargs - self.trust_params = {type: TrustParameterData( - **tp) for type, tp in trust_params.items()} + self.trust_params = { + type: TrustParameterData(**tp) for type, tp in trust_params.items() + } def add_key(self, key: dict) -> None: """ @@ -173,11 +175,13 @@ def serialize(self) -> dict[str, any]: "metadata": self.metadata, "revoked": self.revoked, "keys": self.keys, - "trust_params": {type: param.serialize() for type, param in self.trust_params.items()} + "trust_params": { + type: param.serialize() for type, param in self.trust_params.items() + }, } @staticmethod - def empty(entity_id: str) -> 'TrustSourceData': + def empty(entity_id: str) -> "TrustSourceData": """ Return the empty trust source data. 
@@ -186,10 +190,12 @@ def empty(entity_id: str) -> 'TrustSourceData': :returns: The empty trust source data :rtype: TrustSourceData """ - return TrustSourceData(entity_id, policies={}, metadata={}, revoked=False, keys=[], trust_params={}) + return TrustSourceData( + entity_id, policies={}, metadata={}, revoked=False, keys=[], trust_params={} + ) @staticmethod - def from_dict(data: dict) -> 'TrustSourceData': + def from_dict(data: dict) -> "TrustSourceData": """ Return the trust source data from the given dictionary. diff --git a/pyeudiw/trust/trust_anchors.py b/pyeudiw/trust/trust_anchors.py deleted file mode 100644 index 33e0cb61..00000000 --- a/pyeudiw/trust/trust_anchors.py +++ /dev/null @@ -1,40 +0,0 @@ -import logging -from pyeudiw.federation.statements import ( - get_entity_configurations, - EntityStatement -) -from pyeudiw.storage.db_engine import DBEngine - -logger = logging.getLogger(__name__) - - -def update_trust_anchors_ecs(trust_anchors: list[str], db: DBEngine, httpc_params: dict) -> None: - """ - Update the trust anchors entity configurations. 
- - :param trust_anchors: The trust anchors - :type trust_anchors: list - :param db: The database engine - :type db: DBEngine - :param httpc_params: The HTTP client parameters - :type httpc_params: dict - """ - - ta_ecs = get_entity_configurations( - trust_anchors, httpc_params=httpc_params - ) - - for jwt in ta_ecs: - if isinstance(jwt, bytes): - jwt = jwt.decode() - - ec = EntityStatement(jwt, httpc_params=httpc_params) - if not ec.validate_by_itself(): - logger.warning( - f"The trust anchor failed the validation of its EntityConfiguration {ec}") - - db.add_trust_anchor( - entity_id=ec.sub, - entity_configuration=ec.jwt, - exp=ec.exp - ) diff --git a/pyeudiw/x509/verify.py b/pyeudiw/x509/verify.py index 15099244..5830289d 100644 --- a/pyeudiw/x509/verify.py +++ b/pyeudiw/x509/verify.py @@ -1,12 +1,12 @@ -import pem import logging -from OpenSSL import crypto from datetime import datetime from ssl import DER_cert_to_PEM_cert -from cryptography.x509 import load_der_x509_certificate +import pem +from cryptography.x509 import load_der_x509_certificate from cryptojwt.jwk.ec import ECKey from cryptojwt.jwk.rsa import RSAKey +from OpenSSL import crypto LOG_ERROR = "x509 verification failed: {}" @@ -26,8 +26,7 @@ def _verify_x509_certificate_chain(pems: list[str]): try: store = crypto.X509Store() x509_certs = [ - crypto.load_certificate(crypto.FILETYPE_PEM, str(pem)) - for pem in pems + crypto.load_certificate(crypto.FILETYPE_PEM, str(pem)) for pem in pems ] for cert in x509_certs[:-1]: @@ -87,7 +86,9 @@ def _check_datetime(exp: datetime | None): return True -def verify_x509_attestation_chain(x5c: list[bytes], exp: datetime | None = None) -> bool: +def verify_x509_attestation_chain( + x5c: list[bytes], exp: datetime | None = None +) -> bool: """ Verify the x509 attestation certificate chain. 
diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..ce175db4 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,6 @@ +# pytest.ini or .pytest.ini +[pytest] +minversion = 6.0 +addopts = -ra -q +testpaths = + pyeudiw diff --git a/requirements-dev.txt b/requirements-dev.txt index c97c5efe..91ab5065 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,6 +8,7 @@ isort autoflake bandit autopep8 +black beautifulsoup4 lxml freezegun