author     Nguyễn Gia Phong <cnx@loang.net>  2025-05-29 17:14:48 +0900
committer  Nguyễn Gia Phong <cnx@loang.net>  2025-05-29 17:15:14 +0900
commit     d481c68fef4a78f757d78de92f1fad32ce0dd891 (patch)
tree       179da0673bbc189c8df4c7d1ba6f4e2df96f0e91
parent     b7168e182d1929978efb684f27826d306a30523d (diff)
download   scadere-d481c68fef4a78f757d78de92f1fad32ce0dd891.tar.gz
Run test clients asynchronously
-rw-r--r--  tst/test_listen.py  65
1 file changed, 41 insertions, 24 deletions
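
The change below replaces the sequential per-socket for-loops with a fan-out helper built on asyncio.TaskGroup (Python 3.11+), so every listening socket is exercised by a concurrent client task. A minimal sketch of that pattern, independent of scadere's test helpers (probe and items are hypothetical stand-ins for the per-socket checks and for server.sockets):

    # Minimal sketch of the fan-out pattern adopted in this commit
    # (requires Python >= 3.11 for asyncio.TaskGroup).  `probe` and
    # `items` are hypothetical stand-ins for the per-socket checks
    # (check_feed, not_found, method_not_allowed) and server.sockets.
    from asyncio import TaskGroup, run, sleep


    async def probe(item):
        await sleep(0)      # pretend to exchange a request/response here
        assert item >= 0    # the per-client assertion


    async def check_all(items):
        # All probes run concurrently; the group waits for every task,
        # and the first failure cancels the rest and re-raises on exit.
        async with TaskGroup() as group:
            for item in items:
                group.create_task(probe(item))


    run(check_all(range(3)))
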
diff --git a/tst/test_listen.py b/tst/test_listen.py
index 437fb91..dc4dfd1 100644
--- a/tst/test_listen.py
+++ b/tst/test_listen.py
@@ -16,7 +16,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with scadere.  If not, see <https://www.gnu.org/licenses/>.
 
-from asyncio import open_connection, start_server
+from asyncio import TaskGroup, open_connection, start_server
 from base64 import (urlsafe_b64decode as from_base64,
                     urlsafe_b64encode as base64)
 from contextlib import asynccontextmanager, contextmanager
@@ -31,7 +31,7 @@ from xml.etree.ElementTree import (XML, XMLParser, indent,
                                    tostring as str_from_xml)
 from xml.sax.saxutils import escape
 
-from hypothesis import HealthCheck, example, given, settings
+from hypothesis import HealthCheck, given, settings
 from hypothesis.strategies import (builds, composite, data,
                                    datetimes, integers, lists, text)
 from hypothesis.provisional import domains, urls
@@ -211,23 +211,41 @@ def equal_xml(a, b):
     return str_from_xml(a_copy).rstrip() == str_from_xml(b_copy).rstrip()
 
 
+async def check_feed(socket, base_url):
+    """Check the Atom feed at the given path and its entry pages."""
+    feed = await fetch_xml(socket, base_url, 'application/atom+xml')
+    for feed_entry in feed.findall('entry', ATOM_NAMESPACES):
+        link = feed_entry.find('link', ATOM_NAMESPACES).attrib
+        assert link['rel'] == 'alternate'
+        page = await fetch_xml(socket, link['href'], link['type'])
+        assert equal_xml(feed_entry.find('.//dl', XHTML_NAMESPACES),
+                         page.find('.//dl', XHTML_NAMESPACES))
+
+
+async def check_server(sockets, func, *args):
+    """Test server listening for connections on sockets using func."""
+    async with TaskGroup() as group:
+        for socket in sockets:
+            group.create_task(func(socket, *args))
+
+
 @given(urls().filter(is_base_url).filter(has_usual_path),
        lists(certificates(), min_size=1))
-@settings(deadline=None)
 async def test_http_200(base_url, certs):
     base_path = urlsplit(base_url).path
     with tmp_cert_file(certs) as cert_file:
         handler = partial(handle, cert_file, base_url)
         async with await start_server(handler, 'localhost') as server:
-            for socket in server.sockets:
-                feed = await fetch_xml(socket, base_path,
-                                       'application/atom+xml')
-                for e in feed.findall('entry', ATOM_NAMESPACES):
-                    link = e.find('link', ATOM_NAMESPACES).attrib
-                    assert link['rel'] == 'alternate'
-                    page = await fetch_xml(socket, link['href'], link['type'])
-                    assert equal_xml(e.find('.//dl', XHTML_NAMESPACES),
-                                     page.find('.//dl', XHTML_NAMESPACES))
+            await check_server(server.sockets, check_feed, base_path)
+
+
+async def not_found(socket, url):
+    """Send GET request for URL and expect a 404 status from socket."""
+    async with connect(socket) as (reader, writer):
+        writer.write(f'GET {urlsplit(url).path}\r\n'.encode())
+        await writer.drain()
+        response = await reader.readuntil(b'\r\n')
+        assert response == b'HTTP/1.1 404 Not Found\r\n'
 
 
 @given(data())
@@ -239,12 +257,16 @@ async def test_unrecognized_url(drawer):
     with tmp_cert_file(()) as cert_file:
         handler = partial(handle, cert_file, base_url)
         async with await start_server(handler, 'localhost') as server:
-            for socket in server.sockets:
-                async with connect(socket) as (reader, writer):
-                    writer.write(f'GET {urlsplit(url).path}\r\n'.encode())
-                    await writer.drain()
-                    response = await reader.readuntil(b'\r\n')
-                    assert response == b'HTTP/1.1 404 Not Found\r\n'
+            await check_server(server.sockets, not_found, urlsplit(url).path)
+
+
+async def method_not_allowed(socket, request):
+    """Expect from socket a HTTP response with status 405."""
+    async with connect(socket) as (reader, writer):
+        writer.write(f'{request}\r\n'.encode())
+        await writer.drain()
+        response = await reader.readuntil(b'\r\n')
+        assert response == b'HTTP/1.1 405 Method Not Allowed\r\n'
 
 
 @given(urls(), text().filter(lambda method: not method.startswith('GET ')))
@@ -252,9 +274,4 @@ async def test_unallowed_method(base_url, request):
     with tmp_cert_file(()) as cert_file:
         handler = partial(handle, cert_file, base_url)
         async with await start_server(handler, 'localhost') as server:
-            for socket in server.sockets:
-                async with connect(socket) as (reader, writer):
-                    writer.write(f'{request}\r\n'.encode())
-                    await writer.drain()
-                    response = await reader.readuntil(b'\r\n')
-                    assert response == b'HTTP/1.1 405 Method Not Allowed\r\n'
+            await check_server(server.sockets, method_not_allowed, request)
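
For a sense of how the new helpers compose end to end, a self-contained sketch against a throwaway asyncio server that always answers 404; expect_404 is a hypothetical analogue of not_found(), and the suite's connect() helper is assumed to wrap the open_connection/close dance below in an async context manager:

    # Hedged sketch: a check_server-style fan-out against a throwaway
    # asyncio server that always answers 404.  No scadere internals used.
    from asyncio import TaskGroup, open_connection, run, start_server


    async def handler(reader, writer):
        await reader.readuntil(b'\r\n')
        writer.write(b'HTTP/1.1 404 Not Found\r\n')
        await writer.drain()
        writer.close()
        await writer.wait_closed()


    async def expect_404(socket, path):
        # Hypothetical analogue of not_found() without the connect() helper.
        host, port = socket.getsockname()[:2]
        reader, writer = await open_connection(host, port)
        try:
            writer.write(f'GET {path}\r\n'.encode())
            await writer.drain()
            assert await reader.readuntil(b'\r\n') == b'HTTP/1.1 404 Not Found\r\n'
        finally:
            writer.close()
            await writer.wait_closed()


    async def check_server(sockets, func, *args):
        # Same shape as the helper added in the diff above.
        async with TaskGroup() as group:
            for socket in sockets:
                group.create_task(func(socket, *args))


    async def main():
        async with await start_server(handler, 'localhost') as server:
            await check_server(server.sockets, expect_404, '/nonexistent')


    run(main())

One consequence of the TaskGroup-based fan-out: the first failing check cancels the probes still running on the other sockets and surfaces as an ExceptionGroup, so a broken listener no longer stalls the remaining checks.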