about summary refs log tree commit diff
diff options
context:
space:
mode:
authorNguyễn Gia Phong <cnx@loang.net>2025-06-03 22:20:22 +0900
committerNguyễn Gia Phong <cnx@loang.net>2025-06-03 22:20:22 +0900
commit8b83c1f04c808558a8097022466b2d4327dd62af (patch)
tree9436b81ce05b06e0d1e789fda7a8b74a67e1f463
parentddaee1e438b06ced6ec621db0e37d4c9968fe835 (diff)
downloadscadere-8b83c1f04c808558a8097022466b2d4327dd62af.tar.gz
DRY up base64 handling
-rw-r--r--src/scadere/check.py10
-rw-r--r--src/scadere/listen.py11
-rw-r--r--tst/test_check.py10
-rw-r--r--tst/test_listen.py15
4 files changed, 30 insertions, 16 deletions
diff --git a/src/scadere/check.py b/src/scadere/check.py
index 5764382..23ba189 100644
--- a/src/scadere/check.py
+++ b/src/scadere/check.py
@@ -27,7 +27,12 @@ from sys import argv, stderr, stdout
 
 from . import __version__, GNUHelpFormatter, NetLoc
 
-__all__ = ['check', 'main']
+__all__ = ['main']
+
+
+def base64_from_str(string):
+    """Convert string to base64 format in bytes."""
+    return base64(string.encode()).decode()
 
 
 def check(netlocs, after, output, fake_ca=None):
@@ -62,8 +67,7 @@ def check(netlocs, after, output, fake_ca=None):
                 print(not_before.isoformat(), not_after.isoformat(),
                       # As unique identifier
                       hostname, port, cert['serialNumber'],
-                      base64(ca.encode()).decode(),
-                      file=output)
+                      base64_from_str(ca), file=output)
 
 
 def main(arguments=argv[1:]):
diff --git a/src/scadere/listen.py b/src/scadere/listen.py
index 6dc8f3a..bf179e6 100644
--- a/src/scadere/listen.py
+++ b/src/scadere/listen.py
@@ -31,7 +31,7 @@ from sys import argv
 
 from . import __version__, GNUHelpFormatter, NetLoc
 
-__all__ = ['listen', 'main', 'parse_summary']
+__all__ = ['main']
 
 
 def parse_summary(line):
@@ -66,13 +66,18 @@ async def describe_status(writer, status):
     await writer.drain()
 
 
+def str_from_base64(string):
+    """Decode string in base64 format."""
+    return from_base64(string.encode()).decode()
+
+
 def body(not_before, not_after, hostname, port, serial, issuer):
     """Describe the given certificate in XHTML."""
     return (('h1', 'TLS certificate information'),
             ('dl',
              ('dt', 'Domain'), ('dd', hostname),
              ('dt', 'Port'), ('dd', port),
-             ('dt', 'Issuer'), ('dd', from_base64(issuer.encode()).decode()),
+             ('dt', 'Issuer'), ('dd', str_from_base64(issuer)),
              ('dt', 'Serial number'), ('dd', serial),
              ('dt', 'Valid from'), ('dd', not_before),
              ('dt', 'Valid until'), ('dd', not_after)))
@@ -83,7 +88,7 @@ def entry(base_url, cert):
     not_before, not_after, hostname, port, serial, issuer = cert
     url = urljoin(base_url, path(hostname, port, issuer, serial))
     return ('entry',
-            ('author', ('name', from_base64(issuer.encode()).decode())),
+            ('author', ('name', str_from_base64(issuer))),
             ('content', {'type': 'xhtml'},
              ('div', {'xmlns': 'http://www.w3.org/1999/xhtml'}, *body(*cert))),
             ('id', url),
diff --git a/tst/test_check.py b/tst/test_check.py
index a87788e..397b9ca 100644
--- a/tst/test_check.py
+++ b/tst/test_check.py
@@ -22,17 +22,23 @@ from datetime import datetime, timedelta, timezone
 from io import StringIO
 from ssl import Purpose, create_default_context as tls_context
 
+from hypothesis import given
 from pytest import mark
 from trustme import CA
 
-from scadere.check import check
-from scadere.listen import parse_summary
+from scadere.check import base64_from_str, check
+from scadere.listen import parse_summary, str_from_base64
 
 SECONDS_AGO = datetime.now(tz=timezone.utc)
 NEXT_DAY = SECONDS_AGO + timedelta(days=1)
 NEXT_WEEK = SECONDS_AGO + timedelta(days=7)
 
 
+@given(...)
+def test_base64(string: str):
+    assert str_from_base64(base64_from_str(string)) == string
+
+
 async def noop(reader, writer):
     """Do nothing."""
     writer.close()
diff --git a/tst/test_listen.py b/tst/test_listen.py
index 3737baa..e204d4f 100644
--- a/tst/test_listen.py
+++ b/tst/test_listen.py
@@ -17,8 +17,6 @@
 # along with scadere.  If not, see <https://www.gnu.org/licenses/>.
 
 from asyncio import TaskGroup, open_connection, start_server
-from base64 import (urlsafe_b64decode as from_base64,
-                    urlsafe_b64encode as base64)
 from contextlib import asynccontextmanager, contextmanager
 from copy import deepcopy
 from email.parser import BytesHeaderParser
@@ -36,8 +34,9 @@ from hypothesis.strategies import (builds, composite, data,
                                    datetimes, integers, lists, text)
 from hypothesis.provisional import domains, urls
 
-from scadere.listen import (body, entry, handle, is_subdomain,
-                            path, with_trailing_slash, xml)
+from scadere.check import base64_from_str
+from scadere.listen import (body, entry, handle, is_subdomain, path,
+                            str_from_base64, with_trailing_slash, xml)
 
 ATOM_NAMESPACES = {'': 'http://www.w3.org/2005/Atom'}
 XHTML_NAMESPACES = {'': 'http://www.w3.org/1999/xhtml'}
@@ -55,7 +54,7 @@ def serials():
 
 def ca_names():
     """Return a Hypothesis strategy for CA names."""
-    return text().map(lambda name: base64(name.encode()).decode())
+    return text().map(lambda name: base64_from_str(name))
 
 
 @given(domains(), ports(), ca_names(), serials())
@@ -75,7 +74,7 @@ def test_body(hostname, port, issuer, serial, not_before, not_after):
                  (v for k, v in r[-1][1:] if k == 'dd')))
     assert d['Domain'] == hostname
     assert d['Port'] == port
-    assert d['Issuer'] == from_base64(issuer.encode()).decode()
+    assert d['Issuer'] == str_from_base64(issuer)
     assert d['Serial number'] == serial
     assert d['Valid from'] == not_before
     assert d['Valid until'] == not_after
@@ -88,7 +87,7 @@ def test_atom_entry(base_url, hostname, port,
     cert = not_before, not_after, hostname, port, serial, issuer
     r = str_from_xml(xml(entry(base_url, cert)),
                      'unicode', short_empty_elements=False)
-    issuer_str = from_base64(issuer.encode()).decode()
+    issuer_str = str_from_base64(issuer)
     url = urljoin(base_url, path(hostname, port, issuer, serial))
     assert r == f'''<entry>
   <author>
@@ -142,7 +141,7 @@ def certificates(draw):
     port = draw(ports())
     serial = draw(serials())
     # Free-formed UTF-8 could easily creates malformed XML.
-    issuer = base64(draw(text(ascii_letters)).encode()).decode()
+    issuer = base64_from_str(draw(text(ascii_letters)))
     return f'{not_before} {not_after} {hostname} {port} {serial} {issuer}'