[mpact][compiler] fullresnet model (#13)

* [mpact][compiler] fullresnet model

Yes, this adds 30 seconds of testing time to MPACT,
but that is a good motivator to make our compiler
a lot faster!

* Update resnet.py

Reformatted with black and darker.

---------

Co-authored-by: Yinying Li <yinyingli@google.com>
diff --git a/test/python/resnet.py b/test/python/resnet.py
new file mode 100644
index 0000000..850028a
--- /dev/null
+++ b/test/python/resnet.py
@@ -0,0 +1,286 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+import numpy as np
+
+from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
+
+
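+# Spike activation: emits 1.0 wherever its input is non-negative, 0.0 elsewhere.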
+def spike(input):
+    return (input >= 0).float()
+
+
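+# Identity module, used when a tdLayer is constructed without a batch norm.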
+class Straight(torch.nn.Module):
+    def forward(self, input):
+        return input
+
+
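+# Wraps a layer so it is applied independently at every time step (the last
+# dimension of the input), then runs the optional batch norm on the stacked result.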
+class tdLayer(torch.nn.Module):
+    def __init__(self, layer, bn=None):
+        super(tdLayer, self).__init__()
+        self.layer = layer
+        self.bn = bn if bn is not None else Straight()
+
+    def forward(self, X):
+        T = X.size(-1)
+        out = []
+        for t in range(T):
+            m = self.layer(X[..., t])
+            out.append(m)
+        out = torch.stack(out, dim=-1)
+        out = self.bn(out)
+        return out
+
+
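+# Leaky integrate-and-fire neuron: the membrane potential decays, accumulates
+# each time step's input, spikes when it reaches the threshold, and is reset to
+# zero wherever a spike fired.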
+class LIF(torch.nn.Module):
+    def __init__(self):
+        super(LIF, self).__init__()
+        self.thresh = 1.0
+        self.decay = 0.5
+        self.act = spike
+        self.gama = 1.0
+
+    def forward(self, X, gama=1):
+        mem = 0
+        spike_pot = []
+        T = X.size(-1)
+        for t in range(T):
+            mem = mem * self.decay + X[..., t]
+            spike = self.act(mem - self.thresh)
+            mem = mem * (1.0 - spike)
+            spike_pot.append(spike)
+        spike_pot = torch.stack(spike_pot, dim=-1)
+        return spike_pot
+
+
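+# Batch norm over 5-D (N, C, H, W, T) activations; this inference-only variant
+# normalizes with the running statistics and scales the result by alpha.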
+class tdBatchNorm(torch.nn.BatchNorm2d):
+    def __init__(
+        self,
+        num_features,
+        eps=1e-05,
+        momentum=0.1,
+        alpha=1,
+        affine=True,
+        track_running_stats=True,
+    ):
+        super(tdBatchNorm, self).__init__(
+            num_features, eps, momentum, affine, track_running_stats
+        )
+        self.alpha = alpha
+
+    def forward(self, input):
+        # Inference-only path: normalize with the running statistics.
+        mean = self.running_mean
+        var = self.running_var
+        input = (
+            self.alpha
+            * (input - mean[None, :, None, None, None])
+            / (torch.sqrt(var[None, :, None, None, None] + self.eps))
+        )
+        if self.affine:
+            input = (
+                input * self.weight[None, :, None, None, None]
+                + self.bias[None, :, None, None, None]
+            )
+        return input
+
+
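+# 3x3 convolution with padding.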
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    return torch.nn.Conv2d(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        stride=stride,
+        padding=dilation,
+        groups=groups,
+        bias=False,
+        dilation=dilation,
+    )
+
+
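+# 1x1 convolution.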
+def conv1x1(in_planes, out_planes, stride=1):
+    return torch.nn.Conv2d(
+        in_planes, out_planes, kernel_size=1, stride=stride, bias=False
+    )
+
+
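+# Standard ResNet basic block, with the convolutions wrapped in tdLayer and the
+# ReLU activations replaced by LIF spiking neurons.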
+class BasicBlock(torch.nn.Module):
+    expansion = 1
+
+    def __init__(
+        self,
+        inplanes,
+        planes,
+        stride=1,
+        downsample=None,
+        groups=1,
+        base_width=64,
+        dilation=1,
+        norm_layer=None,
+    ):
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = tdBatchNorm
+            # norm_layer = torch.nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+        self.conv1_s = tdLayer(self.conv1, self.bn1)
+        self.conv2_s = tdLayer(self.conv2, self.bn2)
+        self.spike1 = LIF()
+        self.spike2 = LIF()
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1_s(x)
+        out = self.spike1(out)
+        out = self.conv2_s(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.spike2(out)
+
+        return out
+
+
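+# Small spiking ResNet: the input is repeated over T time steps, pushed through
+# three residual stages, average-pooled, and classified per time step.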
+class ResNety(torch.nn.Module):
+    def __init__(
+        self,
+        block,
+        layers,
+        num_classes=10,
+        zero_init_residual=False,
+        groups=1,
+        width_per_group=64,
+        replace_stride_with_dilation=None,
+        norm_layer=None,
+    ):
+        super(ResNety, self).__init__()
+        if norm_layer is None:
+            norm_layer = tdBatchNorm
+            # norm_layer = torch.nn.BatchNorm2d
+        self._norm_layer = norm_layer
+        self.inplanes = 64
+        self.dilation = 1
+        self.groups = groups
+        self.base_width = width_per_group
+        self.pre = torch.nn.Sequential(
+            tdLayer(
+                layer=torch.nn.Conv2d(
+                    3, self.inplanes, kernel_size=(3, 3), stride=(1, 1)
+                ),
+                bn=self._norm_layer(self.inplanes),
+            ),
+            LIF(),
+        )
+        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.avgpool = tdLayer(torch.nn.AdaptiveAvgPool2d((1, 1)))
+        self.fc = tdLayer(torch.nn.Linear(256, num_classes))
+        self.T = 6  # number of simulated time steps
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                torch.nn.init.kaiming_normal_(
+                    m.weight, mode="fan_out", nonlinearity="relu"
+                )
+
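+    # Builds one residual stage of `blocks` BasicBlocks; when the spatial size or
+    # channel count changes, a tdLayer-wrapped 1x1 convolution downsamples the identity path.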
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = tdLayer(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(
+            block(
+                self.inplanes,
+                planes,
+                stride,
+                downsample,
+                self.groups,
+                self.base_width,
+                previous_dilation,
+                norm_layer,
+            )
+        )
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(
+                block(
+                    self.inplanes,
+                    planes,
+                    groups=self.groups,
+                    base_width=self.base_width,
+                    dilation=self.dilation,
+                    norm_layer=norm_layer,
+                )
+            )
+
+        return torch.nn.Sequential(*layers)
+
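+    # Repeats the input along a new trailing time dimension, runs the network,
+    # and returns the per-time-step logits stacked along dim=1.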
+    def _forward_impl(self, input):
+        out = []
+        input = input.unsqueeze(-1).repeat(1, 1, 1, 1, self.T)
+        x = self.pre(input)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1, x.size(-1))
+        x = self.fc(x)
+        for t in range(self.T):
+            out.append(x[..., t])
+        return torch.stack(out, dim=1)
+
+    def forward(self, x):
+        return self._forward_impl(x)
+
+
+resnet = ResNety(block=BasicBlock, layers=[2, 2, 2], num_classes=10)
+resnet.train(False)  # switch to inference
+
+# Get a random input.
+#   B x RGB x H x W
+x = torch.rand(1, 3, 16, 16)
+
+#
+# CHECK: pytorch
+# CHECK: mpact
+# CHECK: passed
+#
+
+with torch.no_grad():
+    # Run it with PyTorch.
+    print("pytorch")
+    res1 = resnet(x)
+    print(res1)
+
+    # Run it with MPACT.
+    # TODO: make this work
+    print("mpact")
+    res2 = mpact_jit(resnet, x)
+    print(res2)
+
+# Inputs and weights are random on every run, so rather than checking against
+# golden values we simply verify that the PyTorch and MPACT results agree.
+np.testing.assert_allclose(res1.numpy(), res2, rtol=1e-5, atol=0)
+print("passed")