[mpact][compiler] add number of model parameters utility (#63)

Also addressed recent change in softmax that now
requires an explicit dimension.
diff --git a/python/mpact/models/train.py b/python/mpact/models/train.py
index 791edcc..2fdaf68 100644
--- a/python/mpact/models/train.py
+++ b/python/mpact/models/train.py
@@ -2,6 +2,16 @@
 import torch.nn.functional as F
 
 
+def num_all_parameters(model):
+    """Returns the number of all parameters in a model."""
+    return sum(p.numel() for p in model.parameters())
+
+
+def num_parameters(model):
+    """Returns the number of trainable parameters in a model."""
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
 def training_loop(model, optimizer, loss_function, train, validation, epochs=10):
     """A rudimentary PyTorch training loop for classification with training and validation data."""
     for epoch in range(epochs):
@@ -27,7 +37,9 @@
             output = model(inp)
             loss = loss_function(output, target)
             vloss += loss.data.item()
-            correct = torch.eq(torch.max(F.softmax(output), dim=1)[1], target).view(-1)
+            correct = torch.eq(
+                torch.max(F.softmax(output, dim=1), dim=1)[1], target
+            ).view(-1)
             num_correct += torch.sum(correct).item()
             num_total += correct.shape[0]
 
diff --git a/test/python/train_simple.py b/test/python/train_simple.py
index dd4b0f3..8658e4d 100644
--- a/test/python/train_simple.py
+++ b/test/python/train_simple.py
@@ -7,7 +7,7 @@
 
 from mpact.mpactbackend import mpact_jit
 from mpact.models.kernels import SimpleNet
-from mpact.models.train import training_loop
+from mpact.models.train import training_loop, num_all_parameters, num_parameters
 
 
 A = torch.tensor(
@@ -94,6 +94,12 @@
 validation = DataLoader(validation_data, batch_size=2)
 
 
+# CHECK-LABEL: parameters
+# CHECK-COUNT-2: 182
+print("parameters")
+print(num_all_parameters(net))
+print(num_parameters(net))
+
 # Run it with PyTorch.
 # CHECK-LABEL: pytorch
 # CHECK:       Epoch 9