blob: 8a015de0d6b93ca9c154ffd994a69c7d1d50236d [file] [log] [blame] [edit]
import torch
class GraphConv(torch.nn.Module):
    """Single graph-convolution layer computing ``adj_mat @ (inp @ kernel) + bias``.

    Both learnable parameters are initialised to all ones, so the layer's
    output is fully deterministic for a given input.

    Args:
        input_dim: Number of input features per node (rows of ``kernel``).
        output_dim: Number of output features per node (columns of ``kernel``
            and length of ``bias``).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # NOTE(review): all-ones initialisation is not a usual training init;
        # it looks intended for deterministic tests/benchmarks — confirm.
        # torch.ones allocates and fills in one step, replacing the legacy
        # uninitialised torch.Tensor(...) constructor followed by init.ones_.
        self.kernel = torch.nn.Parameter(torch.ones(input_dim, output_dim))
        self.bias = torch.nn.Parameter(torch.ones(output_dim))

    def forward(self, inp, adj_mat):
        """Run one graph-convolution step.

        Args:
            inp: 2-D feature matrix of shape ``(num_nodes, input_dim)``
                (``torch.mm`` only accepts 2-D operands).
            adj_mat: Matrix with ``num_nodes`` columns, passed to
                ``torch.spmm``; presumably a sparse adjacency matrix — confirm
                with callers.

        Returns:
            Tensor of shape ``(adj_mat.shape[0], output_dim)``.
        """
        # Project every node's features into the output feature space.
        support = torch.mm(inp, self.kernel)
        # Row i of the result is the sum of support rows weighted by row i of
        # adj_mat.
        output = torch.spmm(adj_mat, support)
        # The 1-D bias broadcasts across rows, so it is added to every node.
        return output + self.bias