# Tests for the HTTP server
# Copyright (C) 2025 Nguyễn Gia Phong
#
# This file is part of scadere.
#
# Scadere is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Scadere is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with scadere.  If not, see <https://www.gnu.org/licenses/>.
from asyncio import TaskGroup, open_connection, start_server
from contextlib import asynccontextmanager, contextmanager
from copy import deepcopy
from email.parser import BytesHeaderParser
from functools import partial
from pathlib import Path
from string import ascii_letters
from tempfile import mkstemp
from urllib.parse import urljoin, urlsplit
from xml.etree.ElementTree import (XML, XMLParser, indent,
tostring as str_from_xml)
from xml.sax.saxutils import escape
from hypothesis import HealthCheck, given, settings
from hypothesis.strategies import (builds, composite, data,
datetimes, integers, lists, text)
from hypothesis.provisional import domains, urls
from scadere.check import base64_from_str
from scadere.listen import (body, entry, handle, is_subdomain, path,
str_from_base64, with_trailing_slash, xml)
# Prefix-to-URI maps passed as the namespaces argument of ElementTree's
# find/findall below; the empty prefix stands for the default namespace,
# so unprefixed tag names in the search paths are matched within it.
ATOM_NAMESPACES = {'': 'http://www.w3.org/2005/Atom'}
XHTML_NAMESPACES = {'': 'http://www.w3.org/1999/xhtml'}
def ports():
    """Return a Hypothesis strategy generating TCP port numbers.

    The generated integers range from 1 to 65535, both inclusive.
    """
    return integers(min_value=1, max_value=65535)
def serials():
    """Return a Hypothesis strategy for TLS serial numbers.

    Each number is a non-negative integer fitting in 20 bytes,
    rendered as lowercase hexadecimal digits without the 0x prefix.
    """
    largest = 256**20 - 1
    return builds(lambda number: format(number, 'x'), integers(0, largest))
def ca_names():
    """Return a Hypothesis strategy for CA names.

    Arbitrary text is drawn and then passed through base64_from_str.
    """
    return text().map(base64_from_str)
@given(domains(), ports(), ca_names(), serials())
def test_path(hostname, port, issuer, serial):
    """Check that the first four path segments carry the cert identity."""
    segments = path(hostname, port, issuer, serial).split('/')
    host_part, port_part, issuer_part, serial_part = segments[:4]
    assert host_part == hostname
    assert int(port_part) == port
    assert issuer_part == issuer
    assert serial_part == serial
@given(domains(), ports(), ca_names(), serials(), datetimes(), datetimes())
def test_body(hostname, port, issuer, serial, not_before, not_after):
    """Check the description list at the end of the rendered body."""
    rendered = body(not_before, not_after, hostname, port, serial, issuer)
    tag, *rows = rendered[-1]
    assert tag == 'dl'
    # Pair each term (dt) with the description (dd) at the same position.
    terms = [value for kind, value in rows if kind == 'dt']
    descriptions = [value for kind, value in rows if kind == 'dd']
    fields = dict(zip(terms, descriptions))
    assert fields['Domain'] == hostname
    assert fields['Port'] == port
    assert fields['Issuer'] == str_from_base64(issuer)
    assert fields['Serial number'] == serial
    assert fields['Valid from'] == not_before
    assert fields['Valid until'] == not_after
@given(urls(), domains(), ports(),
       ca_names(), serials(), datetimes(), datetimes())
