diff options
Diffstat (limited to 'tst/test_listen.py')
-rw-r--r-- | tst/test_listen.py | 98 |
1 files changed, 76 insertions, 22 deletions
diff --git a/tst/test_listen.py b/tst/test_listen.py index cc6a9a1..3862d9d 100644 --- a/tst/test_listen.py +++ b/tst/test_listen.py @@ -19,8 +19,10 @@ from asyncio import TaskGroup, open_connection, start_server from contextlib import asynccontextmanager, contextmanager from copy import deepcopy +from datetime import datetime from email.parser import BytesHeaderParser from functools import partial +from http import HTTPMethod from pathlib import Path from tempfile import mkstemp from urllib.parse import urljoin, urlsplit @@ -28,12 +30,14 @@ from xml.etree.ElementTree import (XML, XMLParser, indent, tostring as str_from_xml) from hypothesis import HealthCheck, given, settings -from hypothesis.strategies import (booleans, builds, composite, data, - datetimes, integers, lists, text) +from hypothesis.strategies import (booleans, builds, composite, + data, datetimes, from_type, + integers, lists, sampled_from, text) from hypothesis.provisional import domains, urls +from pytest import raises from scadere.check import base64_from_str -from scadere.listen import handle, is_subdomain, path, with_trailing_slash +from scadere.listen import handle, is_subdomain, path, with_trailing_slash, xml ATOM_NAMESPACES = {'': 'http://www.w3.org/2005/Atom'} XHTML_NAMESPACES = {'': 'http://www.w3.org/1999/xhtml'} @@ -81,6 +85,17 @@ def test_is_subdomain(subject, objects): or not subject.removesuffix(obj).endswith('.')) +def xml_unsupported_type(child): + """Check if child is of a type supported by the XML constructor.""" + return not isinstance(child, (tuple, str, datetime)) + + +@given(text(), from_type(type).flatmap(from_type).filter(xml_unsupported_type)) +def test_xml_unsupported_type(tag, child): + with raises(AssertionError, match='Expected code to be unreachable'): + xml((tag, {}, child)) + + @composite def certificates(draw): """Return a Hypothesis strategy for certificate summaries.""" @@ -148,15 +163,22 @@ async def connect(socket): await writer.wait_closed() +async def write_request(writer, request): + """Write given request.""" + writer.write(request.encode()) + await writer.drain() + assert writer.can_write_eof() + writer.write_eof() + + async def fetch_xml(socket, url, content_type): """Fetch the content at the URL from the socket.""" header_parser = BytesHeaderParser() xml_parser = XMLParser(encoding='utf-8') async with connect(socket) as (reader, writer): - writer.write(f'GET {url}\r\n'.encode()) - await writer.drain() + await write_request(writer, f'GET {url} HTTP/2\r\n') status = await reader.readuntil(b'\r\n') - assert status == b'HTTP/1.1 200 OK\r\n' + assert status == b'HTTP/2 200 OK\r\n' headers_bytes = await reader.readuntil(b'\r\n\r\n') headers = header_parser.parsebytes(headers_bytes) assert headers['Content-Type'] == content_type @@ -176,11 +198,11 @@ def equal_xml(a, b): async def check_feed(socket, base_url): """Check the Atom feed at the given path and its entry pages.""" feed = await fetch_xml(socket, base_url, 'application/atom+xml') - for feed_entry in feed.findall('entry', ATOM_NAMESPACES): - link = feed_entry.find('link', ATOM_NAMESPACES).attrib + for entry in feed.findall('entry', ATOM_NAMESPACES): + link = entry.find('link', ATOM_NAMESPACES).attrib assert link['rel'] == 'alternate' page = await fetch_xml(socket, link['href'], link['type']) - assert equal_xml(feed_entry.find('.//dl', XHTML_NAMESPACES), + assert equal_xml(entry.find('.//dl', XHTML_NAMESPACES), page.find('.//dl', XHTML_NAMESPACES)) @@ -202,13 +224,28 @@ async def test_content(base_url, certs): await check_server(server.sockets, check_feed, base_path) +async def bad_request(socket, request): + """Expect from socket a HTTP response with status 400.""" + async with connect(socket) as (reader, writer): + await write_request(writer, request) + status = await reader.readuntil(b'\r\n') + assert status == b'HTTP/1.1 400 Bad Request\r\n' + + +@given(text().filter(lambda request: '\r\n' not in request)) +async def test_incomplete_request(request): + with tmp_cert_file(()) as cert_file: + handler = partial(handle, cert_file, 'http://localhost') + async with await start_server(handler, 'localhost') as server: + await check_server(server.sockets, bad_request, request) + + async def not_found(socket, url): - """Send GET request for URL and expect a 404 status from socket.""" + """Send GET request for URL and expect HTTP status 404 from socket.""" async with connect(socket) as (reader, writer): - writer.write(f'GET {url}\r\n'.encode()) - await writer.drain() - response = await reader.readuntil(b'\r\n') - assert response == b'HTTP/1.1 404 Not Found\r\n' + await write_request(writer, f'GET {url} HTTP/1\r\n') + status = await reader.readuntil(b'\r\n') + assert status == b'HTTP/1 404 Not Found\r\n' @given(data()) @@ -223,18 +260,35 @@ async def test_unrecognized_url(drawer): await check_server(server.sockets, not_found, urlsplit(url).path) -async def method_not_allowed(socket, request): +async def method_not_allowed(socket, method, url): """Expect from socket a HTTP response with status 405.""" async with connect(socket) as (reader, writer): - writer.write(f'{request}\r\n'.encode()) - await writer.drain() - response = await reader.readuntil(b'\r\n') - assert response == b'HTTP/1.1 405 Method Not Allowed\r\n' + await write_request(writer, f'{method} {url} HTTP/1.1\r\n') + status = await reader.readuntil(b'\r\n') + assert status == b'HTTP/1.1 405 Method Not Allowed\r\n' + + +@given(urls(), sampled_from(HTTPMethod).filter(lambda method: method != 'GET')) +async def test_unallowed_method(base_url, method): + with tmp_cert_file(()) as cert_file: + handler = partial(handle, cert_file, base_url) + async with await start_server(handler, 'localhost') as server: + await check_server(server.sockets, method_not_allowed, + method.value, base_url) + + +async def unsupported_http_version(socket, url, version): + """Expect from socket a HTTP response with status 505.""" + async with connect(socket) as (reader, writer): + await write_request(writer, f'GET {url} HTTP/{version}\r\n') + status = await reader.readuntil(b'\r\n') + assert status == b'HTTP/1.1 505 HTTP Version Not Supported\r\n' -@given(urls(), text().filter(lambda method: not method.startswith('GET '))) -async def test_unallowed_method(base_url, request): +@given(urls().filter(is_base_url).filter(has_usual_path), integers(10)) +async def test_unsupported_http_version(base_url, version): with tmp_cert_file(()) as cert_file: handler = partial(handle, cert_file, base_url) async with await start_server(handler, 'localhost') as server: - await check_server(server.sockets, method_not_allowed, request) + await check_server(server.sockets, unsupported_http_version, + base_url, version) |