| import torch | |
| class MVNet(torch.nn.Module): | |
| def forward(self, x, v): | |
| return torch.mv(x, v) | |
| class MMNet(torch.nn.Module): | |
| def forward(self, x, v): | |
| return torch.mm(x, v) | |
| class AddNet(torch.nn.Module): | |
| def forward(self, x, v): | |
| return torch.add(x, v) | |
| class MulNet(torch.nn.Module): | |
| def forward(self, x, v): | |
| return torch.mul(x, v) | |
| class SelfNet(torch.nn.Module): | |
| def forward(self, x): | |
| return x | |
| class SDDMMNet(torch.nn.Module): | |
| def forward(self, x, y, z): | |
| return torch.mul(x, torch.mm(y, z)) |