author     Nguyễn Gia Phong <cnx@loang.net>  2025-06-04 15:41:28 +0900
committer  Nguyễn Gia Phong <cnx@loang.net>  2025-06-04 15:41:28 +0900
commit     11d05505cdf25b77cfbdf09f5f1d1be79eeaa0f3 (patch)
tree       e0c19143ef56cdb856846a59d3f17ec323acd85a
parent     10bffe228843ea4c59110fc9ce40663a7a144338 (diff)
download   scadere-11d05505cdf25b77cfbdf09f5f1d1be79eeaa0f3.tar.gz
Handle server errors
-rw-r--r--  src/scadere/listen.py  220
-rw-r--r--  tst/test_listen.py      98
2 files changed, 212 insertions, 106 deletions
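
This commit makes the listener report errors to the client: an unreadable request line now yields 400, a non-GET method 405, an unrecognised HTTP version 505, and failures while building a response 500. As a rough sketch (not part of the commit; the default host and port are assumptions, and the expected status lines follow the tests further down), the new error paths can be probed by hand against a running instance:

    import asyncio

    async def probe(request, host='localhost', port=8080):
        """Send a raw HTTP request and return the status line of the reply."""
        # Default host/port are assumptions; use whatever the listener
        # was actually started with.
        reader, writer = await asyncio.open_connection(host, port)
        writer.write(request)
        await writer.drain()
        writer.write_eof()  # signal end of request, as the tests below do
        status = await reader.readuntil(b'\r\n')
        writer.close()
        await writer.wait_closed()
        return status

    # Expected status lines, per the handler and tests in this commit:
    #   probe(b'GET / HTTP/10\r\n')   ->  b'HTTP/1.1 505 HTTP Version Not Supported\r\n'
    #   probe(b'POST / HTTP/1.1\r\n') ->  b'HTTP/1.1 405 Method Not Allowed\r\n'
    #   probe(b'no CRLF at all')      ->  b'HTTP/1.1 400 Bad Request\r\n'
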
diff --git a/src/scadere/listen.py b/src/scadere/listen.py
index d8c1178..59dbef6 100644
--- a/src/scadere/listen.py
+++ b/src/scadere/listen.py
@@ -23,6 +23,7 @@ from datetime import datetime, timezone
 from functools import partial
 from http import HTTPStatus
 from pathlib import Path
+from string import digits
 from typing import assert_never
 from urllib.parse import parse_qs, urljoin, urlsplit
 from xml.etree.ElementTree import (Element as xml_element,
@@ -47,6 +48,18 @@ def path(hostname, port, issuer, serial):
     return f'{hostname}/{port}/{issuer}/{serial}'
 
 
+def supported_http_version(version):
+    """Check if given HTTP version complies with section 2.5 of RFC 9110."""
+    match len(version):
+        case 1:
+            return version in digits
+        case 3:
+            major, period, minor = version
+            return major in digits and period == '.' and minor in digits
+        case _:
+            return False
+
+
 def datetime_from_str(string, unavailable_ok=False):
     """Parse datetime from string in ISO 8601 format."""
     if string == 'N/A' and unavailable_ok:
@@ -54,9 +67,10 @@ def datetime_from_str(string, unavailable_ok=False):
     return datetime.fromisoformat(string)
 
 
-async def write_status(writer, status):
-    """Write the given HTTP/1.1 status line."""
-    writer.write(f'HTTP/1.1 {status.value} {status.phrase}\r\n'.encode())
+async def write_status(writer, http_version, status):
+    """Write the given HTTP status line."""
+    status = f'HTTP/{http_version} {status.value} {status.phrase}\r\n'
+    writer.write(status.encode())
     await writer.drain()
 
 
@@ -66,9 +80,9 @@ async def write_content_type(writer, content_type):
     await writer.drain()
 
 
-async def describe_status(writer, status):
+async def describe_status(writer, status, http_version='1.1'):
     """Write a HTTP/1.1 response including status description."""
-    await write_status(writer, status)
+    await write_status(writer, http_version, status)
     content = f'{status.description}\n'.encode()
     await write_content_type(writer, 'text/plain')
     writer.write(f'Content-Length: {len(content)}\r\n\r\n'.encode())
@@ -121,6 +135,52 @@ def entry(base_url, cert):
             ('updated', not_before))
 
 
+def split_domain(domain):
+    """Split domain and order by ascending level."""
+    return tuple(domain.split('.')[::-1])
+
+
+def is_subdomain(subject, objects):
+    """Check if subject is a subdomain of any object."""
+    if not objects:
+        return True
+    sbj_parts = split_domain(subject)
+    return any(sbj_parts[:len(obj_parts)] == obj_parts
+               for obj_parts in map(split_domain, objects))
+
+
+def feed(base_url, filename, certificates, domains):
+    """Construct an Atom feed based on the given information."""
+    return ('feed', {'xmlns': 'http://www.w3.org/2005/Atom'},
+            ('id', base_url),
+            ('link', {'rel': 'self', 'href': base_url}),
+            ('title', filename),
+            ('updated', datetime.now(tz=timezone.utc).isoformat()),
+            ('generator',
+             {'uri': 'https://trong.loang.net/scadere/about',
+              'version': __version__},
+             'Scadere'),
+            *(entry(base_url, cert) for cert in certificates
+              if is_subdomain(cert[2], domains)))
+
+
+def page(certificate):
+    """Construct an XHTML page for the given TLS certificate."""
+    not_before, not_after, hostname, port, serial, issuer = certificate
+    return ('html', {'xmlns': 'http://www.w3.org/1999/xhtml',
+                     'lang': 'en'},
+            ('head',
+             ('meta', {'name': 'color-scheme',
+                       'content': 'dark light'}),
+             ('meta', {'name': 'viewport',
+                       'content': ('width=device-width,'
+                                   'initial-scale=1.0')}),
+             ('link', {'rel': 'icon', 'href': 'data:,'}),
+             ('title', f'TLS certificate - {hostname}:{port}')),
+            ('body', *body(not_before, not_after,
+                           hostname, port, serial, issuer)))
+
+
 def xml(tree, parent=None):
     """Construct XML element from the given tree."""
     tag, attrs, children = ((tree[0], tree[1], tree[2:])
@@ -138,96 +198,88 @@ def xml(tree, parent=None):
                 elem.text = child
             case datetime():
                 elem.text = child.isoformat()
-            case _:  # pragma: no cover
+            case _:
                 assert_never(child)
     if parent is None:
         indent(elem)
     return elem
 
 
-async def write_xml(writer, document):
+async def write_xml(writer, http_version, application, func, *args):
     """Write given document as XML."""
-    content = tuple(map(str.encode,
-                        strings_from_xml(xml(document), 'unicode',
-                                         xml_declaration=True,
-                                         default_namespace=None)))
-    writer.write(f'Content-Length: {sum(map(len, content))}\r\n\r\n'.encode())
-    for part in content:
-        writer.write(part)
-        await writer.drain()
-
-
-def split_domain(domain):
-    """Split domain and order by ascending level."""
-    return tuple(domain.split('.')[::-1])
-
-
-def is_subdomain(subject, objects):
-    """Check if subject is a subdomain of any object."""
-    if not objects:
-        return True
-    sbj_parts = split_domain(subject)
-    return any(sbj_parts[:len(obj_parts)] == obj_parts
-               for obj_parts in map(split_domain, objects))
+    try:
+        content = tuple(map(str.encode,
+                            strings_from_xml(xml(func(*args)), 'unicode',
+                                             xml_declaration=True,
+                                             default_namespace=None)))
+    except Exception:  # pragma: no cover
+        await describe_status(writer, HTTPStatus.INTERNAL_SERVER_ERROR,
+                              http_version)
+        raise
+    else:
+        await write_status(writer, http_version, HTTPStatus.OK)
+        await write_content_type(writer, f'application/{application}+xml')
+        length = sum(map(len, content))
+        writer.write(f'Content-Length: {length}\r\n\r\n'.encode())
+        for part in content:
+            writer.write(part)
+            await writer.drain()
 
 
 async def handle(certs, base_url, reader, writer):
     """Handle HTTP request."""
-    summaries = map(parse_summary, certs.read_text().splitlines())
-    lookup = {urlsplit(urljoin(base_url,
-                               path(hostname, port, issuer, serial))).path:
-              (datetime_from_str(not_before),
-               datetime_from_str(not_after, unavailable_ok=True),
-               hostname, port, serial, issuer)
-              for not_before, not_after, hostname, port, serial, issuer
-              in summaries}
-    request = await reader.readuntil(b'\r\n')
-    url = request.removeprefix(b'GET ').rsplit(b' HTTP/', 1)[0].strip()
-    url_parts = urlsplit(urljoin(base_url, url.decode()))
-    domains = tuple(parse_qs(url_parts.query).get('domain', []))
-
-    if not request.startswith(b'GET '):
-        await describe_status(writer, HTTPStatus.METHOD_NOT_ALLOWED)
-    elif url_parts.path == urlsplit(base_url).path:  # Atom feed
-        await write_status(writer, HTTPStatus.OK)
-        await write_content_type(writer, 'application/atom+xml')
-        feed = ('feed', {'xmlns': 'http://www.w3.org/2005/Atom'},
-                ('id', base_url),
-                ('link', {'rel': 'self', 'href': base_url}),
-                ('title', certs.name),
-                ('updated', datetime.now(tz=timezone.utc).isoformat()),
-                ('generator',
-                 {'uri': 'https://trong.loang.net/scadere/about',
-                  'version': __version__},
-                 'Scadere'),
-                *(entry(base_url, cert) for cert in lookup.values()
-                  if is_subdomain(cert[2], domains)))
-        await write_xml(writer, feed)
-    elif url_parts.path in lookup:  # accessible Atom entry's link/ID
-        await write_status(writer, HTTPStatus.OK)
-        await write_content_type(writer, 'application/xhtml+xml')
-        (not_before, not_after,
-         hostname, port, serial, issuer) = lookup[url_parts.path]
-        page = ('html', {'xmlns': 'http://www.w3.org/1999/xhtml',
-                         'lang': 'en'},
-                ('head',
-                 ('meta', {'name': 'color-scheme',
-                           'content': 'dark light'}),
-                 ('meta', {'name': 'viewport',
-                           'content': ('width=device-width,'
-                                       'initial-scale=1.0')}),
-                 ('link', {'rel': 'icon', 'href': 'data:,'}),
-                 ('title', f'TLS certificate - {hostname}:{port}')),
-                ('body', *body(not_before, not_after,
-                               hostname, port, serial, issuer)))
-        await write_xml(writer, page)
-    else:
-        await describe_status(writer, HTTPStatus.NOT_FOUND)
-
-    assert writer.can_write_eof()
-    writer.write_eof()
-    writer.close()
-    await writer.wait_closed()
+    try:
+        try:
+            request = await reader.readuntil(b'\r\n')
+        except Exception:
+            await describe_status(writer, HTTPStatus.BAD_REQUEST)
+            return
+
+        if not request.startswith(b'GET '):
+            await describe_status(writer, HTTPStatus.METHOD_NOT_ALLOWED)
+            return
+
+        try:
+            # Unpacking raises ValueError when b' HTTP/' is absent
+            url, version = request.removeprefix(b'GET ').rsplit(b' HTTP/', 1)
+            http_version = version.strip().decode()
+            if not supported_http_version(http_version):
+                raise ValueError
+        except ValueError:
+            await describe_status(writer,
+                                  HTTPStatus.HTTP_VERSION_NOT_SUPPORTED)
+            return
+
+        try:
+            url_parts = urlsplit(urljoin(base_url, url.strip().decode()))
+            domains = tuple(parse_qs(url_parts.query).get('domain', []))
+            summaries = map(parse_summary, certs.read_text().splitlines())
+            lookup = {urlsplit(urljoin(base_url,
+                                       path(hostname, port,
+                                            issuer, serial))).path:
+                      (datetime_from_str(not_before),
+                       datetime_from_str(not_after, unavailable_ok=True),
+                       hostname, port, serial, issuer)
+                      for not_before, not_after, hostname, port, serial, issuer
+                      in summaries}
+        except Exception:  # pragma: no cover
+            await describe_status(writer, HTTPStatus.INTERNAL_SERVER_ERROR,
+                                  http_version)
+            raise
+
+        if url_parts.path == urlsplit(base_url).path:  # Atom feed
+            await write_xml(writer, http_version, 'atom', feed,
+                            base_url, certs.name, lookup.values(), domains)
+        elif url_parts.path in lookup:  # accessible Atom entry's link/ID
+            await write_xml(writer, http_version, 'xhtml', page,
+                            lookup.get(url_parts.path))
+        else:
+            await describe_status(writer, HTTPStatus.NOT_FOUND, http_version)
+    finally:
+        assert writer.can_write_eof()
+        writer.write_eof()
+        writer.close()
+        await writer.wait_closed()
 
 
 async def listen(certs, base_url, host, port):  # pragma: no cover
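
For reference, the new version check above is purely syntactic (a single digit, or digit-dot-digit); a quick interactive sketch of its behaviour, assuming the package is importable as scadere.listen:

    >>> from scadere.listen import supported_http_version
    >>> [supported_http_version(v) for v in ('1', '1.1', '2', '10', '1.x')]
    [True, True, True, False, False]
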
diff --git a/tst/test_listen.py b/tst/test_listen.py
index cc6a9a1..3862d9d 100644
--- a/tst/test_listen.py
+++ b/tst/test_listen.py
@@ -19,8 +19,10 @@
 from asyncio import TaskGroup, open_connection, start_server
 from contextlib import asynccontextmanager, contextmanager
 from copy import deepcopy
+from datetime import datetime
 from email.parser import BytesHeaderParser
 from functools import partial
+from http import HTTPMethod
 from pathlib import Path
 from tempfile import mkstemp
 from urllib.parse import urljoin, urlsplit
@@ -28,12 +30,14 @@ from xml.etree.ElementTree import (XML, XMLParser, indent,
                                    tostring as str_from_xml)
 
 from hypothesis import HealthCheck, given, settings
-from hypothesis.strategies import (booleans, builds, composite, data,
-                                   datetimes, integers, lists, text)
+from hypothesis.strategies import (booleans, builds, composite,
+                                   data, datetimes, from_type,
+                                   integers, lists, sampled_from, text)
 from hypothesis.provisional import domains, urls
+from pytest import raises
 
 from scadere.check import base64_from_str
-from scadere.listen import handle, is_subdomain, path, with_trailing_slash
+from scadere.listen import handle, is_subdomain, path, with_trailing_slash, xml
 
 ATOM_NAMESPACES = {'': 'http://www.w3.org/2005/Atom'}
 XHTML_NAMESPACES = {'': 'http://www.w3.org/1999/xhtml'}
@@ -81,6 +85,17 @@ def test_is_subdomain(subject, objects):
                     or not subject.removesuffix(obj).endswith('.'))
 
 
+def xml_unsupported_type(child):
+    """Check if child is of a type supported by the XML constructor."""
+    return not isinstance(child, (tuple, str, datetime))
+
+
+@given(text(), from_type(type).flatmap(from_type).filter(xml_unsupported_type))
+def test_xml_unsupported_type(tag, child):
+    with raises(AssertionError, match='Expected code to be unreachable'):
+        xml((tag, {}, child))
+
+
 @composite
 def certificates(draw):
     """Return a Hypothesis strategy for certificate summaries."""
@@ -148,15 +163,22 @@ async def connect(socket):
         await writer.wait_closed()
 
 
+async def write_request(writer, request):
+    """Write given request."""
+    writer.write(request.encode())
+    await writer.drain()
+    assert writer.can_write_eof()
+    writer.write_eof()
+
+
 async def fetch_xml(socket, url, content_type):
     """Fetch the content at the URL from the socket."""
     header_parser = BytesHeaderParser()
     xml_parser = XMLParser(encoding='utf-8')
     async with connect(socket) as (reader, writer):
-        writer.write(f'GET {url}\r\n'.encode())
-        await writer.drain()
+        await write_request(writer, f'GET {url} HTTP/2\r\n')
         status = await reader.readuntil(b'\r\n')
-        assert status == b'HTTP/1.1 200 OK\r\n'
+        assert status == b'HTTP/2 200 OK\r\n'
         headers_bytes = await reader.readuntil(b'\r\n\r\n')
         headers = header_parser.parsebytes(headers_bytes)
         assert headers['Content-Type'] == content_type
@@ -176,11 +198,11 @@ def equal_xml(a, b):
 async def check_feed(socket, base_url):
     """Check the Atom feed at the given path and its entry pages."""
     feed = await fetch_xml(socket, base_url, 'application/atom+xml')
-    for feed_entry in feed.findall('entry', ATOM_NAMESPACES):
-        link = feed_entry.find('link', ATOM_NAMESPACES).attrib
+    for entry in feed.findall('entry', ATOM_NAMESPACES):
+        link = entry.find('link', ATOM_NAMESPACES).attrib
         assert link['rel'] == 'alternate'
         page = await fetch_xml(socket, link['href'], link['type'])
-        assert equal_xml(feed_entry.find('.//dl', XHTML_NAMESPACES),
+        assert equal_xml(entry.find('.//dl', XHTML_NAMESPACES),
                          page.find('.//dl', XHTML_NAMESPACES))
 
 
@@ -202,13 +224,28 @@ async def test_content(base_url, certs):
             await check_server(server.sockets, check_feed, base_path)
 
 
+async def bad_request(socket, request):
+    """Expect from socket a HTTP response with status 400."""
+    async with connect(socket) as (reader, writer):
+        await write_request(writer, request)
+        status = await reader.readuntil(b'\r\n')
+        assert status == b'HTTP/1.1 400 Bad Request\r\n'
+
+
+@given(text().filter(lambda request: '\r\n' not in request))
+async def test_incomplete_request(request):
+    with tmp_cert_file(()) as cert_file:
+        handler = partial(handle, cert_file, 'http://localhost')
+        async with await start_server(handler, 'localhost') as server:
+            await check_server(server.sockets, bad_request, request)
+
+
 async def not_found(socket, url):
-    """Send GET request for URL and expect a 404 status from socket."""
+    """Send GET request for URL and expect HTTP status 404 from socket."""
     async with connect(socket) as (reader, writer):
-        writer.write(f'GET {url}\r\n'.encode())
-        await writer.drain()
-        response = await reader.readuntil(b'\r\n')
-        assert response == b'HTTP/1.1 404 Not Found\r\n'
+        await write_request(writer, f'GET {url} HTTP/1\r\n')
+        status = await reader.readuntil(b'\r\n')
+        assert status == b'HTTP/1 404 Not Found\r\n'
 
 
 @given(data())
@@ -223,18 +260,35 @@ async def test_unrecognized_url(drawer):
             await check_server(server.sockets, not_found, urlsplit(url).path)
 
 
-async def method_not_allowed(socket, request):
+async def method_not_allowed(socket, method, url):
     """Expect from socket a HTTP response with status 405."""
     async with connect(socket) as (reader, writer):
-        writer.write(f'{request}\r\n'.encode())
-        await writer.drain()
-        response = await reader.readuntil(b'\r\n')
-        assert response == b'HTTP/1.1 405 Method Not Allowed\r\n'
+        await write_request(writer, f'{method} {url} HTTP/1.1\r\n')
+        status = await reader.readuntil(b'\r\n')
+        assert status == b'HTTP/1.1 405 Method Not Allowed\r\n'
+
+
+@given(urls(), sampled_from(HTTPMethod).filter(lambda method: method != 'GET'))
+async def test_unallowed_method(base_url, method):
+    with tmp_cert_file(()) as cert_file:
+        handler = partial(handle, cert_file, base_url)
+        async with await start_server(handler, 'localhost') as server:
+            await check_server(server.sockets, method_not_allowed,
+                               method.value, base_url)
+
+
+async def unsupported_http_version(socket, url, version):
+    """Expect from socket a HTTP response with status 505."""
+    async with connect(socket) as (reader, writer):
+        await write_request(writer, f'GET {url} HTTP/{version}\r\n')
+        status = await reader.readuntil(b'\r\n')
+        assert status == b'HTTP/1.1 505 HTTP Version Not Supported\r\n'
 
 
-@given(urls(), text().filter(lambda method: not method.startswith('GET ')))
-async def test_unallowed_method(base_url, request):
+@given(urls().filter(is_base_url).filter(has_usual_path), integers(10))
+async def test_unsupported_http_version(base_url, version):
     with tmp_cert_file(()) as cert_file:
         handler = partial(handle, cert_file, base_url)
         async with await start_server(handler, 'localhost') as server:
-            await check_server(server.sockets, method_not_allowed, request)
+            await check_server(server.sockets, unsupported_http_version,
+                               base_url, version)
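
The new error-path tests above can be run in isolation with something along these lines (a hypothetical invocation; the exact options and the async pytest plugin required depend on the project's test setup):

    python -m pytest tst/test_listen.py -k 'incomplete_request or unallowed_method or unsupported_http_version'
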