about summary refs log tree commit diff
path: root/tst/test_listen.py
diff options
context:
space:
mode:
Diffstat (limited to 'tst/test_listen.py')
-rw-r--r--tst/test_listen.py98
1 files changed, 76 insertions, 22 deletions
diff --git a/tst/test_listen.py b/tst/test_listen.py
index cc6a9a1..3862d9d 100644
--- a/tst/test_listen.py
+++ b/tst/test_listen.py
@@ -19,8 +19,10 @@
 from asyncio import TaskGroup, open_connection, start_server
 from contextlib import asynccontextmanager, contextmanager
 from copy import deepcopy
+from datetime import datetime
 from email.parser import BytesHeaderParser
 from functools import partial
+from http import HTTPMethod
 from pathlib import Path
 from tempfile import mkstemp
 from urllib.parse import urljoin, urlsplit
@@ -28,12 +30,14 @@ from xml.etree.ElementTree import (XML, XMLParser, indent,
                                    tostring as str_from_xml)
 
 from hypothesis import HealthCheck, given, settings
-from hypothesis.strategies import (booleans, builds, composite, data,
-                                   datetimes, integers, lists, text)
+from hypothesis.strategies import (booleans, builds, composite,
+                                   data, datetimes, from_type,
+                                   integers, lists, sampled_from, text)
 from hypothesis.provisional import domains, urls
+from pytest import raises
 
 from scadere.check import base64_from_str
-from scadere.listen import handle, is_subdomain, path, with_trailing_slash
+from scadere.listen import handle, is_subdomain, path, with_trailing_slash, xml
 
 ATOM_NAMESPACES = {'': 'http://www.w3.org/2005/Atom'}
 XHTML_NAMESPACES = {'': 'http://www.w3.org/1999/xhtml'}
@@ -81,6 +85,17 @@ def test_is_subdomain(subject, objects):
                     or not subject.removesuffix(obj).endswith('.'))
 
 
+def xml_unsupported_type(child):
+    """Check if child is of a type supported by the XML constructor."""
+    return not isinstance(child, (tuple, str, datetime))
+
+
+@given(text(), from_type(type).flatmap(from_type).filter(xml_unsupported_type))
+def test_xml_unsupported_type(tag, child):
+    with raises(AssertionError, match='Expected code to be unreachable'):
+        xml((tag, {}, child))
+
+
 @composite
 def certificates(draw):
     """Return a Hypothesis strategy for certificate summaries."""
@@ -148,15 +163,22 @@ async def connect(socket):
         await writer.wait_closed()
 
 
+async def write_request(writer, request):
+    """Write given request."""
+    writer.write(request.encode())
+    await writer.drain()
+    assert writer.can_write_eof()
+    writer.write_eof()
+
+
 async def fetch_xml(socket, url, content_type):
     """Fetch the content at the URL from the socket."""
     header_parser = BytesHeaderParser()
     xml_parser = XMLParser(encoding='utf-8')
     async with connect(socket) as (reader, writer):
-        writer.write(f'GET {url}\r\n'.encode())
-        await writer.drain()
+        await write_request(writer, f'GET {url} HTTP/2\r\n')
         status = await reader.readuntil(b'\r\n')
-        assert status == b'HTTP/1.1 200 OK\r\n'
+        assert status == b'HTTP/2 200 OK\r\n'
         headers_bytes = await reader.readuntil(b'\r\n\r\n')
         headers = header_parser.parsebytes(headers_bytes)
         assert headers['Content-Type'] == content_type
@@ -176,11 +198,11 @@ def equal_xml(a, b):
 async def check_feed(socket, base_url):
     """Check the Atom feed at the given path and its entry pages."""
     feed = await fetch_xml(socket, base_url, 'application/atom+xml')
-    for feed_entry in feed.findall('entry', ATOM_NAMESPACES):
-        link = feed_entry.find('link', ATOM_NAMESPACES).attrib
+    for entry in feed.findall('entry', ATOM_NAMESPACES):
+        link = entry.find('link', ATOM_NAMESPACES).attrib
         assert link['rel'] == 'alternate'
         page = await fetch_xml(socket, link['href'], link['type'])
