about summary refs log tree commit diff
diff options
context:
space:
mode:
authorNguyễn Gia Phong <cnx@loang.net>2025-05-27 16:28:58 +0900
committerNguyễn Gia Phong <cnx@loang.net>2025-05-27 17:03:14 +0900
commitaec9f489eed05e01dfe358e71e14870dbbf0196b (patch)
tree12617a0561e3bc2d4749f5a88fe1197478345a08
parent655c4818e30a9eb1c2d40d977ee4b89b0b37a766 (diff)
downloadscadere-aec9f489eed05e01dfe358e71e14870dbbf0196b.tar.gz
Cover all branches in server code
-rw-r--r--tst/test_listen.py48
1 files changed, 41 insertions, 7 deletions
diff --git a/tst/test_listen.py b/tst/test_listen.py
index 86df3d6..573330b 100644
--- a/tst/test_listen.py
+++ b/tst/test_listen.py
@@ -18,11 +18,15 @@ from asyncio import open_connection, start_server
 from base64 import (urlsafe_b64decode as from_base64,
                     urlsafe_b64encode as base64)
 from contextlib import asynccontextmanager, contextmanager
+from copy import deepcopy
+from email.parser import BytesHeaderParser
 from functools import partial
 from pathlib import Path
+from string import ascii_letters
 from tempfile import mkstemp
 from urllib.parse import urljoin, urlsplit
-from xml.etree.ElementTree import tostring as str_from_xml
+from xml.etree.ElementTree import (XML, XMLParser, indent,
+                                   tostring as str_from_xml)
 from xml.sax.saxutils import escape
 
 from hypothesis import HealthCheck, given, settings
@@ -32,6 +36,9 @@ from hypothesis.provisional import domains, urls
 
 from scadere.listen import body, entry, handle, path, xml
 
+ATOM_NAMESPACES = {'': 'http://www.w3.org/2005/Atom'}
+XHTML_NAMESPACES = {'': 'http://www.w3.org/1999/xhtml'}
+
 
 def ports():
     """Return a Hypothesis strategy for TCP ports."""
@@ -118,7 +125,8 @@ def certificates(draw):
     hostname = draw(domains())
     port = draw(ports())
     serial = draw(serials())
-    issuer = draw(ca_names())
+    # Free-formed UTF-8 could easily creates malformed XML.
+    issuer = base64(draw(text(ascii_letters)).encode()).decode()
     return f'{not_before} {not_after} {hostname} {port} {serial} {issuer}'
 
 
@@ -158,6 +166,30 @@ async def connect(socket):
         writer.close()
 
 
+async def fetch_xml(socket, url, content_type):
+    """Fetch the content at the URL from the socket."""
+    header_parser = BytesHeaderParser()
+    xml_parser = XMLParser(encoding='utf-8')
+    async with connect(socket) as (reader, writer):
+        writer.write(f'GET {url}\r\n'.encode())
+        await writer.drain()
+        status = await reader.readuntil(b'\r\n')
+        assert status == b'HTTP/1.1 200 OK\r\n'
+        headers_bytes = await reader.readuntil(b'\r\n\r\n')
+        headers = header_parser.parsebytes(headers_bytes)
+        assert headers['Content-Type'] == content_type
+        content = await reader.read(int(headers['Content-Length']))
+        return XML(content.decode(), xml_parser)
+
+
+def equal_xml(a, b):
+    """Check if the two XML elements are equal."""
+    a_copy, b_copy = deepcopy(a), deepcopy(b)
+    indent(a_copy)
+    indent(b_copy)
+    return str_from_xml(a_copy).rstrip() == str_from_xml(b_copy).rstrip()
+
+
 @given(urls().filter(has_usual_path), lists(certificates(), min_size=1))
 async def test_http_200(base_url, certs):
     base_path = urlsplit(base_url).path
@@ -165,11 +197,13 @@ async def test_http_200(base_url, certs):
         handler = partial(handle, cert_file, base_url)
         async with await start_server(handler, 'localhost') as server:
             socket, = server.sockets
-            async with connect(*socket.getsockname()) as (reader, writer):
-                writer.write(f'GET {base_path}\r\n'.encode())
-                await writer.drain()
-                response = await reader.readuntil(b'\r\n')
-                assert response == b'HTTP/1.1 200 OK\r\n'
+            feed = await fetch_xml(socket, base_path, 'application/atom+xml')
+            for feed_entry in feed.findall('entry', ATOM_NAMESPACES):
+                link = feed_entry.find('link', ATOM_NAMESPACES).attrib
+                assert link['rel'] == 'alternate'
+                page = await fetch_xml(socket, link['href'], link['type'])
+                assert equal_xml(feed_entry.find('.//dl', XHTML_NAMESPACES),
+                                 page.find('.//dl', XHTML_NAMESPACES))
 
 
 @given(data())