[mpact][compiler] add MM and SDDMM unit tests (#38)
This also exposed lack of sparsity propagation
for the element-wise torch.mul operation, fixed
in the same PR
diff --git a/python/mpact/mpactbackend.py b/python/mpact/mpactbackend.py
index 5607082..9b33892 100644
--- a/python/mpact/mpactbackend.py
+++ b/python/mpact/mpactbackend.py
@@ -402,7 +402,7 @@
# TODO: use upstream _opname implementation when available
opname = node.target._schema.name.split("::")[1]
# Zero preserving elt-wise unary op.
- if opname in {"abs", "neg", "relu", "sin"}:
+ if opname in {"abs", "neg", "relu", "sin", "mul"}:
node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
elif opname == "_to_sparse" or opname == "to_sparse":
dim = len(node.meta.get("val").shape)
diff --git a/test/python/sddmm.py b/test/python/sddmm.py
new file mode 100644
index 0000000..8793120
--- /dev/null
+++ b/test/python/sddmm.py
@@ -0,0 +1,60 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+import numpy as np
+
+from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
+
+from mpact.models.kernels import MMNet, SDDMMNet
+
+mmnet = MMNet()
+sddmmnet = SDDMMNet()
+
+# Construct very sparse matrix.
+idx = torch.tensor([[0, 4], [0, 4]], dtype=torch.int64)
+val = torch.tensor([2.0, 3.0], dtype=torch.float64)
+S = torch.sparse_coo_tensor(idx, val, size=[5, 5])
+
+# Trivial dense inputs.
+A = torch.arange(0, 25, dtype=torch.float32).view(5, 5)
+B = torch.arange(25, 50, dtype=torch.float32).view(5, 5)
+
+#
+# CHECK: pytorch
+# CHECK: tensor({{\[}}[ 400., 410., 420., 430., 440.],
+# CHECK: [1275., 1310., 1345., 1380., 1415.],
+# CHECK: [2150., 2210., 2270., 2330., 2390.],
+# CHECK: [3025., 3110., 3195., 3280., 3365.],
+# CHECK: [3900., 4010., 4120., 4230., 4340.]{{\]}})
+# CHECK: tensor(indices=tensor({{\[}}[0, 4],
+# CHECK: [0, 4]{{\]}}),
+# CHECK: values=tensor([ 800., 13020.]),
+# CHECK: size=(5, 5), nnz=2, dtype=torch.float64, layout=torch.sparse_coo)
+# CHECK: mpact
+# CHECK: {{\[}}[ 400. 410. 420. 430. 440.]
+# CHECK: [1275. 1310. 1345. 1380. 1415.]
+# CHECK: [2150. 2210. 2270. 2330. 2390.]
+# CHECK: [3025. 3110. 3195. 3280. 3365.]
+# CHECK: [3900. 4010. 4120. 4230. 4340.]{{\]}}
+# CHECK: [0 2]
+# CHECK: [0 4]
+# CHECK: [0 4]
+# CHECK: [ 800. 13020.]
+#
+
+# Run it with PyTorch.
+print("pytorch")
+dense = mmnet(A, B)
+print(dense)
+res = sddmmnet(S, A, B)
+print(res)
+
+# Run it with MPACT.
+print("mpact")
+dense = mpact_jit(mmnet, A, B)
+print(dense)
+res = mpact_jit(sddmmnet, S, A, B)
+print(res[0])
+print(res[1])
+print(res[2])
+print(res[3])