Diffstat (limited to 'src/rub/xml.py')
-rw-r--r--  src/rub/xml.py | 48
1 file changed, 37 insertions(+), 11 deletions(-)
diff --git a/src/rub/xml.py b/src/rub/xml.py
index ed61a8b..4c3a2ae 100644
--- a/src/rub/xml.py
+++ b/src/rub/xml.py
@@ -17,12 +17,13 @@
 # along with rub.  If not, see <https://www.gnu.org/licenses/>.
 
 from copy import deepcopy
-from functools import partial
 from pathlib import Path
 
-from lxml.etree import QName, XML, XSLT, XSLTExtension
+from lxml.builder import E
+from lxml.html import document_fromstring as from_html
+from lxml.etree import QName, XML, XSLT, XSLTExtension, tostring as serialize
 
-__all__ = ['NS', 'generator', 'recurse']
+__all__ = ['NS', 'Processor', 'recurse']
 
 NS = 'https://rub.parody'
 
@@ -57,14 +58,39 @@ class Evaluator(XSLTExtension):
         handle(self, context, input_node, output_parent)
 
 
-def generator(xslt, **handlers):
-    """Return a function taking an XML file and apply given XSLT."""
-    stylesheet = xslt.read_bytes()
-    extensions = {(NS, 'eval'): Evaluator(**handlers)}
-    transform = XSLT(XML(stylesheet), extensions=extensions)
+class Serializer(XSLTExtension):
+    def execute(self, context, self_node, input_node, output_parent):
+        output_parent.text = serialize(deepcopy(input_node))
+
+
+class Processor:
+    """Callable XSLT processor."""
 
-    def make(src, dest):
+    def __init__(self, xslt: Path, change_name, **handlers) -> None:
+        self.xslt, self.change_name = xslt, change_name
+        stylesheet = xslt.read_bytes()
+        extensions = {(NS, 'eval'): Evaluator(**handlers),
+                      (NS, 'serialize'): Serializer()}
+        self.transform = XSLT(XML(stylesheet), extensions=extensions)
+
+    def process(self, src: Path, dest: Path) -> None:
         dest.parent.mkdir(mode=0o755, parents=True, exist_ok=True)
-        dest.write_text(str(transform(XML(src.read_bytes()))))
+        dest.write_text(str(self.transform(XML(src.read_bytes()))))
+
 
-    return make
+def gen_omnifeed(sources: list[Path], pages: list[Path],
+                 out_dir: Path, dest: Path) -> None:
+    """Generate generic global feed."""
+    entries = []
+    for src, page in zip(sources, pages):
+        src_root = XML(src.read_bytes())
+        desc = src_root.findtext('description', '', {None: NS})
+        if not desc: continue
+        title = src_root.findtext('title', '', {None: NS})
+        date = src_root.findtext('date', '', {None: NS})
+        page_root = from_html(page.read_bytes())
+        path = str(page.relative_to(out_dir))
+        entries.append(E.entry(E.title(title), E.description(desc),
+                               E.date(date), E.path(path), page_root))
+    dest.parent.mkdir(mode=0o755, parents=True, exist_ok=True)
+    dest.write_bytes(serialize(E.feed(*entries), pretty_print=True))
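A minimal usage sketch of the new Processor class, assuming a hypothetical
page.xslt stylesheet that declares xmlns:rub="https://rub.parody" and uses the
rub:eval and rub:serialize extension elements, and assuming Evaluator dispatches
handlers by keyword; the handler, the change_name callable and every path below
are illustrative, not part of this commit:

    from pathlib import Path

    from rub.xml import Processor


    def author(extension, context, input_node, output_parent):
        # Evaluator handlers are invoked as handle(self, context, input_node,
        # output_parent); this one just writes a placeholder text value,
        # mirroring what Serializer does with output_parent.text.
        output_parent.text = 'Example Author'


    # Stylesheet path and change_name callable are placeholders; the real
    # signature expected for change_name is not shown in this diff.
    process_page = Processor(Path('page.xslt'),
                             lambda path: path.with_suffix('.html'),
                             author=author)
    process_page.process(Path('posts/hello.xml'), Path('out/posts/hello.html'))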
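Similarly, a sketch of driving gen_omnifeed under an assumed layout (XML
sources in src/, rendered pages in out/); only sources with a non-empty
description end up in the feed, and each entry carries the title, description,
date, the page path relative to out_dir and the parsed HTML page itself:

    from pathlib import Path

    from rub.xml import gen_omnifeed

    # Paths are examples; pages must live under out_dir so that
    # page.relative_to(out_dir) succeeds.
    sources = sorted(Path('src').glob('*.xml'))
    pages = [Path('out') / source.with_suffix('.html').name
             for source in sources]
    gen_omnifeed(sources, pages, Path('out'), Path('out/omnifeed.xml'))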