diff options
-rw-r--r-- | src/scadere/check.py | 32 | ||||
-rw-r--r-- | tst/test_check.py | 15 | ||||
-rw-r--r-- | tst/test_listen.py | 16 |
3 files changed, 37 insertions, 26 deletions
diff --git a/src/scadere/check.py b/src/scadere/check.py index aaabe3f..288e599 100644 --- a/src/scadere/check.py +++ b/src/scadere/check.py @@ -31,18 +31,19 @@ from . import __version__, GNUHelpFormatter, NetLoc __all__ = ['main'] -class CtlChrTrans: - """Translator for printing Unicode control characters.""" +def is_control_character(character): + """Check if a Unicode character belongs to the control category.""" + return unicode_category(character) == 'Cc' - def __getitem__(self, ordinal): - if unicode_category(chr(ordinal)) == 'Cc': - return 0xfffd # replacement character '�' - raise KeyError + +def printable(string): + """Check if the given Unicode string is printable.""" + return not any(map(is_control_character, string)) def base64_from_str(string): """Convert string to base64 format in bytes.""" - return base64(string.translate(CtlChrTrans()).encode()).decode() + return base64(string.encode()).decode() def check(netlocs, after, output, fake_ca=None): @@ -55,6 +56,7 @@ def check(netlocs, after, output, fake_ca=None): fake_ca.configure_trust(ctx) for hostname, port in netlocs: + now = datetime.now(tz=timezone.utc).isoformat(timespec='seconds') netloc = f'{hostname}:{port}' stderr.write(f'TLS certificate for {netloc} ') try: @@ -64,19 +66,27 @@ def check(netlocs, after, output, fake_ca=None): cert = conn.getpeercert() except Exception as exception: stderr.write(f'cannot be retrieved: {exception}\n') - now = datetime.now(tz=timezone.utc).isoformat() print(now, 'N/A', hostname, port, 'N/A', base64_from_str(str(exception)), file=output) - else: - ca = dict(chain.from_iterable(cert['issuer']))['organizationName'] + continue + + try: not_before = parsedate(cert['notBefore']) not_after = parsedate(cert['notAfter']) + ca = dict(chain.from_iterable(cert['issuer']))['organizationName'] + if not printable(ca): + raise ValueError(f'CA name contains control character: {ca!r}') + serial = int(cert['serialNumber'], 16) + except Exception as exception: + stderr.write(f'cannot be parsed: {exception}\n') + print(now, 'N/A', hostname, port, 'N/A', + base64_from_str(str(exception)), file=output) + else: if after < not_after: after_seconds = after.isoformat(timespec='seconds') stderr.write(f'will not expire at {after_seconds}\n') else: stderr.write(f'will expire at {not_after.isoformat()}\n') - serial = cert['serialNumber'].translate(CtlChrTrans()) print(not_before.isoformat(), not_after.isoformat(), # As unique identifier hostname, port, serial, diff --git a/tst/test_check.py b/tst/test_check.py index 0ebc0ab..23be6f5 100644 --- a/tst/test_check.py +++ b/tst/test_check.py @@ -17,7 +17,6 @@ # along with scadere. If not, see <https://www.gnu.org/licenses/>. from asyncio import get_running_loop, start_server -from base64 import urlsafe_b64encode as base64 from datetime import datetime, timedelta, timezone from io import StringIO from ssl import Purpose, create_default_context as tls_context @@ -26,7 +25,7 @@ from hypothesis import given from pytest import mark from trustme import CA -from scadere.check import CtlChrTrans, base64_from_str, check +from scadere.check import base64_from_str, check, printable from scadere.listen import parse_summary, str_from_base64 SECONDS_AGO = datetime.now(tz=timezone.utc) @@ -36,8 +35,7 @@ NEXT_WEEK = SECONDS_AGO + timedelta(days=7) @given(...) def test_base64(string: str): - printable_string = string.translate(CtlChrTrans()) - assert str_from_base64(base64_from_str(string)) == printable_string + assert str_from_base64(base64_from_str(string)) == string async def noop(reader, writer): @@ -67,7 +65,7 @@ async def get_cert_summary(netloc, after, ca): @mark.parametrize('domain', ['localhost']) -@mark.parametrize('ca_name', ['trustme']) +@mark.parametrize('ca_name', ['trustme', '\x1f']) @mark.parametrize('not_after', [SECONDS_AGO, NEXT_DAY, NEXT_WEEK]) @mark.parametrize('after', [NEXT_DAY, NEXT_WEEK]) @mark.parametrize('trust_ca', [False, True]) @@ -88,11 +86,14 @@ async def test_check(domain, ca_name, not_after, after, trust_ca): elif not_after == SECONDS_AGO: assert failed_to_get_cert(summary) assert 'certificate has expired' in str_from_base64(summary[-1]) + elif not printable(ca_name): + assert failed_to_get_cert(summary) + assert 'control character' in str_from_base64(summary[-1]) elif not_after > after: assert summary is None else: assert summary[0] == SECONDS_AGO.isoformat(timespec='seconds') assert summary[1] == not_after.isoformat(timespec='seconds') assert summary[2] == domain - assert summary[3] == str(port) - assert summary[5] == base64(ca_name.encode()).decode() + assert int(summary[3]) == port + assert str_from_base64(summary[5]) == ca_name diff --git a/tst/test_listen.py b/tst/test_listen.py index 3862d9d..45289d5 100644 --- a/tst/test_listen.py +++ b/tst/test_listen.py @@ -30,13 +30,13 @@ from xml.etree.ElementTree import (XML, XMLParser, indent, tostring as str_from_xml) from hypothesis import HealthCheck, given, settings -from hypothesis.strategies import (booleans, builds, composite, - data, datetimes, from_type, - integers, lists, sampled_from, text) +from hypothesis.strategies import (booleans, composite, data, + datetimes, from_type, integers, + lists, sampled_from, text) from hypothesis.provisional import domains, urls from pytest import raises -from scadere.check import base64_from_str +from scadere.check import base64_from_str, printable from scadere.listen import handle, is_subdomain, path, with_trailing_slash, xml ATOM_NAMESPACES = {'': 'http://www.w3.org/2005/Atom'} @@ -50,12 +50,12 @@ def ports(): def serials(): """Return a Hypothesis strategy for TLS serial number.""" - return builds(lambda n: hex(n).removeprefix('0x'), integers(0, 256**20-1)) + return integers(0, 256**20-1) def base64s(): - """Return a Hypothesis strategy for CA names.""" - return text().map(base64_from_str) + """Return a Hypothesis strategy for printable strings in base64.""" + return text().filter(printable).map(base64_from_str) @given(domains(), ports(), base64s(), serials()) @@ -64,7 +64,7 @@ def test_path_with_cert(hostname, port, issuer, serial): assert r[0] == hostname assert int(r[1]) == port assert r[2] == issuer - assert r[3] == serial + assert int(r[3]) == serial @given(domains(), ports(), base64s()) |