| # RUN: %PYTHON %s | FileCheck %s |
| |
| import torch |
| import numpy as np |
| |
| from mpact.mpactbackend import mpact_jit |
| |
| from mpact.models.kernels import MMNet |
| |
| |
def print_sparse(res):
    """Print the first three components of ``res``, one per line.

    ``res`` is indexed at positions 0, 1 and 2 in that order, and each item is
    printed on its own line; any further items are ignored. Judging by this
    test's CHECK lines, MPACT returns a sparse CSR result as (positions,
    indices, values) arrays -- NOTE(review): confirm against mpact_jit.
    """
    for idx in range(3):
        print(res[idx])
| |
| |
net = MMNet()

# Dense operands: X holds 0..15 and Y holds 16..31, each laid out row-major
# as a 4x4 float32 matrix.
X = torch.arange(16, dtype=torch.float32).reshape(4, 4)
Y = torch.arange(16, 32, dtype=torch.float32).reshape(4, 4)

# Sparse operand with three nonzeros:
#   [[0, 1, 0, 0],
#    [0, 0, 0, 2],
#    [0, 0, 0, 0],
#    [3, 0, 0, 0]]
A = torch.zeros(4, 4, dtype=torch.float32)
A[0, 1] = 1.0
A[1, 3] = 2.0
A[3, 0] = 3.0
# CSR copy of A, used as the sparse input to the network.
S = A.to_sparse_csr()
| |
| # |
| # CHECK: pytorch |
| # CHECK: tensor({{\[}}[ 152., 158., 164., 170.], |
| # CHECK: [ 504., 526., 548., 570.], |
| # CHECK: [ 856., 894., 932., 970.], |
| # CHECK: [1208., 1262., 1316., 1370.]{{\]}}) |
| # CHECK: tensor({{\[}}[20., 21., 22., 23.], |
| # CHECK: [56., 58., 60., 62.], |
| # CHECK: [ 0., 0., 0., 0.], |
| # CHECK: [48., 51., 54., 57.]{{\]}}) |
| # CHECK: tensor({{\[}}[ 9., 0., 0., 2.], |
| # CHECK: [21., 4., 0., 10.], |
| # CHECK: [33., 8., 0., 18.], |
| # CHECK: [45., 12., 0., 26.]{{\]}}) |
| # CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]), |
| # CHECK: col_indices=tensor([3, 0, 1]), |
| # CHECK: values=tensor([2., 6., 3.]), size=(4, 4), nnz=3, |
| # CHECK: layout=torch.sparse_csr) |
| # CHECK: mpact |
| # CHECK: {{\[}}[ 152. 158. 164. 170.] |
| # CHECK: [ 504. 526. 548. 570.] |
| # CHECK: [ 856. 894. 932. 970.] |
| # CHECK: [1208. 1262. 1316. 1370.]{{\]}} |
| # CHECK: {{\[}}[20. 21. 22. 23.] |
| # CHECK: [56. 58. 60. 62.] |
| # CHECK: [ 0. 0. 0. 0.] |
| # CHECK: [48. 51. 54. 57.]{{\]}} |
| # CHECK: {{\[}}[ 9. 0. 0. 2.] |
| # CHECK: [21. 4. 0. 10.] |
| # CHECK: [33. 8. 0. 18.] |
| # CHECK: [45. 12. 0. 26.]{{\]}} |
| # CHECK: [0 1 2 2 3] |
| # CHECK: [3 0 1] |
| # CHECK: [2. 6. 3.] |
| # |
| |
# Run it with PyTorch: each dense/sparse operand combination goes through the
# eager module and the result is printed directly.
print("pytorch")
for lhs, rhs in ((X, Y), (S, Y), (X, S), (S, S)):
    res = net(lhs, rhs)
    print(res)

# Run it with MPACT. The first three combinations are printed as-is; the
# sparse-times-sparse result is printed item by item via print_sparse
# (per the CHECK lines, as position, index and value arrays -- NOTE(review):
# confirm against mpact_jit's return convention).
print("mpact")
for lhs, rhs in ((X, Y), (S, Y), (X, S)):
    res = mpact_jit(net, lhs, rhs)
    print(res)
res = mpact_jit(net, S, S)
print_sparse(res)