[mpact][compiler] add gat model (with test) (#42)

* [mpact][compiler] add gat model (with test)

* use underscore for readability of parameters

* extra underscore
diff --git a/benchmark/python/benchmarks/resnet_benchmark.py b/benchmark/python/benchmarks/resnet_benchmark.py
index b455ef0..89b317e 100644
--- a/benchmark/python/benchmarks/resnet_benchmark.py
+++ b/benchmark/python/benchmarks/resnet_benchmark.py
@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-from mpact.models.resnet import resnet20
+from mpact.models.resnet import resnet_20
 from mpact_benchmark.utils.benchmark_utils import benchmark, Backends
 
 
@@ -27,7 +27,7 @@
 )
 def resnet() -> torch.nn.Module:
     """Restnet20 model."""
-    resnet_model = resnet20()
+    resnet_model = resnet_20()
     resnet_model.train(False)
     return resnet_model
 
diff --git a/python/mpact/models/gat.py b/python/mpact/models/gat.py
new file mode 100644
index 0000000..b8ab229
--- /dev/null
+++ b/python/mpact/models/gat.py
@@ -0,0 +1,95 @@
+import torch
+import torch.nn.functional as F
+
+
+class GraphAttentionLayer(torch.nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        n_heads: int,
+        dropout: float = 0.4,
+        leaky_relu_slope: float = 0.2,
+    ):
+        super(GraphAttentionLayer, self).__init__()
+        self.n_heads = n_heads
+        self.dropout = dropout
+        self.n_hidden = out_features
+        # Shared linear projection: one n_hidden-wide block of columns per head.
+        self.W = torch.nn.Parameter(
+            torch.empty(size=(in_features, self.n_hidden * n_heads))
+        )
+        # Per-head attention vector, split into source and target halves.
+        self.a = torch.nn.Parameter(torch.empty(size=(n_heads, 2 * self.n_hidden, 1)))
+        self.leakyrelu = torch.nn.LeakyReLU(leaky_relu_slope)
+        # Initialize all weights to ones so the test output is deterministic.
+        torch.nn.init.ones_(self.W)
+        torch.nn.init.ones_(self.a)
+
+    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
+        n_nodes = h.shape[0]
+        # Project node features into n_heads * n_hidden dimensions.
+        h_transformed = torch.mm(h, self.W)
+        h_transformed = F.dropout(h_transformed, self.dropout, training=self.training)
+        # Reshape to (n_heads, n_nodes, n_hidden) so attention runs per head.
+        h_transformed = h_transformed.view(
+            n_nodes, self.n_heads, self.n_hidden
+        ).permute(1, 0, 2)
+        # Attention logits, masked so only edges present in adj_mat contribute.
+        e = self._get_attention_scores(h_transformed)
+        connectivity_mask = -9e16 * torch.ones_like(e)
+        e = torch.where(adj_mat > 0, e, connectivity_mask)
+        attention = F.softmax(e, dim=-1)
+        attention = F.dropout(attention, self.dropout, training=self.training)
+        # Aggregate neighbor features by attention weight; average over heads.
+        h_prime = torch.matmul(attention, h_transformed)
+        return h_prime.mean(dim=0)
+
+    def _get_attention_scores(self, h_transformed: torch.Tensor):
+        # Broadcasting the source scores against the transposed target scores
+        # yields the full (n_heads, n_nodes, n_nodes) score matrix.
+        source_scores = torch.matmul(h_transformed, self.a[:, : self.n_hidden, :])
+        target_scores = torch.matmul(h_transformed, self.a[:, self.n_hidden :, :])
+        e = source_scores + target_scores.mT
+        return self.leakyrelu(e)
+
+
+class GAT(torch.nn.Module):
+    """
+    Graph Attention Network (GAT) inspired by <https://arxiv.org/pdf/1710.10903.pdf>.
+    """
+
+    def __init__(
+        self,
+        in_features,
+        n_hidden,
+        n_heads,
+        num_classes,
+        dropout=0.4,
+        leaky_relu_slope=0.2,
+    ):
+        super(GAT, self).__init__()
+        self.gat1 = GraphAttentionLayer(
+            in_features=in_features,
+            out_features=n_hidden,
+            n_heads=n_heads,
+            dropout=dropout,
+            leaky_relu_slope=leaky_relu_slope,
+        )
+        self.gat2 = GraphAttentionLayer(
+            in_features=n_hidden,
+            out_features=num_classes,
+            n_heads=1,
+            dropout=dropout,
+            leaky_relu_slope=leaky_relu_slope,
+        )
+
+    def forward(self, input_tensor: torch.Tensor, adj_mat: torch.Tensor):
+        x = self.gat1(input_tensor, adj_mat)
+        x = F.elu(x)
+        x = self.gat2(x, adj_mat)
+        return F.log_softmax(x, dim=1)
+
+
+def gat_4_64_8_3():
+    return GAT(in_features=4, n_hidden=64, n_heads=8, num_classes=3)
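
A note on the deterministic output this model produces in the test below: because `W` and `a` are initialized to ones, every class logit comes out identical for a given node, so the final `log_softmax` over the 3 classes is `ln(1/3) ≈ -1.0986` everywhere. A minimal sketch of that last step (the input here is illustrative, not part of the patch):

```python
import torch
import torch.nn.functional as F

# Identical logits per row -> uniform distribution -> log(1/3) per class.
logits = torch.ones(4, 3)
print(F.log_softmax(logits, dim=1))  # every entry is -1.0986...
```
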
diff --git a/python/mpact/models/gcn.py b/python/mpact/models/gcn.py
index db63cb4..c41e6d9 100644
--- a/python/mpact/models/gcn.py
+++ b/python/mpact/models/gcn.py
@@ -39,9 +39,9 @@
         return F.log_softmax(x, dim=1)
 
 
-def graphconv44():
+def graphconv_4_4():
     return GraphConv(input_dim=4, output_dim=4)
 
 
-def gcn4164():
+def gcn_4_16_4():
     return GCN(input_dim=4, hidden_dim=16, output_dim=4)
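
The new names encode the constructor arguments, so each variant is self-describing. A quick usage sketch, assuming the (input, adjacency) forward signature used in test/python/gcn.py (the random input is illustrative only):

```python
import torch
from mpact.models.gcn import gcn_4_16_4

# gcn_4_16_4 == GCN(input_dim=4, hidden_dim=16, output_dim=4).
net = gcn_4_16_4()
net.eval()
inp = torch.randn(4, 4)  # 4 nodes, 4 features each
adj = torch.eye(4)       # self-loops only, purely for shape checking
print(net(inp, adj).shape)  # expected: torch.Size([4, 4])
```
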
diff --git a/python/mpact/models/resnet.py b/python/mpact/models/resnet.py
index c3f5925..2556597 100644
--- a/python/mpact/models/resnet.py
+++ b/python/mpact/models/resnet.py
@@ -251,5 +251,5 @@
         return self._forward_impl(x)
 
 
-def resnet20():
+def resnet_20():
     return ResNety(block=BasicBlock, layers=[2, 2, 2], num_classes=10)
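
Same mechanical rename here; the architecture (BasicBlock, layers=[2, 2, 2], 10 classes) is unchanged. A hedged smoke test, assuming the usual CIFAR-style 3x32x32 input this kind of ResNet expects:

```python
import torch
from mpact.models.resnet import resnet_20

net = resnet_20()
net.eval()
x = torch.randn(1, 3, 32, 32)  # assumed CIFAR-sized input
with torch.no_grad():
    print(net(x).shape)  # expected: torch.Size([1, 10])
```
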
diff --git a/test/python/gat.py b/test/python/gat.py
new file mode 100644
index 0000000..64a2e7f
--- /dev/null
+++ b/test/python/gat.py
@@ -0,0 +1,48 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+import numpy as np
+
+from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
+
+from mpact.models.gat import gat_4_64_8_3
+
+net = gat_4_64_8_3()
+net.eval()  # Switch to inference.
+
+# Sparse input.
+idx = torch.tensor([[0, 0, 1, 2], [0, 2, 3, 1]], dtype=torch.int64)
+val = torch.tensor([14.0, 3.0, -8.0, 11.0], dtype=torch.float32)
+S = torch.sparse_coo_tensor(idx, val, size=[4, 4])
+
+# Construct adjacency matrix.
+V = 4
+edges = np.array([[0, 1], [0, 2], [1, 2], [1, 3], [2, 3]], dtype=np.int32)
+E = edges.shape[0]
+adj_mat = torch.sparse_coo_tensor(edges.T, torch.ones(E), (V, V), dtype=torch.int64)
+adj_mat = (
+    torch.eye(V) + adj_mat
+)  # Add self-loops to the adjacency matrix (becomes dense)
+
+
+#
+# CHECK: pytorch gat
+# CHECK:   tensor({{\[}}[-1.0986, -1.0986, -1.0986],
+# CHECK:                [-1.0986, -1.0986, -1.0986],
+# CHECK:                [-1.0986, -1.0986, -1.0986],
+# CHECK:                [-1.0986, -1.0986, -1.0986]{{\]}}
+# CHECK: mpact gat
+# CHECK:   {{\[}}[-1.0986123 -1.0986123 -1.0986123]
+# CHECK:         [-1.0986123 -1.0986123 -1.0986123]
+# CHECK:         [-1.0986123 -1.0986123 -1.0986123]
+# CHECK:         [-1.0986123 -1.0986123 -1.0986123]{{\]}}
+#
+with torch.no_grad():
+    # Run it with PyTorch.
+    print("pytorch gat")
+    res = net(S, adj_mat)
+    print(res)
+
+    print("mpact gat")
+    res = mpact_jit(net, S, adj_mat)
+    print(res)
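
The test uses the one-shot `mpact_jit` entry point, but the import also pulls in `mpact_jit_compile` and `mpact_jit_run`, which separate compilation from execution the way test/python/gcn.py does below. A sketch of that split for the same net, assuming `mpact_jit_compile` returns the `(invoker, fn)` pair consumed by `mpact_jit_run`:

```python
# Compile once, then run repeatedly with fresh inputs.
invoker, fn = mpact_jit_compile(net, S, adj_mat)
res = mpact_jit_run(invoker, fn, S, adj_mat)
print(res)
```
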
diff --git a/test/python/gcn.py b/test/python/gcn.py
index 3a4a1d2..e51f7cf 100644
--- a/test/python/gcn.py
+++ b/test/python/gcn.py
@@ -4,9 +4,9 @@
 
 from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
 
-from mpact.models.gcn import graphconv44, gcn4164
+from mpact.models.gcn import graphconv_4_4, gcn_4_16_4
 
-net = graphconv44()
+net = graphconv_4_4()
 net.eval()  # Switch to inference.
 
 # Get random (but reproducible) matrices.
@@ -62,7 +62,7 @@
     res = mpact_jit_run(invoker, fn, inp, adj_mat)
     print(res)
 
-net = gcn4164()
+net = gcn_4_16_4()
 net.eval()  # Switch to inference.
 
 
diff --git a/test/python/resnet.py b/test/python/resnet.py
index fe8cc3c..7ac317b 100644
--- a/test/python/resnet.py
+++ b/test/python/resnet.py
@@ -5,9 +5,9 @@
 
 from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
 
-from mpact.models.resnet import resnet20
+from mpact.models.resnet import resnet_20
 
-resnet = resnet20()
+resnet = resnet_20()
 resnet.eval()  # Switch to inference.
 
 # Get a random input.