summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rwxr-xr-xsrc/fead.py39
1 files changed, 28 insertions, 11 deletions
diff --git a/src/fead.py b/src/fead.py
index 5c42116..dd47266 100755
--- a/src/fead.py
+++ b/src/fead.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Advert generator from web feeds
 # Copyright (C) 2022, 2024  Nguyễn Gia Phong
+# Copyright (C) 2023  Ngô Ngọc Đức Huy
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU Affero General Public License as published
@@ -18,7 +19,7 @@
 __version__ = '0.1.3'
 
 from argparse import ArgumentParser, FileType, HelpFormatter
-from asyncio import CancelledError, gather, open_connection, run
+from asyncio import CancelledError, TaskGroup, gather, open_connection, run
 from collections import namedtuple
 from datetime import datetime
 from email.utils import parsedate_to_datetime
@@ -29,6 +30,7 @@ from pathlib import Path
 from re import compile as regex
 from sys import stdin, stdout
 from textwrap import shorten
+from traceback import print_exception
 from urllib.error import HTTPError
 from urllib.parse import urljoin, urlsplit
 from warnings import warn
@@ -192,14 +194,26 @@ async def fetch(raw_url):
                         response.getheaders(), response)
 
 
-async def fetch_all(urls):
-    """Fetch all given URLs asynchronously and return them parsed."""
-    tasks = gather(*map(fetch, urls))
-    try:
-        return await tasks
-    except:
-        tasks.cancel()  # structured concurrency
-        raise
+async def fetch_all(urls, strict):
+    """Fetch all given URLs asynchronously and return them parsed.
+
+    If in strict mode, abort when encounter the first error.
+    """
+    if strict:
+        async with TaskGroup() as group:
+            tasks = tuple(group.create_task(fetch(url)) for url in urls)
+        return (task.result() for task in tasks)
+    else:
+        feeds, exceptions = [], []
+        for result in await gather(*map(fetch, urls), return_exceptions=True):
+            if isinstance(result, BaseException):
+                exceptions.append(result)
+            else:
+                feeds.append(result)
+        if exceptions:
+            warn('some web feed(s) have been skipped')
+            print_exception(ExceptionGroup("ignored errors", exceptions))
+        return feeds
 
 
 def select(n, ads):
@@ -228,6 +242,8 @@ def main():
     parser.add_argument('-f', '--feed', metavar='URL',
                         action='append', dest='feeds',
                         help='addtional web feed URL (multiple use)')
+    parser.add_argument('-s', '--strict', action='store_true',
+                        help='abort when fail to fetch or parse a web feed')
     parser.add_argument('-n', '--count', metavar='N', type=int, default=3,
                         help='maximum number of ads in total (default to 3)')
     parser.add_argument('-p', '--per-feed', metavar='N', type=int, default=1,
@@ -245,8 +261,9 @@ def main():
 
     template = args.template.read()
     args.template.close()
-    for ad in select(args.count, (ad for feed in run(fetch_all(args.feeds))
-                                  for ad in select(args.per_feed, feed))):
+    for ad in select(args.count,
+                     (ad for feed in run(fetch_all(args.feeds, args.strict))
+                      for ad in select(args.per_feed, feed))):
         args.output.write(template.format(**truncate(ad, args.len)._asdict()))
     args.output.close()