blob: 34b886d126b91dd25fe2917e3ab15615d876d83e [file] [log] [blame]
# RUN: %PYTHON %s | FileCheck %s
import torch
import numpy as np
from mpact.mpactbackend import mpact_jit
from mpact.models.kernels import MulNet
def print_sparse(res):
    """Print the first three components of ``res``, one per line.

    In this script it is applied to the results of ``mpact_jit`` calls that
    involve a sparse operand; judging by the expected CHECK output, those
    three components are presumably the CSR row pointers, column indices and
    values (NOTE(review): confirm against mpact_jit's return type).
    """
    # Index one at a time so each component is printed before the next is read.
    for idx in range(3):
        component = res[idx]
        print(component)
# Instantiate the kernel under test (imported from mpact.models.kernels).
# NOTE(review): the expected CHECK output in this file is consistent with
# element-wise multiplication of the two operands — confirm against MulNet.
net = MulNet()
# Dense operands: X holds 0..15 and Y holds 16..31, each laid out as a 4x4
# float32 matrix in row-major order.
X = torch.arange(16, dtype=torch.float32).reshape(4, 4)
Y = torch.arange(16, 32, dtype=torch.float32).reshape(4, 4)

# Sparse operand: a 4x4 float32 matrix with exactly three nonzeros, kept in
# dense form as A and converted to compressed-sparse-row layout as S.
A = torch.zeros(4, 4, dtype=torch.float32)
A[0, 1] = 1.0
A[1, 3] = 2.0
A[3, 0] = 3.0
S = A.to_sparse_csr()
#
# CHECK: pytorch
# CHECK: tensor({{\[}}[ 0., 17., 36., 57.],
# CHECK: [ 80., 105., 132., 161.],
# CHECK: [192., 225., 260., 297.],
# CHECK: [336., 377., 420., 465.]{{\]}})
# CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]),
# CHECK: col_indices=tensor([1, 3, 0]),
# CHECK: values=tensor([17., 46., 84.]), size=(4, 4), nnz=3,
# CHECK: layout=torch.sparse_csr)
# CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]),
# CHECK: col_indices=tensor([1, 3, 0]),
# CHECK: values=tensor([ 1., 14., 36.]), size=(4, 4), nnz=3,
# CHECK: layout=torch.sparse_csr)
# CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]),
# CHECK: col_indices=tensor([1, 3, 0]),
# CHECK: values=tensor([1., 4., 9.]), size=(4, 4), nnz=3,
# CHECK: layout=torch.sparse_csr)
# CHECK: mpact
# CHECK: {{\[}}[ 0. 17. 36. 57.]
# CHECK: [ 80. 105. 132. 161.]
# CHECK: [192. 225. 260. 297.]
# CHECK: [336. 377. 420. 465.]{{\]}}
# CHECK: [0 1 2 2 3]
# CHECK: [1 3 0]
# CHECK: [17. 46. 84.]
# CHECK: [0 1 2 2 3]
# CHECK: [1 3 0]
# CHECK: [ 1. 14. 36.]
# CHECK: [0 1 2 2 3]
# CHECK: [1 3 0]
# CHECK: [1. 4. 9.]
#
# Reference run: evaluate the kernel eagerly in PyTorch on each
# dense/sparse operand combination, in the order the CHECK lines expect.
print("pytorch")
for lhs, rhs in ((X, Y), (S, Y), (X, S), (S, S)):
    res = net(lhs, rhs)
    print(res)
# Compiled run: the same four operand combinations through mpact_jit.
# The all-dense result is printed directly; the three results involving the
# sparse operand S are printed component-wise via print_sparse.
print("mpact")
res = mpact_jit(net, X, Y)
print(res)
for lhs, rhs in ((S, Y), (X, S), (S, S)):
    res = mpact_jit(net, lhs, rhs)
    print_sparse(res)