blob: fcb5a55658b586fa829c66a45b1eba33fa8d089f [file] [log] [blame]
import torch
def spike(input):
    """Heaviside step: 1.0 where ``input >= 0``, 0.0 elsewhere (NaN -> 0.0).

    The comparison mask is converted to float32, the same dtype ``.float()`` yields.
    """
    mask = input >= 0
    return mask.to(dtype=torch.float32)
def sqSum(input):
    """Return the sum of the element-wise squares of ``input`` as a 0-dim tensor."""
    squared = torch.mul(input, input)
    return torch.sum(squared)
class LIF(torch.nn.Module):
    """Leaky integrate-and-fire spiking layer over the last dimension.

    The last dimension of the input is treated as time. At every step:
    1. The membrane potential is multiplied by ``decay``.
    2. The current input slice is added to it.
    3. A spike (1.0) is emitted wherever the potential reaches or exceeds
       ``thresh``.
    4. Positions that spiked have their potential reset to zero.

    Args:
        thresh: Firing threshold (default 1.0).
        decay: Multiplicative leak applied to the membrane each step
            (default 0.5).
    """

    def __init__(self, thresh=1.0, decay=0.5):
        super(LIF, self).__init__()
        self.thresh = thresh
        self.decay = decay
        self.act = spike

    def forward(self, X):
        """Return the binary spike train produced by ``X``.

        Args:
            X: Input tensor whose last dimension indexes time steps.

        Returns:
            A dense tensor with the same shape as ``X``. It holds 0.0/1.0
            values in the dtype produced by ``self.act`` (float32 for the
            default ``spike``). Entry ``[..., t]`` is the spike emitted at
            step ``t``.
        """
        T = X.size(-1)
        if T == 0:
            # torch.stack rejects an empty list; an empty time axis simply
            # yields an empty spike train of the same shape.
            return torch.zeros(X.shape, dtype=torch.float32, device=X.device)
        mem = 0
        spike_pot = []
        for t in range(T):
            mem = mem * self.decay + X[..., t]
            fired = self.act(mem - self.thresh)
            # NOTE(review): the dense->sparse->dense round trip leaves values
            # unchanged. The original "prop hack" note suggests it exists to
            # put sparse-conversion ops into the computation; confirm before
            # removing.
            fired = fired.to_sparse().to_dense()  # prop hack
            # Reset the membrane potential wherever a spike fired.
            mem = mem * (1.0 - fired)
            spike_pot.append(fired)
        return torch.stack(spike_pot, dim=-1)
class tdLayer(torch.nn.Module):
    """Time-distributed wrapper around ``layer``.

    ``layer`` is called separately on each slice along the input's last
    dimension (treated as time). The per-step outputs are stacked along a new
    trailing dimension.
    """

    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, X):
        per_step = [self.layer(X[..., step]) for step in range(X.size(-1))]
        return torch.stack(per_step, dim=-1)
class LIFSumOfSq(torch.nn.Module):
    """Spike the input with a default ``LIF``, then reduce each time step to a scalar.

    Each time step's spike map is reduced with ``sqSum``. The output is 1-D,
    with one entry per step along the input's last dimension.
    """

    def __init__(self):
        super().__init__()
        self.spike = LIF()
        self.layer = tdLayer(sqSum)

    def forward(self, X):
        return self.layer(self.spike(X))