| import torch |
| |
| |
def spike(input):
    """Heaviside-style step: 1.0 where ``input >= 0``, 0.0 elsewhere.

    Args:
        input: Object supporting ``>=`` against a scalar whose comparison
            result has a ``.float()`` method (a ``torch.Tensor`` in this
            file's usage).

    Returns:
        The comparison mask converted with ``.float()``.
    """
    fired = input >= 0
    return fired.float()
| |
| |
def sqSum(input):
    """Return the sum of the element-wise squares of ``input``.

    Args:
        input: Object supporting ``*`` with itself whose product exposes
            ``.sum()`` (a ``torch.Tensor`` in this file's usage).

    Returns:
        ``(input * input).sum()`` -- the reduction runs over every element,
        so no axis is preserved.
    """
    squared = input * input
    return squared.sum()
| |
| |
class LIF(torch.nn.Module):
    """Leaky integrate-and-fire filter applied along the last axis of the input.

    At every step the membrane value is scaled by ``decay``, the current
    input slice is added, and ``act(mem - thresh)`` produces the spike
    tensor for that step.  The membrane is then multiplied by
    ``1 - spike``, which zeroes every entry whose spike value is 1.

    Args:
        thresh: Value subtracted from the membrane before ``act`` is
            applied.  Defaults to 1.0.
        decay: Factor applied to the previous membrane value at each step.
            Defaults to 0.5.
        act: Callable mapping ``mem - thresh`` to the spike tensor.  When
            ``None`` (the default) the module-level ``spike`` function is
            used.
    """

    def __init__(self, thresh=1.0, decay=0.5, act=None):
        super().__init__()
        self.thresh = thresh
        self.decay = decay
        # Resolve the default lazily so a caller-supplied activation never
        # depends on the module-level ``spike`` function.
        self.act = spike if act is None else act

    def forward(self, X):
        """Run the filter over the last axis of ``X`` and return the spikes.

        Args:
            X: Tensor whose last axis is iterated step by step; each slice
                ``X[..., t]`` is the input added to the membrane at step
                ``t``.

        Returns:
            A dense tensor formed by stacking the per-step spike tensors
            along a new last axis.  With the default activation the values
            are 0.0 or 1.0.  If the last axis of ``X`` has length 0, an
            empty float32 tensor with the shape of ``X`` is returned.
        """
        steps = X.size(-1)
        if steps == 0:
            # torch.stack rejects an empty list; zero steps yield an empty
            # spike train instead of an error.
            return torch.zeros_like(X, dtype=torch.float32)

        mem = 0
        spike_train = []
        for t in range(steps):
            # Leak the previous membrane value, then integrate the input.
            mem = mem * self.decay + X[..., t]
            fired = self.act(mem - self.thresh)
            # Sparse round trip kept from the original ("prop hack"); it
            # leaves the values unchanged.
            # NOTE(review): purpose not evident from this file -- confirm
            # before removing.
            fired = fired.to_sparse().to_dense()
            # Reset: entries with a spike value of 1 drop back to zero.
            mem = mem * (1.0 - fired)
            spike_train.append(fired)
        return torch.stack(spike_train, dim=-1)
| |
| |
class tdLayer(torch.nn.Module):
    """Apply a wrapped callable independently to each slice of the last axis.

    Args:
        layer: Callable invoked once per step with ``X[..., t]``.
    """

    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, X):
        """Map ``self.layer`` over the last axis of ``X`` and restack.

        Args:
            X: Tensor whose last axis is iterated; ``X.size(-1)`` gives the
                number of steps.

        Returns:
            The per-step results of ``self.layer`` stacked along a new
            last axis.
        """
        n_steps = X.size(-1)
        per_step = [self.layer(X[..., idx]) for idx in range(n_steps)]
        return torch.stack(per_step, dim=-1)
| |
| |
class LIFSumOfSq(torch.nn.Module):
    """Chain a ``LIF`` filter with a per-step ``sqSum`` reduction.

    The input is first passed through ``LIF``; the result is then fed to a
    ``tdLayer`` wrapping ``sqSum``, which is applied to each slice of the
    last axis.
    """

    def __init__(self):
        super().__init__()
        self.spike = LIF()
        self.layer = tdLayer(sqSum)

    def forward(self, X):
        """Return ``self.layer(self.spike(X))``.

        Args:
            X: Input forwarded unchanged to the ``LIF`` stage.

        Returns:
            The output of the ``tdLayer(sqSum)`` stage applied to the
            ``LIF`` output.
        """
        spikes = self.spike(X)
        return self.layer(spikes)