| # RUN: %PYTHON %s | FileCheck %s |
| |
| import torch |
| import numpy as np |
| |
| from mpact.mpactbackend import mpact_jit |
| |
| from mpact.models.kernels import AddNet |
| |
| |
def print_sparse(res):
    """Print the first three components of an MPACT sparse result, one per line.

    In this test it is used for the sparse + sparse case, where the three
    printed arrays line up with the CSR crow_indices, col_indices and values
    that PyTorch reports for the same computation.
    """
    for idx in range(3):
        print(res[idx])
| |
| |
# The model under test. It is run eagerly by PyTorch and also compiled with
# mpact_jit on the same operand pairs below; the CHECK lines expect the
# elementwise sum of its two arguments in every case.
net = AddNet()
| |
# Construct dense and sparse matrices.
# Two dense 4x4 float32 operands holding 0..15 and 16..31 in row-major order.
X = torch.arange(16, dtype=torch.float32).reshape(4, 4)
Y = torch.arange(16, 32, dtype=torch.float32).reshape(4, 4)

# A mostly-zero 4x4 float32 matrix with nonzeros 1, 2, 3 at (0, 1), (1, 3)
# and (3, 0); S is its CSR form and serves as the sparse operand below.
A = torch.zeros(4, 4, dtype=torch.float32)
A[0, 1] = 1.0
A[1, 3] = 2.0
A[3, 0] = 3.0
S = A.to_sparse_csr()
| |
| # |
| # CHECK: pytorch |
| # CHECK: tensor({{\[}}[16., 18., 20., 22.], |
| # CHECK: [24., 26., 28., 30.], |
| # CHECK: [32., 34., 36., 38.], |
| # CHECK: [40., 42., 44., 46.]{{\]}}) |
| # CHECK: tensor({{\[}}[16., 18., 18., 19.], |
| # CHECK: [20., 21., 22., 25.], |
| # CHECK: [24., 25., 26., 27.], |
| # CHECK: [31., 29., 30., 31.]{{\]}}) |
| # CHECK: tensor({{\[}}[ 0., 2., 2., 3.], |
| # CHECK: [ 4., 5., 6., 9.], |
| # CHECK: [ 8., 9., 10., 11.], |
| # CHECK: [15., 13., 14., 15.]{{\]}}) |
| # CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]), |
| # CHECK: col_indices=tensor([1, 3, 0]), |
| # CHECK: values=tensor([2., 4., 6.]), size=(4, 4), nnz=3, |
| # CHECK: layout=torch.sparse_csr) |
| # CHECK: mpact |
| # CHECK: {{\[}}[16. 18. 20. 22.] |
| # CHECK: [24. 26. 28. 30.] |
| # CHECK: [32. 34. 36. 38.] |
| # CHECK: [40. 42. 44. 46.]{{\]}} |
| # CHECK: {{\[}}[16. 18. 18. 19.] |
| # CHECK: [20. 21. 22. 25.] |
| # CHECK: [24. 25. 26. 27.] |
| # CHECK: [31. 29. 30. 31.]{{\]}} |
| # CHECK: {{\[}}[ 0. 2. 2. 3.] |
| # CHECK: [ 4. 5. 6. 9.] |
| # CHECK: [ 8. 9. 10. 11.] |
| # CHECK: [15. 13. 14. 15.]{{\]}} |
| # CHECK: [0 1 2 2 3] |
| # CHECK: [1 3 0] |
| # CHECK: [2. 4. 6.] |
| # |
| |
# Run it with PyTorch: dense+dense, sparse+dense, dense+sparse, sparse+sparse.
print("pytorch", torch.__version__)
for lhs, rhs in ((X, Y), (S, Y), (X, S), (S, S)):
    res = net(lhs, rhs)
    print(res)

# Run it with MPACT on the same operand pairs. Each entry carries its own
# printer: the sparse+sparse result is printed component-wise via print_sparse.
print("mpact")
for lhs, rhs, show in (
    (X, Y, print),
    (S, Y, print),
    (X, S, print),
    (S, S, print_sparse),
):
    res = mpact_jit(net, lhs, rhs)
    show(res)