| # RUN: %PYTHON %s | FileCheck %s | 
 |  | 
 | import torch | 
 | import numpy as np | 
 |  | 
 | from mpact.mpactbackend import mpact_jit | 
 |  | 
 | from mpact.models.kernels import MulNet | 
 |  | 
 |  | 
 | def print_sparse(res): | 
 |     print(res[0]) | 
 |     print(res[1]) | 
 |     print(res[2]) | 
 |  | 
 |  | 
 | net = MulNet() | 
 |  | 
 | # Construct dense and sparse matrices. | 
 | X = torch.arange(0, 16, dtype=torch.float32).view(4, 4) | 
 | Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4) | 
 | A = torch.tensor( | 
 |     [ | 
 |         [0.0, 1.0, 0.0, 0.0], | 
 |         [0.0, 0.0, 0.0, 2.0], | 
 |         [0.0, 0.0, 0.0, 0.0], | 
 |         [3.0, 0.0, 0.0, 0.0], | 
 |     ], | 
 |     dtype=torch.float32, | 
 | ) | 
 | S = A.to_sparse_csr() | 
 |  | 
 | # | 
 | # CHECK: pytorch | 
 | # CHECK: tensor({{\[}}[  0.,  17.,  36.,  57.], | 
 | # CHECK:              [ 80., 105., 132., 161.], | 
 | # CHECK:              [192., 225., 260., 297.], | 
 | # CHECK:              [336., 377., 420., 465.]{{\]}}) | 
 | # CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]), | 
 | # CHECK:        col_indices=tensor([1, 3, 0]), | 
 | # CHECK:        values=tensor([17., 46., 84.]), size=(4, 4), nnz=3, | 
 | # CHECK:        layout=torch.sparse_csr) | 
 | # CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]), | 
 | # CHECK:        col_indices=tensor([1, 3, 0]), | 
 | # CHECK:        values=tensor([ 1., 14., 36.]), size=(4, 4), nnz=3, | 
 | # CHECK:        layout=torch.sparse_csr) | 
 | # CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]), | 
 | # CHECK:        col_indices=tensor([1, 3, 0]), | 
 | # CHECK:        values=tensor([1., 4., 9.]), size=(4, 4), nnz=3, | 
 | # CHECK:        layout=torch.sparse_csr) | 
 | # CHECK: mpact | 
 | # CHECK:   {{\[}}[  0.  17.  36.  57.] | 
 | # CHECK:         [ 80. 105. 132. 161.] | 
 | # CHECK:         [192. 225. 260. 297.] | 
 | # CHECK:         [336. 377. 420. 465.]{{\]}} | 
 | # CHECK:  [0 1 2 2 3] | 
 | # CHECK:  [1 3 0] | 
 | # CHECK:  [17. 46. 84.] | 
 | # CHECK:  [0 1 2 2 3] | 
 | # CHECK:  [1 3 0] | 
 | # CHECK:  [ 1. 14. 36.] | 
 | # CHECK:  [0 1 2 2 3] | 
 | # CHECK:  [1 3 0] | 
 | # CHECK:  [1. 4. 9.] | 
 | # | 
 |  | 
 | # Run it with PyTorch. | 
 | print("pytorch") | 
 | res = net(X, Y) | 
 | print(res) | 
 | res = net(S, Y) | 
 | print(res) | 
 | res = net(X, S) | 
 | print(res) | 
 | res = net(S, S) | 
 | print(res) | 
 |  | 
 | # Run it with MPACT. | 
 | print("mpact") | 
 | res = mpact_jit(net, X, Y) | 
 | print(res) | 
 | res = mpact_jit(net, S, Y) | 
 | print_sparse(res) | 
 | res = mpact_jit(net, X, S) | 
 | print_sparse(res) | 
 | res = mpact_jit(net, S, S) | 
 | print_sparse(res) |