about summary refs log tree commit diff
diff options
context:
space:
mode:
author Nguyễn Gia Phong <cnx@loang.net> 2025-06-04 16:36:05 +0900
committer Nguyễn Gia Phong <cnx@loang.net> 2025-06-04 16:37:00 +0900
commit 20766657a7146d7a36feebbb297dccd016a1406f (patch)
tree 78493273cbb8f84cdb6e9a7e751e7461a64e0ee6
parent 11d05505cdf25b77cfbdf09f5f1d1be79eeaa0f3 (diff)
download scadere-20766657a7146d7a36feebbb297dccd016a1406f.tar.gz
Err on invalid cert fetch
-rw-r--r-- src/scadere/check.py 32
-rw-r--r-- tst/test_check.py 15
-rw-r--r-- tst/test_listen.py 16
3 files changed, 37 insertions, 26 deletions
diff --git a/src/scadere/check.py b/src/scadere/check.py
index aaabe3f..288e599 100644
--- a/src/scadere/check.py
+++ b/src/scadere/check.py
@@ -31,18 +31,19 @@ from . import __version__, GNUHelpFormatter, NetLoc
 __all__ = ['main']
 
 
-class CtlChrTrans:
-    """Translator for printing Unicode control characters."""
+def is_control_character(character):
+    """Check if a Unicode character belongs to the control category."""
+    return unicode_category(character) == 'Cc'
 
-    def __getitem__(self, ordinal):
-        if unicode_category(chr(ordinal)) == 'Cc':
-            return 0xfffd  # replacement character '�'
-        raise KeyError
+
+def printable(string):
+    """Check if the given Unicode string is printable."""
+    return not any(map(is_control_character, string))
 
 
 def base64_from_str(string):
     """Convert string to base64 format in bytes."""
-    return base64(string.translate(CtlChrTrans()).encode()).decode()
+    return base64(string.encode()).decode()
 
 
 def check(netlocs, after, output, fake_ca=None):
@@ -55,6 +56,7 @@ def check(netlocs, after, output, fake_ca=None):
         fake_ca.configure_trust(ctx)
 
     for hostname, port in netlocs:
+        now = datetime.now(tz=timezone.utc).isoformat(timespec='seconds')
         netloc = f'{hostname}:{port}'
         stderr.write(f'TLS certificate for {netloc} ')
         try:
@@ -64,19 +66,27 @@ def check(netlocs, after, output, fake_ca=None):
                 cert = conn.getpeercert()
         except Exception as exception:
             stderr.write(f'cannot be retrieved: {exception}\n')
-            now = datetime.now(tz=timezone.utc).isoformat()
             print(now, 'N/A', hostname, port, 'N/A',
                   base64_from_str(str(exception)), file=output)
-        else:
-            ca = dict(chain.from_iterable(cert['issuer']))['organizationName']
+            continue
+
+        try:
             not_before = parsedate(cert['notBefore'])
             not_after = parsedate(cert['notAfter'])
+            ca = dict(chain.from_iterable(cert['issuer']))['organizationName']
+            if not printable(ca):
+                raise ValueError(f'CA name contains control character: {ca!r}')
+            serial = int(cert['serialNumber'], 16)
+        except Exception as exception:
+            stderr.write(f'cannot be parsed: {exception}\n')
+            print(now, 'N/A', hostname, port, 'N/A',
+                  base64_from_str(str(exception)), file=output)
+        else:
             if after < not_after:
                 after_seconds = after.isoformat(timespec='seconds')
                 stderr.write(f'will not expire at {after_seconds}\n')
             else:
                 stderr.write(f'will expire at {not_after.isoformat()}\n')
