about summary refs log tree commit diff
diff options
context:
space:
mode:
author Nguyễn Gia Phong <cnx@loang.net> 2025-06-05 17:09:25 +0900
committer Nguyễn Gia Phong <cnx@loang.net> 2025-06-05 17:09:25 +0900
commit 9508ef9fa8ae23610048bff9f03d592c3d1855fb (patch)
tree d5c748b5ae08b355d21645cf5cb6b2d762707dc9
parent 8021b9db75807fb89f5bc2497cbf309ac606f18d (diff)
download scadere-9508ef9fa8ae23610048bff9f03d592c3d1855fb.tar.gz
Strengthen typing
-rw-r--r-- src/scadere/check.py 9
-rw-r--r-- src/scadere/listen.py 79
-rw-r--r-- tst/test_check.py 28
-rw-r--r-- tst/test_listen.py 30
4 files changed, 69 insertions, 77 deletions
diff --git a/src/scadere/check.py b/src/scadere/check.py
index 288e599..0937e75 100644
--- a/src/scadere/check.py
+++ b/src/scadere/check.py
@@ -25,6 +25,7 @@ from socket import AF_INET, socket
 from ssl import create_default_context as tls_context
 from sys import argv, stderr, stdout
 from unicodedata import category as unicode_category
+from uuid import uuid4
 
 from . import __version__, GNUHelpFormatter, NetLoc
 
@@ -56,7 +57,7 @@ def check(netlocs, after, output, fake_ca=None):
         fake_ca.configure_trust(ctx)
 
     for hostname, port in netlocs:
-        now = datetime.now(tz=timezone.utc).isoformat(timespec='seconds')
+        now = datetime.now(timezone.utc).isoformat()
         netloc = f'{hostname}:{port}'
         stderr.write(f'TLS certificate for {netloc} ')
         try:
@@ -66,7 +67,7 @@ def check(netlocs, after, output, fake_ca=None):
                 cert = conn.getpeercert()
         except Exception as exception:
             stderr.write(f'cannot be retrieved: {exception}\n')
