[mpact][compiler] move sample models into centralized models directory (#17)
Good idea by Yinying: tests and benchmarks can now draw from a shared
sample model pool instead of repeating the same PyTorch model code all
over the place.
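
As a minimal sketch of the intended usage (only the resnet20() factory
added below is assumed), a test or benchmark now just imports the model:

    from mpact.models.resnet import resnet20

    resnet = resnet20()   # build the sample spiking ResNet-20
    resnet.train(False)   # switch to inference
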
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index ac10ed2..4b3953a 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -14,6 +14,13 @@
mpactbackend.py
)
+declare_mlir_python_sources(MPACTPythonSources.SampleModels
+ ROOT_DIR "${MPACT_PYTHON_ROOT_DIR}"
+ ADD_TO_PARENT MPACTPythonSources
+ SOURCES
+ models/resnet.py
+)
+
#-------------------------------------------------------------------------------
# Python Modules
#-------------------------------------------------------------------------------
diff --git a/python/mpact/models/resnet.py b/python/mpact/models/resnet.py
new file mode 100644
index 0000000..c3f5925
--- /dev/null
+++ b/python/mpact/models/resnet.py
@@ -0,0 +1,255 @@
+import torch
+import numpy as np
+
+
+def spike(input):
+ return (input >= 0).float()
+
+
+class Straight(torch.nn.Module):
+ def forward(self, input):
+ return input
+
+
+class tdLayer(torch.nn.Module):
+ def __init__(self, layer, bn=None):
+ super(tdLayer, self).__init__()
+ self.layer = layer
+ self.bn = bn if bn is not None else Straight()
+
+ def forward(self, X):
+ T = X.size(-1)
+ out = []
+ for t in range(T):
+ m = self.layer(X[..., t])
+ out.append(m)
+ out = torch.stack(out, dim=-1)
+ out = self.bn(out)
+ return out
+
+
+class LIF(torch.nn.Module):
+ def __init__(self):
+ super(LIF, self).__init__()
+ self.thresh = 1.0
+ self.decay = 0.5
+ self.act = spike
+ self.gama = 1.0
+
+ def forward(self, X, gama=1):
+ mem = 0
+ spike_pot = []
+ T = X.size(-1)
+ for t in range(T):
+ mem = mem * self.decay + X[..., t]
+ spike = self.act(mem - self.thresh)
+ mem = mem * (1.0 - spike)
+ spike_pot.append(spike)
+ spike_pot = torch.stack(spike_pot, dim=-1)
+ return spike_pot
+
+
+class tdBatchNorm(torch.nn.BatchNorm2d):
+ def __init__(
+ self,
+ num_features,
+ eps=1e-05,
+ momentum=0.1,
+ alpha=1,
+ affine=True,
+ track_running_stats=True,
+ ):
+ super(tdBatchNorm, self).__init__(
+ num_features, eps, momentum, affine, track_running_stats
+ )
+ self.alpha = alpha
+
+ def forward(self, input):
+ exponential_average_factor = 0.0
+ mean = self.running_mean
+ var = self.running_var
+ input = (
+ self.alpha
+ * (input - mean[None, :, None, None, None])
+ / (torch.sqrt(var[None, :, None, None, None] + self.eps))
+ )
+ if self.affine:
+ input = (
+ input * self.weight[None, :, None, None, None]
+ + self.bias[None, :, None, None, None]
+ )
+ return input
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ return torch.nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation,
+ )
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ return torch.nn.Conv2d(
+ in_planes, out_planes, kernel_size=1, stride=stride, bias=False
+ )
+
+
+class BasicBlock(torch.nn.Module):
+ expansion = 1
+
+ def __init__(
+ self,
+ inplanes,
+ planes,
+ stride=1,
+ downsample=None,
+ groups=1,
+ base_width=64,
+ dilation=1,
+ norm_layer=None,
+ ):
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = tdBatchNorm
+ # norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+ self.conv1_s = tdLayer(self.conv1, self.bn1)
+ self.conv2_s = tdLayer(self.conv2, self.bn2)
+ self.spike1 = LIF()
+ self.spike2 = LIF()
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1_s(x)
+ out = self.spike1(out)
+ out = self.conv2_s(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.spike2(out)
+
+ return out
+
+
+class ResNety(torch.nn.Module):
+ def __init__(
+ self,
+ block,
+ layers,
+ num_classes=10,
+ zero_init_residual=False,
+ groups=1,
+ width_per_group=64,
+ replace_stride_with_dilation=None,
+ norm_layer=None,
+ ):
+ super(ResNety, self).__init__()
+ if norm_layer is None:
+ norm_layer = tdBatchNorm
+ # norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+ self.inplanes = 64
+ self.dilation = 1
+ self.groups = groups
+ self.base_width = width_per_group
+ self.pre = torch.nn.Sequential(
+ tdLayer(
+ layer=torch.nn.Conv2d(
+ 3, self.inplanes, kernel_size=(3, 3), stride=(1, 1)
+ ),
+ bn=self._norm_layer(self.inplanes),
+ ),
+ LIF(),
+ )
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.avgpool = tdLayer(torch.nn.AdaptiveAvgPool2d((1, 1)))
+ self.fc = tdLayer(torch.nn.Linear(256, num_classes))
+ self.T = 6
+ for m in self.modules():
+ if isinstance(m, torch.nn.Conv2d):
+ torch.nn.init.kaiming_normal_(
+ m.weight, mode="fan_out", nonlinearity="relu"
+ )
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = tdLayer(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(
+ block(
+ self.inplanes,
+ planes,
+ stride,
+ downsample,
+ self.groups,
+ self.base_width,
+ previous_dilation,
+ norm_layer,
+ )
+ )
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(
+ self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation,
+ norm_layer=norm_layer,
+ )
+ )
+
+ return torch.nn.Sequential(*layers)
+
+ def _forward_impl(self, input):
+ out = []
+ input = input.unsqueeze(-1).repeat(1, 1, 1, 1, self.T)
+ x = self.pre(input)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.avgpool(x)
+ x = x.view(x.size(0), -1, x.size(-1))
+ x = self.fc(x)
+ for t in range(self.T):
+ out.append(x[..., t])
+ return torch.stack(out, dim=1)
+
+ def forward(self, x):
+ return self._forward_impl(x)
+
+
+def resnet20():
+ return ResNety(block=BasicBlock, layers=[2, 2, 2], num_classes=10)
diff --git a/test/python/resnet.py b/test/python/resnet.py
index 850028a..b353e8d 100644
--- a/test/python/resnet.py
+++ b/test/python/resnet.py
@@ -5,257 +5,9 @@
from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
+from mpact.models.resnet import resnet20
-def spike(input):
- return (input >= 0).float()
-
-
-class Straight(torch.nn.Module):
- def forward(self, input):
- return input
-
-
-class tdLayer(torch.nn.Module):
- def __init__(self, layer, bn=None):
- super(tdLayer, self).__init__()
- self.layer = layer
- self.bn = bn if bn is not None else Straight()
-
- def forward(self, X):
- T = X.size(-1)
- out = []
- for t in range(T):
- m = self.layer(X[..., t])
- out.append(m)
- out = torch.stack(out, dim=-1)
- out = self.bn(out)
- return out
-
-
-class LIF(torch.nn.Module):
- def __init__(self):
- super(LIF, self).__init__()
- self.thresh = 1.0
- self.decay = 0.5
- self.act = spike
- self.gama = 1.0
-
- def forward(self, X, gama=1):
- mem = 0
- spike_pot = []
- T = X.size(-1)
- for t in range(T):
- mem = mem * self.decay + X[..., t]
- spike = self.act(mem - self.thresh)
- mem = mem * (1.0 - spike)
- spike_pot.append(spike)
- spike_pot = torch.stack(spike_pot, dim=-1)
- return spike_pot
-
-
-class tdBatchNorm(torch.nn.BatchNorm2d):
- def __init__(
- self,
- num_features,
- eps=1e-05,
- momentum=0.1,
- alpha=1,
- affine=True,
- track_running_stats=True,
- ):
- super(tdBatchNorm, self).__init__(
- num_features, eps, momentum, affine, track_running_stats
- )
- self.alpha = alpha
-
- def forward(self, input):
- exponential_average_factor = 0.0
- mean = self.running_mean
- var = self.running_var
- input = (
- self.alpha
- * (input - mean[None, :, None, None, None])
- / (torch.sqrt(var[None, :, None, None, None] + self.eps))
- )
- if self.affine:
- input = (
- input * self.weight[None, :, None, None, None]
- + self.bias[None, :, None, None, None]
- )
- return input
-
-
-def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
- return torch.nn.Conv2d(
- in_planes,
- out_planes,
- kernel_size=3,
- stride=stride,
- padding=dilation,
- groups=groups,
- bias=False,
- dilation=dilation,
- )
-
-
-def conv1x1(in_planes, out_planes, stride=1):
- return torch.nn.Conv2d(
- in_planes, out_planes, kernel_size=1, stride=stride, bias=False
- )
-
-
-class BasicBlock(torch.nn.Module):
- expansion = 1
-
- def __init__(
- self,
- inplanes,
- planes,
- stride=1,
- downsample=None,
- groups=1,
- base_width=64,
- dilation=1,
- norm_layer=None,
- ):
- super(BasicBlock, self).__init__()
- if norm_layer is None:
- norm_layer = tdBatchNorm
- # norm_layer = nn.BatchNorm2d
- if groups != 1 or base_width != 64:
- raise ValueError("BasicBlock only supports groups=1 and base_width=64")
- if dilation > 1:
- raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
- # Both self.conv1 and self.downsample layers downsample the input when stride != 1
- self.conv1 = conv3x3(inplanes, planes, stride)
- self.bn1 = norm_layer(planes)
- self.conv2 = conv3x3(planes, planes)
- self.bn2 = norm_layer(planes)
- self.downsample = downsample
- self.stride = stride
- self.conv1_s = tdLayer(self.conv1, self.bn1)
- self.conv2_s = tdLayer(self.conv2, self.bn2)
- self.spike1 = LIF()
- self.spike2 = LIF()
-
- def forward(self, x):
- identity = x
-
- out = self.conv1_s(x)
- out = self.spike1(out)
- out = self.conv2_s(out)
-
- if self.downsample is not None:
- identity = self.downsample(x)
-
- out += identity
- out = self.spike2(out)
-
- return out
-
-
-class ResNety(torch.nn.Module):
- def __init__(
- self,
- block,
- layers,
- num_classes=10,
- zero_init_residual=False,
- groups=1,
- width_per_group=64,
- replace_stride_with_dilation=None,
- norm_layer=None,
- ):
- super(ResNety, self).__init__()
- if norm_layer is None:
- norm_layer = tdBatchNorm
- # norm_layer = nn.BatchNorm2d
- self._norm_layer = norm_layer
- self.inplanes = 64
- self.dilation = 1
- self.groups = groups
- self.base_width = width_per_group
- self.pre = torch.nn.Sequential(
- tdLayer(
- layer=torch.nn.Conv2d(
- 3, self.inplanes, kernel_size=(3, 3), stride=(1, 1)
- ),
- bn=self._norm_layer(self.inplanes),
- ),
- LIF(),
- )
- self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
- self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
- self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
- self.avgpool = tdLayer(torch.nn.AdaptiveAvgPool2d((1, 1)))
- self.fc = tdLayer(torch.nn.Linear(256, num_classes))
- self.T = 6
- for m in self.modules():
- if isinstance(m, torch.nn.Conv2d):
- torch.nn.init.kaiming_normal_(
- m.weight, mode="fan_out", nonlinearity="relu"
- )
-
- def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
- norm_layer = self._norm_layer
- downsample = None
- previous_dilation = self.dilation
- if dilate:
- self.dilation *= stride
- stride = 1
- if stride != 1 or self.inplanes != planes * block.expansion:
- downsample = tdLayer(
- conv1x1(self.inplanes, planes * block.expansion, stride),
- norm_layer(planes * block.expansion),
- )
-
- layers = []
- layers.append(
- block(
- self.inplanes,
- planes,
- stride,
- downsample,
- self.groups,
- self.base_width,
- previous_dilation,
- norm_layer,
- )
- )
- self.inplanes = planes * block.expansion
- for _ in range(1, blocks):
- layers.append(
- block(
- self.inplanes,
- planes,
- groups=self.groups,
- base_width=self.base_width,
- dilation=self.dilation,
- norm_layer=norm_layer,
- )
- )
-
- return torch.nn.Sequential(*layers)
-
- def _forward_impl(self, input):
- out = []
- input = input.unsqueeze(-1).repeat(1, 1, 1, 1, self.T)
- x = self.pre(input)
- x = self.layer1(x)
- x = self.layer2(x)
- x = self.layer3(x)
- x = self.avgpool(x)
- x = x.view(x.size(0), -1, x.size(-1))
- x = self.fc(x)
- for t in range(self.T):
- out.append(x[..., t])
- return torch.stack(out, dim=1)
-
- def forward(self, x):
- return self._forward_impl(x)
-
-
-resnet = ResNety(block=BasicBlock, layers=[2, 2, 2], num_classes=10)
+resnet = resnet20()
resnet.train(False) # switch to inference
# Get a random input.
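
For reference, a minimal sketch of driving the relocated model through the
MPACT backend; it assumes a CIFAR-sized 1x3x32x32 random input and that
mpact_jit (imported in the test above) takes the module followed by its
inputs:

    import torch
    from mpact.models.resnet import resnet20
    from mpact.mpactbackend import mpact_jit

    resnet = resnet20()
    resnet.train(False)            # inference mode
    x = torch.randn(1, 3, 32, 32)  # assumed CIFAR-sized random input
    res = mpact_jit(resnet, x)     # assumed call form: module, then inputs
    print(res)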