-                serial = cert['serialNumber'].translate(CtlChrTrans())
                 print(not_before.isoformat(), not_after.isoformat(),
                       # As unique identifier
                       hostname, port, serial,
diff --git a/tst/test_check.py b/tst/test_check.py
index 0ebc0ab..23be6f5 100644
--- a/tst/test_check.py
+++ b/tst/test_check.py
@@ -17,7 +17,6 @@
 # along with scadere.  If not, see <https://www.gnu.org/licenses/>.
 
 from asyncio import get_running_loop, start_server
-from base64 import urlsafe_b64encode as base64
 from datetime import datetime, timedelta, timezone
 from io import StringIO
 from ssl import Purpose, create_default_context as tls_context
@@ -26,7 +25,7 @@ from hypothesis import given
 from pytest import mark
 from trustme import CA
 
-from scadere.check import CtlChrTrans, base64_from_str, check
+from scadere.check import base64_from_str, check, printable
 from scadere.listen import parse_summary, str_from_base64
 
 SECONDS_AGO = datetime.now(tz=timezone.utc)
@@ -36,8 +35,7 @@ NEXT_WEEK = SECONDS_AGO + timedelta(days=7)
 
 @given(...)
 def test_base64(string: str):
-    printable_string = string.translate(CtlChrTrans())
-    assert str_from_base64(base64_from_str(string)) == printable_string
+    assert str_from_base64(base64_from_str(string)) == string
 
 
 async def noop(reader, writer):
@@ -67,7 +65,7 @@ async def get_cert_summary(netloc, after, ca):
 
 
 @mark.parametrize('domain', ['localhost'])
-@mark.parametrize('ca_name', ['trustme'])
+@mark.parametrize('ca_name', ['trustme', '\x1f'])
 @mark.parametrize('not_after', [SECONDS_AGO, NEXT_DAY, NEXT_WEEK])
 @mark.parametrize('after', [NEXT_DAY, NEXT_WEEK])
 @mark.parametrize('trust_ca', [False, True])
@@ -88,11 +86,14 @@ async def test_check(domain, ca_name, not_after, after, trust_ca):
         elif not_after == SECONDS_AGO:
             assert failed_to_get_cert(summary)
             assert 'certificate has expired' in str_from_base64(summary[-1])
+        elif not printable(ca_name):
+            assert failed_to_get_cert(summary)
+            assert 'control character' in str_from_base64(summary[-1])
         elif not_after > after:
             assert summary is None
         else:
             assert summary[0] == SECONDS_AGO.isoformat(timespec='seconds')
             assert summary[1] == not_after.isoformat(timespec='seconds')
             assert summary[2] == domain
-            assert summary[3] == str(port)
-            assert summary[5] == base64(ca_name.encode()).decode()
+            assert int(summary[3]) == port
+            assert str_from_base64(summary[5]) == ca_name
diff --git a/tst/test_listen.py b/tst/test_listen.py
index 3862d9d..45289d5 100644
--- a/tst/test_listen.py
+++ b/tst/test_listen.py
@@ -30,13 +30,13 @@ from xml.etree.ElementTree import (XML, XMLParser, indent,
                                    tostring as str_from_xml)
 
 from hypothesis import HealthCheck, given, settings
-from hypothesis.strategies import (booleans, builds, composite,
-                                   data, datetimes, from_type,
-                                   integers, lists, sampled_from, text)
+from hypothesis.strategies import (booleans, composite, data,
+                                   datetimes, from_type, integers,
+                                   lists, sampled_from, text)
 from hypothesis.provisional import domains, urls
 from pytest import raises
 
-from scadere.check import base64_from_str
+from scadere.check import base64_from_str, printable
 from scadere.listen import handle, is_subdomain, path, with_trailing_slash, xml
 
 ATOM_NAMESPACES = {'': 'http://www.w3.org/2005/Atom'}
@@ -50,12 +50,12 @@ def ports():
 
 def serials():
     """Return a Hypothesis strategy for TLS serial number."""
-    return builds(lambda n: hex(n).removeprefix('0x'), integers(0, 256**20-1))
+    return integers(0, 256**20-1)
 
 
 def base64s():
-    """Return a Hypothesis strategy for CA names."""
-    return text().map(base64_from_str)
+    """Return a Hypothesis strategy for printable strings in base64."""
+    return text().filter(printable).map(base64_from_str)
 
 
 @given(domains(), ports(), base64s(), serials())
@@ -64,7 +64,7 @@ def test_path_with_cert(hostname, port, issuer, serial):
     assert r[0] == hostname
     assert int(r[1]) == port
     assert r[2] == issuer
-    assert r[3] == serial
+    assert int(r[3]) == serial
 
 
 @given(domains(), ports(), base64s())