Skip to content

Commit

Permalink
Use mypy to lint types (#195)
Browse files Browse the repository at this point in the history
  • Loading branch information
kislyuk authored Aug 21, 2022
1 parent d39a410 commit 0c1465d
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 25 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
test_deps:
pip install coverage flake8 wheel
pip install coverage flake8 wheel mypy types-certifi types-pyOpenSSL lxml-stubs

lint: test_deps
flake8 $$(python setup.py --name) test
mypy $$(python setup.py --name) --check-untyped-defs

test: test_deps lint
coverage run --source=$$(python setup.py --name) ./test/test.py
Expand Down
20 changes: 11 additions & 9 deletions signxml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from base64 import b64decode, b64encode
from collections import namedtuple
from enum import Enum
from typing import List, Tuple

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa, utils
Expand Down Expand Up @@ -121,7 +122,7 @@ class XMLSignatureProcessor(XMLProcessor):
"urn:oid:1.3.132.0.37": ec.SECT409R1,
"urn:oid:1.3.132.0.38": ec.SECT571K1,
}
known_ecdsa_curve_oids = {ec().name: oid for oid, ec in known_ecdsa_curves.items()}
known_ecdsa_curve_oids = {ec().name: oid for oid, ec in known_ecdsa_curves.items()} # type: ignore

known_c14n_algorithms = {
"http://www.w3.org/TR/2001/REC-xml-c14n-20010315",
Expand All @@ -133,7 +134,7 @@ class XMLSignatureProcessor(XMLProcessor):
}
default_c14n_algorithm = "http://www.w3.org/2006/12/xml-c14n11"

id_attributes = ("Id", "ID", "id", "xml:id")
id_attributes: Tuple[str, ...] = ("Id", "ID", "id", "xml:id")

