blob: a0890c71dd050feeb9b880f94fb13601a82dcf12 [file] [log] [blame]
# RUN: %PYTHON %s | FileCheck %s
import torch
import numpy as np
from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
from mpact.models.kernels import MMNet, SDDMMNet
def print_sparse(res):
print(res[0])
print(res[1])
print(res[2])
print(res[3])
mmnet = MMNet()
sddmmnet = SDDMMNet()
# Construct very sparse matrix.
idx = torch.tensor([[0, 4], [0, 4]], dtype=torch.int64)
val = torch.tensor([2.0, 3.0], dtype=torch.float64)
S = torch.sparse_coo_tensor(idx, val, size=[5, 5])
# Trivial dense inputs.
A = torch.arange(0, 25, dtype=torch.float32).view(5, 5)
B = torch.arange(25, 50, dtype=torch.float32).view(5, 5)
#
# CHECK: pytorch
# CHECK: tensor({{\[}}[ 400., 410., 420., 430., 440.],
# CHECK: [1275., 1310., 1345., 1380., 1415.],
# CHECK: [2150., 2210., 2270., 2330., 2390.],
# CHECK: [3025., 3110., 3195., 3280., 3365.],
# CHECK: [3900., 4010., 4120., 4230., 4340.]{{\]}})
# CHECK: tensor(indices=tensor({{\[}}[0, 4],
# CHECK: [0, 4]{{\]}}),
# CHECK: values=tensor([ 800., 13020.]),
# CHECK: size=(5, 5), nnz=2, dtype=torch.float64, layout=torch.sparse_coo)
# CHECK: mpact
# CHECK: {{\[}}[ 400. 410. 420. 430. 440.]
# CHECK: [1275. 1310. 1345. 1380. 1415.]
# CHECK: [2150. 2210. 2270. 2330. 2390.]
# CHECK: [3025. 3110. 3195. 3280. 3365.]
# CHECK: [3900. 4010. 4120. 4230. 4340.]{{\]}}
# CHECK: [0 2]
# CHECK: [0 4]
# CHECK: [0 4]
# CHECK: [ 800. 13020.]
#
# Run it with PyTorch.
print("pytorch")
dense = mmnet(A, B)
print(dense)
res = sddmmnet(S, A, B)
print(res)
# Run it with MPACT.
print("mpact")
dense = mpact_jit(mmnet, A, B)
print(dense)
res = mpact_jit(sddmmnet, S, A, B)
print_sparse(res)