| # RUN: %PYTHON %s | FileCheck %s | 
 |  | 
 | import torch | 
 | import numpy as np | 
 |  | 
 | from torch.utils.data import Dataset, DataLoader | 
 |  | 
 | from mpact.mpactbackend import mpact_jit | 
 | from mpact.models.kernels import SimpleNet | 
 | from mpact.models.train import training_loop, num_all_parameters, num_parameters | 
 |  | 
 |  | 
 | A = torch.tensor( | 
 |     [ | 
 |         [ | 
 |             [1.0, 1.0, 1.0, 1.0], | 
 |             [0.0, 1.0, 1.0, 1.0], | 
 |             [1.0, 1.0, 1.0, 1.0], | 
 |             [1.0, 1.0, 1.0, 1.0], | 
 |         ], | 
 |         [ | 
 |             [0.0, 0.0, 0.0, 0.0], | 
 |             [0.0, 0.0, 0.0, 0.0], | 
 |             [0.0, 0.0, 0.0, 0.0], | 
 |             [1.0, 0.0, 0.0, 0.0], | 
 |         ], | 
 |         [ | 
 |             [1.0, 1.0, 1.0, 1.0], | 
 |             [1.0, 1.0, 1.0, 1.0], | 
 |             [1.0, 1.0, 1.0, 1.0], | 
 |             [1.0, 1.0, 1.0, 0.0], | 
 |         ], | 
 |         [ | 
 |             [0.0, 0.0, 0.0, 1.0], | 
 |             [0.0, 0.0, 0.0, 0.0], | 
 |             [0.0, 0.0, 0.0, 0.0], | 
 |             [0.0, 0.0, 0.0, 0.0], | 
 |         ], | 
 |     ], | 
 |     dtype=torch.float32, | 
 | ) | 
 |  | 
 |  | 
 | B = torch.tensor( | 
 |     [ | 
 |         [ | 
 |             [0.0, 0.0, 0.0, 0.0], | 
 |             [0.0, 0.0, 0.0, 0.0], | 
 |             [0.0, 0.0, 0.0, 0.0], | 
 |             [0.0, 0.0, 0.0, 0.0], | 
 |         ], | 
 |         [ | 
 |             [1.0, 1.0, 1.0, 1.0], | 
 |             [1.0, 1.0, 1.0, 1.0], | 
 |             [1.0, 1.0, 1.0, 1.0], | 
 |             [1.0, 1.0, 1.0, 1.0], | 
 |         ], | 
 |     ], | 
 |     dtype=torch.float32, | 
 | ) | 
 |  | 
 | # Labels 0:sparse 1:dense | 
 |  | 
 | labA = torch.tensor([1, 0, 1, 0]) | 
 |  | 
 | labB = torch.tensor([0, 1]) | 
 |  | 
 | # A toy training and validation data set consisting of dense/sparse tensors. | 
 |  | 
 |  | 
 | class TrainData(Dataset): | 
 |     def __len__(self): | 
 |         return A.shape[0] | 
 |  | 
 |     def __getitem__(self, index): | 
 |         return A[index], labA[index] | 
 |  | 
 |  | 
 | class ValidationData(Dataset): | 
 |     def __len__(self): | 
 |         return B.shape[0] | 
 |  | 
 |     def __getitem__(self, index): | 
 |         return B[index], labB[index] | 
 |  | 
 |  | 
 | train_data = TrainData() | 
 | validation_data = ValidationData() | 
 |  | 
 | net = SimpleNet() | 
 | optimizer = torch.optim.Adam(net.parameters(), lr=0.001) | 
 | loss_function = torch.nn.CrossEntropyLoss() | 
 | train = DataLoader(train_data, batch_size=2) | 
 | validation = DataLoader(validation_data, batch_size=2) | 
 |  | 
 |  | 
 | # CHECK-LABEL: parameters | 
 | # CHECK-COUNT-2: 182 | 
 | print("parameters") | 
 | print(num_all_parameters(net)) | 
 | print(num_parameters(net)) | 
 |  | 
 | # Run it with PyTorch. | 
 | # CHECK-LABEL: pytorch | 
 | # CHECK:       Epoch 9 | 
 | print("pytorch") | 
 | training_loop(net, optimizer, loss_function, train, validation, epochs=10) | 
 |  | 
 | # Run it with MPACT. | 
 | # CHECK-LABEL: mpact | 
 | print("mpact") | 
 | # TODO: teach MPACT about autograd |