|  | # RUN: %PYTHON %s | FileCheck %s | 
|  |  | 
|  | import torch | 
|  | import numpy as np | 
|  |  | 
|  | from mpact.mpactbackend import mpact_jit | 
|  |  | 
|  | from mpact.models.kernels import MulNet | 
|  |  | 
|  |  | 
def print_sparse(res):
    """Print the first three components of ``res``, one per line.

    Used below for MPACT results that involve a sparse operand; judging by
    the expected output those components are presumably the CSR row
    positions, column indices and values (TODO confirm against mpact_jit).
    Raises IndexError if ``res`` has fewer than three components.
    """
    for component in (0, 1, 2):
        print(res[component])
|  |  | 
|  |  | 
net = MulNet()

# Dense 4x4 float32 operands holding 0..15 and 16..31 in row-major order.
X = torch.arange(16, dtype=torch.float32).reshape(4, 4)
Y = torch.arange(16, 32, dtype=torch.float32).reshape(4, 4)

# A mostly-zero 4x4 float32 matrix with exactly three nonzeros, at
# (0, 1), (1, 3) and (3, 0); S is its CSR-compressed counterpart and serves
# as the sparse operand in the runs below.
A = torch.zeros((4, 4), dtype=torch.float32)
A[0, 1] = 1.0
A[1, 3] = 2.0
A[3, 0] = 3.0
S = A.to_sparse_csr()
|  |  | 
|  | # | 
|  | # CHECK: pytorch | 
|  | # CHECK: tensor({{\[}}[  0.,  17.,  36.,  57.], | 
|  | # CHECK:              [ 80., 105., 132., 161.], | 
|  | # CHECK:              [192., 225., 260., 297.], | 
|  | # CHECK:              [336., 377., 420., 465.]{{\]}}) | 
|  | # CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]), | 
|  | # CHECK:        col_indices=tensor([1, 3, 0]), | 
|  | # CHECK:        values=tensor([17., 46., 84.]), size=(4, 4), nnz=3, | 
|  | # CHECK:        layout=torch.sparse_csr) | 
|  | # CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]), | 
|  | # CHECK:        col_indices=tensor([1, 3, 0]), | 
|  | # CHECK:        values=tensor([ 1., 14., 36.]), size=(4, 4), nnz=3, | 
|  | # CHECK:        layout=torch.sparse_csr) | 
|  | # CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]), | 
|  | # CHECK:        col_indices=tensor([1, 3, 0]), | 
|  | # CHECK:        values=tensor([1., 4., 9.]), size=(4, 4), nnz=3, | 
|  | # CHECK:        layout=torch.sparse_csr) | 
|  | # CHECK: mpact | 
|  | # CHECK:   {{\[}}[  0.  17.  36.  57.] | 
|  | # CHECK:         [ 80. 105. 132. 161.] | 
|  | # CHECK:         [192. 225. 260. 297.] | 
|  | # CHECK:         [336. 377. 420. 465.]{{\]}} | 
|  | # CHECK:  [0 1 2 2 3] | 
|  | # CHECK:  [1 3 0] | 
|  | # CHECK:  [17. 46. 84.] | 
|  | # CHECK:  [0 1 2 2 3] | 
|  | # CHECK:  [1 3 0] | 
|  | # CHECK:  [ 1. 14. 36.] | 
|  | # CHECK:  [0 1 2 2 3] | 
|  | # CHECK:  [1 3 0] | 
|  | # CHECK:  [1. 4. 9.] | 
|  | # | 
|  |  | 
# Evaluate every dense/sparse operand pairing eagerly with PyTorch first;
# its printed results form the "pytorch" section of the expected output.
print("pytorch")
for lhs, rhs in ((X, Y), (S, Y), (X, S), (S, S)):
    res = net(lhs, rhs)
    print(res)

# Then run the same pairings through MPACT. The dense-only result is printed
# directly; results involving a sparse operand are printed component by
# component via print_sparse.
print("mpact")
for lhs, rhs, show in (
    (X, Y, print),
    (S, Y, print_sparse),
    (X, S, print_sparse),
    (S, S, print_sparse),
):
    res = mpact_jit(net, lhs, rhs)
    show(res)