[mpact][test] Separate models from tests (#21)
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index 12c46bd..7815708 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -41,8 +41,8 @@
 declare_mlir_python_sources(MPACTPythonSources.SampleModels
   ROOT_DIR "${MPACT_PYTHON_ROOT_DIR}"
   ADD_TO_PARENT MPACTPythonSources
-  SOURCES
-    models/resnet.py
+  SOURCES_GLOB
+    models/*.py
 )
 
 #-------------------------------------------------------------------------------
diff --git a/python/mpact/models/gcn.py b/python/mpact/models/gcn.py
new file mode 100644
index 0000000..8a015de
--- /dev/null
+++ b/python/mpact/models/gcn.py
@@ -0,0 +1,19 @@
+import torch
+
+
+class GraphConv(torch.nn.Module):
+    def __init__(self, input_dim, output_dim):
+        super(GraphConv, self).__init__()
+        self.kernel = torch.nn.Parameter(torch.Tensor(input_dim, output_dim))
+        torch.nn.init.ones_(self.kernel)
+        self.bias = torch.nn.Parameter(torch.Tensor(output_dim))
+        torch.nn.init.ones_(self.bias)
+
+    def forward(self, inp, adj_mat):
+        # Input matrix times weight matrix.
+        support = torch.mm(inp, self.kernel)
+        # Sparse adjacency matrix times support matrix.
+        output = torch.spmm(adj_mat, support)
+        # Add bias.
+        output = output + self.bias
+        return output
diff --git a/python/mpact/models/kernels.py b/python/mpact/models/kernels.py
new file mode 100644
index 0000000..f8dbf17
--- /dev/null
+++ b/python/mpact/models/kernels.py
@@ -0,0 +1,31 @@
+import torch
+
+
+class MVNet(torch.nn.Module):
+    def forward(self, x, v):
+        return torch.mv(x, v)
+
+
+class MMNet(torch.nn.Module):
+    def forward(self, x, v):
+        return torch.mm(x, v)
+
+
+class AddNet(torch.nn.Module):
+    def forward(self, x, v):
+        return torch.add(x, v)
+
+
+class MulNet(torch.nn.Module):
+    def forward(self, x, v):
+        return torch.mul(x, v)
+
+
+class SelfNet(torch.nn.Module):
+    def forward(self, x):
+        return x
+
+
+class SDDMMNet(torch.nn.Module):
+    def forward(self, x, y, z):
+        return torch.mul(x, torch.mm(y, z))
diff --git a/python/mpact/models/lif.py b/python/mpact/models/lif.py
new file mode 100644
index 0000000..37bf94f
--- /dev/null
+++ b/python/mpact/models/lif.py
@@ -0,0 +1,58 @@
+import torch
+
+
+def spike(input):
+    return (input >= 0).float()
+
+
+def sqSum(input):
+    return (input * input).sum()
+
+
+class LIF(torch.nn.Module):
+    def __init__(self):
+        super(LIF, self).__init__()
+        self.thresh = 1.0
+        self.decay = 0.5
+        self.act = spike
+
+    def forward(self, X):
+        """A filter that yields a binary-valued sparse tensor."""
+        mem = 0
+        spike_pot = []
+        T = X.size(-1)
+        for t in range(T):
+            mem = mem * self.decay + X[..., t]
+            spike = self.act(mem - self.thresh)
+            spike = spike.to_sparse().to_dense()  # prop hack
+            mem = mem * (1.0 - spike)
+            spike_pot.append(spike)
+        spike_pot = torch.stack(spike_pot, dim=-1)
+        return spike_pot
+
+
+class tdLayer(torch.nn.Module):
+    def __init__(self, layer):
+        super(tdLayer, self).__init__()
+        self.layer = layer
+
+    def forward(self, X):
+        T = X.size(-1)
+        out = []
+        for t in range(T):
+            m = self.layer(X[..., t])
+            out.append(m)
+        out = torch.stack(out, dim=-1)
+        return out
+
+
+class Block(torch.nn.Module):
+    def __init__(self):
+        super(Block, self).__init__()
+        self.spike = LIF()
+        self.layer = tdLayer(sqSum)
+
+    def forward(self, X):
+        out = self.spike(X)
+        out = self.layer(out)
+        return out
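Note: gcn.py, kernels.py, and lif.py above make up the shared mpact.models
package that the tests below now import from. As a rough sketch (plain eager
PyTorch, no MPACT compilation; the input values are illustrative, not taken
from the commit), one of the extracted models can be exercised directly:

    import torch
    from mpact.models.gcn import GraphConv

    net = GraphConv(4, 4)
    inp = torch.ones(4, 4)
    adj_mat = torch.eye(4).to_sparse()
    # forward() computes spmm(adj_mat, mm(inp, kernel)) + bias.
    print(net(inp, adj_mat))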
diff --git a/test/python/sparse_gcn.py b/test/python/sparse_gcn.py
index d1d00a0..ce89411 100644
--- a/test/python/sparse_gcn.py
+++ b/test/python/sparse_gcn.py
@@ -3,25 +3,7 @@
 import torch
 
 from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
-
-
-class GraphConv(torch.nn.Module):
-    def __init__(self, input_dim, output_dim):
-        super(GraphConv, self).__init__()
-        self.kernel = torch.nn.Parameter(torch.Tensor(input_dim, output_dim))
-        torch.nn.init.ones_(self.kernel)
-        self.bias = torch.nn.Parameter(torch.Tensor(output_dim))
-        torch.nn.init.ones_(self.bias)
-
-    def forward(self, inp, adj_mat):
-        # Input matrix times weight matrix.
-        support = torch.mm(inp, self.kernel)
-        # Sparse adjacency matrix times support matrix.
-        output = torch.spmm(adj_mat, support)
-        # Add bias.
-        output = output + self.bias
-        return output
-
+from mpact.models.gcn import GraphConv
 
 net = GraphConv(4, 4)
 
diff --git a/test/python/sparse_lif.py b/test/python/sparse_lif.py
index d717af3..38973c9 100644
--- a/test/python/sparse_lif.py
+++ b/test/python/sparse_lif.py
@@ -3,64 +3,7 @@
 import torch
 
 from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
-
-
-def spike(input):
-    return (input >= 0).float()
-
-
-def sqSum(input):
-    return (input * input).sum()
-
-
-class LIF(torch.nn.Module):
-    def __init__(self):
-        super(LIF, self).__init__()
-        self.thresh = 1.0
-        self.decay = 0.5
-        self.act = spike
-
-    def forward(self, X):
-        """A filter that yields a binary-valued sparse tensor."""
-        mem = 0
-        spike_pot = []
-        T = X.size(-1)
-        for t in range(T):
-            mem = mem * self.decay + X[..., t]
-            spike = self.act(mem - self.thresh)
-            spike = spike.to_sparse().to_dense()  # prop hack
-            mem = mem * (1.0 - spike)
-            spike_pot.append(spike)
-        spike_pot = torch.stack(spike_pot, dim=-1)
-        return spike_pot
-
-
-class tdLayer(torch.nn.Module):
-    def __init__(self, layer):
-        super(tdLayer, self).__init__()
-        self.layer = layer
-
-    def forward(self, X):
-        T = X.size(-1)
-        out = []
-        for t in range(T):
-            m = self.layer(X[..., t])
-            out.append(m)
-        out = torch.stack(out, dim=-1)
-        return out
-
-
-class Block(torch.nn.Module):
-    def __init__(self):
-        super(Block, self).__init__()
-        self.spike = LIF()
-        self.layer = tdLayer(sqSum)
-
-    def forward(self, X):
-        out = self.spike(X)
-        out = self.layer(out)
-        return out
-
+from mpact.models.lif import Block
 
 net = Block()
 
diff --git a/test/python/spmv.py b/test/python/spmv.py
index 29467e1..604fab7 100644
--- a/test/python/spmv.py
+++ b/test/python/spmv.py
@@ -3,12 +3,9 @@
 import torch
 
 from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
+from mpact.models.kernels import MVNet
+
-
-class SpMVNet(torch.nn.Module):
-    def forward(self, x, v):
-        return torch.mv(x, v)
-
-net = SpMVNet()
+net = MVNet()
 
 # Get a fixed vector and matrix (which we make 2x2 block "sparse").
 dense_vector = torch.arange(1, 11, dtype=torch.float32)
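With the models factored out, a new test needs only the backend imports plus
one model import. A minimal sketch (eager PyTorch only; whether to wrap the
call in mpact_jit is up to the individual test, and the values below are
illustrative, not from the commit):

    import torch
    from mpact.models.kernels import MVNet

    net = MVNet()
    dense_matrix = torch.arange(1, 17, dtype=torch.float32).view(4, 4)
    dense_vector = torch.arange(1, 5, dtype=torch.float32)
    # MVNet.forward is torch.mv: a matrix-vector product.
    print(net(dense_matrix, dense_vector))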