diff options
author | Nguyễn Gia Phong <cnx@loang.net> | 2025-06-04 15:41:28 +0900 |
---|---|---|
committer | Nguyễn Gia Phong <cnx@loang.net> | 2025-06-04 15:41:28 +0900 |
commit | 11d05505cdf25b77cfbdf09f5f1d1be79eeaa0f3 (patch) | |
tree | e0c19143ef56cdb856846a59d3f17ec323acd85a | |
parent | 10bffe228843ea4c59110fc9ce40663a7a144338 (diff) | |
download | scadere-11d05505cdf25b77cfbdf09f5f1d1be79eeaa0f3.tar.gz |
Handle server errors
-rw-r--r-- | src/scadere/listen.py | 220 | ||||
-rw-r--r-- | tst/test_listen.py | 98 |
2 files changed, 212 insertions, 106 deletions
diff --git a/src/scadere/listen.py b/src/scadere/listen.py index d8c1178..59dbef6 100644 --- a/src/scadere/listen.py +++ b/src/scadere/listen.py @@ -23,6 +23,7 @@ from datetime import datetime, timezone from functools import partial from http import HTTPStatus from pathlib import Path +from string import digits from typing import assert_never from urllib.parse import parse_qs, urljoin, urlsplit from xml.etree.ElementTree import (Element as xml_element, @@ -47,6 +48,18 @@ def path(hostname, port, issuer, serial): return f'{hostname}/{port}/{issuer}/{serial}' +def supported_http_version(version): + """Check if given HTTP version complies with section 2.5 of RFC 9110.""" + match len(version): + case 1: + return version in digits + case 3: + major, period, minor = version + return major in digits and period == '.' and minor in digits + case _: + return False + + def datetime_from_str(string, unavailable_ok=False): """Parse datetime from string in ISO 8601 format.""" if string == 'N/A' and unavailable_ok: @@ -54,9 +67,10 @@ def datetime_from_str(string, unavailable_ok=False): return datetime.fromisoformat(string) -async def write_status(writer, status): - """Write the given HTTP/1.1 status line.""" - writer.write(f'HTTP/1.1 {status.value} {status.phrase}\r\n'.encode()) +async def write_status(writer, http_version, status): + """Write the given HTTP status line.""" + status = f'HTTP/{http_version} {status.value} {status.phrase}\r\n' + writer.write(status.encode()) await writer.drain() @@ -66,9 +80,9 @@ async def write_content_type(writer, content_type): await writer.drain() -async def describe_status(writer, status): +async def describe_status(writer, status, http_version='1.1'): """Write a HTTP/1.1 response including status description.""" - await write_status(writer, status) + await write_status(writer, http_version, status) content = f'{status.description}\n'.encode() await write_content_type(writer, 'text/plain') writer.write(f'Content-Length: {len(content)}\r\n\r\n'.encode()) @@ -121,6 +135,52 @@ def entry(base_url, cert): ('updated', not_before)) +def split_domain(domain): + """Split domain and order by ascending level.""" + return tuple(domain.split('.')[::-1]) + + +def is_subdomain(subject, objects): + """Check if subject is a subdomain of any object.""" + if not objects: + return True + sbj_parts = split_domain(subject) + return any(sbj_parts[:len(obj_parts)] == obj_parts + for obj_parts in map(split_domain, objects)) + + +def feed(base_url, filename, certificates, domains): + """Construct an Atom feed based on the given information.""" + return ('feed', {'xmlns': 'http://www.w3.org/2005/Atom'}, + ('id', base_url), + ('link', {'rel': 'self', 'href': base_url}), + ('title', filename), + ('updated', datetime.now(tz=timezone.utc).isoformat()), + ('generator', + {'uri': 'https://trong.loang.net/scadere/about', + 'version': __version__}, + 'Scadere'), + *(entry(base_url, cert) for cert in certificates + if is_subdomain(cert[2], domains))) + + +def page(certificate): + """Construct an XHTML page for the given TLS certificate.""" + not_before, not_after, hostname, port, serial, issuer = certificate + return ('html', {'xmlns': 'http://www.w3.org/1999/xhtml', + 'lang': 'en'}, + ('head', + ('meta', {'name': 'color-scheme', + 'content': 'dark light'}), + ('meta', {'name': 'viewport', + 'content': ('width=device-width,' + 'initial-scale=1.0')}), + ('link', {'rel': 'icon', 'href': 'data:,'}), + ('title', f'TLS certificate - {hostname}:{port}')), + ('body', *body(not_before, not_after, + hostname, port, serial, issuer))) + + def xml(tree, parent=None): """Construct XML element from the given tree.""" tag, attrs, children = ((tree[0], tree[1], tree[2:]) @@ -138,96 +198,88 @@ def xml(tree, parent=None): elem.text = child case datetime(): elem.text = child.isoformat() - case _: # pragma: no cover + case _: assert_never(child) if parent is None: indent(elem) return elem -async def write_xml(writer, document): +async def write_xml(writer, http_version, application, func, *args): """Write given document as XML.""" - content = tuple(map(str.encode, - strings_from_xml(xml(document), 'unicode', - xml_declaration=True, - default_namespace=None))) - writer.write(f'Content-Length: {sum(map(len, content))}\r\n\r\n'.encode()) - for part in content: - writer.write(part) - await writer.drain() - - -def split_domain(domain): - """Split domain and order by ascending level.""" - return tuple(domain.split('.')[::-1]) - - -def is_subdomain(subject, objects): - """Check if subject is a subdomain of any object.""" - if not objects: - return True - sbj_parts = split_domain(subject) - return any(sbj_parts[:len(obj_parts)] == obj_parts - for obj_parts in map(split_domain, objects)) + try: + content = tuple(map(str.encode, + strings_from_xml(xml(func(*args)), 'unicode', + xml_declaration=True, + default_namespace=None))) + except Exception: # pragma: no cover + await describe_status(writer, HTTPStatus.INTERNAL_SERVER_ERROR, + http_version) + raise + else: + await write_status(writer, http_version, HTTPStatus.OK) + await write_content_type(writer, f'application/{application}+xml') + length = sum(map(len, content)) + writer.write(f'Content-Length: {length}\r\n\r\n'.encode()) + for part in content: + writer.write(part) + await writer.drain() async def handle(certs, base_url, reader, writer): """Handle HTTP request.""" - summaries = map(parse_summary, certs.read_text().splitlines()) - lookup = {urlsplit(urljoin(base_url, - path(hostname, port, issuer, serial))).path: - (datetime_from_str(not_before), - datetime_from_str(not_after, unavailable_ok=True), - hostname, port, serial, issuer) - for not_before, not_after, hostname, port, serial, issuer - in summaries} - request = await reader.readuntil(b'\r\n') - url = request.removeprefix(b'GET ').rsplit(b' HTTP/', 1)[0].strip() - url_parts = urlsplit(urljoin(base_url, url.decode())) - domains = tuple(parse_qs(url_parts.query).get('domain', [])) - - if not request.startswith(b'GET '): - await describe_status(writer, HTTPStatus.METHOD_NOT_ALLOWED) - elif url_parts.path == urlsplit(base_url).path: # Atom feed - await write_status(writer, HTTPStatus.OK) - await write_content_type(writer, 'application/atom+xml') - feed = ('feed', {'xmlns': 'http://www.w3.org/2005/Atom'}, - ('id', base_url), - ('link', {'rel': 'self', 'href': base_url}), - ('title', certs.name), - ('updated', datetime.now(tz=timezone.utc).isoformat()), - ('generator', - {'uri': 'https://trong.loang.net/scadere/about', - 'version': __version__}, - 'Scadere'), - *(entry(base_url, cert) for cert in lookup.values() - if is_subdomain(cert[2], domains))) - await write_xml(writer, feed) - elif url_parts.path in lookup: # accessible Atom entry's link/ID - await write_status(writer, HTTPStatus.OK) - await write_content_type(writer, 'application/xhtml+xml') - (not_before, not_after, - hostname, port, serial, issuer) = lookup[url_parts.path] - page = ('html', {'xmlns': 'http://www.w3.org/1999/xhtml', - 'lang': 'en'}, - ('head', - ('meta', {'name': 'color-scheme', - 'content': 'dark light'}), - ('meta', {'name': 'viewport', - 'content': ('width=device-width,' - 'initial-scale=1.0')}), - ('link', {'rel': 'icon', 'href': 'data:,'}), - ('title', f'TLS certificate - {hostname}:{port}')), - ('body', *body(not_before, not_after, - hostname, port, serial, issuer))) - await write_xml(writer, page) - else: - await describe_status(writer, HTTPStatus.NOT_FOUND) - - assert writer.can_write_eof() - writer.write_eof() - writer.close() - await writer.wait_closed() + try: + try: + request = await reader.readuntil(b'\r\n') + except Exception: + await describe_status(writer, HTTPStatus.BAD_REQUEST) + return + + if not request.startswith(b'GET '): + await describe_status(writer, HTTPStatus.METHOD_NOT_ALLOWED) + return + + try: + # Raise ValueError on the lack of b'HTTP/' + url, version = request.removeprefix(b'GET ').rsplit(b' HTTP/', 1) + http_version = version.strip().decode() + if not supported_http_version(http_version): + raise ValueError + except ValueError: + await describe_status(writer, + HTTPStatus.HTTP_VERSION_NOT_SUPPORTED) + return + + try: + url_parts = urlsplit(urljoin(base_url, url.strip().decode())) + domains = tuple(parse_qs(url_parts.query).get('domain', [])) + summaries = map(parse_summary, certs.read_text().splitlines()) + lookup = {urlsplit(urljoin(base_url, + path(hostname, port, + issuer, serial))).path: + (datetime_from_str(not_before), + datetime_from_str(not_after, unavailable_ok=True), + hostname, port, serial, issuer) + for not_before, not_after, hostname, port, serial, issuer + in summaries} + except Exception: # pragma: no cover + await describe_status(writer, HTTPStatus.INTERNAL_SERVER_ERROR, + http_version) + raise + + if url_parts.path == urlsplit(base_url).path: # Atom feed + await write_xml(writer, http_version, 'atom', feed, + base_url, certs.name, lookup.values(), domains) + elif url_parts.path in lookup: # accessible Atom entry's link/ID + await write_xml(writer, http_version, 'xhtml', page, + lookup.get(url_parts.path)) + else: + await describe_status(writer, HTTPStatus.NOT_FOUND, http_version) + finally: + assert writer.can_write_eof() + writer.write_eof() + writer.close() + await writer.wait_closed() async def listen(certs, base_url, host, port): # pragma: no cover diff --git a/tst/test_listen.py b/tst/test_listen.py index cc6a9a1..3862d9d 100644 --- a/tst/test_listen.py +++ b/tst/test_listen.py @@ -19,8 +19,10 @@ from asyncio import TaskGroup, open_connection, start_server from contextlib import asynccontextmanager, contextmanager from copy import deepcopy +from datetime import datetime from email.parser import BytesHeaderParser from functools import partial +from http import HTTPMethod from pathlib import Path from tempfile import mkstemp from urllib.parse import urljoin, urlsplit @@ -28,12 +30,14 @@ from xml.etree.ElementTree import (XML, XMLParser, indent, tostring as str_from_xml) from hypothesis import HealthCheck, given, settings -from hypothesis.strategies import (booleans, builds, composite, data, - datetimes, integers, lists, text) +from hypothesis.strategies import (booleans, builds, composite, + data, datetimes, from_type, + integers, lists, sampled_from, text) from hypothesis.provisional import domains, urls +from pytest import raises from scadere.check import base64_from_str -from scadere.listen import handle, is_subdomain, path, with_trailing_slash +from scadere.listen import handle, is_subdomain, path, with_trailing_slash, xml ATOM_NAMESPACES = {'': 'http://www.w3.org/2005/Atom'} XHTML_NAMESPACES = {'': 'http://www.w3.org/1999/xhtml'} @@ -81,6 +85,17 @@ def test_is_subdomain(subject, objects): or not subject.removesuffix(obj).endswith('.')) +def xml_unsupported_type(child): + """Check if child is of a type supported by the XML constructor.""" + return not isinstance(child, (tuple, str, datetime)) + + +@given(text(), from_type(type).flatmap(from_type).filter(xml_unsupported_type)) +def test_xml_unsupported_type(tag, child): + with raises(AssertionError, match='Expected code to be unreachable'): + xml((tag, {}, child)) + + @composite def certificates(draw): """Return a Hypothesis strategy for certificate summaries.""" @@ -148,15 +163,22 @@ async def connect(socket): await writer.wait_closed() +async def write_request(writer, request): + """Write given request.""" + writer.write(request.encode()) + await writer.drain() + assert writer.can_write_eof() + writer.write_eof() + + async def fetch_xml(socket, url, content_type): """Fetch the content at the URL from the socket.""" header_parser = BytesHeaderParser() xml_parser = XMLParser(encoding='utf-8') async with connect(socket) as (reader, writer): - writer.write(f'GET {url}\r\n'.encode()) - await writer.drain() + await write_request(writer, f'GET {url} HTTP/2\r\n') status = await reader.readuntil(b'\r\n') - assert status == b'HTTP/1.1 200 OK\r\n' + assert status == b'HTTP/2 200 OK\r\n' headers_bytes = await reader.readuntil(b'\r\n\r\n') headers = header_parser.parsebytes(headers_bytes) assert headers['Content-Type'] == content_type @@ -176,11 +198,11 @@ def equal_xml(a, b): async def check_feed(socket, base_url): """Check the Atom feed at the given path and its entry pages.""" feed = await fetch_xml(socket, base_url, 'application/atom+xml') - for feed_entry in feed.findall('entry', ATOM_NAMESPACES): - link = feed_entry.find('link', ATOM_NAMESPACES).attrib + for entry in feed.findall('entry', ATOM_NAMESPACES): + link = entry.find('link', ATOM_NAMESPACES).attrib assert link['rel'] == 'alternate' page = await fetch_xml(socket, link['href'], link['type']) - assert equal_xml(feed_entry.find('.//dl', XHTML_NAMESPACES), + assert equal_xml(entry.find('.//dl', XHTML_NAMESPACES), page.find('.//dl', XHTML_NAMESPACES)) @@ -202,13 +224,28 @@ async def test_content(base_url, certs): await check_server(server.sockets, check_feed, base_path) +async def bad_request(socket, request): + """Expect from socket a HTTP response with status 400.""" + async with connect(socket) as (reader, writer): + await write_request(writer, request) + status = await reader.readuntil(b'\r\n') + assert status == b'HTTP/1.1 400 Bad Request\r\n' + + +@given(text().filter(lambda request: '\r\n' not in request)) +async def test_incomplete_request(request): + with tmp_cert_file(()) as cert_file: + handler = partial(handle, cert_file, 'http://localhost') + async with await start_server(handler, 'localhost') as server: + await check_server(server.sockets, bad_request, request) + + async def not_found(socket, url): - """Send GET request for URL and expect a 404 status from socket.""" + """Send GET request for URL and expect HTTP status 404 from socket.""" async with connect(socket) as (reader, writer): - writer.write(f'GET {url}\r\n'.encode()) - await writer.drain() - response = await reader.readuntil(b'\r\n') - assert response == b'HTTP/1.1 404 Not Found\r\n' + await write_request(writer, f'GET {url} HTTP/1\r\n') + status = await reader.readuntil(b'\r\n') + assert status == b'HTTP/1 404 Not Found\r\n' @given(data()) @@ -223,18 +260,35 @@ async def test_unrecognized_url(drawer): await check_server(server.sockets, not_found, urlsplit(url).path) -async def method_not_allowed(socket, request): +async def method_not_allowed(socket, method, url): """Expect from socket a HTTP response with status 405.""" async with connect(socket) as (reader, writer): - writer.write(f'{request}\r\n'.encode()) - await writer.drain() - response = await reader.readuntil(b'\r\n') - assert response == b'HTTP/1.1 405 Method Not Allowed\r\n' + await write_request(writer, f'{method} {url} HTTP/1.1\r\n') + status = await reader.readuntil(b'\r\n') + assert status == b'HTTP/1.1 405 Method Not Allowed\r\n' + + +@given(urls(), sampled_from(HTTPMethod).filter(lambda method: method != 'GET')) +async def test_unallowed_method(base_url, method): + with tmp_cert_file(()) as cert_file: + handler = partial(handle, cert_file, base_url) + async with await start_server(handler, 'localhost') as server: + await check_server(server.sockets, method_not_allowed, + method.value, base_url) + + +async def unsupported_http_version(socket, url, version): + """Expect from socket a HTTP response with status 505.""" + async with connect(socket) as (reader, writer): + await write_request(writer, f'GET {url} HTTP/{version}\r\n') + status = await reader.readuntil(b'\r\n') + assert status == b'HTTP/1.1 505 HTTP Version Not Supported\r\n' -@given(urls(), text().filter(lambda method: not method.startswith('GET '))) -async def test_unallowed_method(base_url, request): +@given(urls().filter(is_base_url).filter(has_usual_path), integers(10)) +async def test_unsupported_http_version(base_url, version): with tmp_cert_file(()) as cert_file: handler = partial(handle, cert_file, base_url) async with await start_server(handler, 'localhost') as server: - await check_server(server.sockets, method_not_allowed, request) + await check_server(server.sockets, unsupported_http_version, + base_url, version) |