[mpact][compiler] more sparsity propagation rules with tests (#39)
diff --git a/python/mpact/mpactbackend.py b/python/mpact/mpactbackend.py index 9b33892..d944376 100644 --- a/python/mpact/mpactbackend.py +++ b/python/mpact/mpactbackend.py
@@ -359,6 +359,12 @@ raise RuntimeError(f"Unsupported sparse layout for {a}") +def sparse_arg(args, i): + if isinstance(args[i], torch.fx.node.Node): + return args[i].meta.get("sparsity", None) + return None + + def sparse_export( f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None ) -> torch.export.ExportedProgram: @@ -402,8 +408,22 @@ # TODO: use upstream _opname implementation when available opname = node.target._schema.name.split("::")[1] # Zero preserving elt-wise unary op. - if opname in {"abs", "neg", "relu", "sin", "mul"}: - node.meta["sparsity"] = node.args[0].meta.get("sparsity", None) + if opname in {"abs", "neg", "relu", "sin"}: + node.meta["sparsity"] = sparse_arg(node.args, 0) + # Some simplistic rules for preserving sparsity. Soon + # to be replaced by proper FX graph propagation. + elif opname in {"mul"}: + m0 = sparse_arg(node.args, 0) + m1 = sparse_arg(node.args, 1) + if m0 is not None: + node.meta["sparsity"] = m0 + elif m1 is not None: + node.meta["sparsity"] = m1 + elif opname in {"add", "mm"}: + m0 = sparse_arg(node.args, 0) + m1 = sparse_arg(node.args, 1) + if m0 is not None and m1 is not None: + node.meta["sparsity"] = m0 elif opname == "_to_sparse" or opname == "to_sparse": dim = len(node.meta.get("val").shape) node.meta["sparsity"] = SparsityMeta( @@ -412,13 +432,13 @@ # TODO: Uncomment this to hack sparsity into the network. # elif opname == "_to_dense" or opname == "to_dense": # # hack (assumes we never really want the to_dense for now) - # node.meta["sparsity"] = node.args[0].meta.get("sparsity", None) - elif opname == "select" and node.args[0].meta.get("sparsity", None): + # node.meta["sparsity"] = sparse_arg(node.args, 0) + elif opname == "select" and sparse_arg(node.args, 0): dim = len(node.meta.get("val").shape) node.meta["sparsity"] = SparsityMeta( torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64 ) - elif opname == "stack" and node.args[0][0].meta.get("sparsity", None): + elif opname == "stack" and sparse_arg(node.args[0], 0): dim = len(node.meta.get("val").shape) node.meta["sparsity"] = SparsityMeta( torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
diff --git a/test/python/add.py b/test/python/add.py new file mode 100644 index 0000000..00d4d62 --- /dev/null +++ b/test/python/add.py
@@ -0,0 +1,89 @@ +# RUN: %PYTHON %s | FileCheck %s + +import torch +import numpy as np + +from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run + +from mpact.models.kernels import AddNet + + +def print_sparse(res): + print(res[0]) + print(res[1]) + print(res[2]) + + +net = AddNet() + +# Construct dense and sparse matrices. +X = torch.arange(0, 16, dtype=torch.float32).view(4, 4) +Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4) +A = torch.tensor( + [ + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 2.0], + [0.0, 0.0, 0.0, 0.0], + [3.0, 0.0, 0.0, 0.0], + ], + dtype=torch.float32, +) +S = A.to_sparse_csr() + +# +# CHECK: pytorch +# CHECK: tensor({{\[}}[16., 18., 20., 22.], +# CHECK: [24., 26., 28., 30.], +# CHECK: [32., 34., 36., 38.], +# CHECK: [40., 42., 44., 46.]{{\]}}) +# CHECK: tensor({{\[}}[16., 18., 18., 19.], +# CHECK: [20., 21., 22., 25.], +# CHECK: [24., 25., 26., 27.], +# CHECK: [31., 29., 30., 31.]{{\]}}) +# CHECK: tensor({{\[}}[ 0., 2., 2., 3.], +# CHECK: [ 4., 5., 6., 9.], +# CHECK: [ 8., 9., 10., 11.], +# CHECK: [15., 13., 14., 15.]{{\]}}) +# CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]), +# CHECK: col_indices=tensor([1, 3, 0]), +# CHECK: values=tensor([2., 4., 6.]), size=(4, 4), nnz=3, +# CHECK: layout=torch.sparse_csr) +# CHECK: mpact +# CHECK: {{\[}}[16. 18. 20. 22.] +# CHECK: [24. 26. 28. 30.] +# CHECK: [32. 34. 36. 38.] +# CHECK: [40. 42. 44. 46.]{{\]}} +# CHECK: {{\[}}[16. 18. 18. 19.] +# CHECK: [20. 21. 22. 25.] +# CHECK: [24. 25. 26. 27.] +# CHECK: [31. 29. 30. 31.]{{\]}} +# CHECK: {{\[}}[ 0. 2. 2. 3.] +# CHECK: [ 4. 5. 6. 9.] +# CHECK: [ 8. 9. 10. 11.] +# CHECK: [15. 13. 14. 15.]{{\]}} +# CHECK: [0 1 2 2 3] +# CHECK: [1 3 0] +# CHECK: [2. 4. 6.] +# + +# Run it with PyTorch. +print("pytorch") +res = net(X, Y) +print(res) +res = net(S, Y) +print(res) +res = net(X, S) +print(res) +res = net(S, S) +print(res) + +# Run it with MPACT. +print("mpact") +res = mpact_jit(net, X, Y) +print(res) +res = mpact_jit(net, S, Y) +print(res) +res = mpact_jit(net, X, S) +print(res) +res = mpact_jit(net, S, S) +print_sparse(res)
diff --git a/test/python/mm.py b/test/python/mm.py new file mode 100644 index 0000000..3c51c37 --- /dev/null +++ b/test/python/mm.py
@@ -0,0 +1,89 @@ +# RUN: %PYTHON %s | FileCheck %s + +import torch +import numpy as np + +from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run + +from mpact.models.kernels import MMNet + + +def print_sparse(res): + print(res[0]) + print(res[1]) + print(res[2]) + + +net = MMNet() + +# Construct dense and sparse matrices. +X = torch.arange(0, 16, dtype=torch.float32).view(4, 4) +Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4) +A = torch.tensor( + [ + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 2.0], + [0.0, 0.0, 0.0, 0.0], + [3.0, 0.0, 0.0, 0.0], + ], + dtype=torch.float32, +) +S = A.to_sparse_csr() + +# +# CHECK: pytorch +# CHECK: tensor({{\[}}[ 152., 158., 164., 170.], +# CHECK: [ 504., 526., 548., 570.], +# CHECK: [ 856., 894., 932., 970.], +# CHECK: [1208., 1262., 1316., 1370.]{{\]}}) +# CHECK: tensor({{\[}}[20., 21., 22., 23.], +# CHECK: [56., 58., 60., 62.], +# CHECK: [ 0., 0., 0., 0.], +# CHECK: [48., 51., 54., 57.]{{\]}}) +# CHECK: tensor({{\[}}[ 9., 0., 0., 2.], +# CHECK: [21., 4., 0., 10.], +# CHECK: [33., 8., 0., 18.], +# CHECK: [45., 12., 0., 26.]{{\]}}) +# CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]), +# CHECK: col_indices=tensor([3, 0, 1]), +# CHECK: values=tensor([2., 6., 3.]), size=(4, 4), nnz=3, +# CHECK: layout=torch.sparse_csr) +# CHECK: mpact +# CHECK: {{\[}}[ 152. 158. 164. 170.] +# CHECK: [ 504. 526. 548. 570.] +# CHECK: [ 856. 894. 932. 970.] +# CHECK: [1208. 1262. 1316. 1370.]{{\]}} +# CHECK: {{\[}}[20. 21. 22. 23.] +# CHECK: [56. 58. 60. 62.] +# CHECK: [ 0. 0. 0. 0.] +# CHECK: [48. 51. 54. 57.]{{\]}} +# CHECK: {{\[}}[ 9. 0. 0. 2.] +# CHECK: [21. 4. 0. 10.] +# CHECK: [33. 8. 0. 18.] +# CHECK: [45. 12. 0. 26.]{{\]}} +# CHECK: [0 1 2 2 3] +# CHECK: [3 0 1] +# CHECK: [2. 6. 3.] +# + +# Run it with PyTorch. +print("pytorch") +res = net(X, Y) +print(res) +res = net(S, Y) +print(res) +res = net(X, S) +print(res) +res = net(S, S) +print(res) + +# Run it with MPACT. +print("mpact") +res = mpact_jit(net, X, Y) +print(res) +res = mpact_jit(net, S, Y) +print(res) +res = mpact_jit(net, X, S) +print(res) +res = mpact_jit(net, S, S) +print_sparse(res)
diff --git a/test/python/mul.py b/test/python/mul.py new file mode 100644 index 0000000..fd8692f --- /dev/null +++ b/test/python/mul.py
@@ -0,0 +1,87 @@ +# RUN: %PYTHON %s | FileCheck %s + +import torch +import numpy as np + +from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run + +from mpact.models.kernels import MulNet + + +def print_sparse(res): + print(res[0]) + print(res[1]) + print(res[2]) + + +net = MulNet() + +# Construct dense and sparse matrices. +X = torch.arange(0, 16, dtype=torch.float32).view(4, 4) +Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4) +A = torch.tensor( + [ + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 2.0], + [0.0, 0.0, 0.0, 0.0], + [3.0, 0.0, 0.0, 0.0], + ], + dtype=torch.float32, +) +S = A.to_sparse_csr() + +# +# CHECK: pytorch +# CHECK: tensor({{\[}}[ 0., 17., 36., 57.], +# CHECK: [ 80., 105., 132., 161.], +# CHECK: [192., 225., 260., 297.], +# CHECK: [336., 377., 420., 465.]{{\]}}) +# CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]), +# CHECK: col_indices=tensor([1, 3, 0]), +# CHECK: values=tensor([17., 46., 84.]), size=(4, 4), nnz=3, +# CHECK: layout=torch.sparse_csr) +# CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]), +# CHECK: col_indices=tensor([1, 3, 0]), +# CHECK: values=tensor([ 1., 14., 36.]), size=(4, 4), nnz=3, +# CHECK: layout=torch.sparse_csr) +# CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]), +# CHECK: col_indices=tensor([1, 3, 0]), +# CHECK: values=tensor([1., 4., 9.]), size=(4, 4), nnz=3, +# CHECK: layout=torch.sparse_csr) +# CHECK: mpact +# CHECK: {{\[}}[ 0. 17. 36. 57.] +# CHECK: [ 80. 105. 132. 161.] +# CHECK: [192. 225. 260. 297.] +# CHECK: [336. 377. 420. 465.]{{\]}} +# CHECK: [0 1 2 2 3] +# CHECK: [1 3 0] +# CHECK: [17. 46. 84.] +# CHECK: [0 1 2 2 3] +# CHECK: [1 3 0] +# CHECK: [ 1. 14. 36.] +# CHECK: [0 1 2 2 3] +# CHECK: [1 3 0] +# CHECK: [1. 4. 9.] +# + +# Run it with PyTorch. +print("pytorch") +res = net(X, Y) +print(res) +res = net(S, Y) +print(res) +res = net(X, S) +print(res) +res = net(S, S) +print(res) + +# Run it with MPACT. +print("mpact") +res = mpact_jit(net, X, Y) +print(res) +res = mpact_jit(net, S, Y) +print_sparse(res) +res = mpact_jit(net, X, S) +print_sparse(res) +res = mpact_jit(net, S, S) +print_sparse(res)
diff --git a/test/python/sddmm.py b/test/python/sddmm.py index 8793120..a0890c7 100644 --- a/test/python/sddmm.py +++ b/test/python/sddmm.py
@@ -7,6 +7,14 @@ from mpact.models.kernels import MMNet, SDDMMNet + +def print_sparse(res): + print(res[0]) + print(res[1]) + print(res[2]) + print(res[3]) + + mmnet = MMNet() sddmmnet = SDDMMNet() @@ -54,7 +62,4 @@ dense = mpact_jit(mmnet, A, B) print(dense) res = mpact_jit(sddmmnet, S, A, B) -print(res[0]) -print(res[1]) -print(res[2]) -print(res[3]) +print_sparse(res)