diff options
-rw-r--r-- | pyproject.toml | 2 | ||||
-rw-r--r-- | src/scadere/check.py | 23 | ||||
-rw-r--r-- | src/scadere/listen.py | 49 | ||||
-rw-r--r-- | tst/test_check.py | 9 | ||||
-rw-r--r-- | tst/test_listen.py | 79 |
5 files changed, 76 insertions, 86 deletions
diff --git a/pyproject.toml b/pyproject.toml index 60291a8..4677381 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = 'flit_core.buildapi' name = 'scadere' description = 'TLS certificate renewal reminder' readme = 'README.md' -requires-python = '>=3.10' +requires-python = '>=3.11' license = { file = 'COPYING' } authors = [ { name = 'Nguyễn Gia Phong', email = 'cnx@loang.net' } ] maintainers = [ { name = 'Nguyễn Gia Phong', email = 'chung@loa.loang.net' } ] diff --git a/src/scadere/check.py b/src/scadere/check.py index 23ba189..aaabe3f 100644 --- a/src/scadere/check.py +++ b/src/scadere/check.py @@ -24,15 +24,25 @@ from itertools import chain from socket import AF_INET, socket from ssl import create_default_context as tls_context from sys import argv, stderr, stdout +from unicodedata import category as unicode_category from . import __version__, GNUHelpFormatter, NetLoc __all__ = ['main'] +class CtlChrTrans: + """Translator for printing Unicode control characters.""" + + def __getitem__(self, ordinal): + if unicode_category(chr(ordinal)) == 'Cc': + return 0xfffd # replacement character '�' + raise KeyError + + def base64_from_str(string): """Convert string to base64 format in bytes.""" - return base64(string.encode()).decode() + return base64(string.translate(CtlChrTrans()).encode()).decode() def check(netlocs, after, output, fake_ca=None): @@ -52,9 +62,11 @@ def check(netlocs, after, output, fake_ca=None): server_hostname=hostname) as conn: conn.connect((hostname, port)) cert = conn.getpeercert() - except Exception as e: - stderr.write(f'cannot be retrieved: {e}\n') - print(f'N/A N/A {hostname} {port} N/A {e}', file=output) + except Exception as exception: + stderr.write(f'cannot be retrieved: {exception}\n') + now = datetime.now(tz=timezone.utc).isoformat() + print(now, 'N/A', hostname, port, 'N/A', + base64_from_str(str(exception)), file=output) else: ca = dict(chain.from_iterable(cert['issuer']))['organizationName'] not_before = parsedate(cert['notBefore']) @@ -64,9 +76,10 @@ def check(netlocs, after, output, fake_ca=None): stderr.write(f'will not expire at {after_seconds}\n') else: stderr.write(f'will expire at {not_after.isoformat()}\n') + serial = cert['serialNumber'].translate(CtlChrTrans()) print(not_before.isoformat(), not_after.isoformat(), # As unique identifier - hostname, port, cert['serialNumber'], + hostname, port, serial, base64_from_str(ca), file=output) diff --git a/src/scadere/listen.py b/src/scadere/listen.py index bf179e6..d8c1178 100644 --- a/src/scadere/listen.py +++ b/src/scadere/listen.py @@ -23,6 +23,7 @@ from datetime import datetime, timezone from functools import partial from http import HTTPStatus from pathlib import Path +from typing import assert_never from urllib.parse import parse_qs, urljoin, urlsplit from xml.etree.ElementTree import (Element as xml_element, SubElement as xml_subelement, @@ -41,9 +42,18 @@ def parse_summary(line): def path(hostname, port, issuer, serial): """Return the relative URL for the given certificate's details.""" + if serial == 'N/A': + return f'{hostname}/{port}' return f'{hostname}/{port}/{issuer}/{serial}' +def datetime_from_str(string, unavailable_ok=False): + """Parse datetime from string in ISO 8601 format.""" + if string == 'N/A' and unavailable_ok: + return None + return datetime.fromisoformat(string) + + async def write_status(writer, status): """Write the given HTTP/1.1 status line.""" writer.write(f'HTTP/1.1 {status.value} {status.phrase}\r\n'.encode()) @@ -71,13 +81,21 @@ def str_from_base64(string): return from_base64(string.encode()).decode() -def body(not_before, not_after, hostname, port, serial, issuer): +def body(not_before, not_after, hostname, port, serial, string64): """Describe the given certificate in XHTML.""" + string = str_from_base64(string64) + if not_after is None: + return (('h1', 'TLS certificate problem'), + ('dl', + ('dt', 'Domain'), ('dd', hostname), + ('dt', 'Port'), ('dd', port), + ('dt', 'Time'), ('dd', not_before), + ('dt', 'Error'), ('dd', string))) return (('h1', 'TLS certificate information'), ('dl', ('dt', 'Domain'), ('dd', hostname), ('dt', 'Port'), ('dd', port), - ('dt', 'Issuer'), ('dd', str_from_base64(issuer)), + ('dt', 'Issuer'), ('dd', string), ('dt', 'Serial number'), ('dd', serial), ('dt', 'Valid from'), ('dd', not_before), ('dt', 'Valid until'), ('dd', not_after))) @@ -87,15 +105,19 @@ def entry(base_url, cert): """Construct Atom entry for the given TLS certificate.""" not_before, not_after, hostname, port, serial, issuer = cert url = urljoin(base_url, path(hostname, port, issuer, serial)) + title = (f'TLS cert for {hostname} cannot be retrieved' + if not_after is None + else f'TLS cert for {hostname} will expire at {not_after}') + author = 'Scadere' if not_after is None else str_from_base64(issuer) return ('entry', - ('author', ('name', str_from_base64(issuer))), + ('author', ('name', author)), ('content', {'type': 'xhtml'}, ('div', {'xmlns': 'http://www.w3.org/1999/xhtml'}, *body(*cert))), ('id', url), ('link', {'rel': 'alternate', 'type': 'application/xhtml+xml', 'href': url}), - ('title', (f'TLS cert for {hostname} will expire at {not_after}')), + ('title', title), ('updated', not_before)) @@ -109,12 +131,15 @@ def xml(tree, parent=None): else: elem = xml_subelement(parent, tag, attrs) for child in children: - if isinstance(child, tuple): - xml(child, elem) - elif isinstance(child, datetime): - elem.text = child.isoformat() - else: - elem.text = str(child) + match child: + case tuple(): + xml(child, elem) + case str(): + elem.text = child + case datetime(): + elem.text = child.isoformat() + case _: # pragma: no cover + assert_never(child) if parent is None: indent(elem) return elem @@ -151,7 +176,9 @@ async def handle(certs, base_url, reader, writer): summaries = map(parse_summary, certs.read_text().splitlines()) lookup = {urlsplit(urljoin(base_url, path(hostname, port, issuer, serial))).path: - (not_before, not_after, hostname, port, serial, issuer) + (datetime_from_str(not_before), + datetime_from_str(not_after, unavailable_ok=True), + hostname, port, serial, issuer) for not_before, not_after, hostname, port, serial, issuer in summaries} request = await reader.readuntil(b'\r\n') diff --git a/tst/test_check.py b/tst/test_check.py index 397b9ca..f809a90 100644 --- a/tst/test_check.py +++ b/tst/test_check.py @@ -26,7 +26,7 @@ from hypothesis import given from pytest import mark from trustme import CA -from scadere.check import base64_from_str, check +from scadere.check import CtlChrTrans, base64_from_str, check from scadere.listen import parse_summary, str_from_base64 SECONDS_AGO = datetime.now(tz=timezone.utc) @@ -36,7 +36,8 @@ NEXT_WEEK = SECONDS_AGO + timedelta(days=7) @given(...) def test_base64(string: str): - assert str_from_base64(base64_from_str(string)) == string + printable_string = string.translate(CtlChrTrans()) + assert str_from_base64(base64_from_str(string)) == printable_string async def noop(reader, writer): @@ -84,10 +85,10 @@ async def test_check(domain, ca_name, not_after, after, trust_ca): ca if trust_ca else None) if not trust_ca: assert failed_to_get_cert(summary) - assert 'self-signed certificate' in summary[-1] + assert 'self-signed certificate' in str_from_base64(summary[-1]) elif not_after == SECONDS_AGO: assert failed_to_get_cert(summary) - assert 'certificate has expired' in summary[-1] + assert 'certificate has expired' in str_from_base64(summary[-1]) elif not_after > after: assert summary is None else: diff --git a/tst/test_listen.py b/tst/test_listen.py index e204d4f..cc6a9a1 100644 --- a/tst/test_listen.py +++ b/tst/test_listen.py @@ -22,21 +22,18 @@ from copy import deepcopy from email.parser import BytesHeaderParser from functools import partial from pathlib import Path -from string import ascii_letters from tempfile import mkstemp from urllib.parse import urljoin, urlsplit from xml.etree.ElementTree import (XML, XMLParser, indent, tostring as str_from_xml) -from xml.sax.saxutils import escape from hypothesis import HealthCheck, given, settings -from hypothesis.strategies import (builds, composite, data, +from hypothesis.strategies import (booleans, builds, composite, data, datetimes, integers, lists, text) from hypothesis.provisional import domains, urls from scadere.check import base64_from_str -from scadere.listen import (body, entry, handle, is_subdomain, path, - str_from_base64, with_trailing_slash, xml) +from scadere.listen import handle, is_subdomain, path, with_trailing_slash ATOM_NAMESPACES = {'': 'http://www.w3.org/2005/Atom'} XHTML_NAMESPACES = {'': 'http://www.w3.org/1999/xhtml'} @@ -52,13 +49,13 @@ def serials(): return builds(lambda n: hex(n).removeprefix('0x'), integers(0, 256**20-1)) -def ca_names(): +def base64s(): """Return a Hypothesis strategy for CA names.""" - return text().map(lambda name: base64_from_str(name)) + return text().map(base64_from_str) -@given(domains(), ports(), ca_names(), serials()) -def test_path(hostname, port, issuer, serial): +@given(domains(), ports(), base64s(), serials()) +def test_path_with_cert(hostname, port, issuer, serial): r = path(hostname, port, issuer, serial).split('/') assert r[0] == hostname assert int(r[1]) == port @@ -66,57 +63,9 @@ def test_path(hostname, port, issuer, serial): assert r[3] == serial -@given(domains(), ports(), ca_names(), serials(), datetimes(), datetimes()) -def test_body(hostname, port, issuer, serial, not_before, not_after): - r = body(not_before, not_after, hostname, port, serial, issuer) - assert r[-1][0] == 'dl' - d = dict(zip((v for k, v in r[-1][1:] if k == 'dt'), - (v for k, v in r[-1][1:] if k == 'dd'))) - assert d['Domain'] == hostname - assert d['Port'] == port - assert d['Issuer'] == str_from_base64(issuer) - assert d['Serial number'] == serial - assert d['Valid from'] == not_before - assert d['Valid until'] == not_after - - -@given(urls(), domains(), ports(), - ca_names(), serials(), datetimes(), datetimes()) -def test_atom_entry(base_url, hostname, port, - issuer, serial, not_before, not_after): - cert = not_before, not_after, hostname, port, serial, issuer - r = str_from_xml(xml(entry(base_url, cert)), - 'unicode', short_empty_elements=False) - issuer_str = str_from_base64(issuer) - url = urljoin(base_url, path(hostname, port, issuer, serial)) - assert r == f'''<entry> - <author> - <name>{escape(issuer_str)}</name> - </author> - <content type="xhtml"> - <div xmlns="http://www.w3.org/1999/xhtml"> - <h1>TLS certificate information</h1> - <dl> - <dt>Domain</dt> - <dd>{hostname}</dd> - <dt>Port</dt> - <dd>{port}</dd> - <dt>Issuer</dt> - <dd>{escape(issuer_str)}</dd> - <dt>Serial number</dt> - <dd>{serial}</dd> - <dt>Valid from</dt> - <dd>{not_before.isoformat()}</dd> - <dt>Valid until</dt> - <dd>{not_after.isoformat()}</dd> - </dl> - </div> - </content> - <id>{url}</id> - <link rel="alternate" type="application/xhtml+xml" href="{url}"></link> - <title>TLS cert for {hostname} will expire at {not_after}</title> - <updated>{not_before.isoformat()}</updated> -</entry>''' +@given(domains(), ports(), base64s()) +def test_path_without_cert(hostname, port, error): + assert path(hostname, port, error, 'N/A') == f'{hostname}/{port}' @given(domains(), lists(domains())) @@ -135,13 +84,13 @@ def test_is_subdomain(subject, objects): @composite def certificates(draw): """Return a Hypothesis strategy for certificate summaries.""" + valid = draw(booleans()) not_before = draw(datetimes()).isoformat() - not_after = draw(datetimes()).isoformat() + not_after = draw(datetimes()).isoformat() if valid else 'N/A' hostname = draw(domains()) port = draw(ports()) - serial = draw(serials()) - # Free-formed UTF-8 could easily creates malformed XML. - issuer = base64_from_str(draw(text(ascii_letters))) + serial = draw(serials()) if valid else 'N/A' + issuer = draw(base64s()) return f'{not_before} {not_after} {hostname} {port} {serial} {issuer}' @@ -245,7 +194,7 @@ async def check_server(sockets, func, *args): @given(urls().filter(is_base_url).filter(has_usual_path), lists(certificates(), min_size=1)) @settings(deadline=None) -async def test_http_200(base_url, certs): +async def test_content(base_url, certs): base_path = urlsplit(base_url).path with tmp_cert_file(certs) as cert_file: handler = partial(handle, cert_file, base_url) |