diff options
Diffstat (limited to 'tst/test_listen.py')
-rw-r--r-- | tst/test_listen.py | 182 |
1 files changed, 161 insertions, 21 deletions
diff --git a/tst/test_listen.py b/tst/test_listen.py index d3a8052..f35ce82 100644 --- a/tst/test_listen.py +++ b/tst/test_listen.py @@ -14,35 +14,175 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <https://www.gnu.org/licenses/>. -from base64 import urlsafe_b64decode as from_base64 +from asyncio import open_connection, start_server +from base64 import (urlsafe_b64decode as from_base64, + urlsafe_b64encode as base64) +from contextlib import closing, contextmanager +from functools import partial +from pathlib import Path +from tempfile import mkstemp +from urllib.parse import urljoin, urlsplit +from xml.etree.ElementTree import tostring as str_from_xml +from xml.sax.saxutils import escape from hypothesis import given -from hypothesis.strategies import integers, datetimes, text -from hypothesis.provisional import domains +from hypothesis.strategies import (builds, composite, datetimes, + integers, lists, text) +from hypothesis.provisional import domains, urls -from scadere.listen import body, path +from scadere.listen import body, entry, handle, path, xml -@given(domains(), integers(1, 65535), text(), integers(0, 256**20)) +def ports(): + """Return a Hypothesis strategy for TCP ports.""" + return integers(1, 65535) + + +def serials(): + """Return a Hypothesis strategy for TLS serial number.""" + return builds(lambda n: hex(n).removeprefix('0x'), integers(0, 256**20-1)) + + +def ca_names(): + """Return a Hypothesis strategy for CA names.""" + return text().map(lambda name: base64(name.encode()).decode()) + + +@given(domains(), ports(), ca_names(), serials()) def test_path(hostname, port, issuer, serial): - r = path(hostname, port, issuer, hex(serial).removeprefix('0x')).split('/') - assert(r[0] == hostname) - assert(int(r[1]) == port) - assert(from_base64(r[2]).decode() == issuer) - assert(int(r[3], 16) == serial) + r = path(hostname, port, issuer, serial).split('/') + assert r[0] == hostname + assert int(r[1]) == port + assert r[2] == issuer + assert r[3] == serial -@given(domains(), integers(1, 65535), text(), integers(0, 256**20), - datetimes(), datetimes()) +@given(domains(), ports(), ca_names(), serials(), datetimes(), datetimes()) def test_body(hostname, port, issuer, serial, not_before, not_after): - r = body(not_before, not_after, hostname, port, - hex(serial).removeprefix('0x'), issuer) - assert(r[-1][0] == 'dl') + r = body(not_before, not_after, hostname, port, serial, issuer) + assert r[-1][0] == 'dl' d = dict(zip((v for k, v in r[-1][1:] if k == 'dt'), (v for k, v in r[-1][1:] if k == 'dd'))) - assert(d['Domain'] == hostname) - assert(d['Port'] == port) - assert(d['Issuer'] == issuer) - assert(int(d['Serial number'], 16) == serial) - assert(d['Valid from'] == not_before) - assert(d['Valid until'] == not_after) + assert d['Domain'] == hostname + assert d['Port'] == port + assert d['Issuer'] == from_base64(issuer.encode()).decode() + assert d['Serial number'] == serial + assert d['Valid from'] == not_before + assert d['Valid until'] == not_after + + +@given(urls(), domains(), ports(), + ca_names(), serials(), datetimes(), datetimes()) +def test_atom_entry(base_url, hostname, port, + issuer, serial, not_before, not_after): + cert = not_before, not_after, hostname, port, serial, issuer + r = str_from_xml(xml(entry(base_url, cert)), + 'unicode', short_empty_elements=False) + issuer_str = from_base64(issuer.encode()).decode() + url = urljoin(base_url, path(hostname, port, issuer, serial)) + assert r == f'''<entry> + <author> + <name>{escape(issuer_str)}</name> + </author> + <content type="xhtml"> + <div xmlns="http://www.w3.org/1999/xhtml"> + <h1>TLS certificate information</h1> + <dl> + <dt>Domain</dt> + <dd>{hostname}</dd> + <dt>Port</dt> + <dd>{port}</dd> + <dt>Issuer</dt> + <dd>{escape(issuer_str)}</dd> + <dt>Serial number</dt> + <dd>{serial}</dd> + <dt>Valid from</dt> + <dd>{not_before.isoformat()}</dd> + <dt>Valid until</dt> + <dd>{not_after.isoformat()}</dd> + </dl> + </div> + </content> + <id>{url}</id> + <link rel="alternate" type="text/plain" href="{url}"></link> + <title>TLS cert for {hostname} will expire at {not_after}</title> + <updated>{not_before.isoformat()}</updated> +</entry>''' + + +@composite +def certificates(draw): + """Return a Hypothesis strategy for certificate summaries.""" + not_before = draw(datetimes()).isoformat() + not_after = draw(datetimes()).isoformat() + hostname = draw(domains()) + port = draw(ports()) + serial = draw(serials()) + issuer = draw(ca_names()) or '\0' + return f'{not_before} {not_after} {hostname} {port} {serial} {issuer}' + + +@contextmanager +def tmp_cert_file(lines): + cert_file = Path(mkstemp(text=True)[1]) + cert_file.write_text('\n'.join(lines)) + try: + yield cert_file + finally: + cert_file.unlink() + + +@given(urls().filter(lambda url: not urlsplit(url).path.startswith('//')), + lists(certificates())) +async def test_http_200(base_url, certs): + base_path = urlsplit(base_url).path + with tmp_cert_file(certs) as cert_file: + handler = partial(handle, cert_file, base_url) + server = await start_server(handler, 'localhost') + async with server: + socket, = server.sockets + reader, writer = await open_connection(*socket.getsockname()) + with closing(writer): + writer.write(f'GET {base_path}\r\n'.encode()) + await writer.drain() + response = await reader.readuntil(b'\r\n') + assert response == b'HTTP/1.1 200 OK\r\n' + + +@composite +def two_urls(draw, constraint): + """Return a Hypothesis strategy for 2 URLs.""" + first = draw(urls()) + second = draw(urls().filter(partial(constraint, first))) + return first, second + + +@given(two_urls(lambda a, b: urlsplit(a).path != urlsplit(b).path)) +async def test_http_404(url_and_url): + base_url, url = url_and_url + with tmp_cert_file(()) as cert_file: + handler = partial(handle, cert_file, base_url) + server = await start_server(handler, 'localhost') + async with server: + socket, = server.sockets + reader, writer = await open_connection(*socket.getsockname()) + with closing(writer): + writer.write(f'GET {urlsplit(url).path}\r\n'.encode()) + await writer.drain() + response = await reader.read() + assert response == b'HTTP/1.1 404 Not Found\r\n' + + +@given(urls(), text().filter(lambda method: not method.startswith('GET '))) +async def test_http_405(base_url, request): + with tmp_cert_file(()) as cert_file: + handler = partial(handle, cert_file, base_url) + server = await start_server(handler, 'localhost') + async with server: + socket, = server.sockets + reader, writer = await open_connection(*socket.getsockname()) + with closing(writer): + writer.write(f'{request}\r\n'.encode()) + await writer.drain() + response = await reader.read() + assert response == b'HTTP/1.1 405 Method Not Allowed\r\n' |