Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use mypy to lint types #195

Merged
merged 7 commits into from
Aug 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
test_deps:
pip install coverage flake8 wheel
pip install coverage flake8 wheel mypy types-certifi types-pyOpenSSL lxml-stubs

lint: test_deps
flake8 $$(python setup.py --name) test
mypy $$(python setup.py --name) --check-untyped-defs

test: test_deps lint
coverage run --source=$$(python setup.py --name) ./test/test.py
Expand Down
20 changes: 11 additions & 9 deletions signxml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from base64 import b64decode, b64encode
from collections import namedtuple
from enum import Enum
from typing import List, Tuple

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa, utils
Expand Down Expand Up @@ -121,7 +122,7 @@ class XMLSignatureProcessor(XMLProcessor):
"urn:oid:1.3.132.0.37": ec.SECT409R1,
"urn:oid:1.3.132.0.38": ec.SECT571K1,
}
known_ecdsa_curve_oids = {ec().name: oid for oid, ec in known_ecdsa_curves.items()}
known_ecdsa_curve_oids = {ec().name: oid for oid, ec in known_ecdsa_curves.items()} # type: ignore

known_c14n_algorithms = {
"http://www.w3.org/TR/2001/REC-xml-c14n-20010315",
Expand All @@ -133,7 +134,7 @@ class XMLSignatureProcessor(XMLProcessor):
}
default_c14n_algorithm = "http://www.w3.org/2006/12/xml-c14n11"

id_attributes = ("Id", "ID", "id", "xml:id")
id_attributes: Tuple[str, ...] = ("Id", "ID", "id", "xml:id")

