import torch
import torch.nn.functional as F


def num_all_parameters(model):
    """Returns the number of all parameters in a model."""
    return sum(p.numel() for p in model.parameters())


def num_parameters(model):
    """Returns the number of trainable parameters in a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
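

# A minimal sketch (not part of the original file) illustrating the difference
# between the two counters above, using a hypothetical single Linear layer:
# freezing its parameters removes them from num_parameters() but not from
# num_all_parameters().
def _parameter_count_example():
    layer = torch.nn.Linear(10, 2)  # 10 * 2 weights + 2 biases = 22 parameters
    assert num_all_parameters(layer) == 22
    assert num_parameters(layer) == 22
    for p in layer.parameters():
        p.requires_grad = False  # freeze: still present, but no longer trainable
    assert num_all_parameters(layer) == 22
    assert num_parameters(layer) == 0
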

def training_loop(model, optimizer, loss_function, train, validation, epochs=10):
    """A rudimentary PyTorch training loop for classification with training and validation data."""
    for epoch in range(epochs):
        # Switch to training mode.
        model.train()
        tloss = 0.0
        num_train = len(train)  # in batches
        for inp, target in train:  # batch loop (training)
            optimizer.zero_grad()
            output = model(inp)
            loss = loss_function(output, target)
            loss.backward()
            optimizer.step()
            tloss += loss.item()  # .item() replaces the deprecated .data access
        # Switch to inference mode and disable gradient tracking for validation.
        model.eval()
        vloss = 0.0
        num_validation = len(validation)  # in batches
        num_correct = 0
        num_total = 0
        with torch.no_grad():
            for inp, target in validation:  # batch loop (validation)
                output = model(inp)
                loss = loss_function(output, target)
                vloss += loss.item()
                # Predicted class = argmax over class probabilities; softmax is
                # monotonic, so this matches the argmax of the raw logits.
                correct = torch.eq(
                    torch.max(F.softmax(output, dim=1), dim=1)[1], target
                ).view(-1)
                num_correct += torch.sum(correct).item()
                num_total += correct.shape[0]
        # Report stats.
        print(
            "Epoch {:d}, Training loss = {:.2f} #{:d}, Validation loss = {:.2f} #{:d}, Accuracy = {:.2f} #{:d}".format(
                epoch,
                (tloss / num_train) if num_train != 0 else 0,
                num_train,
                (vloss / num_validation) if num_validation != 0 else 0,
                num_validation,
                (num_correct / num_total) if num_total != 0 else 0,
                num_total,
            )
        )
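

# A hedged usage sketch, not part of the original file: the model, data, and
# hyperparameters below are made-up placeholders chosen only to exercise
# training_loop() end to end on random data.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    torch.manual_seed(0)
    # Toy classification data: 64 samples, 8 features, 3 classes.
    features = torch.randn(64, 8)
    labels = torch.randint(0, 3, (64,))
    dataset = TensorDataset(features, labels)
    train_loader = DataLoader(dataset, batch_size=16, shuffle=True)
    validation_loader = DataLoader(dataset, batch_size=16)

    model = torch.nn.Sequential(
        torch.nn.Linear(8, 16),
        torch.nn.ReLU(),
        torch.nn.Linear(16, 3),
    )
    print("Trainable parameters:", num_parameters(model), "of", num_all_parameters(model))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    training_loop(
        model,
        optimizer,
        torch.nn.CrossEntropyLoss(),
        train_loader,
        validation_loader,
        epochs=3,
    )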