| import torch | 
 | import numpy as np | 
 |  | 
 |  | 
 | def spike(input): | 
 |     return (input >= 0).float() | 
 |  | 
 |  | 
 | class Straight(torch.nn.Module): | 
 |     def forward(self, input): | 
 |         return input | 
 |  | 
 |  | 
 | class tdLayer(torch.nn.Module): | 
 |     def __init__(self, layer, bn=None): | 
 |         super(tdLayer, self).__init__() | 
 |         self.layer = layer | 
 |         self.bn = bn if bn is not None else Straight() | 
 |  | 
 |     def forward(self, X): | 
 |         T = X.size(-1) | 
 |         out = [] | 
 |         for t in range(T): | 
 |             m = self.layer(X[..., t]) | 
 |             out.append(m) | 
 |         out = torch.stack(out, dim=-1) | 
 |         out = self.bn(out) | 
 |         return out | 
 |  | 
 |  | 
 | class LIF(torch.nn.Module): | 
 |     def __init__(self): | 
 |         super(LIF, self).__init__() | 
 |         self.thresh = 1.0 | 
 |         self.decay = 0.5 | 
 |         self.act = spike | 
 |         self.gama = 1.0 | 
 |  | 
 |     def forward(self, X, gama=1): | 
 |         mem = 0 | 
 |         spike_pot = [] | 
 |         T = X.size(-1) | 
 |         for t in range(T): | 
 |             mem = mem * self.decay + X[..., t] | 
 |             spike = self.act(mem - self.thresh) | 
 |             mem = mem * (1.0 - spike) | 
 |             spike_pot.append(spike) | 
 |         spike_pot = torch.stack(spike_pot, dim=-1) | 
 |         return spike_pot | 
 |  | 
 |  | 
 | class tdBatchNorm(torch.nn.BatchNorm2d): | 
 |     def __init__( | 
 |         self, | 
 |         num_features, | 
 |         eps=1e-05, | 
 |         momentum=0.1, | 
 |         alpha=1, | 
 |         affine=True, | 
 |         track_running_stats=True, | 
 |     ): | 
 |         super(tdBatchNorm, self).__init__( | 
 |             num_features, eps, momentum, affine, track_running_stats | 
 |         ) | 
 |         self.alpha = alpha | 
 |  | 
 |     def forward(self, input): | 
 |         exponential_average_factor = 0.0 | 
 |         mean = self.running_mean | 
 |         var = self.running_var | 
 |         input = ( | 
 |             self.alpha | 
 |             * (input - mean[None, :, None, None, None]) | 
 |             / (torch.sqrt(var[None, :, None, None, None] + self.eps)) | 
 |         ) | 
 |         if self.affine: | 
 |             input = ( | 
 |                 input * self.weight[None, :, None, None, None] | 
 |                 + self.bias[None, :, None, None, None] | 
 |             ) | 
 |         return input | 
 |  | 
 |  | 
 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): | 
 |     return torch.nn.Conv2d( | 
 |         in_planes, | 
 |         out_planes, | 
 |         kernel_size=3, | 
 |         stride=stride, | 
 |         padding=dilation, | 
 |         groups=groups, | 
 |         bias=False, | 
 |         dilation=dilation, | 
 |     ) | 
 |  | 
 |  | 
 | def conv1x1(in_planes, out_planes, stride=1): | 
 |     return torch.nn.Conv2d( | 
 |         in_planes, out_planes, kernel_size=1, stride=stride, bias=False | 
 |     ) | 
 |  | 
 |  | 
 | class BasicBlock(torch.nn.Module): | 
 |     expansion = 1 | 
 |  | 
 |     def __init__( | 
 |         self, | 
 |         inplanes, | 
 |         planes, | 
 |         stride=1, | 
 |         downsample=None, | 
 |         groups=1, | 
 |         base_width=64, | 
 |         dilation=1, | 
 |         norm_layer=None, | 
 |     ): | 
 |         super(BasicBlock, self).__init__() | 
 |         if norm_layer is None: | 
 |             norm_layer = tdBatchNorm | 
 |             # norm_layer = nn.BatchNorm2d | 
 |         if groups != 1 or base_width != 64: | 
 |             raise ValueError("BasicBlock only supports groups=1 and base_width=64") | 
 |         if dilation > 1: | 
 |             raise NotImplementedError("Dilation > 1 not supported in BasicBlock") | 
 |         # Both self.conv1 and self.downsample layers downsample the input when stride != 1 | 
 |         self.conv1 = conv3x3(inplanes, planes, stride) | 
 |         self.bn1 = norm_layer(planes) | 
 |         self.conv2 = conv3x3(planes, planes) | 
 |         self.bn2 = norm_layer(planes) | 
 |         self.downsample = downsample | 
 |         self.stride = stride | 
 |         self.conv1_s = tdLayer(self.conv1, self.bn1) | 
 |         self.conv2_s = tdLayer(self.conv2, self.bn2) | 
 |         self.spike1 = LIF() | 
 |         self.spike2 = LIF() | 
 |  | 
 |     def forward(self, x): | 
 |         identity = x | 
 |  | 
 |         out = self.conv1_s(x) | 
 |         out = self.spike1(out) | 
 |         out = self.conv2_s(out) | 
 |  | 
 |         if self.downsample is not None: | 
 |             identity = self.downsample(x) | 
 |  | 
 |         out += identity | 
 |         out = self.spike2(out) | 
 |  | 
 |         return out | 
 |  | 
 |  | 
 | class ResNety(torch.nn.Module): | 
 |     def __init__( | 
 |         self, | 
 |         block, | 
 |         layers, | 
 |         num_classes=10, | 
 |         zero_init_residual=False, | 
 |         groups=1, | 
 |         width_per_group=64, | 
 |         replace_stride_with_dilation=None, | 
 |         norm_layer=None, | 
 |     ): | 
 |         super(ResNety, self).__init__() | 
 |         if norm_layer is None: | 
 |             norm_layer = tdBatchNorm | 
 |             # norm_layer = nn.BatchNorm2d | 
 |         self._norm_layer = norm_layer | 
 |         self.inplanes = 64 | 
 |         self.dilation = 1 | 
 |         self.groups = groups | 
 |         self.base_width = width_per_group | 
 |         self.pre = torch.nn.Sequential( | 
 |             tdLayer( | 
 |                 layer=torch.nn.Conv2d( | 
 |                     3, self.inplanes, kernel_size=(3, 3), stride=(1, 1) | 
 |                 ), | 
 |                 bn=self._norm_layer(self.inplanes), | 
 |             ), | 
 |             LIF(), | 
 |         ) | 
 |         self.layer1 = self._make_layer(block, 64, layers[0], stride=2) | 
 |         self.layer2 = self._make_layer(block, 128, layers[1], stride=2) | 
 |         self.layer3 = self._make_layer(block, 256, layers[2], stride=2) | 
 |         self.avgpool = tdLayer(torch.nn.AdaptiveAvgPool2d((1, 1))) | 
 |         self.fc = tdLayer(torch.nn.Linear(256, num_classes)) | 
 |         self.T = 6 | 
 |         for m in self.modules(): | 
 |             if isinstance(m, torch.nn.Conv2d): | 
 |                 torch.nn.init.kaiming_normal_( | 
 |                     m.weight, mode="fan_out", nonlinearity="relu" | 
 |                 ) | 
 |  | 
 |     def _make_layer(self, block, planes, blocks, stride=1, dilate=False): | 
 |         norm_layer = self._norm_layer | 
 |         downsample = None | 
 |         previous_dilation = self.dilation | 
 |         if dilate: | 
 |             self.dilation *= stride | 
 |             stride = 1 | 
 |         if stride != 1 or self.inplanes != planes * block.expansion: | 
 |             downsample = tdLayer( | 
 |                 conv1x1(self.inplanes, planes * block.expansion, stride), | 
 |                 norm_layer(planes * block.expansion), | 
 |             ) | 
 |  | 
 |         layers = [] | 
 |         layers.append( | 
 |             block( | 
 |                 self.inplanes, | 
 |                 planes, | 
 |                 stride, | 
 |                 downsample, | 
 |                 self.groups, | 
 |                 self.base_width, | 
 |                 previous_dilation, | 
 |                 norm_layer, | 
 |             ) | 
 |         ) | 
 |         self.inplanes = planes * block.expansion | 
 |         for _ in range(1, blocks): | 
 |             layers.append( | 
 |                 block( | 
 |                     self.inplanes, | 
 |                     planes, | 
 |                     groups=self.groups, | 
 |                     base_width=self.base_width, | 
 |                     dilation=self.dilation, | 
 |                     norm_layer=norm_layer, | 
 |                 ) | 
 |             ) | 
 |  | 
 |         return torch.nn.Sequential(*layers) | 
 |  | 
 |     def _forward_impl(self, input): | 
 |         out = [] | 
 |         input = input.unsqueeze(-1).repeat(1, 1, 1, 1, self.T) | 
 |         x = self.pre(input) | 
 |         x = self.layer1(x) | 
 |         x = self.layer2(x) | 
 |         x = self.layer3(x) | 
 |         x = self.avgpool(x) | 
 |         x = x.view(x.size(0), -1, x.size(-1)) | 
 |         x = self.fc(x) | 
 |         for t in range(self.T): | 
 |             out.append(x[..., t]) | 
 |         return torch.stack(out, dim=1) | 
 |  | 
 |     def forward(self, x): | 
 |         return self._forward_impl(x) | 
 |  | 
 |  | 
 | def resnet_20(): | 
 |     return ResNety(block=BasicBlock, layers=[2, 2, 2], num_classes=10) |