diff options
-rw-r--r-- | tst/test_listen.py | 39 |
1 files changed, 19 insertions, 20 deletions
diff --git a/tst/test_listen.py b/tst/test_listen.py index ee95c3c..662c576 100644 --- a/tst/test_listen.py +++ b/tst/test_listen.py @@ -16,7 +16,7 @@ # You should have received a copy of the GNU Affero General Public License # along with scadere. If not, see <https://www.gnu.org/licenses/>. -from asyncio import TaskGroup, open_connection, start_server +from asyncio import open_unix_connection, start_unix_server from contextlib import asynccontextmanager, contextmanager from copy import deepcopy from datetime import datetime @@ -24,7 +24,8 @@ from email.parser import BytesHeaderParser from functools import partial from http import HTTPMethod from pathlib import Path -from tempfile import mkstemp +from shutil import rmtree +from tempfile import mkdtemp, mkstemp from urllib.parse import urljoin, urlsplit from xml.etree.ElementTree import (XML, XMLParser, indent, tostring as str_from_xml) @@ -151,7 +152,7 @@ def has_usual_path(url): @asynccontextmanager async def connect(socket): """Return a read-write stream for an asyncio TCP connection.""" - reader, writer = await open_connection(*socket.getsockname()[:2]) + reader, writer = await open_unix_connection(socket) try: yield reader, writer finally: @@ -202,11 +203,15 @@ async def check_feed(socket, base_url): page.find('.//dl', XHTML_NAMESPACES)) -async def check_server(sockets, func, *args): - """Test server listening for connections on sockets using func.""" - async with TaskGroup() as group: - for socket in sockets: - group.create_task(func(socket, *args)) +async def check_server(handler, func, *args): + """Test request handler using func.""" + d = Path(mkdtemp()) + socket = d / 'sock' + try: + async with await start_unix_server(handler, socket): + await func(socket, *args) + finally: + rmtree(d) @given(urls().filter(is_base_url).filter(has_usual_path), @@ -216,8 +221,7 @@ async def test_content(base_url, certs): base_path = urlsplit(base_url).path with tmp_cert_file(certs) as cert_file: handler = partial(handle, cert_file, base_url) - async with await start_server(handler, 'localhost') as server: - await check_server(server.sockets, check_feed, base_path) + await check_server(handler, check_feed, base_path) async def bad_request(socket, request): @@ -232,8 +236,7 @@ async def bad_request(socket, request): async def test_incomplete_request(request): with tmp_cert_file(()) as cert_file: handler = partial(handle, cert_file, 'http://localhost') - async with await start_server(handler, 'localhost') as server: - await check_server(server.sockets, bad_request, request) + await check_server(handler, bad_request, request) async def not_found(socket, url): @@ -252,8 +255,7 @@ async def test_unrecognized_url(drawer): label='request URL') with tmp_cert_file(()) as cert_file: handler = partial(handle, cert_file, base_url) - async with await start_server(handler, 'localhost') as server: - await check_server(server.sockets, not_found, urlsplit(url).path) + await check_server(handler, not_found, urlsplit(url).path) async def method_not_allowed(socket, method, url): @@ -268,9 +270,7 @@ async def method_not_allowed(socket, method, url): async def test_unallowed_method(base_url, method): with tmp_cert_file(()) as cert_file: handler = partial(handle, cert_file, base_url) - async with await start_server(handler, 'localhost') as server: - await check_server(server.sockets, method_not_allowed, - method.value, base_url) + await check_server(handler, method_not_allowed, method.value, base_url) async def unsupported_http_version(socket, url, version): @@ -285,6 +285,5 @@ async def unsupported_http_version(socket, url, version): async def test_unsupported_http_version(base_url, version): with tmp_cert_file(()) as cert_file: handler = partial(handle, cert_file, base_url) - async with await start_server(handler, 'localhost') as server: - await check_server(server.sockets, unsupported_http_version, - base_url, version) + await check_server(handler, unsupported_http_version, + base_url, version) |