[mpact][test] Seperate models from tests (#21)
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index 12c46bd..7815708 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -41,8 +41,8 @@
declare_mlir_python_sources(MPACTPythonSources.SampleModels
ROOT_DIR "${MPACT_PYTHON_ROOT_DIR}"
ADD_TO_PARENT MPACTPythonSources
- SOURCES
- models/resnet.py
+ SOURCES_GLOB
+ models/*.py
)
#-------------------------------------------------------------------------------
diff --git a/python/mpact/models/gcn.py b/python/mpact/models/gcn.py
new file mode 100644
index 0000000..8a015de
--- /dev/null
+++ b/python/mpact/models/gcn.py
@@ -0,0 +1,19 @@
+import torch
+
+
+class GraphConv(torch.nn.Module):
+ def __init__(self, input_dim, output_dim):
+ super(GraphConv, self).__init__()
+ self.kernel = torch.nn.Parameter(torch.Tensor(input_dim, output_dim))
+ torch.nn.init.ones_(self.kernel)
+ self.bias = torch.nn.Parameter(torch.Tensor(output_dim))
+ torch.nn.init.ones_(self.bias)
+
+ def forward(self, inp, adj_mat):
+ # Input matrix times weight matrix.
+ support = torch.mm(inp, self.kernel)
+ # Sparse adjacency matrix times support matrix.
+ output = torch.spmm(adj_mat, support)
+ # Add bias.
+ output = output + self.bias
+ return output
diff --git a/python/mpact/models/kernels.py b/python/mpact/models/kernels.py
new file mode 100644
index 0000000..f8dbf17
--- /dev/null
+++ b/python/mpact/models/kernels.py
@@ -0,0 +1,31 @@
+import torch
+
+
+class MVNet(torch.nn.Module):
+ def forward(self, x, v):
+ return torch.mv(x, v)
+
+
+class MMNet(torch.nn.Module):
+ def forward(self, x, v):
+ return torch.mm(x, v)
+
+
+class AddNet(torch.nn.Module):
+ def forward(self, x, v):
+ return torch.add(x, v)
+
+
+class MulNet(torch.nn.Module):
+ def forward(self, x, v):
+ return torch.mul(x, v)
+
+
+class SelfNet(torch.nn.Module):
+ def forward(self, x):
+ return x
+
+
+class SDDMMNet(torch.nn.Module):
+ def forward(self, x, y, z):
+ return torch.mul(x, torch.mm(y, z))
diff --git a/python/mpact/models/lif.py b/python/mpact/models/lif.py
new file mode 100644
index 0000000..37bf94f
--- /dev/null
+++ b/python/mpact/models/lif.py
@@ -0,0 +1,58 @@
+import torch
+
+
+def spike(input):
+ return (input >= 0).float()
+
+
+def sqSum(input):
+ return (input * input).sum()
+
+
+class LIF(torch.nn.Module):
+ def __init__(self):
+ super(LIF, self).__init__()
+ self.thresh = 1.0
+ self.decay = 0.5
+ self.act = spike
+
+ def forward(self, X):
+ """A filter that yields a binary-valued sparse tensor."""
+ mem = 0
+ spike_pot = []
+ T = X.size(-1)
+ for t in range(T):
+ mem = mem * self.decay + X[..., t]
+ spike = self.act(mem - self.thresh)
+ spike = spike.to_sparse().to_dense() # prop hack
+ mem = mem * (1.0 - spike)
+ spike_pot.append(spike)
+ spike_pot = torch.stack(spike_pot, dim=-1)
+ return spike_pot
+
+
+class tdLayer(torch.nn.Module):
+ def __init__(self, layer):
+ super(tdLayer, self).__init__()
+ self.layer = layer
+
+ def forward(self, X):
+ T = X.size(-1)
+ out = []
+ for t in range(T):
+ m = self.layer(X[..., t])
+ out.append(m)
+ out = torch.stack(out, dim=-1)
+ return out
+
+
+class Block(torch.nn.Module):
+ def __init__(self):
+ super(Block, self).__init__()
+ self.spike = LIF()
+ self.layer = tdLayer(sqSum)
+
+ def forward(self, X):
+ out = self.spike(X)
+ out = self.layer(out)
+ return out
diff --git a/test/python/sparse_gcn.py b/test/python/sparse_gcn.py
index d1d00a0..ce89411 100644
--- a/test/python/sparse_gcn.py
+++ b/test/python/sparse_gcn.py
@@ -3,25 +3,7 @@
import torch
from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
-
-
-class GraphConv(torch.nn.Module):
- def __init__(self, input_dim, output_dim):
- super(GraphConv, self).__init__()
- self.kernel = torch.nn.Parameter(torch.Tensor(input_dim, output_dim))
- torch.nn.init.ones_(self.kernel)
- self.bias = torch.nn.Parameter(torch.Tensor(output_dim))
- torch.nn.init.ones_(self.bias)
-
- def forward(self, inp, adj_mat):
- # Input matrix times weight matrix.
- support = torch.mm(inp, self.kernel)
- # Sparse adjacency matrix times support matrix.
- output = torch.spmm(adj_mat, support)
- # Add bias.
- output = output + self.bias
- return output
-
+from mpact.models.gcn import GraphConv
net = GraphConv(4, 4)
diff --git a/test/python/sparse_lif.py b/test/python/sparse_lif.py
index d717af3..38973c9 100644
--- a/test/python/sparse_lif.py
+++ b/test/python/sparse_lif.py
@@ -3,64 +3,7 @@
import torch
from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
-
-
-def spike(input):
- return (input >= 0).float()
-
-
-def sqSum(input):
- return (input * input).sum()
-
-
-class LIF(torch.nn.Module):
- def __init__(self):
- super(LIF, self).__init__()
- self.thresh = 1.0
- self.decay = 0.5
- self.act = spike
-
- def forward(self, X):
- """A filter that yields a binary-valued sparse tensor."""
- mem = 0
- spike_pot = []
- T = X.size(-1)
- for t in range(T):
- mem = mem * self.decay + X[..., t]
- spike = self.act(mem - self.thresh)
- spike = spike.to_sparse().to_dense() # prop hack
- mem = mem * (1.0 - spike)
- spike_pot.append(spike)
- spike_pot = torch.stack(spike_pot, dim=-1)
- return spike_pot
-
-
-class tdLayer(torch.nn.Module):
- def __init__(self, layer):
- super(tdLayer, self).__init__()
- self.layer = layer
-
- def forward(self, X):
- T = X.size(-1)
- out = []
- for t in range(T):
- m = self.layer(X[..., t])
- out.append(m)
- out = torch.stack(out, dim=-1)
- return out
-
-
-class Block(torch.nn.Module):
- def __init__(self):
- super(Block, self).__init__()
- self.spike = LIF()
- self.layer = tdLayer(sqSum)
-
- def forward(self, X):
- out = self.spike(X)
- out = self.layer(out)
- return out
-
+from mpact.models.lif import Block
net = Block()
diff --git a/test/python/spmv.py b/test/python/spmv.py
index 29467e1..604fab7 100644
--- a/test/python/spmv.py
+++ b/test/python/spmv.py
@@ -3,12 +3,9 @@
import torch
from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
+from mpact.models.kernels import MVNet
-class SpMVNet(torch.nn.Module):
- def forward(self, x, v):
- return torch.mv(x, v)
-
-net = SpMVNet()
+net = MVNet()
# Get a fixed vector and matrix (which we make 2x2 block "sparse").
dense_vector = torch.arange(1, 11, dtype=torch.float32)