[mpact][compiler] more sparsity propagation rules with tests (#39)
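
This change factors the sparsity lookup into a small sparse_arg helper and adds
propagation rules beyond the zero-preserving unary ops: an elementwise mul
result is marked sparse as soon as either operand is sparse (x * 0 == 0, so the
sparse operand's zero pattern carries over), while add and mm keep a sparse
result only when both operands are sparse (a dense operand can populate every
output entry). New FileCheck tests exercise AddNet, MMNet and MulNet over all
four dense/sparse operand combinations, and sddmm.py is refactored to share a
print_sparse helper.

A minimal standalone sketch of the intended rules (illustration only, not part
of the patch; propagate and the "csr" stand-ins are made up here, while the
real code stores SparsityMeta objects in node.meta["sparsity"]):

    def propagate(opname, m0, m1):
        # mul is zero preserving in each operand: one sparse argument
        # already fixes the zero pattern of the result.
        if opname == "mul":
            return m0 if m0 is not None else m1
        # add and mm need both operands sparse, since a dense operand
        # can populate every entry of the result.
        if opname in {"add", "mm"}:
            return m0 if (m0 is not None and m1 is not None) else None
        return None

    print(propagate("mul", "csr", None))   # csr  (result stays sparse)
    print(propagate("add", "csr", None))   # None (result is dense)
    print(propagate("mm", "csr", "csr"))   # csr

These rules are a stopgap; as the in-code comment notes, they are slated to be
replaced by proper FX graph propagation.
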
diff --git a/python/mpact/mpactbackend.py b/python/mpact/mpactbackend.py
index 9b33892..d944376 100644
--- a/python/mpact/mpactbackend.py
+++ b/python/mpact/mpactbackend.py
@@ -359,6 +359,12 @@
         raise RuntimeError(f"Unsupported sparse layout for {a}")
+def sparse_arg(args, i):
+    if isinstance(args[i], torch.fx.node.Node):
+        return args[i].meta.get("sparsity", None)
+    return None
+
+
 def sparse_export(
     f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
 ) -> torch.export.ExportedProgram:
@@ -402,8 +408,22 @@
             # TODO: use upstream _opname implementation when available
             opname = node.target._schema.name.split("::")[1]
             # Zero preserving elt-wise unary op.
-            if opname in {"abs", "neg", "relu", "sin", "mul"}:
-                node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
+            if opname in {"abs", "neg", "relu", "sin"}:
+                node.meta["sparsity"] = sparse_arg(node.args, 0)
+            # Some simplistic rules for preserving sparsity. Soon
+            # to be replaced by proper FX graph propagation.
+            elif opname in {"mul"}:
+                m0 = sparse_arg(node.args, 0)
+                m1 = sparse_arg(node.args, 1)
+                if m0 is not None:
+                    node.meta["sparsity"] = m0
+                elif m1 is not None:
+                    node.meta["sparsity"] = m1
+            elif opname in {"add", "mm"}:
+                m0 = sparse_arg(node.args, 0)
+                m1 = sparse_arg(node.args, 1)
+                if m0 is not None and m1 is not None:
+                    node.meta["sparsity"] = m0
             elif opname == "_to_sparse" or opname == "to_sparse":
                 dim = len(node.meta.get("val").shape)
                 node.meta["sparsity"] = SparsityMeta(
@@ -412,13 +432,13 @@
             # TODO: Uncomment this to hack sparsity into the network.
             # elif opname == "_to_dense" or opname == "to_dense":
             #    # hack (assumes we never really want the to_dense for now)
-            #    node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
-            elif opname == "select" and node.args[0].meta.get("sparsity", None):
+            #    node.meta["sparsity"] = sparse_arg(node.args, 0)
+            elif opname == "select" and sparse_arg(node.args, 0):
                 dim = len(node.meta.get("val").shape)
                 node.meta["sparsity"] = SparsityMeta(
                     torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
                 )
-            elif opname == "stack" and node.args[0][0].meta.get("sparsity", None):
+            elif opname == "stack" and sparse_arg(node.args[0], 0):
                 dim = len(node.meta.get("val").shape)
                 node.meta["sparsity"] = SparsityMeta(
                     torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
diff --git a/test/python/add.py b/test/python/add.py
new file mode 100644
index 0000000..00d4d62
--- /dev/null
+++ b/test/python/add.py
@@ -0,0 +1,89 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+import numpy as np
+
+from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
+
+from mpact.models.kernels import AddNet
+
+
+def print_sparse(res):
+    print(res[0])
+    print(res[1])
+    print(res[2])
+
+
+net = AddNet()
+
+# Construct dense and sparse matrices.
+X = torch.arange(0, 16, dtype=torch.float32).view(4, 4)
+Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)
+A = torch.tensor(
+    [
+        [0.0, 1.0, 0.0, 0.0],
+        [0.0, 0.0, 0.0, 2.0],
+        [0.0, 0.0, 0.0, 0.0],
+        [3.0, 0.0, 0.0, 0.0],
+    ],
+    dtype=torch.float32,
+)
+S = A.to_sparse_csr()
+
+#
+# CHECK: pytorch
+# CHECK: tensor({{\[}}[16., 18., 20., 22.],
+# CHECK: [24., 26., 28., 30.],
+# CHECK: [32., 34., 36., 38.],
+# CHECK: [40., 42., 44., 46.]{{\]}})
+# CHECK: tensor({{\[}}[16., 18., 18., 19.],
+# CHECK: [20., 21., 22., 25.],
+# CHECK: [24., 25., 26., 27.],
+# CHECK: [31., 29., 30., 31.]{{\]}})
+# CHECK: tensor({{\[}}[ 0., 2., 2., 3.],
+# CHECK: [ 4., 5., 6., 9.],
+# CHECK: [ 8., 9., 10., 11.],
+# CHECK: [15., 13., 14., 15.]{{\]}})
+# CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]),
+# CHECK: col_indices=tensor([1, 3, 0]),
+# CHECK: values=tensor([2., 4., 6.]), size=(4, 4), nnz=3,
+# CHECK: layout=torch.sparse_csr)
+# CHECK: mpact
+# CHECK: {{\[}}[16. 18. 20. 22.]
+# CHECK: [24. 26. 28. 30.]
+# CHECK: [32. 34. 36. 38.]
+# CHECK: [40. 42. 44. 46.]{{\]}}
+# CHECK: {{\[}}[16. 18. 18. 19.]
+# CHECK: [20. 21. 22. 25.]
+# CHECK: [24. 25. 26. 27.]
+# CHECK: [31. 29. 30. 31.]{{\]}}
+# CHECK: {{\[}}[ 0. 2. 2. 3.]
+# CHECK: [ 4. 5. 6. 9.]
+# CHECK: [ 8. 9. 10. 11.]
+# CHECK: [15. 13. 14. 15.]{{\]}}
+# CHECK: [0 1 2 2 3]
+# CHECK: [1 3 0]
+# CHECK: [2. 4. 6.]
+#
+
+# Run it with PyTorch.
+print("pytorch")
+res = net(X, Y)
+print(res)
+res = net(S, Y)
+print(res)
+res = net(X, S)
+print(res)
+res = net(S, S)
+print(res)
+
+# Run it with MPACT.
+print("mpact")
+res = mpact_jit(net, X, Y)
+print(res)
+res = mpact_jit(net, S, Y)
+print(res)
+res = mpact_jit(net, X, S)
+print(res)
+res = mpact_jit(net, S, S)
+print_sparse(res)
diff --git a/test/python/mm.py b/test/python/mm.py
new file mode 100644
index 0000000..3c51c37
--- /dev/null
+++ b/test/python/mm.py
@@ -0,0 +1,89 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+import numpy as np
+
+from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
+
+from mpact.models.kernels import MMNet
+
+
+def print_sparse(res):
+    print(res[0])
+    print(res[1])
+    print(res[2])
+
+
+net = MMNet()
+
+# Construct dense and sparse matrices.
+X = torch.arange(0, 16, dtype=torch.float32).view(4, 4)
+Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)
+A = torch.tensor(
+    [
+        [0.0, 1.0, 0.0, 0.0],
+        [0.0, 0.0, 0.0, 2.0],
+        [0.0, 0.0, 0.0, 0.0],
+        [3.0, 0.0, 0.0, 0.0],
+    ],
+    dtype=torch.float32,
+)
+S = A.to_sparse_csr()
+
+#
+# CHECK: pytorch
+# CHECK: tensor({{\[}}[ 152., 158., 164., 170.],
+# CHECK: [ 504., 526., 548., 570.],
+# CHECK: [ 856., 894., 932., 970.],
+# CHECK: [1208., 1262., 1316., 1370.]{{\]}})
+# CHECK: tensor({{\[}}[20., 21., 22., 23.],
+# CHECK: [56., 58., 60., 62.],
+# CHECK: [ 0., 0., 0., 0.],
+# CHECK: [48., 51., 54., 57.]{{\]}})
+# CHECK: tensor({{\[}}[ 9., 0., 0., 2.],
+# CHECK: [21., 4., 0., 10.],
+# CHECK: [33., 8., 0., 18.],
+# CHECK: [45., 12., 0., 26.]{{\]}})
+# CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]),
+# CHECK: col_indices=tensor([3, 0, 1]),
+# CHECK: values=tensor([2., 6., 3.]), size=(4, 4), nnz=3,
+# CHECK: layout=torch.sparse_csr)
+# CHECK: mpact
+# CHECK: {{\[}}[ 152. 158. 164. 170.]
+# CHECK: [ 504. 526. 548. 570.]
+# CHECK: [ 856. 894. 932. 970.]
+# CHECK: [1208. 1262. 1316. 1370.]{{\]}}
+# CHECK: {{\[}}[20. 21. 22. 23.]
+# CHECK: [56. 58. 60. 62.]
+# CHECK: [ 0. 0. 0. 0.]
+# CHECK: [48. 51. 54. 57.]{{\]}}
+# CHECK: {{\[}}[ 9. 0. 0. 2.]
+# CHECK: [21. 4. 0. 10.]
+# CHECK: [33. 8. 0. 18.]
+# CHECK: [45. 12. 0. 26.]{{\]}}
+# CHECK: [0 1 2 2 3]
+# CHECK: [3 0 1]
+# CHECK: [2. 6. 3.]
+#
+
+# Run it with PyTorch.
+print("pytorch")
+res = net(X, Y)
+print(res)
+res = net(S, Y)
+print(res)
+res = net(X, S)
+print(res)
+res = net(S, S)
+print(res)
+
+# Run it with MPACT.
+print("mpact")
+res = mpact_jit(net, X, Y)
+print(res)
+res = mpact_jit(net, S, Y)
+print(res)
+res = mpact_jit(net, X, S)
+print(res)
+res = mpact_jit(net, S, S)
+print_sparse(res)
diff --git a/test/python/mul.py b/test/python/mul.py
new file mode 100644
index 0000000..fd8692f
--- /dev/null
+++ b/test/python/mul.py
@@ -0,0 +1,87 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+import numpy as np
+
+from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
+
+from mpact.models.kernels import MulNet
+
+
+def print_sparse(res):
+    print(res[0])
+    print(res[1])
+    print(res[2])
+
+
+net = MulNet()
+
+# Construct dense and sparse matrices.
+X = torch.arange(0, 16, dtype=torch.float32).view(4, 4)
+Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)
+A = torch.tensor(
+    [
+        [0.0, 1.0, 0.0, 0.0],
+        [0.0, 0.0, 0.0, 2.0],
+        [0.0, 0.0, 0.0, 0.0],
+        [3.0, 0.0, 0.0, 0.0],
+    ],
+    dtype=torch.float32,
+)
+S = A.to_sparse_csr()
+
+#
+# CHECK: pytorch
+# CHECK: tensor({{\[}}[ 0., 17., 36., 57.],
+# CHECK: [ 80., 105., 132., 161.],
+# CHECK: [192., 225., 260., 297.],
+# CHECK: [336., 377., 420., 465.]{{\]}})
+# CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]),
+# CHECK: col_indices=tensor([1, 3, 0]),
+# CHECK: values=tensor([17., 46., 84.]), size=(4, 4), nnz=3,
+# CHECK: layout=torch.sparse_csr)
+# CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]),
+# CHECK: col_indices=tensor([1, 3, 0]),
+# CHECK: values=tensor([ 1., 14., 36.]), size=(4, 4), nnz=3,
+# CHECK: layout=torch.sparse_csr)
+# CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]),
+# CHECK: col_indices=tensor([1, 3, 0]),
+# CHECK: values=tensor([1., 4., 9.]), size=(4, 4), nnz=3,
+# CHECK: layout=torch.sparse_csr)
+# CHECK: mpact
+# CHECK: {{\[}}[ 0. 17. 36. 57.]
+# CHECK: [ 80. 105. 132. 161.]
+# CHECK: [192. 225. 260. 297.]
+# CHECK: [336. 377. 420. 465.]{{\]}}
+# CHECK: [0 1 2 2 3]
+# CHECK: [1 3 0]
+# CHECK: [17. 46. 84.]
+# CHECK: [0 1 2 2 3]
+# CHECK: [1 3 0]
+# CHECK: [ 1. 14. 36.]
+# CHECK: [0 1 2 2 3]
+# CHECK: [1 3 0]
+# CHECK: [1. 4. 9.]
+#
+
+# Run it with PyTorch.
+print("pytorch")
+res = net(X, Y)
+print(res)
+res = net(S, Y)
+print(res)
+res = net(X, S)
+print(res)
+res = net(S, S)
+print(res)
+
+# Run it with MPACT.
+print("mpact")
+res = mpact_jit(net, X, Y)
+print(res)
+res = mpact_jit(net, S, Y)
+print_sparse(res)
+res = mpact_jit(net, X, S)
+print_sparse(res)
+res = mpact_jit(net, S, S)
+print_sparse(res)
diff --git a/test/python/sddmm.py b/test/python/sddmm.py
index 8793120..a0890c7 100644
--- a/test/python/sddmm.py
+++ b/test/python/sddmm.py
@@ -7,6 +7,14 @@
from mpact.models.kernels import MMNet, SDDMMNet
+
+def print_sparse(res):
+    print(res[0])
+    print(res[1])
+    print(res[2])
+    print(res[3])
+
+
mmnet = MMNet()
sddmmnet = SDDMMNet()
@@ -54,7 +62,4 @@
dense = mpact_jit(mmnet, A, B)
print(dense)
res = mpact_jit(sddmmnet, S, A, B)
-print(res[0])
-print(res[1])
-print(res[2])
-print(res[3])
+print_sparse(res)