# Tests for the HTTP server
# Copyright (C) 2025 Nguyễn Gia Phong
#
# This file is part of scadere.
#
# Scadere is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Scadere is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with scadere. If not, see <https://www.gnu.org/licenses/>.

from asyncio import TaskGroup, open_connection, start_server
from contextlib import asynccontextmanager, contextmanager
from copy import deepcopy
from email.parser import BytesHeaderParser
from functools import partial
from pathlib import Path
from string import ascii_letters
from tempfile import mkstemp
from urllib.parse import urljoin, urlsplit
from xml.etree.ElementTree import (XML, XMLParser, indent,
                                   tostring as str_from_xml)
from xml.sax.saxutils import escape

from hypothesis import HealthCheck, given, settings
from hypothesis.strategies import (builds, composite, data,
                                   datetimes, integers, lists, text)
from hypothesis.provisional import domains, urls

from scadere.check import base64_from_str
from scadere.listen import (body, entry, handle, is_subdomain, path,
                            str_from_base64, with_trailing_slash, xml)

# Default-namespace maps for ElementTree find/findall lookups.
ATOM_NAMESPACES = {'': 'http://www.w3.org/2005/Atom'}
XHTML_NAMESPACES = {'': 'http://www.w3.org/1999/xhtml'}


def ports():
    """Return a Hypothesis strategy for TCP ports (1 to 65535)."""
    return integers(min_value=1, max_value=65535)


def serials():
    """Return a Hypothesis strategy for TLS serial numbers.

    Serial numbers are drawn as non-negative integers below 256**20
    and rendered as lowercase hexadecimal without the 0x prefix.
    """
    largest = 256**20 - 1
    return builds(lambda number: format(number, 'x'), integers(0, largest))


def ca_names():
    """Return a Hypothesis strategy for Base64-encoded CA names."""
    return text().map(base64_from_str)
@given(domains(), ports(), ca_names(), serials())
def test_path(hostname, port, issuer, serial):
    """Check that the path lists hostname, port, issuer and serial."""
    r = path(hostname, port, issuer, serial).split('/')
    assert r[0] == hostname
    assert int(r[1]) == port
    assert r[2] == issuer
    assert r[3] == serial


@given(domains(), ports(), ca_names(), serials(), datetimes(), datetimes())
def test_body(hostname, port, issuer, serial, not_before, not_after):
    """Check that the last body element is a dl describing the cert."""
    r = body(not_before, not_after, hostname, port, serial, issuer)
    assert r[-1][0] == 'dl'
    # Pair each dt term with the dd description at the same position.
    d = dict(zip((v for k, v in r[-1][1:] if k == 'dt'),
                 (v for k, v in r[-1][1:] if k == 'dd')))
    assert d['Domain'] == hostname
    assert d['Port'] == port
    assert d['Issuer'] == str_from_base64(issuer)
    assert d['Serial number'] == serial
    assert d['Valid from'] == not_before
    assert d['Valid until'] == not_after


@given(urls(), domains(), ports(), ca_names(), serials(),
       datetimes(), datetimes())
def test_atom_entry(base_url, hostname, port, issuer, serial,
                    not_before, not_after):
    """Check the serialized Atom entry of a certificate."""
    cert = not_before, not_after, hostname, port, serial, issuer
    r = str_from_xml(xml(entry(base_url, cert)),
                     'unicode', short_empty_elements=False)
    issuer_str = str_from_base64(issuer)
    url = urljoin(base_url, path(hostname, port, issuer, serial))
    # NOTE(review): the expected markup below looks like it lost
    # its XML tags somewhere along the way -- restore it from the
    # upstream copy of this test before relying on the assertion.
    assert r == f''' {escape(issuer_str)}

TLS certificate information

Domain
{hostname}
Port
{port}
Issuer
{escape(issuer_str)}
Serial number
{serial}
Valid from
{not_before.isoformat()}
Valid until
{not_after.isoformat()}
{url} TLS cert for {hostname} will expire at {not_after} {not_before.isoformat()}
'''


@given(domains(), lists(domains()))
def test_is_subdomain(subject, objects):
    """Check is_subdomain against suffix matching on label boundary."""
    if not objects:
        assert is_subdomain(subject, objects)
    elif is_subdomain(subject, objects):
        # What remains after removing a matching parent is either
        # nothing (same domain) or labels ending with a dot.
        assert any(child == '' or child.endswith('.')
                   for child in map(subject.removesuffix, objects))
    else:
        for obj in objects:
            assert (not subject.endswith(obj)
                    or not subject.removesuffix(obj).endswith('.'))


@composite
def certificates(draw):
    """Return a Hypothesis strategy for certificate summaries."""
    not_before = draw(datetimes()).isoformat()
    not_after = draw(datetimes()).isoformat()
    hostname = draw(domains())
    port = draw(ports())
    serial = draw(serials())
    # Free-form UTF-8 could easily create malformed XML.
    issuer = base64_from_str(draw(text(ascii_letters)))
    return f'{not_before} {not_after} {hostname} {port} {serial} {issuer}'


@contextmanager
def tmp_cert_file(lines):
    """Yield the path to a temporary file holding the given lines.

    The lines are joined by line feeds.  The file is removed
    when the context exits, even if writing it failed.
    """
    descriptor, name = mkstemp(text=True)
    cert_file = Path(name)
    try:
        # Write through the descriptor mkstemp already opened,
        # so that it gets closed instead of leaking once per example.
        with open(descriptor, 'w') as stream:
            stream.write('\n'.join(lines))
        yield cert_file
    finally:
        cert_file.unlink()


