|  | # RUN: %PYTHON %s | FileCheck %s | 
|  |  | 
|  | import torch | 
|  | import numpy as np | 
|  |  | 
|  | from mpact.mpactbackend import mpact_jit | 
|  |  | 
|  | from mpact.models.kernels import AddNet | 
|  |  | 
|  |  | 
def print_sparse(res):
    """Print the first three components of ``res``, one per line.

    In this test it is applied to the MPACT result of the sparse+sparse
    addition. The expected output checked for that call is the CSR row
    pointers, the column indices, and then the nonzero values.
    Indexing happens lazily, one component at a time, so a result with fewer
    than three components prints what it has before raising ``IndexError``.
    """
    for idx in range(3):
        print(res[idx])
|  |  | 
|  |  | 
net = AddNet()

# Construct dense and sparse matrices.
# Dense 4x4 float32 operands holding 0..15 and 16..31 in row-major order.
X = torch.arange(0, 16, dtype=torch.float32).view(4, 4)
Y = torch.arange(16, 32, dtype=torch.float32).reshape(4, 4)

# A is an all-zero 4x4 float32 matrix except for three nonzeros;
# S is the same matrix converted to the CSR sparse layout.
A = torch.zeros(4, 4, dtype=torch.float32)
A[0, 1] = 1.0
A[1, 3] = 2.0
A[3, 0] = 3.0
S = A.to_sparse_csr()
|  |  | 
|  | # | 
|  | # CHECK: pytorch | 
|  | # CHECK:   tensor({{\[}}[16., 18., 20., 22.], | 
|  | # CHECK:                [24., 26., 28., 30.], | 
|  | # CHECK:                [32., 34., 36., 38.], | 
|  | # CHECK:                [40., 42., 44., 46.]{{\]}}) | 
|  | # CHECK:  tensor({{\[}}[16., 18., 18., 19.], | 
|  | # CHECK:               [20., 21., 22., 25.], | 
|  | # CHECK:               [24., 25., 26., 27.], | 
|  | # CHECK:               [31., 29., 30., 31.]{{\]}}) | 
|  | # CHECK:  tensor({{\[}}[ 0.,  2.,  2.,  3.], | 
|  | # CHECK:               [ 4.,  5.,  6.,  9.], | 
|  | # CHECK:               [ 8.,  9., 10., 11.], | 
|  | # CHECK:               [15., 13., 14., 15.]{{\]}}) | 
|  | # CHECK:  tensor(crow_indices=tensor([0, 1, 2, 2, 3]), | 
|  | # CHECK:         col_indices=tensor([1, 3, 0]), | 
|  | # CHECK:         values=tensor([2., 4., 6.]), size=(4, 4), nnz=3, | 
|  | # CHECK:         layout=torch.sparse_csr) | 
|  | # CHECK: mpact | 
|  | # CHECK:   {{\[}}[16. 18. 20. 22.] | 
|  | # CHECK:         [24. 26. 28. 30.] | 
|  | # CHECK:         [32. 34. 36. 38.] | 
|  | # CHECK:         [40. 42. 44. 46.]{{\]}} | 
|  | # CHECK:   {{\[}}[16. 18. 18. 19.] | 
|  | # CHECK:         [20. 21. 22. 25.] | 
|  | # CHECK:         [24. 25. 26. 27.] | 
|  | # CHECK:         [31. 29. 30. 31.]{{\]}} | 
|  | # CHECK:   {{\[}}[ 0.  2.  2.  3.] | 
|  | # CHECK:         [ 4.  5.  6.  9.] | 
|  | # CHECK:         [ 8.  9. 10. 11.] | 
|  | # CHECK:         [15. 13. 14. 15.]{{\]}} | 
|  | # CHECK:  [0 1 2 2 3] | 
|  | # CHECK:  [1 3 0] | 
|  | # CHECK:  [2. 4. 6.] | 
|  | # | 
|  |  | 
# Run it with PyTorch: dense+dense, sparse+dense, dense+sparse and
# sparse+sparse, printing each result in that order.
print("pytorch")
for lhs, rhs in ((X, Y), (S, Y), (X, S), (S, S)):
    res = net(lhs, rhs)
    print(res)
|  |  | 
# Run it with MPACT, compiling the same network for the same operand pairs.
print("mpact")
for lhs, rhs in ((X, Y), (S, Y), (X, S)):
    # At least one operand is dense here; the result is printed directly.
    res = mpact_jit(net, lhs, rhs)
    print(res)

# sparse + sparse: print the result component by component.
res = mpact_jit(net, S, S)
print_sparse(res)