Disable gradients in inference mode (#80)
diff --git a/python/mpact/models/train.py b/python/mpact/models/train.py
index 2fdaf68..d5e27cc 100644
--- a/python/mpact/models/train.py
+++ b/python/mpact/models/train.py
@@ -28,20 +28,21 @@
tloss += loss.data.item()
# Switch to inference mode.
- model.eval()
- vloss = 0.0
- num_validation = len(validation) # in batches
- num_correct = 0
- num_total = 0
- for inp, target in validation: # batch loop (validation)
- output = model(inp)
- loss = loss_function(output, target)
- vloss += loss.data.item()
- correct = torch.eq(
- torch.max(F.softmax(output, dim=1), dim=1)[1], target
- ).view(-1)
- num_correct += torch.sum(correct).item()
- num_total += correct.shape[0]
+ model.eval() # disables e.g. model drop-out
+ with torch.no_grad(): # disables gradient computations
+ vloss = 0.0
+ num_validation = len(validation) # in batches
+ num_correct = 0
+ num_total = 0
+ for inp, target in validation: # batch loop (validation)
+ output = model(inp)
+ loss = loss_function(output, target)
+ vloss += loss.data.item()
+ correct = torch.eq(
+ torch.max(F.softmax(output, dim=1), dim=1)[1], target
+ ).view(-1)
+ num_correct += torch.sum(correct).item()
+ num_total += correct.shape[0]
# Report stats.
print(