about summary refs log tree commit diff
diff options
context:
space:
mode:
author Nguyễn Gia Phong <cnx@loang.net> 2025-05-26 17:45:48 +0900
committer Nguyễn Gia Phong <cnx@loang.net> 2025-05-26 17:45:48 +0900
commit 56a032568443bdf85dd37df5f6716b3475626d6a (patch)
tree 487de1b9bff07ede605d5321faa1f09706e3bd00
parent b37d71bca632c1e29a3402fbaf69a14843eab8f2 (diff)
download scadere-56a032568443bdf85dd37df5f6716b3475626d6a.tar.gz
Fix handling of base URL
-rw-r--r-- pyproject.toml 3
-rw-r--r-- src/scadere/listen.py 24
-rw-r--r-- tst/test_listen.py 182
3 files changed, 178 insertions, 31 deletions
diff --git a/pyproject.toml b/pyproject.toml
index 4104f67..244eba7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,10 @@ urls = { Source = 'https://trong.loang.net/scadere' }
 scripts = { scadere = 'scadere.__main__:main' }
 
 [tool.pytest.ini_options]
+asyncio_mode = 'auto'
+asyncio_default_fixture_loop_scope = 'function'
 testpaths = [ 'tst' ]
+verbosity_assertions = 2
 
 [tool.coverage.run]
 branch = true
diff --git a/src/scadere/listen.py b/src/scadere/listen.py
index aa80c32..0777f4a 100644
--- a/src/scadere/listen.py
+++ b/src/scadere/listen.py
@@ -21,7 +21,7 @@ from functools import partial
 from urllib.parse import parse_qs, urljoin, urlsplit
 from xml.etree.ElementTree import (Element as xml_element,
                                    SubElement as xml_subelement,
-                                   indent, tostring as xml_to_string)
+                                   indent, tostring as str_from_xml)
 
 from . import __version__
 
@@ -84,13 +84,14 @@ async def handle(certs, base_url, reader, writer):
     """Handle HTTP request."""
     summaries = tuple(cert.rstrip().split(maxsplit=5)
                       for cert in certs.read_text().splitlines())
-    lookup = {f'/{path(hostname, port, issuer, serial)}':
+    lookup = {urlsplit(urljoin(base_url,
+                               path(hostname, port, issuer, serial))).path:
               (not_before, not_after, hostname, port, serial, issuer)
               for not_before, not_after, hostname, port, serial, issuer
               in summaries}
     request = await reader.readuntil(b'\r\n')
     url = request.removeprefix(b'GET ').rsplit(b' HTTP/', 1)[0]
-    url_parts = urlsplit(url.decode())
+    url_parts = urlsplit(urljoin(base_url, url.decode()))
     domains = tuple(parse_qs(url_parts.query).get('domain', ['']))
 
     if not request.startswith(b'GET '):
@@ -99,7 +100,9 @@ async def handle(certs, base_url, reader, writer):
         writer.close()
         await writer.wait_closed()
         return
-    elif url_parts.path == '/':  # Atom feed
+    elif url.startswith(b'//'):  # urljoin goes haywire
+        writer.write(b'HTTP/1.1 404 Not Found\r\n')
+    elif url_parts.path == urlsplit(base_url).path:  # Atom feed
         writer.write(b'HTTP/1.1 200 OK\r\n')
         writer.write(b'Content-Type: application/atom+xml\r\n')
         feed = xml(('feed', {'xmlns': 'http://www.w3.org/2005/Atom'},
@@ -112,9 +115,10 @@ async def handle(certs, base_url, reader, writer):
                       'version': __version__},
                      'Scadere'),
                     *(entry(base_url, cert)
-                      for cert in summaries if cert[2].endswith(domains))))
-        content = xml_to_string(feed, 'unicode', xml_declaration=True,
-                                default_namespace=None).encode()
+                      for cert in lookup.values()
+                      if cert[2].endswith(domains))))
+        content = str_from_xml(feed, 'unicode', xml_declaration=True,
+                               default_namespace=None).encode()
         writer.write(f'Content-Length: {len(content)}\r\n\r\n'.encode())
         writer.write(content)
     elif url_parts.path in lookup:  # accessible Atom entry's link/ID
@@ -134,8 +138,8 @@ async def handle(certs, base_url, reader, writer):
                      ('title', f'TLS certificate - {hostname}:{port}')),
                     ('body', *body(not_before, not_after,
                                    hostname, port, serial, issuer))))
