summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gnu/packages/machine-learning.scm79
1 files changed, 79 insertions, 0 deletions
diff --git a/gnu/packages/machine-learning.scm b/gnu/packages/machine-learning.scm
index 5748275242..057d90a709 100644
--- a/gnu/packages/machine-learning.scm
+++ b/gnu/packages/machine-learning.scm
@@ -104,6 +104,7 @@
   #:use-module (gnu packages python-science)
   #:use-module (gnu packages python-web)
   #:use-module (gnu packages python-xyz)
+  #:use-module (gnu packages rdf)
   #:use-module (gnu packages regex)
   #:use-module (gnu packages rpc)
   #:use-module (gnu packages serialization)
@@ -4161,6 +4162,84 @@ Note: currently this package does not provide GPU support.")
        (replace "onnx" onnx-for-torch2)
        (replace "onnx-optimizer" onnx-optimizer-for-torch2)))))
 
+(define-public python-pytorch-geometric
+  (package
+    (name "python-pytorch-geometric")
+    (version "2.4.0")
+    (source (origin
+              (method git-fetch)
+              (uri (git-reference
+                    (url "https://github.com/pyg-team/pytorch_geometric/")
+                    (commit version)))
+              (file-name (git-file-name name version))
+              (sha256
+               (base32
+                "0hrs579asjsph16hyb4ablkbgfwd5j9y5s6ny7ahn3qrbkl2ji1g"))))
+    (build-system pyproject-build-system)
+    (arguments
+     (list
+      #:test-flags
+      ;; Hangs with AttributeError: 'NoneType' object has no attribute 'rpc_async'
+      '(list "--ignore=test/distributed/test_rpc.py"
+             ;; A message passing jinja template is missing
+             "--ignore=test/nn/conv/test_message_passing.py"
+             "--ignore=test/nn/test_sequential.py"
+             "--ignore=test/nn/models/test_basic_gnn.py"
+             ;; These all fail with a size mismatch error such as
+             ;; RuntimeError: shape '[-1, 2, 1, 1]' is invalid for input of size 3
+             "--ignore=test/explain/algorithm/test_captum_explainer.py"
+             "-k" (string-append
+                   ;; Permissions error
+                   "not test_packaging"
+                   ;; These refuse to be run on CPU and really want a GPU
+                   " and not test_add_random_walk_pe"
+                   " and not test_asap"
+                   " and not test_two_hop"))
+      #:phases
+      '(modify-phases %standard-phases
+         (add-after 'unpack 'delete-top-level-directories
+           (lambda _
+             ;; The presence of these directories confuses the pyproject build
+             ;; system.
+             (for-each delete-file-recursively
+                       '("conda" "docker" "graphgym")))))))
+    (propagated-inputs
+     (list onnx
+           python-captum
+           python-graphviz
+           python-h5py
+           python-jinja2
+           python-matplotlib
+           python-networkx
+           python-numba
+           python-numpy
+           python-opt-einsum
+           python-pandas
+           python-protobuf
+           python-psutil
+           python-pyparsing
+           python-pytorch-lightning
+           python-rdflib
+           python-requests
+           python-scikit-image
+           python-scikit-learn
+           python-scipy
+           python-statsmodels
+           python-sympy
+           python-tabulate
+           python-torchmetrics
+           python-tqdm))
+    (native-inputs
+     (list python-flit-core
+           python-pytest
+           python-pytest-cov))
+    (home-page "https://pyg.org")
+    (synopsis "Graph Neural Network library for PyTorch")
+    (description
+     "PyG is a library built upon PyTorch to easily write and train Graph
+Neural Networks for a wide range of applications related to structured data.")
+    (license license:expat)))
+
 (define-public python-lightning-cloud
   (package
     (name "python-lightning-cloud")