import torch


def spike(x):
    # Heaviside step: 1.0 where x >= 0, else 0.0.
    return (x >= 0).float()


def sqSum(x):
    # Sum of squares over all elements (returns a 0-dim tensor).
    return (x * x).sum()


class LIF(torch.nn.Module):
    def __init__(self):
        super(LIF, self).__init__()
        self.thresh = 1.0  # firing threshold
        self.decay = 0.5   # membrane leak factor per time step
        self.act = spike

    def forward(self, X):
        """Leaky integrate-and-fire filter.

        Treats the last dimension of X as time and yields a binary-valued
        (0/1) spike tensor of the same shape. For example, with decay=0.5,
        thresh=1.0, and a constant input of 0.6, the membrane potential
        evolves as 0.6, 0.9, 1.05, ..., so the neuron first fires at the
        third step and is then reset to 0.
        """
        mem = 0
        spike_pot = []
        T = X.size(-1)
        for t in range(T):
            # Leak the membrane potential and integrate the current input.
            mem = mem * self.decay + X[..., t]
            # Emit a spike wherever the potential reaches the threshold.
            spike = self.act(mem - self.thresh)
            # prop hack: round-trip through a sparse tensor (values unchanged).
            spike = spike.to_sparse().to_dense()
            # Reset the potential of neurons that just fired.
            mem = mem * (1.0 - spike)
            spike_pot.append(spike)
        spike_pot = torch.stack(spike_pot, dim=-1)
        return spike_pot


class tdLayer(torch.nn.Module):
    """Wraps a layer so it is applied independently at every time step."""

    def __init__(self, layer):
        super(tdLayer, self).__init__()
        self.layer = layer

    def forward(self, X):
        # The last dimension of X is time; apply the wrapped layer per step
        # and stack the results back along the time dimension.
        T = X.size(-1)
        out = []
        for t in range(T):
            m = self.layer(X[..., t])
            out.append(m)
        out = torch.stack(out, dim=-1)
        return out


class LIFSumOfSq(torch.nn.Module):
    """LIF spiking followed by a per-time-step sum of squares."""

    def __init__(self):
        super(LIFSumOfSq, self).__init__()
        self.spike = LIF()
        self.layer = tdLayer(sqSum)

    def forward(self, X):
        out = self.spike(X)    # binary spike train, same shape as X
        out = self.layer(out)  # one scalar per time step
        return out
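

# Minimal usage sketch. The batch size, feature count, and number of time
# steps below are illustrative assumptions, not requirements of the classes
# above; the only convention relied on is that time is the last dimension.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, N, T = 2, 4, 8
    X = torch.rand(B, N, T)

    spikes = LIF()(X)      # 0/1 spike tensor, shape (2, 4, 8)
    out = LIFSumOfSq()(X)  # spike count per time step, shape (8,)
    print(spikes.shape, out.shape)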