about summary refs log tree commit diff
diff options
context:
space:
mode:
authorNguyễn Gia Phong <cnx@loang.net>2025-05-29 15:59:23 +0900
committerNguyễn Gia Phong <cnx@loang.net>2025-05-29 15:59:23 +0900
commitbd2898d5182ed4ac4e46f5035c6df765d59ad298 (patch)
tree1958437806c9f7640f00c69322dccaa98063c2fd
parenta727804142db3258c5af8a0a31f79454418ccee2 (diff)
downloadscadere-bd2898d5182ed4ac4e46f5035c6df765d59ad298.tar.gz
Relieve backpressure
-rw-r--r--src/scadere/listen.py116
-rw-r--r--tst/test_check.py1
-rw-r--r--tst/test_listen.py14
3 files changed, 79 insertions, 52 deletions
diff --git a/src/scadere/listen.py b/src/scadere/listen.py
index 80929ea..aa322ee 100644
--- a/src/scadere/listen.py
+++ b/src/scadere/listen.py
@@ -19,13 +19,14 @@
 from argparse import ArgumentParser
 from asyncio import run, start_server
 from base64 import urlsafe_b64decode as from_base64
-from datetime import datetime
+from datetime import datetime, timezone
 from functools import partial
+from http import HTTPStatus
 from pathlib import Path
 from urllib.parse import parse_qs, urljoin, urlsplit
 from xml.etree.ElementTree import (Element as xml_element,
                                    SubElement as xml_subelement,
-                                   indent, tostring as str_from_xml)
+                                   indent, tostringlist as strings_from_xml)
 from sys import argv
 
 from . import __version__, GNUHelpFormatter, NetLoc
@@ -43,6 +44,38 @@ def path(hostname, port, issuer, serial):
     return f'{hostname}/{port}/{issuer}/{serial}'
 
 
