blob: 8658e4d65f49ec0de059ebe5aa4c3032e40ca30d [file] [log] [blame]
# RUN: %PYTHON %s | FileCheck %s
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from mpact.mpactbackend import mpact_jit
from mpact.models.kernels import SimpleNet
from mpact.models.train import training_loop, num_all_parameters, num_parameters
# Training inputs: four 4x4 matrices alternating dense (all ones but one
# entry) and sparse (all zeros but one entry). Each matrix deviates from a
# uniform fill in exactly one position, recorded explicitly below.
A = torch.ones((4, 4, 4), dtype=torch.float32)
A[1::2] = 0.0  # samples 1 and 3 start out sparse (all zeros)
A[0, 1, 0] = 0.0  # dense sample 0: single zero
A[1, 3, 0] = 1.0  # sparse sample 1: single one
A[2, 3, 3] = 0.0  # dense sample 2: single zero
A[3, 0, 3] = 1.0  # sparse sample 3: single one
# Validation inputs: an all-zeros (sparse) 4x4 matrix followed by an
# all-ones (dense) 4x4 matrix.
B = torch.zeros((2, 4, 4), dtype=torch.float32)
B[1].fill_(1.0)
# Class labels, one per input matrix: 0 = sparse, 1 = dense.
labA = torch.tensor([1, 0] * 2)  # A alternates dense, sparse, dense, sparse
labB = torch.arange(2)  # B is sparse then dense
# A toy training and validation data set consisting of dense/sparse tensors.
class TrainData(Dataset):
    """Training split: pairs each matrix in ``A`` with its label in ``labA``.

    Both tensors are module-level globals looked up on every call; nothing is
    copied or cached at construction time.
    """

    def __getitem__(self, index):
        """Return the ``(matrix, label)`` pair stored at position ``index``."""
        matrix = A[index]
        label = labA[index]
        return matrix, label

    def __len__(self):
        """Return the number of training samples (leading dimension of ``A``)."""
        return A.size(0)
class ValidationData(Dataset):
    """Validation split: pairs each matrix in ``B`` with its label in ``labB``.

    Both tensors are module-level globals looked up on every call; nothing is
    copied or cached at construction time.
    """

    def __getitem__(self, index):
        """Return the ``(matrix, label)`` pair stored at position ``index``."""
        matrix = B[index]
        label = labB[index]
        return matrix, label

    def __len__(self):
        """Return the number of validation samples (leading dimension of ``B``)."""
        return B.size(0)
train_data = TrainData()
validation_data = ValidationData()

# Seed the global RNG before building the model so any random weight
# initialization in SimpleNet (and therefore whatever training_loop prints)
# is reproducible from run to run. The FileCheck patterns below only look at
# labels, parameter counts, and the final epoch marker, so seeding does not
# change what is matched -- it just makes failures easier to reproduce.
torch.manual_seed(0)
net = SimpleNet()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
loss_function = torch.nn.CrossEntropyLoss()
train = DataLoader(train_data, batch_size=2)
validation = DataLoader(validation_data, batch_size=2)

# Both parameter-count helpers are expected to report the same value.
# CHECK-LABEL: parameters
# CHECK-COUNT-2: 182
print("parameters")
print(num_all_parameters(net))
print(num_parameters(net))

# Run it with PyTorch.
# CHECK-LABEL: pytorch
# CHECK: Epoch 9
print("pytorch")
training_loop(net, optimizer, loss_function, train, validation, epochs=10)

# Run it with MPACT.
# For now only the section label is printed; mpact_jit is imported at the top
# of the file but not yet used by this test.
# CHECK-LABEL: mpact
print("mpact")
# TODO: teach MPACT about autograd