diff options
author | Nguyễn Gia Phong <cnx@loang.net> | 2025-05-29 17:14:48 +0900 |
---|---|---|
committer | Nguyễn Gia Phong <cnx@loang.net> | 2025-05-29 17:15:14 +0900 |
commit | d481c68fef4a78f757d78de92f1fad32ce0dd891 (patch) | |
tree | 179da0673bbc189c8df4c7d1ba6f4e2df96f0e91 | |
parent | b7168e182d1929978efb684f27826d306a30523d (diff) | |
download | scadere-d481c68fef4a78f757d78de92f1fad32ce0dd891.tar.gz |
Run test clients asynchronously
-rw-r--r-- | tst/test_listen.py | 65 |
1 files changed, 41 insertions, 24 deletions
diff --git a/tst/test_listen.py b/tst/test_listen.py index 437fb91..dc4dfd1 100644 --- a/tst/test_listen.py +++ b/tst/test_listen.py @@ -16,7 +16,7 @@ # You should have received a copy of the GNU Affero General Public License # along with scadere. If not, see <https://www.gnu.org/licenses/>. -from asyncio import open_connection, start_server +from asyncio import TaskGroup, open_connection, start_server from base64 import (urlsafe_b64decode as from_base64, urlsafe_b64encode as base64) from contextlib import asynccontextmanager, contextmanager @@ -31,7 +31,7 @@ from xml.etree.ElementTree import (XML, XMLParser, indent, tostring as str_from_xml) from xml.sax.saxutils import escape -from hypothesis import HealthCheck, example, given, settings +from hypothesis import HealthCheck, given, settings from hypothesis.strategies import (builds, composite, data, datetimes, integers, lists, text) from hypothesis.provisional import domains, urls @@ -211,23 +211,41 @@ def equal_xml(a, b): return str_from_xml(a_copy).rstrip() == str_from_xml(b_copy).rstrip() +async def check_feed(socket, base_url): + """Check the Atom feed at the given path and its entry pages.""" + feed = await fetch_xml(socket, base_url, 'application/atom+xml') + for feed_entry in feed.findall('entry', ATOM_NAMESPACES): + link = feed_entry.find('link', ATOM_NAMESPACES).attrib + assert link['rel'] == 'alternate' + page = await fetch_xml(socket, link['href'], link['type']) + assert equal_xml(feed_entry.find('.//dl', XHTML_NAMESPACES), + page.find('.//dl', XHTML_NAMESPACES)) + + +async def check_server(sockets, func, *args): + """Test server listening for connections on sockets using func.""" + async with TaskGroup() as group: + for socket in sockets: + group.create_task(func(socket, *args)) + + @given(urls().filter(is_base_url).filter(has_usual_path), lists(certificates(), min_size=1)) -@settings(deadline=None) async def test_http_200(base_url, certs): base_path = urlsplit(base_url).path with tmp_cert_file(certs) as cert_file: handler = partial(handle, cert_file, base_url) async with await start_server(handler, 'localhost') as server: - for socket in server.sockets: - feed = await fetch_xml(socket, base_path, - 'application/atom+xml') - for e in feed.findall('entry', ATOM_NAMESPACES): - link = e.find('link', ATOM_NAMESPACES).attrib - assert link['rel'] == 'alternate' - page = await fetch_xml(socket, link['href'], link['type']) - assert equal_xml(e.find('.//dl', XHTML_NAMESPACES), - page.find('.//dl', XHTML_NAMESPACES)) + await check_server(server.sockets, check_feed, base_path) + + +async def not_found(socket, url): + """Send GET request for URL and expect a 404 status from socket.""" + async with connect(socket) as (reader, writer): + writer.write(f'GET {urlsplit(url).path}\r\n'.encode()) + await writer.drain() + response = await reader.readuntil(b'\r\n') + assert response == b'HTTP/1.1 404 Not Found\r\n' @given(data()) @@ -239,12 +257,16 @@ async def test_unrecognized_url(drawer): with tmp_cert_file(()) as cert_file: handler = partial(handle, cert_file, base_url) async with await start_server(handler, 'localhost') as server: - for socket in server.sockets: - async with connect(socket) as (reader, writer): - writer.write(f'GET {urlsplit(url).path}\r\n'.encode()) - await writer.drain() - response = await reader.readuntil(b'\r\n') - assert response == b'HTTP/1.1 404 Not Found\r\n' + await check_server(server.sockets, not_found, urlsplit(url).path) + + +async def method_not_allowed(socket, request): + """Expect from socket a HTTP response with status 405.""" + async with connect(socket) as (reader, writer): + writer.write(f'{request}\r\n'.encode()) + await writer.drain() + response = await reader.readuntil(b'\r\n') + assert response == b'HTTP/1.1 405 Method Not Allowed\r\n' @given(urls(), text().filter(lambda method: not method.startswith('GET '))) @@ -252,9 +274,4 @@ async def test_unallowed_method(base_url, request): with tmp_cert_file(()) as cert_file: handler = partial(handle, cert_file, base_url) async with await start_server(handler, 'localhost') as server: - for socket in server.sockets: - async with connect(socket) as (reader, writer): - writer.write(f'{request}\r\n'.encode()) - await writer.drain() - response = await reader.readuntil(b'\r\n') - assert response == b'HTTP/1.1 405 Method Not Allowed\r\n' + await check_server(server.sockets, method_not_allowed, request) |