blob: 662e11a0cf0058dc415b936de1f356e976ab4615 [file] [log] [blame]
# RUN: %PYTHON %s | FileCheck %s
import torch
from mpact.mpactbackend import mpact_jit
from mpact.models.kernels import MVNet
net = MVNet()

# Fixed inputs: the vector [1, ..., 10] and the 10x10 matrix holding the
# values 1..100 in row-major order. Every entry of the matrix is nonzero, but
# it is converted to block sparse row (BSR) layout with 2x2 blocks, hence the
# scare quotes around "sparse".
dense_vector = torch.arange(1, 11, dtype=torch.float32)
dense_input = torch.arange(1, 101, dtype=torch.float32).view(10, 10)
sparse_matrix = dense_input.to_sparse_bsr(blocksize=(2, 2))

#
# CHECK: pytorch
# CHECK: tensor([ 385., 935., 1485., 2035., 2585., 3135., 3685., 4235., 4785., 5335.])
# CHECK: mpact
# CHECK: [ 385. 935. 1485. 2035. 2585. 3135. 3685. 4235. 4785. 5335.]
#

# Evaluate `net` on the same inputs twice: first eagerly in PyTorch, then
# through MPACT's JIT. A label is printed before each result so FileCheck can
# tell the two outputs apart. The expected values above equal
# dense_input @ dense_vector, i.e. element i is sum_k (10*i + k) * k for
# k = 1..10.
for backend, compute in (
    ("pytorch", lambda: net(sparse_matrix, dense_vector)),
    ("mpact", lambda: mpact_jit(net, sparse_matrix, dense_vector)),
):
    print(backend)
    res = compute()
    print(res)