blob: 3cdd90acf1785dded3087ef0311696cbe3f332fd [file] [log] [blame]
# RUN: %PYTHON %s | FileCheck %s
import torch
import numpy as np
from mpact.mpactbackend import mpact_jit
from mpact.models.kernels import CountEq
# The kernel under test. Below it is invoked both eagerly as
# net(S, target) and through mpact_jit(net, S, target). Judging by the
# CHECK values it counts the entries of S equal to `target` (assumption;
# confirm against mpact.models.kernels.CountEq).
net = CountEq()
# Input matrix for the equality-count kernel: a 4x4 float32 tensor that
# holds only the small integer values 0, 1, 2 and 3. Comparing it against
# the targets 0..4 therefore yields the counts 10, 3, 1, 2 and 0.
A = torch.tensor(
    [
        [0, 1, 0, 0],
        [0, 0, 0, 2],
        [0, 0, 1, 1],
        [3, 0, 3, 0],
    ],
    dtype=torch.float32,
)
# TODO: sparsifying the input is an interesting idiom here (the sum could
# be collapsed into the eq for full sparsity), but it needs PyTorch
# support. Until then S is simply the dense A; the sparse variants to try
# once that support exists are kept below.
S = A
# S = A.to_sparse()
# S = A.to_sparse_csr()
#
# CHECK: pytorch
# CHECK: 10
# CHECK: 3
# CHECK: 1
# CHECK: 2
# CHECK: 0
# CHECK: mpact
# CHECK: 10
# CHECK: 3
# CHECK: 1
# CHECK: 2
# CHECK: 0
#
# Reference run with plain PyTorch: evaluate the kernel eagerly for each
# target 0..4 and print the result as a Python scalar (via .item()). The
# printed lines are matched by the "pytorch" CHECK sequence above.
print("pytorch")
for target in (torch.tensor(k) for k in range(5)):
    res = net(S, target).item()
    print(res)
# MPACT run: pass the same kernel, input and targets 0..4 through
# mpact_jit and print whatever it returns. The printed lines are matched
# by the "mpact" CHECK sequence, so they must agree with the PyTorch run.
print("mpact")
for target in (torch.tensor(k) for k in range(5)):
    res = mpact_jit(net, S, target)
    print(res)