def _get_digest(self, data, digest_algorithm):
hasher = Hash(algorithm=digest_algorithm, backend=default_backend())
Expand Down Expand Up @@ -586,7 +587,8 @@ def _verify_signature_with_pubkey(self, signed_info_c14n, raw_signature, key_val
x = bytes_to_long(key_data[:len(key_data) // 2])
y = bytes_to_long(key_data[len(key_data) // 2:])
curve_class = self.known_ecdsa_curves[named_curve.get("URI")]
key = ec.EllipticCurvePublicNumbers(x=x, y=y, curve=curve_class()).public_key(backend=default_backend())
ecpn = ec.EllipticCurvePublicNumbers(x=x, y=y, curve=curve_class()) # type: ignore
key = ecpn.public_key(backend=default_backend())
elif not isinstance(key, ec.EllipticCurvePublicKey):
raise InvalidInput("DER encoded key value does not match specified signature algorithm")
dss_signature = self._encode_dss_signature(raw_signature, key.key_size)
Expand All @@ -604,8 +606,8 @@ def _verify_signature_with_pubkey(self, signed_info_c14n, raw_signature, key_val
q = self._get_long(dsa_key_value, "Q")
g = self._get_long(dsa_key_value, "G", require=False)
y = self._get_long(dsa_key_value, "Y")
pn = dsa.DSAPublicNumbers(y=y, parameter_numbers=dsa.DSAParameterNumbers(p=p, q=q, g=g))
key = pn.public_key(backend=default_backend())
dsapn = dsa.DSAPublicNumbers(y=y, parameter_numbers=dsa.DSAParameterNumbers(p=p, q=q, g=g))
key = dsapn.public_key(backend=default_backend()) # type: ignore
elif not isinstance(key, dsa.DSAPublicKey):
raise InvalidInput("DER encoded key value does not match specified signature algorithm")
# TODO: supply meaningful key_size_bits for signature length assertion
Expand Down Expand Up @@ -898,7 +900,7 @@ def verify(self, data, require_x509=True, x509_cert=None, cert_subject_name=None
der_encoded_key_value=der_encoded_key_value,
signature_alg=signature_alg)

verify_results = []
verify_results: List[VerifyResult] = []
for reference in self._findall(signed_info, "Reference"):
copied_root = self.fromstring(self.tostring(root))
copied_signature_ref = self._get_signature(copied_root)
Expand Down Expand Up @@ -982,9 +984,9 @@ def check_der_key_value_matches_cert_public_key(self, der_encoded_key_value, pub
elif "dsa-" in signature_alg \
and isinstance(der_public_key, dsa.DSAPublicKey) \
and isinstance(public_key.to_cryptography_key(), dsa.DSAPublicKey):
p = der_public_key.public_numbers().parameter_numbers().p
q = der_public_key.public_numbers().parameter_numbers().q
g = der_public_key.public_numbers().parameter_numbers().g
p = der_public_key.public_numbers().parameter_numbers().p # type: ignore
q = der_public_key.public_numbers().parameter_numbers().q # type: ignore
g = der_public_key.public_numbers().parameter_numbers().g # type: ignore

pubk_p = public_key.to_cryptography_key().public_numbers().p
pubk_q = public_key.to_cryptography_key().public_numbers().q
Expand Down
2 changes: 1 addition & 1 deletion signxml/__pyinstaller/hook-signxml.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Hook for pyinstaller to include the files are signxml/schemas/* into the final build."""

from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import collect_data_files # type: ignore

datas = collect_data_files('signxml', excludes=['__pyinstaller'])
10 changes: 7 additions & 3 deletions signxml/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import cryptography.exceptions


class InvalidSignature(cryptography.exceptions.InvalidSignature):
class SignXMLException(Exception):
pass


class InvalidSignature(cryptography.exceptions.InvalidSignature, SignXMLException):
"""
Raised when signature validation fails.
"""
Expand All @@ -23,9 +27,9 @@ class InvalidCertificate(InvalidSignature):
"""


class InvalidInput(ValueError):
class InvalidInput(ValueError, SignXMLException):
pass


class RedundantCert(Exception):
class RedundantCert(SignXMLException):
pass
16 changes: 5 additions & 11 deletions signxml/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,13 @@
import os
import re
import struct
import sys
import textwrap
from base64 import b64decode, b64encode
from xml.etree import ElementTree as stdlibElementTree

from lxml import etree

from ..exceptions import InvalidCertificate, InvalidInput, RedundantCert

USING_PYTHON2 = True if sys.version_info < (3, 0) else False
from ..exceptions import InvalidCertificate, InvalidInput, RedundantCert, SignXMLException

PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----"
Expand Down Expand Up @@ -49,8 +46,6 @@ def bytes_to_long(s):
# On Python 2, indexing into a bytearray returns a byte string; on Python 3, an int.
return s
acc = 0
if USING_PYTHON2:
acc = long(acc) # noqa
unpack = struct.unpack
length = len(s)
if length % 4:
Expand All @@ -72,8 +67,6 @@ def long_to_bytes(n, blocksize=0):
"""
# after much testing, this algorithm was deemed to be the fastest
s = b''
if USING_PYTHON2:
n = long(n) # noqa
pack = struct.pack
while n > 0:
s = pack(b'>I', n & 0xffffffff) + s
Expand Down Expand Up @@ -109,7 +102,7 @@ def bits_to_bytes_unit(num_of_bits):

def strip_pem_header(cert):
try:
return re.search(pem_regexp, ensure_str(cert)).group(1).replace("\r", "")
return re.search(pem_regexp, ensure_str(cert)).group(1).replace("\r", "") # type: ignore
except Exception:
return ensure_str(cert).replace("\r", "")

Expand All @@ -132,7 +125,7 @@ def __getattr__(self, a):


class XMLProcessor:
_schema, _default_parser = None, None
_schema, _default_parser, _parser, schema_file = None, None, None, ""

@classmethod
def schema(cls):
Expand Down Expand Up @@ -245,7 +238,8 @@ def verify_x509_cert_chain(cert_chain, ca_pem_file=None, ca_path=None):
context.load_verify_locations(ensure_bytes(ca_pem_file, none_ok=True), capath=ca_path)
store = context.get_cert_store()
certs = list(reversed(cert_chain))
end_of_chain, last_error = None, None
end_of_chain = None
last_error: Exception = SignXMLException("Invalid certificate chain")
while len(certs) > 0:
for cert in certs:
try:
Expand Down