about summary refs log tree commit diff
path: root/tst/test_listen.py
diff options
context:
space:
mode:
Diffstat (limited to 'tst/test_listen.py')
-rw-r--r--tst/test_listen.py182
1 files changed, 161 insertions, 21 deletions
diff --git a/tst/test_listen.py b/tst/test_listen.py
index d3a8052..f35ce82 100644
--- a/tst/test_listen.py
+++ b/tst/test_listen.py
@@ -14,35 +14,175 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
-from base64 import urlsafe_b64decode as from_base64
+from asyncio import open_connection, start_server
+from base64 import (urlsafe_b64decode as from_base64,
+                    urlsafe_b64encode as base64)
+from contextlib import closing, contextmanager
+from functools import partial
+from pathlib import Path
+from tempfile import mkstemp
+from urllib.parse import urljoin, urlsplit
+from xml.etree.ElementTree import tostring as str_from_xml
+from xml.sax.saxutils import escape
 
 from hypothesis import given
-from hypothesis.strategies import integers, datetimes, text
-from hypothesis.provisional import domains
+from hypothesis.strategies import (builds, composite, datetimes,
+                                   integers, lists, text)
+from hypothesis.provisional import domains, urls
 
-from scadere.listen import body, path
+from scadere.listen import body, entry, handle, path, xml
 
 
-@given(domains(), integers(1, 65535), text(), integers(0, 256**20))
+def ports():
+    """Return a Hypothesis strategy for TCP ports."""
+    return integers(1, 65535)
+
+
+def serials():
+    """Return a Hypothesis strategy for TLS serial number."""
+    return builds(lambda n: hex(n).removeprefix('0x'), integers(0, 256**20-1))
+
+
+def ca_names():
+    """Return a Hypothesis strategy for CA names."""
+    return text().map(lambda name: base64(name.encode()).decode())
+
+
+@given(domains(), ports(), ca_names(), serials())
 def test_path(hostname, port, issuer, serial):
-    r = path(hostname, port, issuer, hex(serial).removeprefix('0x')).split('/')
-    assert(r[0] == hostname)
-    assert(int(r[1]) == port)
-    assert(from_base64(r[2]).decode() == issuer)
-    assert(int(r[3], 16) == serial)
+    r = path(hostname, port, issuer, serial).split('/')
+    assert r[0] == hostname
+    assert int(r[1]) == port
+    assert r[2] == issuer
+    assert r[3] == serial
 
 
-@given(domains(), integers(1, 65535), text(), integers(0, 256**20),
-       datetimes(), datetimes())
+@given(domains(), ports(), ca_names(), serials(), datetimes(), datetimes())
 def test_body(hostname, port, issuer, serial, not_before, not_after):
-    r = body(not_before, not_after, hostname, port,
-             hex(serial).removeprefix('0x'), issuer)
-    assert(r[-1][0] == 'dl')
+    r = body(not_before, not_after, hostname, port, serial, issuer)
+    assert r[-1][0] == 'dl'
     d = dict(zip((v for k, v in r[-1][1:] if k == 'dt'),
                  (v for k, v in r[-1][1:] if k == 'dd')))
