| # RUN: %PYTHON %s | FileCheck %s |
| |
| import torch |
| import numpy as np |
| |
| from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run |
| |
| from mpact.models.kernels import MMNet, SDDMMNet |
| |
| |
def print_sparse(res):
    """Print components 0 through 3 of ``res``, one ``print`` call each.

    In this test ``res`` is the value ``mpact_jit`` returns for a sparse
    result. Judging from the expected output, the four components appear to
    be a positions array, row coordinates, column coordinates and values.
    That reading is an assumption; confirm it against the mpact backend.

    Raises:
        IndexError: if ``res`` has fewer than four components. Any components
            before the missing one have already been printed by then.
    """
    for component_index in range(4):
        print(res[component_index])
| |
| |
# Models under test. Per the CHECK lines below, mmnet yields A @ B, and
# sddmmnet yields that product sampled at S's nonzeros and scaled by S's values.
mmnet = MMNet()
sddmmnet = SDDMMNet()

# A 5x5 float64 COO matrix with only two nonzeros: 2.0 at (0, 0), 3.0 at (4, 4).
sparse_coords = torch.tensor([[0, 4], [0, 4]], dtype=torch.int64)
sparse_values = torch.tensor([2.0, 3.0], dtype=torch.float64)
S = torch.sparse_coo_tensor(sparse_coords, sparse_values, size=[5, 5])

# Dense 5x5 float32 operands: A holds 0..24 and B holds 25..49, row-major.
A, B = (
    torch.arange(start, start + 25, dtype=torch.float32).view(5, 5)
    for start in (0, 25)
)
| |
| # |
| # CHECK: pytorch |
| # CHECK: tensor({{\[}}[ 400., 410., 420., 430., 440.], |
| # CHECK: [1275., 1310., 1345., 1380., 1415.], |
| # CHECK: [2150., 2210., 2270., 2330., 2390.], |
| # CHECK: [3025., 3110., 3195., 3280., 3365.], |
| # CHECK: [3900., 4010., 4120., 4230., 4340.]{{\]}}) |
| # CHECK: tensor(indices=tensor({{\[}}[0, 4], |
| # CHECK: [0, 4]{{\]}}), |
| # CHECK: values=tensor([ 800., 13020.]), |
| # CHECK: size=(5, 5), nnz=2, dtype=torch.float64, layout=torch.sparse_coo) |
| # CHECK: mpact |
| # CHECK: {{\[}}[ 400. 410. 420. 430. 440.] |
| # CHECK: [1275. 1310. 1345. 1380. 1415.] |
| # CHECK: [2150. 2210. 2270. 2330. 2390.] |
| # CHECK: [3025. 3110. 3195. 3280. 3365.] |
| # CHECK: [3900. 4010. 4120. 4230. 4340.]{{\]}} |
| # CHECK: [0 2] |
| # CHECK: [0 4] |
| # CHECK: [0 4] |
| # CHECK: [ 800. 13020.] |
| # |
| |
# Reference run in eager PyTorch; the sparse result prints as a torch COO tensor.
print("pytorch")
torch_dense = mmnet(A, B)
print(torch_dense)
torch_sparse = sddmmnet(S, A, B)
print(torch_sparse)

# Same models through MPACT. The sparse result is printed component by
# component with print_sparse, as the CHECK lines above expect.
print("mpact")
mpact_dense = mpact_jit(mmnet, A, B)
print(mpact_dense)
mpact_sparse = mpact_jit(sddmmnet, S, A, B)
print_sparse(mpact_sparse)