[mpact][test] add a count-equal idiom (for sparse consideration) (#73)
The equal operator currently does not sparsify under
PyTorch, but if it did, this would be a great candidate
for further optimization by performing the sum() without
materializing the intermediate result!
diff --git a/python/mpact/models/kernels.py b/python/mpact/models/kernels.py
index 36e2394..14f505c 100644
--- a/python/mpact/models/kernels.py
+++ b/python/mpact/models/kernels.py
@@ -36,6 +36,12 @@
return (x * x).sum()
+class CountEq(torch.nn.Module):
+    def forward(self, x, s):  # x: input tensor; s: value to match (a 0-d tensor in the test below)
+        nums = (x == s).sum()  # 0-d tensor holding the count of entries in x equal to s
+        return nums
+
+
class FeatureScale(torch.nn.Module):
def forward(self, F):
sum_vector = torch.sum(F, dim=1)
diff --git a/test/python/counteq.py b/test/python/counteq.py
new file mode 100644
index 0000000..3cdd90a
--- /dev/null
+++ b/test/python/counteq.py
@@ -0,0 +1,56 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+import numpy as np
+
+from mpact.mpactbackend import mpact_jit
+
+from mpact.models.kernels import CountEq
+
+
+net = CountEq()
+
+# Construct dense and sparse matrices.
+A = torch.tensor(
+    [
+        [0.0, 1.0, 0.0, 0.0],
+        [0.0, 0.0, 0.0, 2.0],
+        [0.0, 0.0, 1.0, 1.0],
+        [3.0, 0.0, 3.0, 0.0],
+    ],
+    dtype=torch.float32,
+)
+
+# TODO: very interesting idiom to sparsify (collapsing the sum
+# into the eq for full sparsity), but this needs PyTorch support
+S = A
+# S = A.to_sparse()
+# S = A.to_sparse_csr()
+
+#
+# CHECK: pytorch
+# CHECK: 10
+# CHECK: 3
+# CHECK: 1
+# CHECK: 2
+# CHECK: 0
+# CHECK: mpact
+# CHECK: 10
+# CHECK: 3
+# CHECK: 1
+# CHECK: 2
+# CHECK: 0
+#
+
+# Run it with PyTorch.
+print("pytorch")
+for i in range(5):  # count occurrences of each value 0..4 in the matrix
+    target = torch.tensor(i)
+    res = net(S, target).item()  # .item() converts the 0-d count tensor to a Python scalar
+    print(res)
+
+print("mpact")
+for i in range(5):  # same counts via the MPACT JIT; output must match the PyTorch run
+    target = torch.tensor(i)
+    res = mpact_jit(net, S, target)  # printed as-is; presumably already a scalar — verify against backend
+    print(res)