def test_atom_entry(base_url, hostname, port,
                    issuer, serial, not_before, not_after):
    """Compare the serialized Atom entry of a certificate to a literal."""
    cert = not_before, not_after, hostname, port, serial, issuer
    # Serialize to str and spell out empty elements with explicit end tags.
    r = str_from_xml(xml(entry(base_url, cert)),
                     'unicode', short_empty_elements=False)
    issuer_str = str_from_base64(issuer)
    url = urljoin(base_url, path(hostname, port, issuer, serial))
    # NOTE(review): the expected literal below contains no XML tags even
    # though it is compared against serialized XML -- the markup looks
    # like it was lost from this copy of the file; confirm against
    # the upstream source before relying on this test.
    assert r == f'''{escape(issuer_str)}
TLS certificate information
Domain
{hostname}
Port
{port}
Issuer
{escape(issuer_str)}
Serial number
{serial}
Valid from
{not_before.isoformat()}
Valid until
{not_after.isoformat()}
{url}TLS cert for {hostname} will expire at {not_after}{not_before.isoformat()}'''
@given(domains(), lists(domains()))
def test_is_subdomain(subject, objects):
    """Check is_subdomain against suffix matching on label boundaries."""
    if not objects:
        # An empty list of objects is expected to accept any subject.
        assert is_subdomain(subject, objects)
        return

    if is_subdomain(subject, objects):
        # What remains after stripping an object must be either nothing
        # or a prefix ending at a label boundary.
        prefixes = [subject.removesuffix(obj) for obj in objects]
        assert any(not prefix or prefix.endswith('.')
                   for prefix in prefixes)
        return

    # No object may be a proper parent domain of the subject.
    assert not any(subject.endswith(obj)
                   and subject.removesuffix(obj).endswith('.')
                   for obj in objects)
@composite
def certificates(draw):
    """Return a Hypothesis strategy for certificate summaries.

    A summary is a single line of six space-separated fields:
    start and end of validity in ISO 8601 format, host name,
    port, serial number and issuer.
    """
    validity = [draw(datetimes()).isoformat() for _ in range(2)]
    hostname = draw(domains())
    port = draw(ports())
    serial = draw(serials())
    # Free-form UTF-8 could easily create malformed XML.
    issuer = base64_from_str(draw(text(ascii_letters)))
    return ' '.join([*validity, hostname, str(port), serial, issuer])
@contextmanager
def tmp_cert_file(lines):
    """Provide a temporary file holding the given lines.

    The lines are joined by line feeds and written to a fresh file,
    whose path is yielded as a pathlib.Path.  The file is removed
    on exit, even if writing it or the managed block fails.
    """
    # mkstemp returns an open OS-level file descriptor along with the name,
    # and it is up to the caller to close that descriptor.
    descriptor, name = mkstemp(text=True)
    cert_file = Path(name)
    try:
        # Wrapping the descriptor in a file object makes the with statement
        # close it, so no descriptor is leaked per temporary file.
        with open(descriptor, 'w') as stream:
            stream.write('\n'.join(lines))
        yield cert_file
    finally:
        cert_file.unlink()
def is_base_url(url):
    """Tell whether the given URL ends with a slash.

    Base URLs are expected to have this property,
    which the command-line argument parser enforces.
    """
    has_trailing_slash = url.endswith('/')
    return has_trailing_slash
@given(urls())
def test_with_trailing_slash(url):
    """Check that a slash is appended only when one is missing."""
    expected = url if is_base_url(url) else f'{url}/'
    assert with_trailing_slash(url) == expected
def is_root(base_url, url):
    """Check if the given URL points to the same path as the base URL.

    The path of the URL is joined onto the base URL
    and the resulting path is compared to that of the base URL.
    Paths starting with // are excluded because urljoin
    reads what follows the slashes as a network location.
    """
    requested_path = urlsplit(url).path
    joined = urljoin(base_url, requested_path)
    return urlsplit(joined).path == urlsplit(base_url).path
def has_usual_path(url):
    """Check if the path of the given URL survives urljoin unchanged.

    This is the case when the URL is its own root according to is_root.
    """
    return is_root(base_url=url, url=url)
@asynccontextmanager
async def connect(socket):
    """Open an asyncio connection to the address the socket is bound to.

    The reader and writer streams are yielded as a pair,
    and the writer is closed and awaited on exit.
    """
    address = socket.getsockname()
    # Only the first two items of the address are passed along.
    streams = await open_connection(*address[:2])
    _, writer = streams
    try:
        yield streams
    finally:
        writer.close()
        await writer.wait_closed()
async def fetch_xml(socket, url, content_type):
    """Fetch and parse the XML document at the URL from the socket.

    The response is asserted to have status 200, the given content type
    and a body as long as declared by the Content-Length header.
    The root element of the parsed body is returned.
    """
    request = f'GET {url}\r\n'.encode()
    async with connect(socket) as (reader, writer):
        writer.write(request)
        await writer.drain()

        status_line = await reader.readuntil(b'\r\n')
        assert status_line == b'HTTP/1.1 200 OK\r\n'

        # The header section ends at the first empty line.
        raw_headers = await reader.readuntil(b'\r\n\r\n')
        headers = BytesHeaderParser().parsebytes(raw_headers)
        assert headers['Content-Type'] == content_type

        # Everything up to end-of-file is the response body.
        payload = await reader.read()
        assert len(payload) == int(headers['Content-Length'])
        return XML(payload.decode(), XMLParser(encoding='utf-8'))