def is_base_url(url):
    """Check if the given URL has a trailing slash.

    The parser for command-line arguments enforces this property
    for base URLs.
    """
    return url.endswith('/')


@given(urls())
def test_with_trailing_slash(url):
    """Check that a slash is appended only when it is missing."""
    if is_base_url(url):
        assert with_trailing_slash(url) == url
    else:
        assert with_trailing_slash(url) == f'{url}/'


def is_root(base_url, url):
    """Check if the given URL points to the same base URL.

    Paths starting with // are excluded because urljoin
    sometimes confuses them with URL scheme.
    """
    base_path = urlsplit(base_url).path
    url_path = urlsplit(url).path
    return urlsplit(urljoin(base_url, url_path)).path == base_path


def has_usual_path(url):
    """Check if the given URL path survives urljoin with itself.

    URLs failing this check have paths that tend to mess with urljoin.
    """
    return is_root(url, url)


@asynccontextmanager
async def connect(socket):
    """Return a read-write stream for an asyncio TCP connection.

    The connection is made to the address the given socket is bound to
    and the writer is closed when the context exits.
    """
    reader, writer = await open_connection(*socket.getsockname()[:2])
    try:
        yield reader, writer
    finally:
        writer.close()
        await writer.wait_closed()


async def fetch_xml(socket, url, content_type):
    """Fetch the content at the URL from the socket.

    Assert that the response has status 200, the given content type
    and a content length matching the body, then return the body
    parsed as an XML element.
    """
    header_parser = BytesHeaderParser()
    xml_parser = XMLParser(encoding='utf-8')
    async with connect(socket) as (reader, writer):
        writer.write(f'GET {url}\r\n'.encode())
        await writer.drain()
        status = await reader.readuntil(b'\r\n')
        assert status == b'HTTP/1.1 200 OK\r\n'
        headers_bytes = await reader.readuntil(b'\r\n\r\n')
        headers = header_parser.parsebytes(headers_bytes)
        assert headers['Content-Type'] == content_type
        # Read until EOF: whatever follows the headers is the body.
        content = await reader.read()
        assert len(content) == int(headers['Content-Length'])
        return XML(content.decode(), xml_parser)


def equal_xml(a, b):
    """Check if the two XML elements are equal.

    Copies of both elements are re-indented before comparison
    so that differences in whitespace layout are ignored
    and the arguments are left untouched.
    """
    a_copy, b_copy = deepcopy(a), deepcopy(b)
    indent(a_copy)
    indent(b_copy)
    return str_from_xml(a_copy).rstrip() == str_from_xml(b_copy).rstrip()


async def check_feed(socket, base_url):
    """Check the Atom feed at the given path and its entry pages.

    Each entry must link to an alternate page whose dl element
    matches the one embedded in the feed entry.
    """
    feed = await fetch_xml(socket, base_url, 'application/atom+xml')
    for feed_entry in feed.findall('entry', ATOM_NAMESPACES):
        link = feed_entry.find('link', ATOM_NAMESPACES).attrib
        assert link['rel'] == 'alternate'
        page = await fetch_xml(socket, link['href'], link['type'])
        assert equal_xml(feed_entry.find('.//dl', XHTML_NAMESPACES),
                         page.find('.//dl', XHTML_NAMESPACES))


async def check_server(sockets, func, *args):
    """Test server listening for connections on sockets using func.

    The coroutine function func is called concurrently
    for every socket, followed by the remaining arguments.
    """
    async with TaskGroup() as group:
        for socket in sockets:
            group.create_task(func(socket, *args))


@given(urls().filter(is_base_url).filter(has_usual_path),
       lists(certificates(), min_size=1))
@settings(deadline=None)
async def test_http_200(base_url, certs):
    """Check that the feed and entry pages are served successfully."""
    base_path = urlsplit(base_url).path
    with tmp_cert_file(certs) as cert_file:
        handler = partial(handle, cert_file, base_url)
        async with await start_server(handler, 'localhost') as server:
            await check_server(server.sockets, check_feed, base_path)


async def not_found(socket, url):
    """Send GET request for URL and expect a 404 status from socket."""
    async with connect(socket) as (reader, writer):
        writer.write(f'GET {url}\r\n'.encode())
        await writer.drain()
        response = await reader.readuntil(b'\r\n')
        assert response == b'HTTP/1.1 404 Not Found\r\n'


@given(data())
@settings(suppress_health_check=[HealthCheck.too_slow])
async def test_unrecognized_url(drawer):
    """Check that paths outside of the base URL get status 404."""
    base_url = drawer.draw(urls().filter(is_base_url), label='base URL')
    url = drawer.draw(urls().filter(lambda url: not is_root(base_url, url)),
                      label='request URL')
    with tmp_cert_file(()) as cert_file:
        handler = partial(handle, cert_file, base_url)
        async with await start_server(handler, 'localhost') as server:
            await check_server(server.sockets, not_found,
                               urlsplit(url).path)


async def method_not_allowed(socket, request):
    """Expect from socket a HTTP response with status 405."""
    async with connect(socket) as (reader, writer):
        writer.write(f'{request}\r\n'.encode())
        await writer.drain()
        response = await reader.readuntil(b'\r\n')
        assert response == b'HTTP/1.1 405 Method Not Allowed\r\n'


@given(urls(), text().filter(lambda method: not method.startswith('GET ')))
async def test_unallowed_method(base_url, request):
    """Check that requests not starting with GET get status 405."""
    with tmp_cert_file(()) as cert_file:
        handler = partial(handle, cert_file, base_url)
        async with await start_server(handler, 'localhost') as server:
            await check_server(server.sockets, method_not_allowed, request)