-    assert(d['Domain'] == hostname)
-    assert(d['Port'] == port)
-    assert(d['Issuer'] == issuer)
-    assert(int(d['Serial number'], 16) == serial)
-    assert(d['Valid from'] == not_before)
-    assert(d['Valid until'] == not_after)
+    assert d['Domain'] == hostname
+    assert d['Port'] == port
+    assert d['Issuer'] == from_base64(issuer.encode()).decode()
+    assert d['Serial number'] == serial
+    assert d['Valid from'] == not_before
+    assert d['Valid until'] == not_after
+
+
+@given(urls(), domains(), ports(),
+       ca_names(), serials(), datetimes(), datetimes())
+def test_atom_entry(base_url, hostname, port,
+                    issuer, serial, not_before, not_after):
+    cert = not_before, not_after, hostname, port, serial, issuer
+    r = str_from_xml(xml(entry(base_url, cert)),
+                     'unicode', short_empty_elements=False)
+    issuer_str = from_base64(issuer.encode()).decode()
+    url = urljoin(base_url, path(hostname, port, issuer, serial))
+    assert r == f'''<entry>
+  <author>
+    <name>{escape(issuer_str)}</name>
+  </author>
+  <content type="xhtml">
+    <div xmlns="http://www.w3.org/1999/xhtml">
+      <h1>TLS certificate information</h1>
+      <dl>
+        <dt>Domain</dt>
+        <dd>{hostname}</dd>
+        <dt>Port</dt>
+        <dd>{port}</dd>
+        <dt>Issuer</dt>
+        <dd>{escape(issuer_str)}</dd>
+        <dt>Serial number</dt>
+        <dd>{serial}</dd>
+        <dt>Valid from</dt>
+        <dd>{not_before.isoformat()}</dd>
+        <dt>Valid until</dt>
+        <dd>{not_after.isoformat()}</dd>
+      </dl>
+    </div>
+  </content>
+  <id>{url}</id>
+  <link rel="alternate" type="text/plain" href="{url}"></link>
+  <title>TLS cert for {hostname} will expire at {not_after}</title>
+  <updated>{not_before.isoformat()}</updated>
+</entry>'''
+
+
+@composite
+def certificates(draw):
+    """Return a Hypothesis strategy for certificate summaries."""
+    not_before = draw(datetimes()).isoformat()
+    not_after = draw(datetimes()).isoformat()
+    hostname = draw(domains())
+    port = draw(ports())
+    serial = draw(serials())
+    issuer = draw(ca_names()) or '\0'
+    return f'{not_before} {not_after} {hostname} {port} {serial} {issuer}'
+
+
+@contextmanager
+def tmp_cert_file(lines):
+    cert_file = Path(mkstemp(text=True)[1])
+    cert_file.write_text('\n'.join(lines))
+    try:
+        yield cert_file
+    finally:
+        cert_file.unlink()
+
+
+@given(urls().filter(lambda url: not urlsplit(url).path.startswith('//')),
+       lists(certificates()))
+async def test_http_200(base_url, certs):
+    base_path = urlsplit(base_url).path
+    with tmp_cert_file(certs) as cert_file:
+        handler = partial(handle, cert_file, base_url)
+        server = await start_server(handler, 'localhost')
+        async with server:
+            socket, = server.sockets
+            reader, writer = await open_connection(*socket.getsockname())
+            with closing(writer):
+                writer.write(f'GET {base_path}\r\n'.encode())
+                await writer.drain()
+                response = await reader.readuntil(b'\r\n')
+                assert response == b'HTTP/1.1 200 OK\r\n'
+
+
+@composite
+def two_urls(draw, constraint):
+    """Return a Hypothesis strategy for 2 URLs."""
+    first = draw(urls())
+    second = draw(urls().filter(partial(constraint, first)))
+    return first, second
+
+
+@given(two_urls(lambda a, b: urlsplit(a).path != urlsplit(b).path))
+async def test_http_404(url_and_url):
+    base_url, url = url_and_url
+    with tmp_cert_file(()) as cert_file:
+        handler = partial(handle, cert_file, base_url)
+        server = await start_server(handler, 'localhost')
+        async with server:
+            socket, = server.sockets
+            reader, writer = await open_connection(*socket.getsockname())
+            with closing(writer):
+                writer.write(f'GET {urlsplit(url).path}\r\n'.encode())
+                await writer.drain()
+                response = await reader.read()
+                assert response == b'HTTP/1.1 404 Not Found\r\n'
+
+
+@given(urls(), text().filter(lambda method: not method.startswith('GET ')))
+async def test_http_405(base_url, request):
+    with tmp_cert_file(()) as cert_file:
+        handler = partial(handle, cert_file, base_url)
+        server = await start_server(handler, 'localhost')
+        async with server:
+            socket, = server.sockets
+            reader, writer = await open_connection(*socket.getsockname())
+            with closing(writer):
+                writer.write(f'{request}\r\n'.encode())
+                await writer.drain()
+                response = await reader.read()
+                assert response == b'HTTP/1.1 405 Method Not Allowed\r\n'