-            print(now, 'N/A', hostname, port, 'N/A',
+            print('N/A', now, hostname, port, uuid4().int,
                   base64_from_str(str(exception)), file=output)
             continue
 
@@ -79,7 +80,7 @@ def check(netlocs, after, output, fake_ca=None):
             serial = int(cert['serialNumber'], 16)
         except Exception as exception:
             stderr.write(f'cannot be parsed: {exception}\n')
-            print(now, 'N/A', hostname, port, 'N/A',
+            print('N/A', now, hostname, port, uuid4().int,
                   base64_from_str(str(exception)), file=output)
         else:
             if after < not_after:
@@ -111,7 +112,7 @@ def main(arguments=argv[1:]):
                         help='output file (default to stdout)')
     args = parser.parse_args(arguments)
     with args.output:  # pragma: no cover
-        after = datetime.now(tz=timezone.utc) + timedelta(days=args.days)
+        after = datetime.now(timezone.utc) + timedelta(days=args.days)
         check(args.netloc, after, args.output)
 
 
diff --git a/src/scadere/listen.py b/src/scadere/listen.py
index 59dbef6..add4b05 100644
--- a/src/scadere/listen.py
+++ b/src/scadere/listen.py
@@ -22,6 +22,7 @@ from base64 import urlsafe_b64decode as from_base64
 from datetime import datetime, timezone
 from functools import partial
 from http import HTTPStatus
+from operator import call
 from pathlib import Path
 from string import digits
 from typing import assert_never
@@ -32,20 +33,34 @@ from xml.etree.ElementTree import (Element as xml_element,
 from sys import argv
 
 from . import __version__, GNUHelpFormatter, NetLoc
+from .check import base64_from_str
 
 __all__ = ['main']
 
 
+def datetime_from_str(string, unavailable_ok=False):
+    """Parse datetime from string in ISO 8601 format."""
+    if string == 'N/A' and unavailable_ok:
+        return None
+    return datetime.fromisoformat(string)
+
+
+def str_from_base64(string64):
+    """Decode string in base64 format."""
+    return from_base64(string64.encode()).decode()
+
+
 def parse_summary(line):
     """Parse TLS certificate into a summary tuple."""
-    return tuple(line.rstrip('\r\n').split(' ', maxsplit=5))
+    return tuple(map(call,
+                     (partial(datetime_from_str, unavailable_ok=True),
+                      datetime_from_str, str, int, int, str_from_base64),
+                     line.rstrip('\r\n').split(' ', maxsplit=5)))
 
 
-def path(hostname, port, issuer, serial):
+def path(hostname, port, number, string):
     """Return the relative URL for the given certificate's details."""
-    if serial == 'N/A':
-        return f'{hostname}/{port}'
-    return f'{hostname}/{port}/{issuer}/{serial}'
+    return f'{hostname}/{port}/{base64_from_str(string)}/{number}'
 
 
 def supported_http_version(version):
@@ -60,13 +75,6 @@ def supported_http_version(version):
             return False
 
 
-def datetime_from_str(string, unavailable_ok=False):
-    """Parse datetime from string in ISO 8601 format."""
-    if string == 'N/A' and unavailable_ok:
-        return None
-    return datetime.fromisoformat(string)
-
-
 async def write_status(writer, http_version, status):
     """Write the given HTTP status line."""
     status = f'HTTP/{http_version} {status.value} {status.phrase}\r\n'
@@ -90,39 +98,33 @@ async def describe_status(writer, status, http_version='1.1'):
     await writer.drain()
 
 
-def str_from_base64(string):
-    """Decode string in base64 format."""
-    return from_base64(string.encode()).decode()
-
-
-def body(not_before, not_after, hostname, port, serial, string64):
+def body(not_before, not_after, hostname, port, number, string):
     """Describe the given certificate in XHTML."""
-    string = str_from_base64(string64)
-    if not_after is None:
+    if not_before is None:
         return (('h1', 'TLS certificate problem'),
                 ('dl',
                  ('dt', 'Domain'), ('dd', hostname),
                  ('dt', 'Port'), ('dd', port),
-                 ('dt', 'Time'), ('dd', not_before),
+                 ('dt', 'Time'), ('dd', not_after),
                  ('dt', 'Error'), ('dd', string)))
     return (('h1', 'TLS certificate information'),
             ('dl',
              ('dt', 'Domain'), ('dd', hostname),
              ('dt', 'Port'), ('dd', port),
              ('dt', 'Issuer'), ('dd', string),
-             ('dt', 'Serial number'), ('dd', serial),
+             ('dt', 'Serial number'), ('dd', number),
              ('dt', 'Valid from'), ('dd', not_before),
              ('dt', 'Valid until'), ('dd', not_after)))
 
 
 def entry(base_url, cert):
     """Construct Atom entry for the given TLS certificate."""
-    not_before, not_after, hostname, port, serial, issuer = cert
-    url = urljoin(base_url, path(hostname, port, issuer, serial))
+    not_before, not_after, hostname, port, number, string = cert
+    url = urljoin(base_url, path(hostname, port, number, string))
     title = (f'TLS cert for {hostname} cannot be retrieved'
-             if not_after is None
+             if not_before is None
              else f'TLS cert for {hostname} will expire at {not_after}')
-    author = 'Scadere' if not_after is None else str_from_base64(issuer)
+    author = 'Scadere' if not_before is None else string
     return ('entry',
             ('author', ('name', author)),
             ('content', {'type': 'xhtml'},
@@ -132,7 +134,7 @@ def entry(base_url, cert):
                       'type': 'application/xhtml+xml',
                       'href': url}),
             ('title', title),
-            ('updated', not_before))
+            ('updated', not_after))
 
 
 def split_domain(domain):
@@ -155,7 +157,7 @@ def feed(base_url, filename, certificates, domains):
             ('id', base_url),
             ('link', {'rel': 'self', 'href': base_url}),
             ('title', filename),
