summary refs log tree commit diff
path: root/gnu
diff options
context:
space:
mode:
authorRicardo Wurmus <rekado@elephly.net>2023-07-10 13:13:11 +0200
committerRicardo Wurmus <rekado@elephly.net>2023-07-10 13:13:45 +0200
commitd0296970fb8ed97ac17bd4c580351af961a8c0fb (patch)
tree0a2d2b61c876da7fa72d61b1cf51b8aa37b22ed5 /gnu
parente3d9d896b540f82e4511f2bd6ae6373390ee2d4d (diff)
downloadguix-d0296970fb8ed97ac17bd4c580351af961a8c0fb.tar.gz
gnu: Add python-captum.
* gnu/packages/machine-learning.scm (python-captum): New variable.
Diffstat (limited to 'gnu')
-rw-r--r--gnu/packages/machine-learning.scm45
1 files changed, 45 insertions, 0 deletions
diff --git a/gnu/packages/machine-learning.scm b/gnu/packages/machine-learning.scm
index 5b98705943..f50398b555 100644
--- a/gnu/packages/machine-learning.scm
+++ b/gnu/packages/machine-learning.scm
@@ -3868,6 +3868,51 @@ AI services.")
 Actions for the Lightning suite of libraries.")
     (license license:asl2.0)))
 
+(define-public python-captum
+  (package
+    (name "python-captum")
+    (version "0.6.0")
+    (source (origin
+              (method git-fetch)
+              (uri (git-reference
+                    (url "https://github.com/pytorch/captum")
+                    (commit (string-append "v" version))))
+              (file-name (git-file-name name version))
+              (sha256
+               (base32
+                "1h4n91ivhjxm6wj0vgqpfss2dmq4sjcp0appd08cd5naisabjyb5"))))
+    (build-system pyproject-build-system)
+    (arguments
+     (list
+      #:test-flags
+      '(list "-k"
+             ;; These two tests (out of more than 1000 tests) fail because of
+             ;; accuracy problems.
+             "not test_softmax_classification_batch_multi_target\
+ and not test_softmax_classification_batch_zero_baseline")))
+    (propagated-inputs (list python-matplotlib python-numpy python-pytorch))
+    (native-inputs (list jupyter
+                         python-annoy
+                         python-black
+                         python-flake8
+                         python-flask
+                         python-flask-compress
+                         python-ipython
+                         python-ipywidgets
+                         python-mypy
+                         python-parameterized
+                         python-pytest
+                         python-pytest-cov
+                         python-scikit-learn))
+    (home-page "https://captum.ai")
+    (synopsis "Model interpretability for PyTorch")
+    (description "Captum is a model interpretability and understanding library
+for PyTorch.  Captum contains general purpose implementations of integrated
+gradients, saliency maps, smoothgrad, vargrad and others for PyTorch models.
+It has quick integration for models built with domain-specific libraries such
+as torchvision, torchtext, and others.")
+    (license license:bsd-3)))
+
 (define-public python-readchar
   (package
     (name "python-readchar")