+async def write_status(writer, status):
+    """Write the given HTTP/1.1 status line."""
+    writer.write(f'HTTP/1.1 {status.value} {status.phrase}\r\n'.encode())
+    await writer.drain()
+
+
+async def write_content_type(writer, content_type):
+    """Write the given HTTP content type."""
+    writer.write(f'Content-Type: {content_type}\r\n'.encode())
+    await writer.drain()
+
+
+async def describe_status(writer, status):
+    """Write a HTTP/1.1 response including status description."""
+    await write_status(writer, status)
+    content = f'{status.description}\n'.encode()
+    await write_content_type(writer, 'text/plain')
+    writer.write(f'Content-Length: {len(content)}\r\n\r\n'.encode())
+    writer.write(content)
+    await writer.drain()
+
+
+async def write_xml(writer, document):
+    content = strings_from_xml(xml(document), 'unicode',
+                               xml_declaration=True, default_namespace=None)
+    length = len(''.join(content))
+    writer.write(f'Content-Length: {length}\r\n\r\n'.encode())
+    for part in content:
+        writer.write(part.encode())
+        await writer.drain()
+
+
 def body(not_before, not_after, hostname, port, serial, issuer):
     """Describe the given certificate in XHTML."""
     return (('h1', 'TLS certificate information'),
@@ -106,56 +139,47 @@ async def handle(certs, base_url, reader, writer):
     domains = tuple(parse_qs(url_parts.query).get('domain', ['']))
 
     if not request.startswith(b'GET '):
-        writer.write(b'HTTP/1.1 405 Method Not Allowed\r\n')
-        await writer.drain()
-        writer.close()
-        await writer.wait_closed()
-        return
+        await describe_status(writer, HTTPStatus.METHOD_NOT_ALLOWED)
     elif url.startswith(b'//'):  # urljoin goes haywire
-        writer.write(b'HTTP/1.1 404 Not Found\r\n')
+        await describe_status(writer, HTTPStatus.NOT_FOUND)
     elif url_parts.path == urlsplit(base_url).path:  # Atom feed
-        writer.write(b'HTTP/1.1 200 OK\r\n')
-        writer.write(b'Content-Type: application/atom+xml\r\n')
-        feed = xml(('feed', {'xmlns': 'http://www.w3.org/2005/Atom'},
-                    ('id', base_url),
-                    ('link', {'rel': 'self', 'href': base_url}),
-                    ('title', certs.name),
-                    ('updated', datetime.now().isoformat()),
-                    ('generator',
-                     {'uri': 'https://trong.loang.net/scadere/about',
-                      'version': __version__},
-                     'Scadere'),
-                    *(entry(base_url, cert)
-                      for cert in lookup.values()
-                      if cert[2].endswith(domains))))
-        content = str_from_xml(feed, 'unicode', xml_declaration=True,
-                               default_namespace=None).encode()
-        writer.write(f'Content-Length: {len(content)}\r\n\r\n'.encode())
-        writer.write(content)
+        await write_status(writer, HTTPStatus.OK)
+        await write_content_type(writer, 'application/atom+xml')
+        feed = ('feed', {'xmlns': 'http://www.w3.org/2005/Atom'},
+                ('id', base_url),
+                ('link', {'rel': 'self', 'href': base_url}),
+                ('title', certs.name),
+                ('updated', datetime.now(tz=timezone.utc).isoformat()),
+                ('generator',
+                 {'uri': 'https://trong.loang.net/scadere/about',
+                  'version': __version__},
+                 'Scadere'),
+                *(entry(base_url, cert) for cert in lookup.values()
+                  if cert[2].endswith(domains)))
+        await write_xml(writer, feed)
     elif url_parts.path in lookup:  # accessible Atom entry's link/ID
-        writer.write(b'HTTP/1.1 200 OK\r\n')
-        writer.write(b'Content-Type: application/xhtml+xml\r\n')
+        await write_status(writer, HTTPStatus.OK)
+        await write_content_type(writer, 'application/xhtml+xml')
         (not_before, not_after,
          hostname, port, serial, issuer) = lookup[url_parts.path]
-        page = xml(('html', {'xmlns': 'http://www.w3.org/1999/xhtml',
-                             'lang': 'en'},
-                    ('head',
-                     ('meta', {'name': 'color-scheme',
-                               'content': 'dark light'}),
-                     ('meta', {'name': 'viewport',
-                               'content': ('width=device-width,'
-                                           'initial-scale=1.0')}),
-                     ('link', {'rel': 'icon', 'href': 'data:,'}),
-                     ('title', f'TLS certificate - {hostname}:{port}')),
-                    ('body', *body(not_before, not_after,
-                                   hostname, port, serial, issuer))))
-        content = str_from_xml(page, 'unicode', xml_declaration=True,
-                               default_namespace=None).encode()
-        writer.write(f'Content-Length: {len(content)}\r\n\r\n'.encode())
-        writer.write(content)
+        page = ('html', {'xmlns': 'http://www.w3.org/1999/xhtml',
+                         'lang': 'en'},
+                ('head',
+                 ('meta', {'name': 'color-scheme',
+                           'content': 'dark light'}),
+                 ('meta', {'name': 'viewport',
+                           'content': ('width=device-width,'
+                                       'initial-scale=1.0')}),
+                 ('link', {'rel': 'icon', 'href': 'data:,'}),
+                 ('title', f'TLS certificate - {hostname}:{port}')),
+                ('body', *body(not_before, not_after,
+                               hostname, port, serial, issuer)))
+        await write_xml(writer, page)
     else:
-        writer.write(b'HTTP/1.1 404 Not Found\r\n')
-    await writer.drain()
+        await describe_status(writer, HTTPStatus.NOT_FOUND)
+
+    assert writer.can_write_eof()
+    writer.write_eof()
     writer.close()
     await writer.wait_closed()
 
diff --git a/tst/test_check.py b/tst/test_check.py
index b9a89ff..9890209 100644
--- a/tst/test_check.py
+++ b/tst/test_check.py
@@ -36,6 +36,7 @@ NEXT_WEEK = SECONDS_AGO + timedelta(days=7)
 async def noop(reader, writer):
     """Do nothing."""
     writer.close()
+    await writer.wait_closed()
 
 
 def failed_to_get_cert(summary):
diff --git a/tst/test_listen.py b/tst/test_listen.py
index d4d4d76..4d49ae2 100644
--- a/tst/test_listen.py
+++ b/tst/test_listen.py
@@ -36,7 +36,7 @@ from hypothesis.strategies import (builds, composite, data,
                                    datetimes, integers, lists, text)
 from hypothesis.provisional import domains, urls
 
-from scadere.listen import body, entry, handle, path, xml
+from scadere.listen import body, entry, handle, path, with_trailing_slash, xml
 
 ATOM_NAMESPACES = {'': 'http://www.w3.org/2005/Atom'}
 XHTML_NAMESPACES = {'': 'http://www.w3.org/1999/xhtml'}
@@ -167,6 +167,7 @@ async def connect(socket):
         yield reader, writer
     finally:
         writer.close()
+        await writer.wait_closed()
 
 
 async def fetch_xml(socket, url, content_type):
@@ -181,7 +182,8 @@ async def fetch_xml(socket, url, content_type):
         headers_bytes = await reader.readuntil(b'\r\n\r\n')
         headers = header_parser.parsebytes(headers_bytes)
         assert headers['Content-Type'] == content_type
-        content = await reader.read(int(headers['Content-Length']))
+        content = await reader.read()
+        assert len(content) == int(headers['Content-Length'])
         return XML(content.decode(), xml_parser)
 
 
@@ -213,7 +215,7 @@ async def test_http_200(base_url, certs):
 @example(type('//', (), {'draw': lambda *a, **kw: 'https://a.example//b'}))
 @given(data())
 @settings(suppress_health_check=[HealthCheck.too_slow])
-async def test_http_404(drawer):
+async def test_unrecognized_url(drawer):
     base_url = drawer.draw(urls(), label='base URL')
     url = drawer.draw(urls().filter(lambda url: not is_root(base_url, url)),
                       label='request URL')
@@ -224,12 +226,12 @@ async def test_http_404(drawer):
                 async with connect(socket) as (reader, writer):
                     writer.write(f'GET {urlsplit(url).path}\r\n'.encode())
                     await writer.drain()
-                    response = await reader.read()
+                    response = await reader.readuntil(b'\r\n')
                     assert response == b'HTTP/1.1 404 Not Found\r\n'
 
 
 @given(urls(), text().filter(lambda method: not method.startswith('GET ')))
-async def test_http_405(base_url, request):
+async def test_unallowed_method(base_url, request):
     with tmp_cert_file(()) as cert_file:
         handler = partial(handle, cert_file, base_url)
         async with await start_server(handler, 'localhost') as server:
@@ -237,5 +239,5 @@ async def test_http_405(base_url, request):
                 async with connect(socket) as (reader, writer):
                     writer.write(f'{request}\r\n'.encode())
                     await writer.drain()
-                    response = await reader.read()
+                    response = await reader.readuntil(b'\r\n')
                     assert response == b'HTTP/1.1 405 Method Not Allowed\r\n'