-            ('updated', datetime.now(tz=timezone.utc).isoformat()),
+            ('updated', datetime.now(timezone.utc).isoformat()),
             ('generator',
              {'uri': 'https://trong.loang.net/scadere/about',
               'version': __version__},
@@ -166,7 +168,7 @@ def feed(base_url, filename, certificates, domains):
 
 def page(certificate):
     """Construct an XHTML page for the given TLS certificate."""
-    not_before, not_after, hostname, port, serial, issuer = certificate
+    hostname, port = certificate[2:4]
     return ('html', {'xmlns': 'http://www.w3.org/1999/xhtml',
                      'lang': 'en'},
             ('head',
@@ -177,8 +179,7 @@ def page(certificate):
                                    'initial-scale=1.0')}),
              ('link', {'rel': 'icon', 'href': 'data:,'}),
              ('title', f'TLS certificate - {hostname}:{port}')),
-            ('body', *body(not_before, not_after,
-                           hostname, port, serial, issuer)))
+            ('body', *body(*certificate)))
 
 
 def xml(tree, parent=None):
@@ -196,6 +197,8 @@ def xml(tree, parent=None):
                 xml(child, elem)
             case str():
                 elem.text = child
+            case int():
+                elem.text = str(child)
             case datetime():
                 elem.text = child.isoformat()
             case _:
@@ -251,17 +254,13 @@ async def handle(certs, base_url, reader, writer):
             return
 
         try:
+            summaries = tuple(map(parse_summary,
+                                  certs.read_text().splitlines()))
+            paths = tuple(urlsplit(urljoin(base_url, path(*s[-4:]))).path
+                          for s in summaries)
+            lookup = dict(map(tuple, zip(paths, summaries)))
             url_parts = urlsplit(urljoin(base_url, url.strip().decode()))
             domains = tuple(parse_qs(url_parts.query).get('domain', []))
-            summaries = map(parse_summary, certs.read_text().splitlines())
-            lookup = {urlsplit(urljoin(base_url,
-                                       path(hostname, port,
-                                            issuer, serial))).path:
-                      (datetime_from_str(not_before),
-                       datetime_from_str(not_after, unavailable_ok=True),
-                       hostname, port, serial, issuer)
-                      for not_before, not_after, hostname, port, serial, issuer
-                      in summaries}
         except Exception:  # pragma: no cover
             await describe_status(writer, HTTPStatus.INTERNAL_SERVER_ERROR,
                                   http_version)
diff --git a/tst/test_check.py b/tst/test_check.py
index c5516e2..4b2c955 100644
--- a/tst/test_check.py
+++ b/tst/test_check.py
@@ -28,7 +28,8 @@ from trustme import CA
 from scadere.check import base64_from_str, check, printable
 from scadere.listen import parse_summary, str_from_base64
 
-SECONDS_AGO = datetime.now(tz=timezone.utc)
+# Times in X.509 certificates are YYYYMMDDHHMMSSZ (RFC 5280)
+SECONDS_AGO = datetime.now(timezone.utc).replace(microsecond=0)
 NEXT_DAY = SECONDS_AGO + timedelta(days=1)
 NEXT_WEEK = SECONDS_AGO + timedelta(days=7)
 
@@ -44,11 +45,6 @@ async def noop(reader, writer):
     await writer.wait_closed()
 
 
-def failed_to_get_cert(summary):
-    """Return if any field is N/A."""
-    return any(field == 'N/A' for field in summary)
-
-
 async def get_cert_summary(netloc, after, ca):
     """Fetch TLS certificate expiration summary for netloc."""
     loop = get_running_loop()
@@ -77,19 +73,19 @@ async def test_check(domain, ca_name, not_after, after, trust_ca):
         summary = await get_cert_summary((domain, port), after,
                                          ca if trust_ca else None)
         if not trust_ca:
