about summary refs log tree commit diff
diff options
context:
space:
mode:
authorNguyễn Gia Phong <cnx@loang.net>2025-05-27 16:27:34 +0900
committerNguyễn Gia Phong <cnx@loang.net>2025-05-27 17:03:14 +0900
commit655c4818e30a9eb1c2d40d977ee4b89b0b37a766 (patch)
treeb97b75d76002492c2530c1fefef4fb3b39b5ed60
parent6b1a96e7726aadfe00704134b60888048a96c212 (diff)
downloadscadere-655c4818e30a9eb1c2d40d977ee4b89b0b37a766.tar.gz
Fix test condition for 404 responses
-rw-r--r--tst/test_listen.py56
1 files changed, 28 insertions, 28 deletions
diff --git a/tst/test_listen.py b/tst/test_listen.py
index dc2011b..86df3d6 100644
--- a/tst/test_listen.py
+++ b/tst/test_listen.py
@@ -25,9 +25,9 @@ from urllib.parse import urljoin, urlsplit
 from xml.etree.ElementTree import tostring as str_from_xml
 from xml.sax.saxutils import escape
 
-from hypothesis import given
-from hypothesis.strategies import (builds, composite, datetimes,
-                                   integers, lists, text)
+from hypothesis import HealthCheck, given, settings
+from hypothesis.strategies import (builds, composite, data,
+                                   datetimes, integers, lists, text)
 from hypothesis.provisional import domains, urls
 
 from scadere.listen import body, entry, handle, path, xml
@@ -132,16 +132,26 @@ def tmp_cert_file(lines):
         cert_file.unlink()
 
 
+def is_root(base_url, url):
+    """Check if the given URL points to the same base URL.
+
+    Paths starting with // are excluded because urljoin
+    sometimes confuse them with URL scheme.
+    """
+    url_path = urlsplit(url).path
+    return (urlsplit(urljoin(base_url, url_path)).path == url_path
+            and not url_path.startswith('//'))
+
+
 def has_usual_path(url):
     """Check if the given URL path tends to mess with urljoin."""
-    url_path = urlsplit(url).path
-    return not (url_path.startswith('//') or url_path.endswith('/.'))
+    return is_root(url, url)
 
 
 @asynccontextmanager
-async def connect(*args, **kwargs):
+async def connect(socket):
     """Return a read-write stream for an asyncio TCP connection."""
-    reader, writer = await open_connection(*args, **kwargs)
+    reader, writer = await open_connection(*socket.getsockname())
     try:
         yield reader, writer
     finally:
@@ -153,8 +163,7 @@ async def test_http_200(base_url, certs):
     base_path = urlsplit(base_url).path
     with tmp_cert_file(certs) as cert_file:
         handler = partial(handle, cert_file, base_url)
-        server = await start_server(handler, 'localhost')
-        async with server:
+        async with await start_server(handler, 'localhost') as server:
             socket, = server.sockets
             async with connect(*socket.getsockname()) as (reader, writer):
                 writer.write(f'GET {base_path}\r\n'.encode())
@@ -163,23 +172,16 @@ async def test_http_200(base_url, certs):
                 assert response == b'HTTP/1.1 200 OK\r\n'
 
 
-@composite
-def two_urls(draw, constraint):
-    """Return a Hypothesis strategy for 2 URLs."""
-    first = draw(urls())
-    second = draw(urls().filter(partial(constraint, first)))
-    return first, second
-
-
-@given(two_urls(lambda a, b: urlsplit(a).path != urlsplit(b).path))
-async def test_http_404(url_and_url):
-    base_url, url = url_and_url
+@given(data())
+@settings(suppress_health_check=[HealthCheck.too_slow])
+async def test_http_404(drawer):
+    base_url = drawer.draw(urls(), label='base URL')
+    url = drawer.draw(urls().filter(lambda url: not is_root(base_url, url)),
+                      label='request URL')
     with tmp_cert_file(()) as cert_file:
         handler = partial(handle, cert_file, base_url)
-        server = await start_server(handler, 'localhost')
-        async with server:
-            socket, = server.sockets
-            async with connect(*socket.getsockname()) as (reader, writer):
+        async with await start_server(handler, 'localhost') as server:
+            async with connect(*server.sockets) as (reader, writer):
                 writer.write(f'GET {urlsplit(url).path}\r\n'.encode())
                 await writer.drain()
                 response = await reader.read()
@@ -190,10 +192,8 @@ async def test_http_404(url_and_url):
 async def test_http_405(base_url, request):
     with tmp_cert_file(()) as cert_file:
         handler = partial(handle, cert_file, base_url)
-        server = await start_server(handler, 'localhost')
-        async with server:
-            socket, = server.sockets
-            async with connect(*socket.getsockname()) as (reader, writer):
+        async with await start_server(handler, 'localhost') as server:
+            async with connect(*server.sockets) as (reader, writer):
                 writer.write(f'{request}\r\n'.encode())
                 await writer.drain()
                 response = await reader.read()