import torch
import torch.nn.functional as F


class GraphAttentionLayer(torch.nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        n_heads: int,
        dropout: float = 0.4,
        leaky_relu_slope: float = 0.2,
    ):
        super(GraphAttentionLayer, self).__init__()
        self.n_heads = n_heads
        self.dropout = dropout
        self.n_hidden = out_features
        # Shared linear projection: each node is mapped to n_hidden features per
        # head, stored as one (in_features, n_hidden * n_heads) matrix.
        self.W = torch.nn.Parameter(
            torch.empty(size=(in_features, self.n_hidden * n_heads))
        )
        # Per-head attention vector a, later split into source and target halves.
        self.a = torch.nn.Parameter(torch.empty(size=(n_heads, 2 * self.n_hidden, 1)))
        self.leakyrelu = torch.nn.LeakyReLU(leaky_relu_slope)
        # Glorot (Xavier) initialization breaks the symmetry between heads; an
        # all-ones start would make every head compute identical scores.
        torch.nn.init.xavier_normal_(self.W)
        torch.nn.init.xavier_normal_(self.a)

    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
        n_nodes = h.shape[0]
        # Linear projection for all heads at once: (n_nodes, n_hidden * n_heads).
        h_transformed = torch.mm(h, self.W)
        h_transformed = F.dropout(h_transformed, self.dropout, training=self.training)
        # Reshape to (n_heads, n_nodes, n_hidden) so each head attends independently.
        h_transformed = h_transformed.view(
            n_nodes, self.n_heads, self.n_hidden
        ).permute(1, 0, 2)
        # Raw pairwise attention scores per head: (n_heads, n_nodes, n_nodes).
        e = self._get_attention_scores(h_transformed)
        # Mask non-edges with a large negative value so their weights vanish after
        # the softmax; adj_mat broadcasts across the head dimension.
        connectivity_mask = -9e16 * torch.ones_like(e)
        e = torch.where(adj_mat > 0, e, connectivity_mask)
        # Normalize scores over each node's neighborhood.
        attention = F.softmax(e, dim=-1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        # Aggregate neighbor features, then average over the heads.
        h_prime = torch.matmul(attention, h_transformed)
        return h_prime.mean(dim=0)

    def _get_attention_scores(self, h_transformed: torch.Tensor):
        # a is split into a source half and a target half; the two per-node score
        # vectors broadcast into the full (n_heads, n_nodes, n_nodes) score matrix.
        source_scores = torch.matmul(h_transformed, self.a[:, : self.n_hidden, :])
        target_scores = torch.matmul(h_transformed, self.a[:, self.n_hidden :, :])
        e = source_scores + target_scores.mT
        return self.leakyrelu(e)


class GAT(torch.nn.Module):
    """
    Two-layer Graph Attention Network (GAT) inspired by <https://arxiv.org/pdf/1710.10903.pdf>.
    """

    def __init__(
        self,
        in_features: int,
        n_hidden: int,
        n_heads: int,
        num_classes: int,
        dropout: float = 0.4,
        leaky_relu_slope: float = 0.2,
    ):
        super(GAT, self).__init__()
        # First layer: multi-head attention over the input node features.
        self.gat1 = GraphAttentionLayer(
            in_features=in_features,
            out_features=n_hidden,
            n_heads=n_heads,
            dropout=dropout,
            leaky_relu_slope=leaky_relu_slope,
        )
        # Second layer: single-head attention producing per-class scores.
        self.gat2 = GraphAttentionLayer(
            in_features=n_hidden,
            out_features=num_classes,
            n_heads=1,
            dropout=dropout,
            leaky_relu_slope=leaky_relu_slope,
        )

    def forward(self, input_tensor: torch.Tensor, adj_mat: torch.Tensor):
        x = self.gat1(input_tensor, adj_mat)
        # ELU non-linearity between the two attention layers.
        x = F.elu(x)
        x = self.gat2(x, adj_mat)
        # Per-node log-probabilities over the classes.
        return F.log_softmax(x, dim=1)


def gat_4_64_8_3():
    # Factory for a GAT with 4 input features, 64 hidden units per head,
    # 8 attention heads in the first layer, and 3 output classes.
    return GAT(in_features=4, n_hidden=64, n_heads=8, num_classes=3)
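

# Minimal usage sketch: build the gat_4_64_8_3 model and run one forward pass.
# The 5-node random graph, the features, and the self-loop handling below are
# illustrative assumptions for the example only, not values from any dataset.
if __name__ == "__main__":
    torch.manual_seed(0)

    n_nodes = 5
    model = gat_4_64_8_3()
    model.eval()  # disable dropout for a deterministic demo pass

    # Node features: (n_nodes, in_features) with in_features = 4.
    features = torch.rand(n_nodes, 4)

    # Random symmetric adjacency matrix with self-loops so every node
    # attends to itself as well as to its neighbors.
    adj_mat = torch.randint(0, 2, (n_nodes, n_nodes))
    adj_mat = ((adj_mat + adj_mat.T) > 0).long()
    adj_mat.fill_diagonal_(1)

    log_probs = model(features, adj_mat)
    print(log_probs.shape)  # torch.Size([5, 3]): log-probabilities over 3 classes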