-        content = xml_to_string(page, 'unicode', xml_declaration=True,
-                                default_namespace=None).encode()
+        content = str_from_xml(page, 'unicode', xml_declaration=True,
+                               default_namespace=None).encode()
         writer.write(f'Content-Length: {len(content)}\r\n\r\n'.encode())
         writer.write(content)
     else:
@@ -145,7 +149,7 @@ async def handle(certs, base_url, reader, writer):
     await writer.wait_closed()
 
 
-async def listen(certs, base_url, host, port):
+async def listen(certs, base_url, host, port):  # pragma: no cover
     """Serve HTTP server for TLS certificate expirations' Atom feed."""
     server = await start_server(partial(handle, certs, base_url), host, port)
     async with server:
diff --git a/tst/test_listen.py b/tst/test_listen.py
index d3a8052..f35ce82 100644
--- a/tst/test_listen.py
+++ b/tst/test_listen.py
@@ -14,35 +14,175 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
-from base64 import urlsafe_b64decode as from_base64
+from asyncio import open_connection, start_server
+from base64 import (urlsafe_b64decode as from_base64,
+                    urlsafe_b64encode as base64)
+from contextlib import closing, contextmanager
+from functools import partial
+from pathlib import Path
+from tempfile import mkstemp
+from urllib.parse import urljoin, urlsplit
+from xml.etree.ElementTree import tostring as str_from_xml
+from xml.sax.saxutils import escape
 
 from hypothesis import given
-from hypothesis.strategies import integers, datetimes, text
-from hypothesis.provisional import domains
+from hypothesis.strategies import (builds, composite, datetimes,
+                                   integers, lists, text)
+from hypothesis.provisional import domains, urls
 
-from scadere.listen import body, path
+from scadere.listen import body, entry, handle, path, xml
 
 
-@given(domains(), integers(1, 65535), text(), integers(0, 256**20))
+def ports():
+    """Return a Hypothesis strategy for TCP ports."""
+    return integers(1, 65535)
+
+
+def serials():
+    """Return a Hypothesis strategy for TLS serial number."""
+    return builds(lambda n: hex(n).removeprefix('0x'), integers(0, 256**20-1))
+
+
+def ca_names():
+    """Return a Hypothesis strategy for CA names."""
+    return text().map(lambda name: base64(name.encode()).decode())
+
+
+@given(domains(), ports(), ca_names(), serials())
 def test_path(hostname, port, issuer, serial):
-    r = path(hostname, port, issuer, hex(serial).removeprefix('0x')).split('/')
-    assert(r[0] == hostname)
-    assert(int(r[1]) == port)
-    assert(from_base64(r[2]).decode() == issuer)
-    assert(int(r[3], 16) == serial)
+    r = path(hostname, port, issuer, serial).split('/')
+    assert r[0] == hostname
+    assert int(r[1]) == port
+    assert r[2] == issuer
+    assert r[3] == serial
 
 
-@given(domains(), integers(1, 65535), text(), integers(0, 256**20),
-       datetimes(), datetimes())
+@given(domains(), ports(), ca_names(), serials(), datetimes(), datetimes())
 def test_body(hostname, port, issuer, serial, not_before, not_after):
-    r = body(not_before, not_after, hostname, port,
-             hex(serial).removeprefix('0x'), issuer)
-    assert(r[-1][0] == 'dl')
+    r = body(not_before, not_after, hostname, port, serial, issuer)
+    assert r[-1][0] == 'dl'
     d = dict(zip((v for k, v in r[-1][1:] if k == 'dt'),
                  (v for k, v in r[-1][1:] if k == 'dd')))
