about summary refs log tree commit diff
path: root/tst
diff options
context:
space:
mode:
Diffstat (limited to 'tst')
-rw-r--r--tst/test_listen.py35
1 files changed, 22 insertions, 13 deletions
diff --git a/tst/test_listen.py b/tst/test_listen.py
index 6ef7645..437fb91 100644
--- a/tst/test_listen.py
+++ b/tst/test_listen.py
@@ -142,6 +142,23 @@ def tmp_cert_file(lines):
         cert_file.unlink()
 
 
+def is_base_url(url):
+    """Check if the given URL has a trailing slash.
+
+    The parser for command-line arguments
+    enforces this property for base URLs.
+    """
+    return url.endswith('/')
+
+
+@given(urls())
+def test_with_trailing_slash(url):
+    if is_base_url(url):
+        assert with_trailing_slash(url) == url
+    else:
+        assert with_trailing_slash(url) == f'{url}/'
+
+
 def is_root(base_url, url):
     """Check if the given URL points to the same base URL.
 
@@ -150,8 +167,7 @@ def is_root(base_url, url):
     """
     base_path = urlsplit(base_url).path
     url_path = urlsplit(url).path
-    return (urlsplit(urljoin(base_url, url_path)).path == base_path
-            and not url_path.startswith('//'))
+    return urlsplit(urljoin(base_url, url_path)).path == base_path
 
 
 def has_usual_path(url):
@@ -195,7 +211,9 @@ def equal_xml(a, b):
     return str_from_xml(a_copy).rstrip() == str_from_xml(b_copy).rstrip()
 
 
-@given(urls().filter(has_usual_path), lists(certificates(), min_size=1))
+@given(urls().filter(is_base_url).filter(has_usual_path),
+       lists(certificates(), min_size=1))
+@settings(deadline=None)
 async def test_http_200(base_url, certs):
     base_path = urlsplit(base_url).path
     with tmp_cert_file(certs) as cert_file:
@@ -212,11 +230,10 @@ async def test_http_200(base_url, certs):
                                      page.find('.//dl', XHTML_NAMESPACES))
 
 
-@example(type('//', (), {'draw': lambda *a, **kw: 'https://a.example//b'}))
 @given(data())
 @settings(suppress_health_check=[HealthCheck.too_slow])
 async def test_unrecognized_url(drawer):
-    base_url = drawer.draw(urls(), label='base URL')
+    base_url = drawer.draw(urls().filter(is_base_url), label='base URL')
     url = drawer.draw(urls().filter(lambda url: not is_root(base_url, url)),
                       label='request URL')
     with tmp_cert_file(()) as cert_file:
@@ -241,11 +258,3 @@ async def test_unallowed_method(base_url, request):
                     await writer.drain()
                     response = await reader.readuntil(b'\r\n')
                     assert response == b'HTTP/1.1 405 Method Not Allowed\r\n'
-
-
-@given(urls())
-def test_with_trailing_slash(base_url):
-    if base_url.endswith('/'):
-        assert with_trailing_slash(base_url) == base_url
-    else:
-        assert with_trailing_slash(base_url) == f'{base_url}/'