From 56a032568443bdf85dd37df5f6716b3475626d6a Mon Sep 17 00:00:00 2001 From: Nguyễn Gia Phong Date: Mon, 26 May 2025 17:45:48 +0900 Subject: Fix handling of base URL --- tst/test_listen.py | 182 ++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 161 insertions(+), 21 deletions(-) (limited to 'tst/test_listen.py') diff --git a/tst/test_listen.py b/tst/test_listen.py index d3a8052..f35ce82 100644 --- a/tst/test_listen.py +++ b/tst/test_listen.py @@ -14,35 +14,175 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from base64 import urlsafe_b64decode as from_base64 +from asyncio import open_connection, start_server +from base64 import (urlsafe_b64decode as from_base64, + urlsafe_b64encode as base64) +from contextlib import closing, contextmanager +from functools import partial +from pathlib import Path +from tempfile import mkstemp +from urllib.parse import urljoin, urlsplit +from xml.etree.ElementTree import tostring as str_from_xml +from xml.sax.saxutils import escape from hypothesis import given -from hypothesis.strategies import integers, datetimes, text -from hypothesis.provisional import domains +from hypothesis.strategies import (builds, composite, datetimes, + integers, lists, text) +from hypothesis.provisional import domains, urls -from scadere.listen import body, path +from scadere.listen import body, entry, handle, path, xml -@given(domains(), integers(1, 65535), text(), integers(0, 256**20)) +def ports(): + """Return a Hypothesis strategy for TCP ports.""" + return integers(1, 65535) + + +def serials(): + """Return a Hypothesis strategy for TLS serial number.""" + return builds(lambda n: hex(n).removeprefix('0x'), integers(0, 256**20-1)) + + +def ca_names(): + """Return a Hypothesis strategy for CA names.""" + return text().map(lambda name: base64(name.encode()).decode()) + + +@given(domains(), ports(), ca_names(), serials()) def test_path(hostname, port, issuer, serial): - r = path(hostname, port, issuer, hex(serial).removeprefix('0x')).split('/') - assert(r[0] == hostname) - assert(int(r[1]) == port) - assert(from_base64(r[2]).decode() == issuer) - assert(int(r[3], 16) == serial) + r = path(hostname, port, issuer, serial).split('/') + assert r[0] == hostname + assert int(r[1]) == port + assert r[2] == issuer + assert r[3] == serial -@given(domains(), integers(1, 65535), text(), integers(0, 256**20), - datetimes(), datetimes()) +@given(domains(), ports(), ca_names(), serials(), datetimes(), datetimes()) def test_body(hostname, port, issuer, serial, not_before, not_after): - r = body(not_before, not_after, hostname, port, - hex(serial).removeprefix('0x'), issuer) - assert(r[-1][0] == 'dl') + r = body(not_before, not_after, hostname, port, serial, issuer) + assert r[-1][0] == 'dl' d = dict(zip((v for k, v in r[-1][1:] if k == 'dt'), (v for k, v in r[-1][1:] if k == 'dd'))) - assert(d['Domain'] == hostname) - assert(d['Port'] == port) - assert(d['Issuer'] == issuer) - assert(int(d['Serial number'], 16) == serial) - assert(d['Valid from'] == not_before) - assert(d['Valid until'] == not_after) + assert d['Domain'] == hostname + assert d['Port'] == port + assert d['Issuer'] == from_base64(issuer.encode()).decode() + assert d['Serial number'] == serial + assert d['Valid from'] == not_before + assert d['Valid until'] == not_after + + +@given(urls(), domains(), ports(), + ca_names(), serials(), datetimes(), datetimes()) +def test_atom_entry(base_url, hostname, port, + issuer, serial, not_before, not_after): + cert = not_before, not_after, hostname, port, serial, issuer + r = str_from_xml(xml(entry(base_url, cert)), + 'unicode', short_empty_elements=False) + issuer_str = from_base64(issuer.encode()).decode() + url = urljoin(base_url, path(hostname, port, issuer, serial)) + assert r == f''' + + {escape(issuer_str)} + + +
+

TLS certificate information

+
+
Domain
+
{hostname}
+
Port
+
{port}
+
Issuer
+
{escape(issuer_str)}
+
Serial number
+
{serial}
+
Valid from
+
{not_before.isoformat()}
+
Valid until
+
{not_after.isoformat()}
+
+
+
+ {url} + + TLS cert for {hostname} will expire at {not_after} + {not_before.isoformat()} +
''' + + +@composite +def certificates(draw): + """Return a Hypothesis strategy for certificate summaries.""" + not_before = draw(datetimes()).isoformat() + not_after = draw(datetimes()).isoformat() + hostname = draw(domains()) + port = draw(ports()) + serial = draw(serials()) + issuer = draw(ca_names()) or '\0' + return f'{not_before} {not_after} {hostname} {port} {serial} {issuer}' + + +@contextmanager +def tmp_cert_file(lines): + cert_file = Path(mkstemp(text=True)[1]) + cert_file.write_text('\n'.join(lines)) + try: + yield cert_file + finally: + cert_file.unlink() + + +@given(urls().filter(lambda url: not urlsplit(url).path.startswith('//')), + lists(certificates())) +async def test_http_200(base_url, certs): + base_path = urlsplit(base_url).path + with tmp_cert_file(certs) as cert_file: + handler = partial(handle, cert_file, base_url) + server = await start_server(handler, 'localhost') + async with server: + socket, = server.sockets + reader, writer = await open_connection(*socket.getsockname()) + with closing(writer): + writer.write(f'GET {base_path}\r\n'.encode()) + await writer.drain() + response = await reader.readuntil(b'\r\n') + assert response == b'HTTP/1.1 200 OK\r\n' + + +@composite +def two_urls(draw, constraint): + """Return a Hypothesis strategy for 2 URLs.""" + first = draw(urls()) + second = draw(urls().filter(partial(constraint, first))) + return first, second + + +@given(two_urls(lambda a, b: urlsplit(a).path != urlsplit(b).path)) +async def test_http_404(url_and_url): + base_url, url = url_and_url + with tmp_cert_file(()) as cert_file: + handler = partial(handle, cert_file, base_url) + server = await start_server(handler, 'localhost') + async with server: + socket, = server.sockets + reader, writer = await open_connection(*socket.getsockname()) + with closing(writer): + writer.write(f'GET {urlsplit(url).path}\r\n'.encode()) + await writer.drain() + response = await reader.read() + assert response == b'HTTP/1.1 404 Not Found\r\n' + + +@given(urls(), text().filter(lambda method: not method.startswith('GET '))) +async def test_http_405(base_url, request): + with tmp_cert_file(()) as cert_file: + handler = partial(handle, cert_file, base_url) + server = await start_server(handler, 'localhost') + async with server: + socket, = server.sockets + reader, writer = await open_connection(*socket.getsockname()) + with closing(writer): + writer.write(f'{request}\r\n'.encode()) + await writer.drain() + response = await reader.read() + assert response == b'HTTP/1.1 405 Method Not Allowed\r\n' -- cgit 1.4.1