-        assert equal_xml(feed_entry.find('.//dl', XHTML_NAMESPACES),
+        assert equal_xml(entry.find('.//dl', XHTML_NAMESPACES),
                          page.find('.//dl', XHTML_NAMESPACES))
 
 
@@ -202,13 +224,28 @@ async def test_content(base_url, certs):
             await check_server(server.sockets, check_feed, base_path)
 
 
+async def bad_request(socket, request):
+    """Expect from socket a HTTP response with status 400."""
+    async with connect(socket) as (reader, writer):
+        await write_request(writer, request)
+        status = await reader.readuntil(b'\r\n')
+        assert status == b'HTTP/1.1 400 Bad Request\r\n'
+
+
+@given(text().filter(lambda request: '\r\n' not in request))
+async def test_incomplete_request(request):
+    with tmp_cert_file(()) as cert_file:
+        handler = partial(handle, cert_file, 'http://localhost')
+        async with await start_server(handler, 'localhost') as server:
+            await check_server(server.sockets, bad_request, request)
+
+
 async def not_found(socket, url):
-    """Send GET request for URL and expect a 404 status from socket."""
+    """Send GET request for URL and expect HTTP status 404 from socket."""
     async with connect(socket) as (reader, writer):
-        writer.write(f'GET {url}\r\n'.encode())
-        await writer.drain()
-        response = await reader.readuntil(b'\r\n')
-        assert response == b'HTTP/1.1 404 Not Found\r\n'
+        await write_request(writer, f'GET {url} HTTP/1\r\n')
+        status = await reader.readuntil(b'\r\n')
+        assert status == b'HTTP/1 404 Not Found\r\n'
 
 
 @given(data())
@@ -223,18 +260,35 @@ async def test_unrecognized_url(drawer):
             await check_server(server.sockets, not_found, urlsplit(url).path)
 
 
-async def method_not_allowed(socket, request):
+async def method_not_allowed(socket, method, url):
     """Expect from socket a HTTP response with status 405."""
     async with connect(socket) as (reader, writer):
-        writer.write(f'{request}\r\n'.encode())
-        await writer.drain()
-        response = await reader.readuntil(b'\r\n')
-        assert response == b'HTTP/1.1 405 Method Not Allowed\r\n'
+        await write_request(writer, f'{method} {url} HTTP/1.1\r\n')
+        status = await reader.readuntil(b'\r\n')
+        assert status == b'HTTP/1.1 405 Method Not Allowed\r\n'
+
+
+@given(urls(), sampled_from(HTTPMethod).filter(lambda method: method != 'GET'))
+async def test_unallowed_method(base_url, method):
+    with tmp_cert_file(()) as cert_file:
+        handler = partial(handle, cert_file, base_url)
+        async with await start_server(handler, 'localhost') as server:
+            await check_server(server.sockets, method_not_allowed,
+                               method.value, base_url)
+
+
+async def unsupported_http_version(socket, url, version):
+    """Expect from socket a HTTP response with status 505."""
+    async with connect(socket) as (reader, writer):
+        await write_request(writer, f'GET {url} HTTP/{version}\r\n')
+        status = await reader.readuntil(b'\r\n')
+        assert status == b'HTTP/1.1 505 HTTP Version Not Supported\r\n'
 
 
-@given(urls(), text().filter(lambda method: not method.startswith('GET ')))
-async def test_unallowed_method(base_url, request):
+@given(urls().filter(is_base_url).filter(has_usual_path), integers(10))
+async def test_unsupported_http_version(base_url, version):
     with tmp_cert_file(()) as cert_file:
         handler = partial(handle, cert_file, base_url)
         async with await start_server(handler, 'localhost') as server:
-            await check_server(server.sockets, method_not_allowed, request)
+            await check_server(server.sockets, unsupported_http_version,
+                               base_url, version)