-    assert(d['Domain'] == hostname)
-    assert(d['Port'] == port)
-    assert(d['Issuer'] == issuer)
-    assert(int(d['Serial number'], 16) == serial)
-    assert(d['Valid from'] == not_before)
-    assert(d['Valid until'] == not_after)
+    assert d['Domain'] == hostname
+    assert d['Port'] == port
+    assert d['Issuer'] == from_base64(issuer.encode()).decode()
+    assert d['Serial number'] == serial
+    assert d['Valid from'] == not_before
+    assert d['Valid until'] == not_after
+
+
+@given(urls(), domains(), ports(),
+       ca_names(), serials(), datetimes(), datetimes())
+def test_atom_entry(base_url, hostname, port,
+                    issuer, serial, not_before, not_after):
+    cert = not_before, not_after, hostname, port, serial, issuer
+    r = str_from_xml(xml(entry(base_url, cert)),
+                     'unicode', short_empty_elements=False)
+    issuer_str = from_base64(issuer.encode()).decode()
+    url = urljoin(base_url, path(hostname, port, issuer, serial))
+    assert r == f'''<entry>
+  <author>
+    <name>{escape(issuer_str)}</name>
+  </author>
+  <content type="xhtml">
+    <div xmlns="http://www.w3.org/1999/xhtml">
+      <h1>TLS certificate information</h1>
+      <dl>
+        <dt>Domain</dt>
+        <dd>{hostname}</dd>
+        <dt>Port</dt>
+        <dd>{port}</dd>
+        <dt>Issuer</dt>
+        <dd>{escape(issuer_str)}</dd>
+        <dt>Serial number</dt>
+        <dd>{serial}</dd>
+        <dt>Valid from</dt>
+        <dd>{not_before.isoformat()}</dd>
+        <dt>Valid until</dt>
+        <dd>{not_after.isoformat()}</dd>
+      </dl>
+    </div>
+  </content>
+  <id>{url}</id>
+  <link rel="alternate" type="text/plain" href="{url}"></link>
+  <title>TLS cert for {hostname} will expire at {not_after}</title>
+  <updated>{not_before.isoformat()}</updated>
+</entry>'''
+
+
+@composite
+def certificates(draw):
+    """Return a Hypothesis strategy for certificate summaries."""
+    not_before = draw(datetimes()).isoformat()
+    not_after = draw(datetimes()).isoformat()
+    hostname = draw(domains())
+    port = draw(ports())
+    serial = draw(serials())
+    issuer = draw(ca_names()) or '\0'
+    return f'{not_before} {not_after} {hostname} {port} {serial} {issuer}'
+
+
+@contextmanager
+def tmp_cert_file(lines):
+    cert_file = Path(mkstemp(text=True)[1])
+    cert_file.write_text('\n'.join(lines))
+    try:
+        yield cert_file
+    finally:
+        cert_file.unlink()
+
+
+@given(urls().filter(lambda url: not urlsplit(url).path.startswith('//')),
+       lists(certificates()))
+async def test_http_200(base_url, certs):
+    base_path = urlsplit(base_url).path
+    with tmp_cert_file(certs) as cert_file:
+        handler = partial(handle, cert_file, base_url)
+        server = await start_server(handler, 'localhost')
+        async with server:
+            socket, = server.sockets
+            reader, writer = await open_connection(*socket.getsockname())
+            with closing(writer):
+                writer.write(f'GET {base_path}\r\n'.encode())
+                await writer.drain()
+                response = await reader.readuntil(b'\r\n')
+                assert response == b'HTTP/1.1 200 OK\r\n'
+
+
+@composite
+def two_urls(draw, constraint):
+    """Return a Hypothesis strategy for 2 URLs."""
+    first = draw(urls())
+    second = draw(urls().filter(partial(constraint, first)))
+    return first, second
+
+
+@given(two_urls(lambda a, b: urlsplit(a).path != urlsplit(b).path))
+async def test_http_404(url_and_url):
+    base_url, url = url_and_url
+    with tmp_cert_file(()) as cert_file:
+        handler = partial(handle, cert_file, base_url)
+        server = await start_server(handler, 'localhost')
+        async with server:
+            socket, = server.sockets
+            reader, writer = await open_connection(*socket.getsockname())
+            with closing(writer):
+                writer.write(f'GET {urlsplit(url).path}\r\n'.encode())
+                await writer.drain()
+                response = await reader.read()
+                assert response == b'HTTP/1.1 404 Not Found\r\n'
+
+
+@given(urls(), text().filter(lambda method: not method.startswith('GET ')))
+async def test_http_405(base_url, request):
+    with tmp_cert_file(()) as cert_file:
+        handler = partial(handle, cert_file, base_url)
+        server = await start_server(handler, 'localhost')
+        async with server:
+            socket, = server.sockets
+            reader, writer = await open_connection(*socket.getsockname())
+            with closing(writer):
+                writer.write(f'{request}\r\n'.encode())
+                await writer.drain()
+                response = await reader.read()
+                assert response == b'HTTP/1.1 405 Method Not Allowed\r\n'