def equal_xml(a, b):
    """Check if the two XML elements are equal up to indentation.

    Copies of both elements are re-indented and serialized
    for comparison, so neither argument is modified.
    """
    def normalized(element):
        """Return the serialized form of a re-indented copy."""
        duplicate = deepcopy(element)
        indent(duplicate)
        return str_from_xml(duplicate).rstrip()

    return normalized(a) == normalized(b)
async def check_feed(socket, base_url):
    """Check the Atom feed at the given path and its entry pages.

    Each entry must link to an alternate page whose description list
    matches the one embedded in the entry itself.
    """
    feed = await fetch_xml(socket, base_url, 'application/atom+xml')
    dl_path = './/dl'
    for item in feed.findall('entry', ATOM_NAMESPACES):
        link_attributes = item.find('link', ATOM_NAMESPACES).attrib
        assert link_attributes['rel'] == 'alternate'
        # The linked page is fetched with the media type the feed declares.
        page = await fetch_xml(socket, link_attributes['href'],
                               link_attributes['type'])
        summary_in_feed = item.find(dl_path, XHTML_NAMESPACES)
        summary_on_page = page.find(dl_path, XHTML_NAMESPACES)
        assert equal_xml(summary_in_feed, summary_on_page)
async def check_server(sockets, func, *args):
"""Test server listening for connections on sockets using func."""
async with TaskGroup() as group:
for socket in sockets:
group.create_task(func(socket, *args))
@given(urls().filter(is_base_url).filter(has_usual_path),
       lists(certificates(), min_size=1))
@settings(deadline=None)
async def test_http_200(base_url, certs):
    """Check the feed and entry pages served for generated certificates."""
    with tmp_cert_file(certs) as cert_file:
        server = await start_server(partial(handle, cert_file, base_url),
                                    'localhost')
        async with server:
            # Requests are sent with only the path component of the base URL.
            await check_server(server.sockets, check_feed,
                               urlsplit(base_url).path)
async def not_found(socket, url):
    """Request the URL from the socket and expect a 404 status line."""
    request = f'GET {url}\r\n'.encode()
    async with connect(socket) as (reader, writer):
        writer.write(request)
        await writer.drain()
        status_line = await reader.readuntil(b'\r\n')
        assert status_line == b'HTTP/1.1 404 Not Found\r\n'
@given(data())
@settings(suppress_health_check=[HealthCheck.too_slow])
async def test_unrecognized_url(drawer):
    """Check that paths other than that of the base URL get a 404."""
    base_url = drawer.draw(urls().filter(is_base_url), label='base URL')
    # The request URL depends on the base URL, hence the interactive draw.
    elsewhere = urls().filter(
        lambda candidate: not is_root(base_url, candidate))
    url = drawer.draw(elsewhere, label='request URL')
    request_path = urlsplit(url).path
    with tmp_cert_file(()) as cert_file:
        server = await start_server(partial(handle, cert_file, base_url),
                                    'localhost')
        async with server:
            await check_server(server.sockets, not_found, request_path)
async def method_not_allowed(socket, request):
    """Send the request line to the socket and expect a 405 status line."""
    request_line = f'{request}\r\n'.encode()
    async with connect(socket) as (reader, writer):
        writer.write(request_line)
        await writer.drain()
        status_line = await reader.readuntil(b'\r\n')
        assert status_line == b'HTTP/1.1 405 Method Not Allowed\r\n'
@given(urls(), text().filter(lambda method: not method.startswith('GET ')))
async def test_unallowed_method(base_url, request):
    """Check that request lines not starting with GET get a 405."""
    with tmp_cert_file(()) as cert_file:
        server = await start_server(partial(handle, cert_file, base_url),
                                    'localhost')
        async with server:
            await check_server(server.sockets, method_not_allowed, request)