diff options
Diffstat (limited to 'tst/test_listen.py')
-rw-r--r-- | tst/test_listen.py | 15 |
1 files changed, 7 insertions, 8 deletions
diff --git a/tst/test_listen.py b/tst/test_listen.py index 3737baa..e204d4f 100644 --- a/tst/test_listen.py +++ b/tst/test_listen.py @@ -17,8 +17,6 @@ # along with scadere. If not, see <https://www.gnu.org/licenses/>. from asyncio import TaskGroup, open_connection, start_server -from base64 import (urlsafe_b64decode as from_base64, - urlsafe_b64encode as base64) from contextlib import asynccontextmanager, contextmanager from copy import deepcopy from email.parser import BytesHeaderParser @@ -36,8 +34,9 @@ from hypothesis.strategies import (builds, composite, data, datetimes, integers, lists, text) from hypothesis.provisional import domains, urls -from scadere.listen import (body, entry, handle, is_subdomain, - path, with_trailing_slash, xml) +from scadere.check import base64_from_str +from scadere.listen import (body, entry, handle, is_subdomain, path, + str_from_base64, with_trailing_slash, xml) ATOM_NAMESPACES = {'': 'http://www.w3.org/2005/Atom'} XHTML_NAMESPACES = {'': 'http://www.w3.org/1999/xhtml'} @@ -55,7 +54,7 @@ def serials(): def ca_names(): """Return a Hypothesis strategy for CA names.""" - return text().map(lambda name: base64(name.encode()).decode()) + return text().map(lambda name: base64_from_str(name)) @given(domains(), ports(), ca_names(), serials()) @@ -75,7 +74,7 @@ def test_body(hostname, port, issuer, serial, not_before, not_after): (v for k, v in r[-1][1:] if k == 'dd'))) assert d['Domain'] == hostname assert d['Port'] == port - assert d['Issuer'] == from_base64(issuer.encode()).decode() + assert d['Issuer'] == str_from_base64(issuer) assert d['Serial number'] == serial assert d['Valid from'] == not_before assert d['Valid until'] == not_after @@ -88,7 +87,7 @@ def test_atom_entry(base_url, hostname, port, cert = not_before, not_after, hostname, port, serial, issuer r = str_from_xml(xml(entry(base_url, cert)), 'unicode', short_empty_elements=False) - issuer_str = from_base64(issuer.encode()).decode() + issuer_str = str_from_base64(issuer) url = urljoin(base_url, path(hostname, port, issuer, serial)) assert r == f'''<entry> <author> @@ -142,7 +141,7 @@ def certificates(draw): port = draw(ports()) serial = draw(serials()) # Free-formed UTF-8 could easily creates malformed XML. - issuer = base64(draw(text(ascii_letters)).encode()).decode() + issuer = base64_from_str(draw(text(ascii_letters))) return f'{not_before} {not_after} {hostname} {port} {serial} {issuer}' |