[mpact][compiler] add gcn model (with test) (#41)
* [mpact][compiler] add gcn model (with test)
* delete print
* Python styleguide on block comments
diff --git a/python/mpact/models/gcn.py b/python/mpact/models/gcn.py
index df9b198..db63cb4 100644
--- a/python/mpact/models/gcn.py
+++ b/python/mpact/models/gcn.py
@@ -1,4 +1,5 @@
import torch
+import torch.nn.functional as F
class GraphConv(torch.nn.Module):
@@ -19,5 +20,28 @@
return output
+class GCN(torch.nn.Module):
+ """
+ Graph Convolutional Network (GCN) inspired by <https://arxiv.org/pdf/1609.02907.pdf>.
+ """
+
+ def __init__(self, input_dim, hidden_dim, output_dim, dropout_p=0.1):
+ super(GCN, self).__init__()
+ self.gc1 = GraphConv(input_dim, hidden_dim)
+ self.gc2 = GraphConv(hidden_dim, output_dim)
+ self.dropout = torch.nn.Dropout(dropout_p)
+
+ def forward(self, input_tensor, adj_mat):
+ x = self.gc1(input_tensor, adj_mat)
+ x = F.relu(x)
+ x = self.dropout(x)
+ x = self.gc2(x, adj_mat)
+ return F.log_softmax(x, dim=1)
+
+
def graphconv44():
return GraphConv(input_dim=4, output_dim=4)
+
+
+def gcn4164():
+ return GCN(input_dim=4, hidden_dim=16, output_dim=4)
diff --git a/test/python/gcn.py b/test/python/gcn.py
index d44f1ea..3a4a1d2 100644
--- a/test/python/gcn.py
+++ b/test/python/gcn.py
@@ -4,9 +4,10 @@
from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
-from mpact.models.gcn import graphconv44
+from mpact.models.gcn import graphconv44, gcn4164
net = graphconv44()
+net.eval() # Switch to inference.
# Get random (but reproducible) matrices.
torch.manual_seed(0)
@@ -60,3 +61,34 @@
print("mpact run")
res = mpact_jit_run(invoker, fn, inp, adj_mat)
print(res)
+
+net = gcn4164()
+net.eval() # Switch to inference.
+
+
+# Sparse input.
+idx = torch.tensor([[0, 0, 1, 2], [0, 2, 3, 1]], dtype=torch.int64)
+val = torch.tensor([14.0, 3.0, -8.0, 11.0], dtype=torch.float32)
+S = torch.sparse_coo_tensor(idx, val, size=[4, 4])
+
+#
+# CHECK: pytorch gcn
+# CHECK: tensor({{\[}}[-1.3863, -1.3863, -1.3863, -1.3863],
+# CHECK: [-1.3863, -1.3863, -1.3863, -1.3863],
+# CHECK: [-1.3863, -1.3863, -1.3863, -1.3863],
+# CHECK: [-1.3863, -1.3863, -1.3863, -1.3863]])
+# CHECK: mpact gcn
+# CHECK: {{\[}}[-1.3862944 -1.3862944 -1.3862944 -1.3862944]
+# CHECK: [-1.3862944 -1.3862944 -1.3862944 -1.3862944]
+# CHECK: [-1.3862944 -1.3862944 -1.3862944 -1.3862944]
+# CHECK: [-1.3862944 -1.3862944 -1.3862944 -1.3862944]{{\]}}
+#
+with torch.no_grad():
+ # Run it with PyTorch.
+ print("pytorch gcn")
+ res = net(S, adj_mat)
+ print(res)
+
+ print("mpact gcn")
+ res = mpact_jit(net, S, adj_mat)
+ print(res)
diff --git a/test/python/resnet.py b/test/python/resnet.py
index cc8e7e1..fe8cc3c 100644
--- a/test/python/resnet.py
+++ b/test/python/resnet.py
@@ -8,7 +8,7 @@
from mpact.models.resnet import resnet20
resnet = resnet20()
-resnet.train(False) # switch to inference
+resnet.eval() # Switch to inference.
# Get a random input.
# B x RGB x H x W