|  | import torch | 
|  | import numpy as np | 
|  |  | 
|  |  | 
|  | def spike(input): | 
|  | return (input >= 0).float() | 
|  |  | 
|  |  | 
|  | class Straight(torch.nn.Module): | 
|  | def forward(self, input): | 
|  | return input | 
|  |  | 
|  |  | 
|  | class tdLayer(torch.nn.Module): | 
|  | def __init__(self, layer, bn=None): | 
|  | super(tdLayer, self).__init__() | 
|  | self.layer = layer | 
|  | self.bn = bn if bn is not None else Straight() | 
|  |  | 
|  | def forward(self, X): | 
|  | T = X.size(-1) | 
|  | out = [] | 
|  | for t in range(T): | 
|  | m = self.layer(X[..., t]) | 
|  | out.append(m) | 
|  | out = torch.stack(out, dim=-1) | 
|  | out = self.bn(out) | 
|  | return out | 
|  |  | 
|  |  | 
|  | class LIF(torch.nn.Module): | 
|  | def __init__(self): | 
|  | super(LIF, self).__init__() | 
|  | self.thresh = 1.0 | 
|  | self.decay = 0.5 | 
|  | self.act = spike | 
|  | self.gama = 1.0 | 
|  |  | 
|  | def forward(self, X, gama=1): | 
|  | mem = 0 | 
|  | spike_pot = [] | 
|  | T = X.size(-1) | 
|  | for t in range(T): | 
|  | mem = mem * self.decay + X[..., t] | 
|  | spike = self.act(mem - self.thresh) | 
|  | mem = mem * (1.0 - spike) | 
|  | spike_pot.append(spike) | 
|  | spike_pot = torch.stack(spike_pot, dim=-1) | 
|  | return spike_pot | 
|  |  | 
|  |  | 
|  | class tdBatchNorm(torch.nn.BatchNorm2d): | 
|  | def __init__( | 
|  | self, | 
|  | num_features, | 
|  | eps=1e-05, | 
|  | momentum=0.1, | 
|  | alpha=1, | 
|  | affine=True, | 
|  | track_running_stats=True, | 
|  | ): | 
|  | super(tdBatchNorm, self).__init__( | 
|  | num_features, eps, momentum, affine, track_running_stats | 
|  | ) | 
|  | self.alpha = alpha | 
|  |  | 
|  | def forward(self, input): | 
|  | exponential_average_factor = 0.0 | 
|  | mean = self.running_mean | 
|  | var = self.running_var | 
|  | input = ( | 
|  | self.alpha | 
|  | * (input - mean[None, :, None, None, None]) | 
|  | / (torch.sqrt(var[None, :, None, None, None] + self.eps)) | 
|  | ) | 
|  | if self.affine: | 
|  | input = ( | 
|  | input * self.weight[None, :, None, None, None] | 
|  | + self.bias[None, :, None, None, None] | 
|  | ) | 
|  | return input | 
|  |  | 
|  |  | 
|  | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): | 
|  | return torch.nn.Conv2d( | 
|  | in_planes, | 
|  | out_planes, | 
|  | kernel_size=3, | 
|  | stride=stride, | 
|  | padding=dilation, | 
|  | groups=groups, | 
|  | bias=False, | 
|  | dilation=dilation, | 
|  | ) | 
|  |  | 
|  |  | 
|  | def conv1x1(in_planes, out_planes, stride=1): | 
|  | return torch.nn.Conv2d( | 
|  | in_planes, out_planes, kernel_size=1, stride=stride, bias=False | 
|  | ) | 
|  |  | 
|  |  | 
|  | class BasicBlock(torch.nn.Module): | 
|  | expansion = 1 | 
|  |  | 
|  | def __init__( | 
|  | self, | 
|  | inplanes, | 
|  | planes, | 
|  | stride=1, | 
|  | downsample=None, | 
|  | groups=1, | 
|  | base_width=64, | 
|  | dilation=1, | 
|  | norm_layer=None, | 
|  | ): | 
|  | super(BasicBlock, self).__init__() | 
|  | if norm_layer is None: | 
|  | norm_layer = tdBatchNorm | 
|  | # norm_layer = nn.BatchNorm2d | 
|  | if groups != 1 or base_width != 64: | 
|  | raise ValueError("BasicBlock only supports groups=1 and base_width=64") | 
|  | if dilation > 1: | 
|  | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") | 
|  | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 | 
|  | self.conv1 = conv3x3(inplanes, planes, stride) | 
|  | self.bn1 = norm_layer(planes) | 
|  | self.conv2 = conv3x3(planes, planes) | 
|  | self.bn2 = norm_layer(planes) | 
|  | self.downsample = downsample | 
|  | self.stride = stride | 
|  | self.conv1_s = tdLayer(self.conv1, self.bn1) | 
|  | self.conv2_s = tdLayer(self.conv2, self.bn2) | 
|  | self.spike1 = LIF() | 
|  | self.spike2 = LIF() | 
|  |  | 
|  | def forward(self, x): | 
|  | identity = x | 
|  |  | 
|  | out = self.conv1_s(x) | 
|  | out = self.spike1(out) | 
|  | out = self.conv2_s(out) | 
|  |  | 
|  | if self.downsample is not None: | 
|  | identity = self.downsample(x) | 
|  |  | 
|  | out += identity | 
|  | out = self.spike2(out) | 
|  |  | 
|  | return out | 
|  |  | 
|  |  | 
|  | class ResNety(torch.nn.Module): | 
|  | def __init__( | 
|  | self, | 
|  | block, | 
|  | layers, | 
|  | num_classes=10, | 
|  | zero_init_residual=False, | 
|  | groups=1, | 
|  | width_per_group=64, | 
|  | replace_stride_with_dilation=None, | 
|  | norm_layer=None, | 
|  | ): | 
|  | super(ResNety, self).__init__() | 
|  | if norm_layer is None: | 
|  | norm_layer = tdBatchNorm | 
|  | # norm_layer = nn.BatchNorm2d | 
|  | self._norm_layer = norm_layer | 
|  | self.inplanes = 64 | 
|  | self.dilation = 1 | 
|  | self.groups = groups | 
|  | self.base_width = width_per_group | 
|  | self.pre = torch.nn.Sequential( | 
|  | tdLayer( | 
|  | layer=torch.nn.Conv2d( | 
|  | 3, self.inplanes, kernel_size=(3, 3), stride=(1, 1) | 
|  | ), | 
|  | bn=self._norm_layer(self.inplanes), | 
|  | ), | 
|  | LIF(), | 
|  | ) | 
|  | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) | 
|  | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) | 
|  | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) | 
|  | self.avgpool = tdLayer(torch.nn.AdaptiveAvgPool2d((1, 1))) | 
|  | self.fc = tdLayer(torch.nn.Linear(256, num_classes)) | 
|  | self.T = 6 | 
|  | for m in self.modules(): | 
|  | if isinstance(m, torch.nn.Conv2d): | 
|  | torch.nn.init.kaiming_normal_( | 
|  | m.weight, mode="fan_out", nonlinearity="relu" | 
|  | ) | 
|  |  | 
|  | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): | 
|  | norm_layer = self._norm_layer | 
|  | downsample = None | 
|  | previous_dilation = self.dilation | 
|  | if dilate: | 
|  | self.dilation *= stride | 
|  | stride = 1 | 
|  | if stride != 1 or self.inplanes != planes * block.expansion: | 
|  | downsample = tdLayer( | 
|  | conv1x1(self.inplanes, planes * block.expansion, stride), | 
|  | norm_layer(planes * block.expansion), | 
|  | ) | 
|  |  | 
|  | layers = [] | 
|  | layers.append( | 
|  | block( | 
|  | self.inplanes, | 
|  | planes, | 
|  | stride, | 
|  | downsample, | 
|  | self.groups, | 
|  | self.base_width, | 
|  | previous_dilation, | 
|  | norm_layer, | 
|  | ) | 
|  | ) | 
|  | self.inplanes = planes * block.expansion | 
|  | for _ in range(1, blocks): | 
|  | layers.append( | 
|  | block( | 
|  | self.inplanes, | 
|  | planes, | 
|  | groups=self.groups, | 
|  | base_width=self.base_width, | 
|  | dilation=self.dilation, | 
|  | norm_layer=norm_layer, | 
|  | ) | 
|  | ) | 
|  |  | 
|  | return torch.nn.Sequential(*layers) | 
|  |  | 
|  | def _forward_impl(self, input): | 
|  | out = [] | 
|  | input = input.unsqueeze(-1).repeat(1, 1, 1, 1, self.T) | 
|  | x = self.pre(input) | 
|  | x = self.layer1(x) | 
|  | x = self.layer2(x) | 
|  | x = self.layer3(x) | 
|  | x = self.avgpool(x) | 
|  | x = x.view(x.size(0), -1, x.size(-1)) | 
|  | x = self.fc(x) | 
|  | for t in range(self.T): | 
|  | out.append(x[..., t]) | 
|  | return torch.stack(out, dim=1) | 
|  |  | 
|  | def forward(self, x): | 
|  | return self._forward_impl(x) | 
|  |  | 
|  |  | 
|  | def resnet_20(): | 
|  | return ResNety(block=BasicBlock, layers=[2, 2, 2], num_classes=10) |