diff options
author | Nguyễn Gia Phong <cnx@loang.net> | 2025-05-26 17:45:48 +0900 |
---|---|---|
committer | Nguyễn Gia Phong <cnx@loang.net> | 2025-05-26 17:45:48 +0900 |
commit | 56a032568443bdf85dd37df5f6716b3475626d6a (patch) | |
tree | 487de1b9bff07ede605d5321faa1f09706e3bd00 | |
parent | b37d71bca632c1e29a3402fbaf69a14843eab8f2 (diff) | |
download | scadere-56a032568443bdf85dd37df5f6716b3475626d6a.tar.gz |
Fix handling of base URL
-rw-r--r-- | pyproject.toml | 3 | ||||
-rw-r--r-- | src/scadere/listen.py | 24 | ||||
-rw-r--r-- | tst/test_listen.py | 182 |
3 files changed, 178 insertions, 31 deletions
diff --git a/pyproject.toml b/pyproject.toml index 4104f67..244eba7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,10 @@ urls = { Source = 'https://trong.loang.net/scadere' } scripts = { scadere = 'scadere.__main__:main' } [tool.pytest.ini_options] +asyncio_mode = 'auto' +asyncio_default_fixture_loop_scope = 'function' testpaths = [ 'tst' ] +verbosity_assertions = 2 [tool.coverage.run] branch = true diff --git a/src/scadere/listen.py b/src/scadere/listen.py index aa80c32..0777f4a 100644 --- a/src/scadere/listen.py +++ b/src/scadere/listen.py @@ -21,7 +21,7 @@ from functools import partial from urllib.parse import parse_qs, urljoin, urlsplit from xml.etree.ElementTree import (Element as xml_element, SubElement as xml_subelement, - indent, tostring as xml_to_string) + indent, tostring as str_from_xml) from . import __version__ @@ -84,13 +84,14 @@ async def handle(certs, base_url, reader, writer): """Handle HTTP request.""" summaries = tuple(cert.rstrip().split(maxsplit=5) for cert in certs.read_text().splitlines()) - lookup = {f'/{path(hostname, port, issuer, serial)}': + lookup = {urlsplit(urljoin(base_url, + path(hostname, port, issuer, serial))).path: (not_before, not_after, hostname, port, serial, issuer) for not_before, not_after, hostname, port, serial, issuer in summaries} request = await reader.readuntil(b'\r\n') url = request.removeprefix(b'GET ').rsplit(b' HTTP/', 1)[0] - url_parts = urlsplit(url.decode()) + url_parts = urlsplit(urljoin(base_url, url.decode())) domains = tuple(parse_qs(url_parts.query).get('domain', [''])) if not request.startswith(b'GET '): @@ -99,7 +100,9 @@ async def handle(certs, base_url, reader, writer): writer.close() await writer.wait_closed() return - elif url_parts.path == '/': # Atom feed + elif url.startswith(b'//'): # urljoin goes haywire + writer.write(b'HTTP/1.1 404 Not Found\r\n') + elif url_parts.path == urlsplit(base_url).path: # Atom feed writer.write(b'HTTP/1.1 200 OK\r\n') writer.write(b'Content-Type: application/atom+xml\r\n') feed = xml(('feed', {'xmlns': 'http://www.w3.org/2005/Atom'}, @@ -112,9 +115,10 @@ async def handle(certs, base_url, reader, writer): 'version': __version__}, 'Scadere'), *(entry(base_url, cert) - for cert in summaries if cert[2].endswith(domains)))) - content = xml_to_string(feed, 'unicode', xml_declaration=True, - default_namespace=None).encode() + for cert in lookup.values() + if cert[2].endswith(domains)))) + content = str_from_xml(feed, 'unicode', xml_declaration=True, + default_namespace=None).encode() writer.write(f'Content-Length: {len(content)}\r\n\r\n'.encode()) writer.write(content) elif url_parts.path in lookup: # accessible Atom entry's link/ID @@ -134,8 +138,8 @@ async def handle(certs, base_url, reader, writer): ('title', f'TLS certificate - {hostname}:{port}')), ('body', *body(not_before, not_after, hostname, port, serial, issuer)))) - content = xml_to_string(page, 'unicode', xml_declaration=True, - default_namespace=None).encode() + content = str_from_xml(page, 'unicode', xml_declaration=True, + default_namespace=None).encode() writer.write(f'Content-Length: {len(content)}\r\n\r\n'.encode()) writer.write(content) else: @@ -145,7 +149,7 @@ async def handle(certs, base_url, reader, writer): await writer.wait_closed() -async def listen(certs, base_url, host, port): +async def listen(certs, base_url, host, port): # pragma: no cover """Serve HTTP server for TLS certificate expirations' Atom feed.""" server = await start_server(partial(handle, certs, base_url), host, port) async with server: diff --git a/tst/test_listen.py b/tst/test_listen.py index d3a8052..f35ce82 100644 --- a/tst/test_listen.py +++ b/tst/test_listen.py @@ -14,35 +14,175 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <https://www.gnu.org/licenses/>. -from base64 import urlsafe_b64decode as from_base64 +from asyncio import open_connection, start_server +from base64 import (urlsafe_b64decode as from_base64, + urlsafe_b64encode as base64) +from contextlib import closing, contextmanager +from functools import partial +from pathlib import Path +from tempfile import mkstemp +from urllib.parse import urljoin, urlsplit +from xml.etree.ElementTree import tostring as str_from_xml +from xml.sax.saxutils import escape from hypothesis import given -from hypothesis.strategies import integers, datetimes, text -from hypothesis.provisional import domains +from hypothesis.strategies import (builds, composite, datetimes, + integers, lists, text) +from hypothesis.provisional import domains, urls -from scadere.listen import body, path +from scadere.listen import body, entry, handle, path, xml -@given(domains(), integers(1, 65535), text(), integers(0, 256**20)) +def ports(): + """Return a Hypothesis strategy for TCP ports.""" + return integers(1, 65535) + + +def serials(): + """Return a Hypothesis strategy for TLS serial number.""" + return builds(lambda n: hex(n).removeprefix('0x'), integers(0, 256**20-1)) + + +def ca_names(): + """Return a Hypothesis strategy for CA names.""" + return text().map(lambda name: base64(name.encode()).decode()) + + +@given(domains(), ports(), ca_names(), serials()) def test_path(hostname, port, issuer, serial): - r = path(hostname, port, issuer, hex(serial).removeprefix('0x')).split('/') - assert(r[0] == hostname) - assert(int(r[1]) == port) - assert(from_base64(r[2]).decode() == issuer) - assert(int(r[3], 16) == serial) + r = path(hostname, port, issuer, serial).split('/') + assert r[0] == hostname + assert int(r[1]) == port + assert r[2] == issuer + assert r[3] == serial -@given(domains(), integers(1, 65535), text(), integers(0, 256**20), - datetimes(), datetimes()) +@given(domains(), ports(), ca_names(), serials(), datetimes(), datetimes()) def test_body(hostname, port, issuer, serial, not_before, not_after): - r = body(not_before, not_after, hostname, port, - hex(serial).removeprefix('0x'), issuer) - assert(r[-1][0] == 'dl') + r = body(not_before, not_after, hostname, port, serial, issuer) + assert r[-1][0] == 'dl' d = dict(zip((v for k, v in r[-1][1:] if k == 'dt'), (v for k, v in r[-1][1:] if k == 'dd'))) - assert(d['Domain'] == hostname) - assert(d['Port'] == port) - assert(d['Issuer'] == issuer) - assert(int(d['Serial number'], 16) == serial) - assert(d['Valid from'] == not_before) - assert(d['Valid until'] == not_after) + assert d['Domain'] == hostname + assert d['Port'] == port + assert d['Issuer'] == from_base64(issuer.encode()).decode() + assert d['Serial number'] == serial + assert d['Valid from'] == not_before + assert d['Valid until'] == not_after + + +@given(urls(), domains(), ports(), + ca_names(), serials(), datetimes(), datetimes()) +def test_atom_entry(base_url, hostname, port, + issuer, serial, not_before, not_after): + cert = not_before, not_after, hostname, port, serial, issuer + r = str_from_xml(xml(entry(base_url, cert)), + 'unicode', short_empty_elements=False) + issuer_str = from_base64(issuer.encode()).decode() + url = urljoin(base_url, path(hostname, port, issuer, serial)) + assert r == f'''<entry> + <author> + <name>{escape(issuer_str)}</name> + </author> + <content type="xhtml"> + <div xmlns="http://www.w3.org/1999/xhtml"> + <h1>TLS certificate information</h1> + <dl> + <dt>Domain</dt> + <dd>{hostname}</dd> + <dt>Port</dt> + <dd>{port}</dd> + <dt>Issuer</dt> + <dd>{escape(issuer_str)}</dd> + <dt>Serial number</dt> + <dd>{serial}</dd> + <dt>Valid from</dt> + <dd>{not_before.isoformat()}</dd> + <dt>Valid until</dt> + <dd>{not_after.isoformat()}</dd> + </dl> + </div> + </content> + <id>{url}</id> + <link rel="alternate" type="text/plain" href="{url}"></link> + <title>TLS cert for {hostname} will expire at {not_after}</title> + <updated>{not_before.isoformat()}</updated> +</entry>''' + + +@composite +def certificates(draw): + """Return a Hypothesis strategy for certificate summaries.""" + not_before = draw(datetimes()).isoformat() + not_after = draw(datetimes()).isoformat() + hostname = draw(domains()) + port = draw(ports()) + serial = draw(serials()) + issuer = draw(ca_names()) or '\0' + return f'{not_before} {not_after} {hostname} {port} {serial} {issuer}' + + +@contextmanager +def tmp_cert_file(lines): + cert_file = Path(mkstemp(text=True)[1]) + cert_file.write_text('\n'.join(lines)) + try: + yield cert_file + finally: + cert_file.unlink() + + +@given(urls().filter(lambda url: not urlsplit(url).path.startswith('//')), + lists(certificates())) +async def test_http_200(base_url, certs): + base_path = urlsplit(base_url).path + with tmp_cert_file(certs) as cert_file: + handler = partial(handle, cert_file, base_url) + server = await start_server(handler, 'localhost') + async with server: + socket, = server.sockets + reader, writer = await open_connection(*socket.getsockname()) + with closing(writer): + writer.write(f'GET {base_path}\r\n'.encode()) + await writer.drain() + response = await reader.readuntil(b'\r\n') + assert response == b'HTTP/1.1 200 OK\r\n' + + +@composite +def two_urls(draw, constraint): + """Return a Hypothesis strategy for 2 URLs.""" + first = draw(urls()) + second = draw(urls().filter(partial(constraint, first))) + return first, second + + +@given(two_urls(lambda a, b: urlsplit(a).path != urlsplit(b).path)) +async def test_http_404(url_and_url): + base_url, url = url_and_url + with tmp_cert_file(()) as cert_file: + handler = partial(handle, cert_file, base_url) + server = await start_server(handler, 'localhost') + async with server: + socket, = server.sockets + reader, writer = await open_connection(*socket.getsockname()) + with closing(writer): + writer.write(f'GET {urlsplit(url).path}\r\n'.encode()) + await writer.drain() + response = await reader.read() + assert response == b'HTTP/1.1 404 Not Found\r\n' + + +@given(urls(), text().filter(lambda method: not method.startswith('GET '))) +async def test_http_405(base_url, request): + with tmp_cert_file(()) as cert_file: + handler = partial(handle, cert_file, base_url) + server = await start_server(handler, 'localhost') + async with server: + socket, = server.sockets + reader, writer = await open_connection(*socket.getsockname()) + with closing(writer): + writer.write(f'{request}\r\n'.encode()) + await writer.drain() + response = await reader.read() + assert response == b'HTTP/1.1 405 Method Not Allowed\r\n' |