blob: d0bff7608a9f436bd6ab48ade66135c8b96b8c7f [file] [log] [blame]
# RUN: %PYTHON %s | FileCheck %s
import torch
import numpy as np
from mpact.mpactbackend import mpact_jit
from mpact.models.kernels import AddNet
def print_sparse(res):
    """Print the first three components of a sparse result, one per line.

    Exactly ``res[0]``, ``res[1]`` and ``res[2]`` are printed, in that order;
    any further elements are ignored. For the sparse-sparse MPACT result in
    this test, the CHECK lines expect these to be the CSR row offsets, column
    indices and values respectively.
    """
    for idx in (0, 1, 2):
        print(res[idx])
net = AddNet()

# Dense 4x4 float32 operands: X holds 0..15 and Y holds 16..31, row-major.
X = torch.arange(0, 16, dtype=torch.float32).view(4, 4)
Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)

# A 4x4 float32 matrix with only three nonzeros, (0,1)=1, (1,3)=2, (3,0)=3,
# plus its CSR-layout counterpart S used as the sparse operand.
A = torch.zeros(4, 4, dtype=torch.float32)
A[0, 1] = 1.0
A[1, 3] = 2.0
A[3, 0] = 3.0
S = A.to_sparse_csr()
#
# CHECK: pytorch
# CHECK: tensor({{\[}}[16., 18., 20., 22.],
# CHECK: [24., 26., 28., 30.],
# CHECK: [32., 34., 36., 38.],
# CHECK: [40., 42., 44., 46.]{{\]}})
# CHECK: tensor({{\[}}[16., 18., 18., 19.],
# CHECK: [20., 21., 22., 25.],
# CHECK: [24., 25., 26., 27.],
# CHECK: [31., 29., 30., 31.]{{\]}})
# CHECK: tensor({{\[}}[ 0., 2., 2., 3.],
# CHECK: [ 4., 5., 6., 9.],
# CHECK: [ 8., 9., 10., 11.],
# CHECK: [15., 13., 14., 15.]{{\]}})
# CHECK: tensor(crow_indices=tensor([0, 1, 2, 2, 3]),
# CHECK: col_indices=tensor([1, 3, 0]),
# CHECK: values=tensor([2., 4., 6.]), size=(4, 4), nnz=3,
# CHECK: layout=torch.sparse_csr)
# CHECK: mpact
# CHECK: {{\[}}[16. 18. 20. 22.]
# CHECK: [24. 26. 28. 30.]
# CHECK: [32. 34. 36. 38.]
# CHECK: [40. 42. 44. 46.]{{\]}}
# CHECK: {{\[}}[16. 18. 18. 19.]
# CHECK: [20. 21. 22. 25.]
# CHECK: [24. 25. 26. 27.]
# CHECK: [31. 29. 30. 31.]{{\]}}
# CHECK: {{\[}}[ 0. 2. 2. 3.]
# CHECK: [ 4. 5. 6. 9.]
# CHECK: [ 8. 9. 10. 11.]
# CHECK: [15. 13. 14. 15.]{{\]}}
# CHECK: [0 1 2 2 3]
# CHECK: [1 3 0]
# CHECK: [2. 4. 6.]
#
# Run it with PyTorch: evaluate every dense/sparse operand pairing eagerly
# and print each result. The print order must match the CHECK lines.
print("pytorch", torch.__version__)
for lhs, rhs in ((X, Y), (S, Y), (X, S), (S, S)):
    res = net(lhs, rhs)
    print(res)

# Run it with MPACT: the same pairings are compiled and run via mpact_jit.
# The sparse-sparse result is printed component-wise with print_sparse.
print("mpact")
for lhs, rhs in ((X, Y), (S, Y), (X, S)):
    res = mpact_jit(net, lhs, rhs)
    print(res)
res = mpact_jit(net, S, S)
print_sparse(res)