-            assert failed_to_get_cert(summary)
-            assert 'self-signed certificate' in str_from_base64(summary[-1])
+            assert summary[0] is None
+            assert 'self-signed certificate' in summary[5]
         elif not_after == SECONDS_AGO:
-            assert failed_to_get_cert(summary)
-            assert 'certificate has expired' in str_from_base64(summary[-1])
+            assert summary[0] is None
+            assert 'certificate has expired' in summary[5]
         elif not printable(ca_name):
-            assert failed_to_get_cert(summary)
-            assert 'control character' in str_from_base64(summary[-1])
+            assert summary[0] is None
+            assert 'control character' in summary[5]
         elif not_after > after:
             assert summary is None
         else:
-            assert summary[0] == SECONDS_AGO.isoformat(timespec='seconds')
-            assert summary[1] == not_after.isoformat(timespec='seconds')
+            assert summary[0] == SECONDS_AGO
+            assert summary[1] == not_after
             assert summary[2] == domain
-            assert int(summary[3]) == port
-            assert str_from_base64(summary[5]) == ca_name
+            assert summary[3] == port
+            assert summary[5] == ca_name
diff --git a/tst/test_listen.py b/tst/test_listen.py
index 55a69e0..ee95c3c 100644
--- a/tst/test_listen.py
+++ b/tst/test_listen.py
@@ -37,7 +37,8 @@ from hypothesis.provisional import domains, urls
 from pytest import raises
 
 from scadere.check import base64_from_str, printable
-from scadere.listen import handle, is_subdomain, path, with_trailing_slash, xml
+from scadere.listen import (handle, is_subdomain, path,
+                            str_from_base64, with_trailing_slash, xml)
 
 ATOM_NAMESPACES = {'': 'http://www.w3.org/2005/Atom'}
 XHTML_NAMESPACES = {'': 'http://www.w3.org/1999/xhtml'}
@@ -58,18 +59,13 @@ def base64s():
     return text().filter(printable).map(base64_from_str)
 
 
-@given(domains(), ports(), base64s(), serials())
-def test_path_with_cert(hostname, port, issuer, serial):
-    r = path(hostname, port, issuer, serial).split('/')
+@given(domains(), ports(), serials(), text())
+def test_path(hostname, port, number, string):
+    r = path(hostname, port, number, string).split('/')
     assert r[0] == hostname
     assert int(r[1]) == port
-    assert r[2] == issuer
-    assert int(r[3]) == serial
-
-
-@given(domains(), ports(), base64s())
-def test_path_without_cert(hostname, port, error):
-    assert path(hostname, port, error, 'N/A') == f'{hostname}/{port}'
+    assert int(r[3]) == number
+    assert str_from_base64(r[2]) == string
 
 
 @given(domains(), lists(domains()))
@@ -87,7 +83,7 @@ def test_is_subdomain(subject, objects):
 
 def xml_unsupported_type(child):
     """Check if child is of a type supported by the XML constructor."""
-    return not isinstance(child, (tuple, str, datetime))
+    return not isinstance(child, (tuple, str, int, datetime))
 
 
 @given(text(), from_type(type).flatmap(from_type).filter(xml_unsupported_type))
@@ -100,13 +96,13 @@ def test_xml_unsupported_type(tag, child):
 def certificates(draw):
     """Return a Hypothesis strategy for certificate summaries."""
     valid = draw(booleans())
-    not_before = draw(datetimes()).isoformat()
-    not_after = draw(datetimes()).isoformat() if valid else 'N/A'
+    not_before = draw(datetimes()).isoformat() if valid else 'N/A'
+    not_after = draw(datetimes()).isoformat()
     hostname = draw(domains())
     port = draw(ports())
-    serial = draw(serials()) if valid else 'N/A'
-    issuer = draw(base64s())
-    return f'{not_before} {not_after} {hostname} {port} {serial} {issuer}'
+    number = draw(serials())
+    string = draw(base64s())
+    return f'{not_before} {not_after} {hostname} {port} {number} {string}'
 
 
 @contextmanager