def _get_digest(self, data, digest_algorithm):
hasher = Hash(algorithm=digest_algorithm, backend=default_backend())
Expand Down Expand Up @@ -586,7 +587,8 @@ def _verify_signature_with_pubkey(self, signed_info_c14n, raw_signature, key_val
x = bytes_to_long(key_data[:len(key_data) // 2])
y = bytes_to_long(key_data[len(key_data) // 2:])
curve_class = self.known_ecdsa_curves[named_curve.get("URI")]
key = ec.EllipticCurvePublicNumbers(x=x, y=y, curve=curve_class()).public_key(backend=default_backend())
ecpn = ec.EllipticCurvePublicNumbers(x=x, y=y, curve=curve_class()) # type: ignore
key = ecpn.public_key(backend=default_backend())
elif not isinstance(key, ec.EllipticCurvePublicKey):
raise InvalidInput("DER encoded key value does not match specified signature algorithm")
dss_signature = self._encode_dss_signature(raw_signature, key.key_size)
Expand All @@ -604,8 +606,8 @@ def _verify_signature_with_pubkey(self, signed_info_c14n, raw_signature, key_val
q = self._get_long(dsa_key_value, "Q")
g = self._get_long(dsa_key_value, "G", require=False)
y = self._get_long(dsa_key_value, "Y")
pn = dsa.DSAPublicNumbers(y=y, parameter_numbers=dsa.DSAParameterNumbers(p=p, q=q, g=g))
key = pn.public_key(backend=default_backend())
dsapn = dsa.DSAPublicNumbers(y=y, parameter_numbers=dsa.DSAParameterNumbers(p=p, q=q, g=g))
key = dsapn.public_key(backend=default_backend()) # type: ignore
elif not isinstance(key, dsa.DSAPublicKey):
raise InvalidInput("DER encoded key value does not match specified signature algorithm")
# TODO: supply meaningful key_size_bits for signature length assertion
Expand Down Expand Up @@ -898,7 +900,7 @@ def verify(self, data, require_x509=True, x509_cert=None, cert_subject_name=None
der_encoded_key_value=der_encoded_key_value,
signature_alg=signature_alg)

verify_results = []
verify_results: List[VerifyResult] = []
for reference in self._findall(signed_info, "Reference"):
copied_root = self.fromstring(self.tostring(root))
copied_signature_ref = self._get_signature(copied_root)
Expand Down Expand Up @@ -982,9 +984,9 @@ def check_der_key_value_matches_cert_public_key(self, der_encoded_key_value, pub
elif "dsa-" in signature_alg \
and isinstance(der_public_key, dsa.DSAPublicKey) \
and isinstance(public_key.to_cryptography_key(), dsa.DSAPublicKey):
p = der_public_key.public_numbers().parameter_numbers().p
q = der_public_key.public_numbers().parameter_numbers().q
g = der_public_key.public_numbers().parameter_numbers().g
p = der_public_key.public_numbers().parameter_numbers().p # type: ignore
q = der_public_key.public_numbers().parameter_numbers().q # type: ignore
g = der_public_key.public_numbers().parameter_numbers().g # type: ignore

pubk_p = public_key.to_cryptography_key().public_numbers().p
pubk_q = public_key.to_cryptography_key().public_numbers().q
Expand Down
2 changes: 1 addition & 1 deletion signxml/__pyinstaller/hook-signxml.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Hook for pyinstaller to include the files are signxml/schemas/* into the final build."""

from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import collect_data_files # type: ignore

datas = collect_data_files('signxml', excludes=['__pyinstaller'])
10 changes: 7 additions & 3 deletions signxml/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import cryptography.exceptions


class InvalidSignature(cryptography.exceptions.InvalidSignature):
class SignXMLException(Exception):
pass


class InvalidSignature(cryptography.exceptions.InvalidSignature, SignXMLException):
"""
Raised when signature validation fails.
"""
Expand All @@ -23,9 +27,9 @@ class InvalidCertificate(InvalidSignature):
"""


class InvalidInput(ValueError):
class InvalidInput(ValueError, SignXMLException):
pass


class RedundantCert(Exception):
class RedundantCert(SignXMLException):
pass
16 changes: 5 additions & 11 deletions signxml/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,13 @@
import os
import re
import struct
import sys
import textwrap
from base64 import b64decode, b64encode
from xml.etree import ElementTree as stdlibElementTree

from lxml import etree

from ..exceptions import InvalidCertificate, InvalidInput, RedundantCert

USING_PYTHON2 = True if sys.version_info < (3, 0) else False
from ..exceptions import InvalidCertificate, InvalidInput, RedundantCert, SignXMLException

PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----"
Expand Down Expand Up @@ -49,8 +46,6 @@ def bytes_to_long(s):
# On Python 2, indexing into a bytearray returns a byte string; on Python 3, an int.
return s
acc = 0
if USING_PYTHON2:
acc = long(acc) # noqa
unpack = struct.unpack
length = len(s)
if length % 4:
Expand All @@ -72,8 +67,6 @@ def long_to_bytes(n, blocksize=0):
"""
# after much testing, this algorithm was deemed to be the fastest
s = b''
if USING_PYTHON2:
n = long(n) # noqa
pack = struct.pack
while n > 0:
s = pack(b'>I', n & 0xffffffff) + s
Expand Down Expand Up @@ -109,7 +102,7 @@ def bits_to_bytes_unit(num_of_bits):

def strip_pem_header(cert):
try:
return re.search(pem_regexp, ensure_str(cert)).group(1).replace("\r", "")
return re.search(pem_regexp, ensure_str(cert)).group(1).replace("\r", "") # type: ignore
except Exception:
return ensure_str(cert).replace("\r", "")

Expand All @@ -132,7 +125,7 @@ def __getattr__(self, a):


class XMLProcessor:
_schema, _default_parser = None, None
_schema, _default_parser, _parser, schema_file = None, None, None, ""

@classmethod
def schema(cls):
Expand Down Expand Up @@ -245,7 +238,8 @@ def verify_x509_cert_chain(cert_chain, ca_pem_file=None, ca_path=None):
context.load_verify_locations(ensure_bytes(ca_pem_file, none_ok=True), capath=ca_path)
store = context.get_cert_store()
certs = list(reversed(cert_chain))
end_of_chain, last_error = None, None
end_of_chain = None
last_error: Exception = SignXMLException("Invalid certificate chain")
while len(certs) > 0:
for cert in certs:
try:
Expand Down

0 comments on commit 0c1465d

Please sign in to comment.