about summary refs log tree commit diff
diff options
context:
space:
mode:
authorNguyễn Gia Phong <cnx@loang.net>2025-05-26 23:23:37 +0900
committerNguyễn Gia Phong <cnx@loang.net>2025-05-26 23:23:37 +0900
commit83fc0b5427e1c43ed1f4391ec6cf6da96aac64a9 (patch)
tree3f0bbba7be7d5717acff63ab48c8cd9f1d3c06a6
parent053b6da35753267456a167bfa84b955781b7d986 (diff)
downloadscadere-83fc0b5427e1c43ed1f4391ec6cf6da96aac64a9.tar.gz
Let URL paths ending with /. be not found
-rw-r--r--tst/test_listen.py30
1 files changed, 21 insertions, 9 deletions
diff --git a/tst/test_listen.py b/tst/test_listen.py
index 8004380..7121ad1 100644
--- a/tst/test_listen.py
+++ b/tst/test_listen.py
@@ -17,7 +17,7 @@
 from asyncio import open_connection, start_server
 from base64 import (urlsafe_b64decode as from_base64,
                     urlsafe_b64encode as base64)
-from contextlib import closing, contextmanager
+from contextlib import asynccontextmanager, contextmanager
 from functools import partial
 from pathlib import Path
 from tempfile import mkstemp
@@ -132,8 +132,23 @@ def tmp_cert_file(lines):
         cert_file.unlink()
 
 
-@given(urls().filter(lambda url: not urlsplit(url).path.startswith('//')),
-       lists(certificates()))
+def has_usual_path(url):
+    """Check if the given URL path tends to mess with urljoin."""
+    url_path = urlsplit(url).path
+    return not (url_path.startswith('//') or url_path.endswith('/.'))
+
+
+@asynccontextmanager
+async def connect(*args, **kwargs):
+    """Return a read-write stream for an asyncio TCP connection."""
+    reader, writer = await open_connection(*args, **kwargs)
+    try:
+        yield reader, writer
+    finally:
+        writer.close()
+
+
+@given(urls().filter(has_usual_path), lists(certificates(), min_size=1))
 async def test_http_200(base_url, certs):
     base_path = urlsplit(base_url).path
     with tmp_cert_file(certs) as cert_file:
@@ -141,8 +156,7 @@ async def test_http_200(base_url, certs):
         server = await start_server(handler, 'localhost')
         async with server:
             socket, = server.sockets
-            reader, writer = await open_connection(*socket.getsockname())
-            with closing(writer):
+            async with connect(*socket.getsockname()) as (reader, writer):
                 writer.write(f'GET {base_path}\r\n'.encode())
                 await writer.drain()
                 response = await reader.readuntil(b'\r\n')
@@ -165,8 +179,7 @@ async def test_http_404(url_and_url):
         server = await start_server(handler, 'localhost')
         async with server:
             socket, = server.sockets
-            reader, writer = await open_connection(*socket.getsockname())
-            with closing(writer):
+            async with connect(*socket.getsockname()) as (reader, writer):
                 writer.write(f'GET {urlsplit(url).path}\r\n'.encode())
                 await writer.drain()
                 response = await reader.read()
@@ -180,8 +193,7 @@ async def test_http_405(base_url, request):
         server = await start_server(handler, 'localhost')
         async with server:
             socket, = server.sockets
-            reader, writer = await open_connection(*socket.getsockname())
-            with closing(writer):
+            async with connect(*socket.getsockname()) as (reader, writer):
                 writer.write(f'{request}\r\n'.encode())
                 await writer.drain()
                 response = await reader.read()