about summary refs log tree commit diff
path: root/tst/test_listen.py
diff options
context:
space:
mode:
Diffstat (limited to 'tst/test_listen.py')
-rw-r--r--tst/test_listen.py39
1 files changed, 19 insertions, 20 deletions
diff --git a/tst/test_listen.py b/tst/test_listen.py
index ee95c3c..662c576 100644
--- a/tst/test_listen.py
+++ b/tst/test_listen.py
@@ -16,7 +16,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with scadere.  If not, see <https://www.gnu.org/licenses/>.
 
-from asyncio import TaskGroup, open_connection, start_server
+from asyncio import open_unix_connection, start_unix_server
 from contextlib import asynccontextmanager, contextmanager
 from copy import deepcopy
 from datetime import datetime
@@ -24,7 +24,8 @@ from email.parser import BytesHeaderParser
 from functools import partial
 from http import HTTPMethod
 from pathlib import Path
-from tempfile import mkstemp
+from shutil import rmtree
+from tempfile import mkdtemp, mkstemp
 from urllib.parse import urljoin, urlsplit
 from xml.etree.ElementTree import (XML, XMLParser, indent,
                                    tostring as str_from_xml)
@@ -151,7 +152,7 @@ def has_usual_path(url):
 @asynccontextmanager
 async def connect(socket):
     """Return a read-write stream for an asyncio TCP connection."""
-    reader, writer = await open_connection(*socket.getsockname()[:2])
+    reader, writer = await open_unix_connection(socket)
     try:
         yield reader, writer
     finally:
@@ -202,11 +203,15 @@ async def check_feed(socket, base_url):
                          page.find('.//dl', XHTML_NAMESPACES))
 
 
-async def check_server(sockets, func, *args):
-    """Test server listening for connections on sockets using func."""
-    async with TaskGroup() as group:
-        for socket in sockets:
-            group.create_task(func(socket, *args))
+async def check_server(handler, func, *args):
+    """Test request handler using func."""
+    d = Path(mkdtemp())
+    socket = d / 'sock'
+    try:
+        async with await start_unix_server(handler, socket):
+            await func(socket, *args)
+    finally:
+        rmtree(d)
 
 
 @given(urls().filter(is_base_url).filter(has_usual_path),
@@ -216,8 +221,7 @@ async def test_content(base_url, certs):
     base_path = urlsplit(base_url).path
     with tmp_cert_file(certs) as cert_file:
         handler = partial(handle, cert_file, base_url)
-        async with await start_server(handler, 'localhost') as server:
-            await check_server(server.sockets, check_feed, base_path)
+        await check_server(handler, check_feed, base_path)
 
 
 async def bad_request(socket, request):
@@ -232,8 +236,7 @@ async def bad_request(socket, request):
 async def test_incomplete_request(request):
     with tmp_cert_file(()) as cert_file:
         handler = partial(handle, cert_file, 'http://localhost')
-        async with await start_server(handler, 'localhost') as server:
-            await check_server(server.sockets, bad_request, request)
+        await check_server(handler, bad_request, request)
 
 
 async def not_found(socket, url):
@@ -252,8 +255,7 @@ async def test_unrecognized_url(drawer):
                       label='request URL')
     with tmp_cert_file(()) as cert_file:
         handler = partial(handle, cert_file, base_url)
-        async with await start_server(handler, 'localhost') as server:
-            await check_server(server.sockets, not_found, urlsplit(url).path)
+        await check_server(handler, not_found, urlsplit(url).path)
 
 
 async def method_not_allowed(socket, method, url):
@@ -268,9 +270,7 @@ async def method_not_allowed(socket, method, url):
 async def test_unallowed_method(base_url, method):
     with tmp_cert_file(()) as cert_file:
         handler = partial(handle, cert_file, base_url)
-        async with await start_server(handler, 'localhost') as server:
-            await check_server(server.sockets, method_not_allowed,
-                               method.value, base_url)
+        await check_server(handler, method_not_allowed, method.value, base_url)
 
 
 async def unsupported_http_version(socket, url, version):
@@ -285,6 +285,5 @@ async def unsupported_http_version(socket, url, version):
 async def test_unsupported_http_version(base_url, version):
     with tmp_cert_file(()) as cert_file:
         handler = partial(handle, cert_file, base_url)
-        async with await start_server(handler, 'localhost') as server:
-            await check_server(server.sockets, unsupported_http_version,
-                               base_url, version)
+        await check_server(handler, unsupported_http_version,
+                           base_url, version)