import torch


class MVNet(torch.nn.Module):
    """Matrix-vector product: torch.mv(x, v)."""
    def forward(self, x, v):
        return torch.mv(x, v)


class MMNet(torch.nn.Module):
    """Matrix-matrix product: torch.mm(x, y)."""
    def forward(self, x, y):
        return torch.mm(x, y)


class AddNet(torch.nn.Module):
    """Elementwise addition: torch.add(x, y)."""
    def forward(self, x, y):
        return torch.add(x, y)


class MulNet(torch.nn.Module):
    """Elementwise multiplication: torch.mul(x, y)."""
    def forward(self, x, y):
        return torch.mul(x, y)


class SelfNet(torch.nn.Module):
    """Identity: returns the input unchanged."""
    def forward(self, x):
        return x


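# Hedged usage sketch (illustrative, not part of the original file): exercises the
# basic kernels above on small random tensors. All shapes are assumptions chosen
# only to make the calls well-formed.
def _basic_kernels_example():
    a = torch.randn(3, 4)
    b = torch.randn(4, 2)
    v = torch.randn(4)
    # (3,) vector, (3, 2) matrix, two (3, 4) elementwise results, and the identity.
    return MVNet()(a, v), MMNet()(a, b), AddNet()(a, a), MulNet()(a, a), SelfNet()(a)

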
class SDDMMNet(torch.nn.Module):
    """Sampled dense-dense matrix multiplication: x * (y @ z), with x acting as the sampling mask."""
    def forward(self, x, y, z):
        return torch.mul(x, torch.mm(y, z))


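# Hedged usage sketch (illustrative, not part of the original file): runs SDDMMNet on
# small dense tensors. When x is a sparse mask, the same pattern computes a sampled
# dense-dense matrix multiplication; dense inputs are used here only to keep the
# example self-contained. Shapes and values are assumptions.
def _sddmm_example():
    mask = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
    y = torch.randn(2, 3)
    z = torch.randn(3, 2)
    # Entries of y @ z are kept only where mask is nonzero.
    return SDDMMNet()(mask, y, z)

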
class SqSum(torch.nn.Module):
    """Sum of squares of all elements of x."""
    def forward(self, x):
        return (x * x).sum()


class CountEq(torch.nn.Module):
    """Counts the number of elements of x equal to the scalar s."""
    def forward(self, x, s):
        nums = (x == s).sum()
        return nums


class FeatureScale(torch.nn.Module):
    """Row-normalizes a feature matrix F: computes diag(1 / rowsum(F)) @ F."""
    def forward(self, F):
        sum_vector = torch.sum(F, dim=1)
        reciprocal_vector = 1 / sum_vector
        # Rows that sum to zero would yield inf reciprocals; map those to zero instead.
        reciprocal_vector[reciprocal_vector == float("inf")] = 0
        scaling_diagonal = torch.diag(reciprocal_vector).to_sparse()
        return scaling_diagonal @ F


class Normalization(torch.nn.Module):
    """Rescales both rows and columns of A by the reciprocal row sums: D^-1 @ A @ D^-1."""
    def forward(self, A):
        sum_vector = torch.sum(A, dim=1)
        reciprocal_vector = 1 / sum_vector
        # Rows that sum to zero would yield inf reciprocals; map those to zero instead.
        reciprocal_vector[reciprocal_vector == float("inf")] = 0
        scaling_diagonal = torch.diag(reciprocal_vector).to_sparse()
        return scaling_diagonal @ A @ scaling_diagonal


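# Hedged usage sketch (illustrative, not part of the original file): applies
# FeatureScale to a small dense matrix so that each nonzero row sums to 1;
# Normalization follows the same pattern but applies the reciprocal row-sum scaling
# on both sides of the input. The values below are assumptions.
def _feature_scale_example():
    F = torch.tensor([[2.0, 2.0], [0.0, 4.0]])
    # Rows are divided by their sums: [[0.5, 0.5], [0.0, 1.0]].
    return FeatureScale()(F)

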
class SimpleNet(torch.nn.Module):
    """Three-layer MLP (16 -> 8 -> 4 -> 2) with ReLU activations."""
    def __init__(self):
        super(SimpleNet, self).__init__()
        # Model parameters (weights and biases of the linear layers).
        self.fc1 = torch.nn.Linear(16, 8)
        self.fc2 = torch.nn.Linear(8, 4)
        self.fc3 = torch.nn.Linear(4, 2)

    def forward(self, x):
        x = x.view(-1, 16)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        return self.fc3(x)  # raw logits; softmax is assumed to be part of the loss function
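

# Hedged usage sketch (illustrative, not part of the original file): a minimal smoke
# test that feeds a random batch through SimpleNet and a couple of the reduction
# modules above. The batch size and input values are assumptions.
def _smoke_test():
    x = torch.randn(4, 16)
    logits = SimpleNet()(x)                           # shape (4, 2), raw scores
    total = SqSum()(x)                                # scalar sum of squares
    matches = CountEq()(torch.tensor([1, 2, 2]), 2)   # counts elements equal to 2
    return logits, total, matches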