aboutsummaryrefslogtreecommitdiff
path: root/tst
diff options
context:
space:
mode:
authorNguyễn Gia Phong <cnx@loang.net>2025-05-26 23:23:37 +0900
committerNguyễn Gia Phong <cnx@loang.net>2025-05-26 23:23:37 +0900
commit83fc0b5427e1c43ed1f4391ec6cf6da96aac64a9 (patch)
tree3f0bbba7be7d5717acff63ab48c8cd9f1d3c06a6 /tst
parent053b6da35753267456a167bfa84b955781b7d986 (diff)
downloadscadere-83fc0b5427e1c43ed1f4391ec6cf6da96aac64a9.tar.gz
Let URL paths ending with /. be not found
Diffstat (limited to 'tst')
-rw-r--r--tst/test_listen.py30
1 files changed, 21 insertions, 9 deletions
diff --git a/tst/test_listen.py b/tst/test_listen.py
index 8004380..7121ad1 100644
--- a/tst/test_listen.py
+++ b/tst/test_listen.py
@@ -17,7 +17,7 @@
from asyncio import open_connection, start_server
from base64 import (urlsafe_b64decode as from_base64,
urlsafe_b64encode as base64)
-from contextlib import closing, contextmanager
+from contextlib import asynccontextmanager, contextmanager
from functools import partial
from pathlib import Path
from tempfile import mkstemp
@@ -132,8 +132,23 @@ def tmp_cert_file(lines):
cert_file.unlink()
-@given(urls().filter(lambda url: not urlsplit(url).path.startswith('//')),
- lists(certificates()))
+def has_usual_path(url):
+ """Check if the given URL path tends to mess with urljoin."""
+ url_path = urlsplit(url).path
+ return not (url_path.startswith('//') or url_path.endswith('/.'))
+
+
+@asynccontextmanager
+async def connect(*args, **kwargs):
+ """Return a read-write stream for an asyncio TCP connection."""
+ reader, writer = await open_connection(*args, **kwargs)
+ try:
+ yield reader, writer
+ finally:
+ writer.close()
+
+
+@given(urls().filter(has_usual_path), lists(certificates(), min_size=1))
async def test_http_200(base_url, certs):
base_path = urlsplit(base_url).path
with tmp_cert_file(certs) as cert_file:
@@ -141,8 +156,7 @@ async def test_http_200(base_url, certs):
server = await start_server(handler, 'localhost')
async with server:
socket, = server.sockets
- reader, writer = await open_connection(*socket.getsockname())
- with closing(writer):
+ async with connect(*socket.getsockname()) as (reader, writer):
writer.write(f'GET {base_path}\r\n'.encode())
await writer.drain()
response = await reader.readuntil(b'\r\n')
@@ -165,8 +179,7 @@ async def test_http_404(url_and_url):
server = await start_server(handler, 'localhost')
async with server:
socket, = server.sockets
- reader, writer = await open_connection(*socket.getsockname())
- with closing(writer):
+ async with connect(*socket.getsockname()) as (reader, writer):
writer.write(f'GET {urlsplit(url).path}\r\n'.encode())
await writer.drain()
response = await reader.read()
@@ -180,8 +193,7 @@ async def test_http_405(base_url, request):
server = await start_server(handler, 'localhost')
async with server:
socket, = server.sockets
- reader, writer = await open_connection(*socket.getsockname())
- with closing(writer):
+ async with connect(*socket.getsockname()) as (reader, writer):
writer.write(f'{request}\r\n'.encode())
await writer.drain()
response = await reader.read()