[mpact][compiler] more sparsity propagation rules with tests (#39)

diff --git a/python/mpact/mpactbackend.py b/python/mpact/mpactbackend.py
index 9b33892..d944376 100644
--- a/python/mpact/mpactbackend.py
+++ b/python/mpact/mpactbackend.py
@@ -359,6 +359,12 @@
         raise RuntimeError(f"Unsupported sparse layout for {a}")
 
 
+def sparse_arg(args, i):
+    """Return the "sparsity" meta of args[i] when it is an FX node, else None."""
+    arg = args[i]
+    return arg.meta.get("sparsity") if isinstance(arg, torch.fx.node.Node) else None
+
+
 def sparse_export(
     f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
 ) -> torch.export.ExportedProgram:
@@ -402,8 +408,22 @@
             # TODO: use upstream _opname implementation when available
             opname = node.target._schema.name.split("::")[1]
             # Zero preserving elt-wise unary op.
-            if opname in {"abs", "neg", "relu", "sin", "mul"}:
-                node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
+            if opname in {"abs", "neg", "relu", "sin"}:
+                node.meta["sparsity"] = sparse_arg(node.args, 0)
+            # Some simplistic rules for preserving sparsity. Soon
+            # to be replaced by proper FX graph propagation.
+            elif opname in {"mul"}:
+                m0 = sparse_arg(node.args, 0)
+                m1 = sparse_arg(node.args, 1)
+                if m0 is not None:
+                    node.meta["sparsity"] = m0
+                elif m1 is not None:
+                    node.meta["sparsity"] = m1
+            elif opname in {"add", "mm"}:
+                m0 = sparse_arg(node.args, 0)
+                m1 = sparse_arg(node.args, 1)
+                if m0 is not None and m1 is not None:
+                    node.meta["sparsity"] = m0
             elif opname == "_to_sparse" or opname == "to_sparse":
                 dim = len(node.meta.get("val").shape)
                 node.meta["sparsity"] = SparsityMeta(
@@ -412,13 +432,13 @@
             # TODO: Uncomment this to hack sparsity into the network.
             # elif opname == "_to_dense" or opname == "to_dense":
             #     # hack (assumes we never really want the to_dense for now)
-            #     node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
-            elif opname == "select" and node.args[0].meta.get("sparsity", None):
+            #     node.meta["sparsity"] = sparse_arg(node.args, 0)
+            elif opname == "select" and sparse_arg(node.args, 0):
                 dim = len(node.meta.get("val").shape)
                 node.meta["sparsity"] = SparsityMeta(
                     torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
                 )
-            elif opname == "stack" and node.args[0][0].meta.get("sparsity", None):
+            elif opname == "stack" and sparse_arg(node.args[0], 0):
                 dim = len(node.meta.get("val").shape)
                 node.meta["sparsity"] = SparsityMeta(
                     torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
diff --git a/test/python/add.py b/test/python/add.py
new file mode 100644
index 0000000..00d4d62
--- /dev/null
+++ b/test/python/add.py
@@ -0,0 +1,89 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+import numpy as np
+
+from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
+
+from mpact.models.kernels import AddNet
+
+
+def print_sparse(res):
+    """Print entries 0, 1 and 2 of `res`, each on its own line."""
+    for idx in range(3):
+        print(res[idx])
+
+
+net = AddNet()
+
+# Construct dense 4x4 operands (X, Y) and a sparse CSR operand (S) built from A.
+X = torch.arange(0, 16, dtype=torch.float32).view(4, 4)  # values 0..15
+Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)  # values 16..31
+A = torch.tensor(
+    [
+        [0.0, 1.0, 0.0, 0.0],
+        [0.0, 0.0, 0.0, 2.0],
+        [0.0, 0.0, 0.0, 0.0],
+        [3.0, 0.0, 0.0, 0.0],
+    ],
+    dtype=torch.float32,
+)
+S = A.to_sparse_csr()  # CSR-layout version of A (three nonzeros)
+
+#
+# CHECK: pytorch
+# CHECK:   tensor({{\[}}[16., 18., 20., 22.],
+# CHECK:                [24., 26., 28., 30.],
+# CHECK:                [32., 34., 36., 38.],
+# CHECK:                [40., 42., 44., 46.]{{\]}})
+# CHECK:  tensor({{\[}}[16., 18., 18., 19.],
+# CHECK:               [20., 21., 22., 25.],
+# CHECK:               [24., 25., 26., 27.],
+# CHECK:               [31., 29., 30., 31.]{{\]}})
+# CHECK:  tensor({{\[}}[ 0.,  2.,  2.,  3.],
+# CHECK:               [ 4.,  5.,  6.,  9.],
+# CHECK:               [ 8.,  9., 10., 11.],
+# CHECK:               [15., 13., 14., 15.]{{\]}})
+# CHECK:  tensor(crow_indices=tensor([0, 1, 2, 2, 3]),
+# CHECK:         col_indices=tensor([1, 3, 0]),
+# CHECK:         values=tensor([2., 4., 6.]), size=(4, 4), nnz=3,
+# CHECK:         layout=torch.sparse_csr)
+# CHECK: mpact
+# CHECK:   {{\[}}[16. 18. 20. 22.]
+# CHECK:         [24. 26. 28. 30.]
+# CHECK:         [32. 34. 36. 38.]
+# CHECK:         [40. 42. 44. 46.]{{\]}}
+# CHECK:   {{\[}}[16. 18. 18. 19.]
+# CHECK:         [20. 21. 22. 25.]
+# CHECK:         [24. 25. 26. 27.]
+# CHECK:         [31. 29. 30. 31.]{{\]}}
+# CHECK:   {{\[}}[ 0.  2.  2.  3.]
+# CHECK:         [ 4.  5.  6.  9.]
+# CHECK:         [ 8.  9. 10. 11.]
+# CHECK:         [15. 13. 14. 15.]{{\]}}
+# CHECK:  [0 1 2 2 3]
+# CHECK:  [1 3 0]
+# CHECK:  [2. 4. 6.]
+#
+
+# Run it with PyTorch on every dense/sparse operand combination.
+print("pytorch")
+res = net(X, Y)  # dense, dense
+print(res)
+res = net(S, Y)  # sparse (CSR), dense
+print(res)
+res = net(X, S)  # dense, sparse (CSR)
+print(res)
+res = net(S, S)  # sparse (CSR), sparse (CSR)
+print(res)
+
+# Run it with MPACT on the same operand combinations.
+print("mpact")
+res = mpact_jit(net, X, Y)  # dense, dense
+print(res)
+res = mpact_jit(net, S, Y)  # sparse (CSR), dense
+print(res)
+res = mpact_jit(net, X, S)  # dense, sparse (CSR)
+print(res)
+res = mpact_jit(net, S, S)  # sparse (CSR), sparse (CSR)
+print_sparse(res)
diff --git a/test/python/mm.py b/test/python/mm.py
new file mode 100644
index 0000000..3c51c37
--- /dev/null
+++ b/test/python/mm.py
@@ -0,0 +1,89 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+import numpy as np
+
+from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
+
+from mpact.models.kernels import MMNet
+
+
+def print_sparse(res):
+    """Print entries 0, 1 and 2 of `res`, each on its own line."""
+    for idx in range(3):
+        print(res[idx])
+
+
+net = MMNet()
+
+# Construct dense 4x4 operands (X, Y) and a sparse CSR operand (S) built from A.
+X = torch.arange(0, 16, dtype=torch.float32).view(4, 4)  # values 0..15
+Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)  # values 16..31
+A = torch.tensor(
+    [
+        [0.0, 1.0, 0.0, 0.0],
+        [0.0, 0.0, 0.0, 2.0],
+        [0.0, 0.0, 0.0, 0.0],
+        [3.0, 0.0, 0.0, 0.0],
+    ],
+    dtype=torch.float32,
+)
+S = A.to_sparse_csr()  # CSR-layout version of A (three nonzeros)
+
+#
+# CHECK: pytorch
+# CHECK:   tensor({{\[}}[ 152.,  158.,  164.,  170.],
+# CHECK:                [ 504.,  526.,  548.,  570.],
+# CHECK:                [ 856.,  894.,  932.,  970.],
+# CHECK:                [1208., 1262., 1316., 1370.]{{\]}})
+# CHECK:  tensor({{\[}}[20., 21., 22., 23.],
+# CHECK:               [56., 58., 60., 62.],
+# CHECK:               [ 0.,  0.,  0.,  0.],
+# CHECK:               [48., 51., 54., 57.]{{\]}})
+# CHECK:  tensor({{\[}}[ 9.,  0.,  0.,  2.],
+# CHECK:               [21.,  4.,  0., 10.],
+# CHECK:               [33.,  8.,  0., 18.],
+# CHECK:               [45., 12.,  0., 26.]{{\]}})
+# CHECK:  tensor(crow_indices=tensor([0, 1, 2, 2, 3]),
+# CHECK:         col_indices=tensor([3, 0, 1]),
+# CHECK:         values=tensor([2., 6., 3.]), size=(4, 4), nnz=3,
+# CHECK:         layout=torch.sparse_csr)
+# CHECK: mpact
+# CHECK:   {{\[}}[ 152.  158.  164.  170.]
+# CHECK:         [ 504.  526.  548.  570.]
+# CHECK:         [ 856.  894.  932.  970.]
+# CHECK:         [1208. 1262. 1316. 1370.]{{\]}}
+# CHECK:   {{\[}}[20. 21. 22. 23.]
+# CHECK:         [56. 58. 60. 62.]
+# CHECK:         [ 0.  0.  0.  0.]
+# CHECK:         [48. 51. 54. 57.]{{\]}}
+# CHECK:   {{\[}}[ 9.  0.  0.  2.]
+# CHECK:         [21.  4.  0. 10.]
+# CHECK:         [33.  8.  0. 18.]
+# CHECK:         [45. 12.  0. 26.]{{\]}}
+# CHECK:  [0 1 2 2 3]
+# CHECK:  [3 0 1]
+# CHECK:  [2. 6. 3.]
+#
+
+# Run it with PyTorch on every dense/sparse operand combination.
+print("pytorch")
+res = net(X, Y)  # dense, dense
+print(res)
+res = net(S, Y)  # sparse (CSR), dense
+print(res)
+res = net(X, S)  # dense, sparse (CSR)
+print(res)
+res = net(S, S)  # sparse (CSR), sparse (CSR)
+print(res)
+
+# Run it with MPACT on the same operand combinations.
+print("mpact")
+res = mpact_jit(net, X, Y)  # dense, dense
+print(res)
+res = mpact_jit(net, S, Y)  # sparse (CSR), dense
+print(res)
+res = mpact_jit(net, X, S)  # dense, sparse (CSR)
+print(res)
+res = mpact_jit(net, S, S)  # sparse (CSR), sparse (CSR)
+print_sparse(res)
diff --git a/test/python/mul.py b/test/python/mul.py
new file mode 100644
index 0000000..fd8692f
--- /dev/null
+++ b/test/python/mul.py
@@ -0,0 +1,87 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+import numpy as np
+
+from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
+
+from mpact.models.kernels import MulNet
+
+
+def print_sparse(res):
+    """Print entries 0, 1 and 2 of `res`, each on its own line."""
+    for idx in range(3):
+        print(res[idx])
+
+
+net = MulNet()
+
+# Construct dense 4x4 operands (X, Y) and a sparse CSR operand (S) built from A.
+X = torch.arange(0, 16, dtype=torch.float32).view(4, 4)  # values 0..15
+Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)  # values 16..31
+A = torch.tensor(
+    [
+        [0.0, 1.0, 0.0, 0.0],
+        [0.0, 0.0, 0.0, 2.0],
+        [0.0, 0.0, 0.0, 0.0],
+        [3.0, 0.0, 0.0, 0.0],
+    ],
+    dtype=torch.float32,
+)
+S = A.to_sparse_csr()  # CSR-layout version of A (three nonzeros)
+
+#
+# CHECK: pytorch
+# CHECK: tensor({{\[}}[  0.,  17.,  36.,  57.],
+# CHECK:              [ 80., 105., 132., 161.],
+# CHECK:              [192., 225., 260., 297.],
+# CHECK:              [336., 377., 420., 465.]{{\]}})
+# CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]),
+# CHECK:        col_indices=tensor([1, 3, 0]),
+# CHECK:        values=tensor([17., 46., 84.]), size=(4, 4), nnz=3,
+# CHECK:        layout=torch.sparse_csr)
+# CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]),
+# CHECK:        col_indices=tensor([1, 3, 0]),
+# CHECK:        values=tensor([ 1., 14., 36.]), size=(4, 4), nnz=3,
+# CHECK:        layout=torch.sparse_csr)
+# CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]),
+# CHECK:        col_indices=tensor([1, 3, 0]),
+# CHECK:        values=tensor([1., 4., 9.]), size=(4, 4), nnz=3,
+# CHECK:        layout=torch.sparse_csr)
+# CHECK: mpact
+# CHECK:   {{\[}}[  0.  17.  36.  57.]
+# CHECK:         [ 80. 105. 132. 161.]
+# CHECK:         [192. 225. 260. 297.]
+# CHECK:         [336. 377. 420. 465.]{{\]}}
+# CHECK:  [0 1 2 2 3]
+# CHECK:  [1 3 0]
+# CHECK:  [17. 46. 84.]
+# CHECK:  [0 1 2 2 3]
+# CHECK:  [1 3 0]
+# CHECK:  [ 1. 14. 36.]
+# CHECK:  [0 1 2 2 3]
+# CHECK:  [1 3 0]
+# CHECK:  [1. 4. 9.]
+#
+
+# Run it with PyTorch on every dense/sparse operand combination.
+print("pytorch")
+res = net(X, Y)  # dense, dense
+print(res)
+res = net(S, Y)  # sparse (CSR), dense
+print(res)
+res = net(X, S)  # dense, sparse (CSR)
+print(res)
+res = net(S, S)  # sparse (CSR), sparse (CSR)
+print(res)
+
+# Run it with MPACT on the same operand combinations.
+print("mpact")
+res = mpact_jit(net, X, Y)  # dense, dense
+print(res)
+res = mpact_jit(net, S, Y)  # sparse (CSR), dense
+print_sparse(res)
+res = mpact_jit(net, X, S)  # dense, sparse (CSR)
+print_sparse(res)
+res = mpact_jit(net, S, S)  # sparse (CSR), sparse (CSR)
+print_sparse(res)
diff --git a/test/python/sddmm.py b/test/python/sddmm.py
index 8793120..a0890c7 100644
--- a/test/python/sddmm.py
+++ b/test/python/sddmm.py
@@ -7,6 +7,14 @@
 
 from mpact.models.kernels import MMNet, SDDMMNet
 
+
+def print_sparse(res):
+    """Print entries 0 through 3 of `res`, each on its own line."""
+    # Entries are emitted in index order.
+    for idx in range(4):
+        print(res[idx])
+
+
 mmnet = MMNet()
 sddmmnet = SDDMMNet()
 
@@ -54,7 +62,4 @@
 dense = mpact_jit(mmnet, A, B)
 print(dense)
 res = mpact_jit(sddmmnet, S, A, B)
-print(res[0])
-print(res[1])
-print(res[2])
-print(res[3])
+print_sparse(res)