blob: d1d00a049227e0d433e61dc5104649fbd08d4772 [file] [log] [blame]
# RUN: %PYTHON %s | FileCheck %s
import torch
from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
class GraphConv(torch.nn.Module):
    """Single graph-convolution layer computing ``adj_mat @ (inp @ kernel) + bias``.

    Both learnable parameters are initialized to all-ones, so the layer's
    output depends only on its inputs. This keeps the expected values in the
    CHECK lines of this test stable.

    Args:
        input_dim: Number of columns of ``inp`` (rows of ``kernel``).
        output_dim: Number of columns of ``kernel`` and length of ``bias``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # torch.empty replaces the legacy torch.Tensor(*sizes) constructor;
        # the uninitialized storage is overwritten right away by ones_.
        self.kernel = torch.nn.Parameter(torch.empty(input_dim, output_dim))
        torch.nn.init.ones_(self.kernel)
        self.bias = torch.nn.Parameter(torch.empty(output_dim))
        torch.nn.init.ones_(self.bias)

    def forward(self, inp, adj_mat):
        """Apply the layer.

        Args:
            inp: 2-D feature matrix of shape ``(N, input_dim)``.
            adj_mat: 2-D matrix of shape ``(M, N)``. In this test the caller
                passes a sparse tensor produced by ``to_sparse()``.

        Returns:
            Tensor of shape ``(M, output_dim)``, with ``bias`` broadcast over
            the rows.
        """
        # Input matrix times weight matrix.
        support = torch.mm(inp, self.kernel)
        # Sparse adjacency matrix times support matrix.
        output = torch.spmm(adj_mat, support)
        # Add bias (broadcast across rows).
        return output + self.bias
# Build the 4 -> 4 layer under test. Its parameters are filled with ones_,
# so constructing it does not consume any random numbers.
net = GraphConv(4, 4)
# Fix the global RNG so the random inputs (and therefore the CHECK values
# below) are reproducible from run to run.
torch.manual_seed(0)
# Tuple elements are evaluated left to right. The dense feature matrix is
# therefore drawn first, then the matrix that is converted to a sparse
# adjacency.
inp, adj_mat = torch.rand(4, 4), torch.rand(4, 4).to_sparse()
#
# CHECK: pytorch
# CHECK: tensor({{\[}}[4.4778, 4.4778, 4.4778, 4.4778],
# CHECK: [5.7502, 5.7502, 5.7502, 5.7502],
# CHECK: [4.6980, 4.6980, 4.6980, 4.6980],
# CHECK: [3.6407, 3.6407, 3.6407, 3.6407]{{\]}})
# CHECK: mpact compile and run
# CHECK: {{\[}}[4.477828 4.477828 4.477828 4.477828 ]
# CHECK: [5.7501717 5.7501717 5.7501717 5.7501717]
# CHECK: [4.697952 4.697952 4.697952 4.697952 ]
# CHECK: [3.640687 3.640687 3.640687 3.640687 ]{{\]}}
# CHECK: mpact compile
# CHECK: mpact run
# CHECK: {{\[}}[4.477828 4.477828 4.477828 4.477828 ]
# CHECK: [5.7501717 5.7501717 5.7501717 5.7501717]
# CHECK: [4.697952 4.697952 4.697952 4.697952 ]
# CHECK: [3.640687 3.640687 3.640687 3.640687 ]{{\]}}
#
with torch.no_grad():
    # Reference result from eager PyTorch.
    print("pytorch")
    eager_out = net(inp, adj_mat)
    print(eager_out)

    # MPACT: a single call that compiles and then runs the module.
    print("mpact compile and run")
    jit_out = mpact_jit(net, inp, adj_mat)
    print(jit_out)

    # MPACT: compile once, then invoke the compiled function separately.
    print("mpact compile")
    invoker, compiled_fn = mpact_jit_compile(net, inp, adj_mat)
    print("mpact run")
    staged_out = mpact_jit_run(invoker, compiled_fn, inp, adj_mat)
    print(staged_out)