|  | import torch | 
|  | import torch.nn.functional as F | 
|  |  | 
|  |  | 
|  | def num_all_parameters(model): | 
|  | """Returns the number of all parameters in a model.""" | 
|  | return sum(p.numel() for p in model.parameters()) | 
|  |  | 
|  |  | 
|  | def num_parameters(model): | 
|  | """Returns the number of trainable parameters in a model.""" | 
|  | return sum(p.numel() for p in model.parameters() if p.requires_grad) | 
|  |  | 
|  |  | 
|  | def training_loop(model, optimizer, loss_function, train, validation, epochs=10): | 
|  | """A rudimentary PyTorch training loop for classification with training and validation data.""" | 
|  | for epoch in range(epochs): | 
|  | # Switch to training mode. | 
|  | model.train() | 
|  | tloss = 0.0 | 
|  | num_train = len(train)  # in batches | 
|  | for inp, target in train:  # batch loop (training) | 
|  | optimizer.zero_grad() | 
|  | output = model(inp) | 
|  | loss = loss_function(output, target) | 
|  | loss.backward() | 
|  | optimizer.step() | 
|  | tloss += loss.data.item() | 
|  |  | 
|  | # Switch to inference mode. | 
|  | model.eval() | 
|  | vloss = 0.0 | 
|  | num_validation = len(validation)  # in batches | 
|  | num_correct = 0 | 
|  | num_total = 0 | 
|  | for inp, target in validation:  # batch loop (validation) | 
|  | output = model(inp) | 
|  | loss = loss_function(output, target) | 
|  | vloss += loss.data.item() | 
|  | correct = torch.eq( | 
|  | torch.max(F.softmax(output, dim=1), dim=1)[1], target | 
|  | ).view(-1) | 
|  | num_correct += torch.sum(correct).item() | 
|  | num_total += correct.shape[0] | 
|  |  | 
|  | # Report stats. | 
|  | print( | 
|  | "Epoch {:d}, Training loss = {:.2f} #{:d}, Validation loss = {:.2f} #{:d}, Accuracy = {:.2f} #{:d}".format( | 
|  | epoch, | 
|  | (tloss / num_train) if num_train != 0 else 0, | 
|  | num_train, | 
|  | (vloss / num_validation) if num_validation != 0 else 0, | 
|  | num_validation, | 
|  | (num_correct / num_total) if num_total != 0 else 0, | 
|  | num_total, | 
|  | ) | 
|  | ) |