[mpact][compiler] add number of model parameters utility (#63)
Also addressed recent change in softmax that now
requires an explicit dimension.
diff --git a/python/mpact/models/train.py b/python/mpact/models/train.py
index 791edcc..2fdaf68 100644
--- a/python/mpact/models/train.py
+++ b/python/mpact/models/train.py
@@ -2,6 +2,16 @@
import torch.nn.functional as F
+def num_all_parameters(model):
+ """Returns the number of all parameters in a model."""
+ return sum(p.numel() for p in model.parameters())
+
+
+def num_parameters(model):
+ """Returns the number of trainable parameters in a model."""
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
def training_loop(model, optimizer, loss_function, train, validation, epochs=10):
"""A rudimentary PyTorch training loop for classification with training and validation data."""
for epoch in range(epochs):
@@ -27,7 +37,9 @@
output = model(inp)
loss = loss_function(output, target)
vloss += loss.data.item()
- correct = torch.eq(torch.max(F.softmax(output), dim=1)[1], target).view(-1)
+ correct = torch.eq(
+ torch.max(F.softmax(output, dim=1), dim=1)[1], target
+ ).view(-1)
num_correct += torch.sum(correct).item()
num_total += correct.shape[0]
diff --git a/test/python/train_simple.py b/test/python/train_simple.py
index dd4b0f3..8658e4d 100644
--- a/test/python/train_simple.py
+++ b/test/python/train_simple.py
@@ -7,7 +7,7 @@
from mpact.mpactbackend import mpact_jit
from mpact.models.kernels import SimpleNet
-from mpact.models.train import training_loop
+from mpact.models.train import training_loop, num_all_parameters, num_parameters
A = torch.tensor(
@@ -94,6 +94,12 @@
validation = DataLoader(validation_data, batch_size=2)
+# CHECK-LABEL: parameters
+# CHECK-COUNT-2: 182
+print("parameters")
+print(num_all_parameters(net))
+print(num_parameters(net))
+
# Run it with PyTorch.
# CHECK-LABEL: pytorch
# CHECK: Epoch 9