[mpact][compiler] add training loop to models with simple test (#60)
* [mpact][compiler] add training loop to models with simple test
Note that MPACT does not support autograd yet, although we will
eventually need to support it as well. This PR adds a very simple
training loop to the models, together with a simple neural network
that uses the training loop to learn to classify the sparse/dense
tensors in a toy training set.
* apply darker linter formatting (I had tested with black?!)
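
For reference, a minimal sketch of how the two new pieces fit together
outside the checked-in test (hypothetical usage; the random TensorDataset
below is illustrative and stands in for the toy sparse/dense data set):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from mpact.models.kernels import SimpleNet
    from mpact.models.train import training_loop

    # Random 4x4 inputs with random binary labels, purely for illustration.
    x = torch.randn(32, 4, 4)
    y = torch.randint(0, 2, (32,))
    data = DataLoader(TensorDataset(x, y), batch_size=8)

    net = SimpleNet()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    loss_function = torch.nn.CrossEntropyLoss()

    # Train and validate on the same random data (illustration only).
    training_loop(net, optimizer, loss_function, data, data, epochs=5)
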
diff --git a/python/mpact/models/kernels.py b/python/mpact/models/kernels.py
index d18d88a..71dd319 100644
--- a/python/mpact/models/kernels.py
+++ b/python/mpact/models/kernels.py
@@ -52,3 +52,19 @@
reciprocal_vector[reciprocal_vector == float("inf")] = 0
scaling_diagonal = torch.diag(reciprocal_vector).to_sparse()
return scaling_diagonal @ A @ scaling_diagonal
+
+
+class SimpleNet(torch.nn.Module):
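+    """A small fully-connected network mapping 4x4 inputs to two classes."""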
+ def __init__(self):
+        super().__init__()
+ # Model parameters (weights and biases of linear layers).
+ self.fc1 = torch.nn.Linear(16, 8)
+ self.fc2 = torch.nn.Linear(8, 4)
+ self.fc3 = torch.nn.Linear(4, 2)
+
+ def forward(self, x):
+ x = x.view(-1, 16)
+ x = torch.nn.functional.relu(self.fc1(x))
+ x = torch.nn.functional.relu(self.fc2(x))
+ return self.fc3(x) # assumes: softmax in loss function
diff --git a/python/mpact/models/train.py b/python/mpact/models/train.py
new file mode 100644
index 0000000..791edcc
--- /dev/null
+++ b/python/mpact/models/train.py
@@ -0,0 +1,49 @@
+import torch
+import torch.nn.functional as F
+
+
+def training_loop(model, optimizer, loss_function, train, validation, epochs=10):
+ """A rudimentary PyTorch training loop for classification with training and validation data."""
+ for epoch in range(epochs):
+ # Switch to training mode.
+ model.train()
+ tloss = 0.0
+ num_train = len(train) # in batches
+ for inp, target in train: # batch loop (training)
+ optimizer.zero_grad()
+ output = model(inp)
+ loss = loss_function(output, target)
+ loss.backward()
+ optimizer.step()
+            tloss += loss.item()
+
+        # Switch to inference mode.
+        model.eval()
+        vloss = 0.0
+        num_validation = len(validation)  # in batches
+        num_correct = 0
+        num_total = 0
+        with torch.no_grad():  # gradients are not needed for validation
+            for inp, target in validation:  # batch loop (validation)
+                output = model(inp)
+                loss = loss_function(output, target)
+                vloss += loss.item()
+                # The predicted class is the argmax over the class dimension.
+                correct = torch.eq(
+                    torch.max(F.softmax(output, dim=1), dim=1)[1], target
+                ).view(-1)
+                num_correct += torch.sum(correct).item()
+                num_total += correct.shape[0]
+
+ # Report stats.
+ print(
+ "Epoch {:d}, Training loss = {:.2f} #{:d}, Validation loss = {:.2f} #{:d}, Accuracy = {:.2f} #{:d}".format(
+ epoch,
+ (tloss / num_train) if num_train != 0 else 0,
+ num_train,
+ (vloss / num_validation) if num_validation != 0 else 0,
+ num_validation,
+ (num_correct / num_total) if num_total != 0 else 0,
+ num_total,
+ )
+ )
diff --git a/test/python/train_simple.py b/test/python/train_simple.py
new file mode 100644
index 0000000..dc9d0f6
--- /dev/null
+++ b/test/python/train_simple.py
@@ -0,0 +1,111 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+import numpy as np
+
+from torch.utils.data import Dataset, DataLoader
+
+from mpact.mpactbackend import mpact_jit
+from mpact.models.kernels import SimpleNet
+from mpact.models.train import training_loop
+
+
+A = torch.tensor(
+ [
+ [
+ [1.0, 1.0, 1.0, 1.0],
+ [0.0, 1.0, 1.0, 1.0],
+ [1.0, 1.0, 1.0, 1.0],
+ [1.0, 1.0, 1.0, 1.0],
+ ],
+ [
+ [0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0],
+ [1.0, 0.0, 0.0, 0.0],
+ ],
+ [
+ [1.0, 1.0, 1.0, 1.0],
+ [1.0, 1.0, 1.0, 1.0],
+ [1.0, 1.0, 1.0, 1.0],
+ [1.0, 1.0, 1.0, 0.0],
+ ],
+ [
+ [0.0, 0.0, 0.0, 1.0],
+ [0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0],
+ ],
+ ],
+ dtype=torch.float32,
+)
+
+
+B = torch.tensor(
+ [
+ [
+ [0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0],
+ ],
+ [
+ [1.0, 1.0, 1.0, 1.0],
+ [1.0, 1.0, 1.0, 1.0],
+ [1.0, 1.0, 1.0, 1.0],
+ [1.0, 1.0, 1.0, 1.0],
+ ],
+ ],
+ dtype=torch.float32,
+)
+
+# Labels: 0 = sparse, 1 = dense
+
+labA = torch.tensor([1, 0, 1, 0])
+
+labB = torch.tensor([0, 1])
+
+# A toy training and validation data set consisting of dense/sparse tensors.
+
+
+class TrainData(Dataset):
+ def __len__(self):
+ return A.shape[0]
+
+ def __getitem__(self, index):
+ return A[index], labA[index]
+
+
+class ValidationData(Dataset):
+ def __len__(self):
+ return B.shape[0]
+
+ def __getitem__(self, index):
+ return B[index], labB[index]
+
+
+train_data = TrainData()
+validation_data = ValidationData()
+
+net = SimpleNet()
+optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
+loss_function = torch.nn.CrossEntropyLoss()
+train = DataLoader(train_data, batch_size=2)
+validation = DataLoader(validation_data, batch_size=2)
+
+
+# Run it with PyTorch.
+# CHECK-LABEL: pytorch
+# CHECK: Epoch 99
+# CHECK-SAME: Accuracy = 1.00
+print("pytorch")
+training_loop(net, optimizer, loss_function, train, validation, epochs=100)
+
+# Run it with MPACT.
+# CHECK-LABEL: mpact
+print("mpact")
+# TODO: teach MPACT about autograd
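+# A hypothetical sketch of what this section could eventually run (mpact_jit
+# is imported above; training it end-to-end still requires autograd support):
+#
+#   output = mpact_jit(net, B)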