[mpact][sparse] add sqsum kernel test. (#47)

Also turns the lowering pipeline into a template so the sparsification
options can be chosen per compilation: a new use_sp_it flag on
MpactBackendCompiler (and mpact_jit_compile) selects
sparse-emit-strategy=sparse-iterator instead of the default vectorization
options (vl=16 enable-simd-index32).
diff --git a/python/mpact/models/kernels.py b/python/mpact/models/kernels.py
index dce3355..d18d88a 100644
--- a/python/mpact/models/kernels.py
+++ b/python/mpact/models/kernels.py
@@ -31,6 +31,11 @@
         return torch.mul(x, torch.mm(y, z))


+class SqSum(torch.nn.Module):
+    def forward(self, x):
+        return (x * x).sum()
+
+
 class FeatureScale(torch.nn.Module):
     def forward(self, F):
         sum_vector = torch.sum(F, dim=1)
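The new kernel squares its input elementwise and fully reduces; on a sparse COO tensor the result matches the dense computation. A minimal standalone sketch of the kernel's behavior in plain PyTorch (the 3x3 example tensor is illustrative, not part of this change):

    import torch

    from mpact.models.kernels import SqSum

    net = SqSum()

    # A 3x3 sparse COO tensor holding the values 1, 2, 3.
    indices = torch.tensor([[0, 1, 2], [2, 0, 1]])
    values = torch.tensor([1.0, 2.0, 3.0])
    x = torch.sparse_coo_tensor(indices, values, (3, 3))

    print(net(x))             # tensor(14.) = 1 + 4 + 9
    print(net(x.to_dense()))  # same result on the dense equivalent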
diff --git a/python/mpact/mpactbackend.py b/python/mpact/mpactbackend.py
index d944376..3e8ae4b 100644
--- a/python/mpact/mpactbackend.py
+++ b/python/mpact/mpactbackend.py
@@ -214,7 +214,7 @@
         return invoke


-LOWERING_PIPELINE = (
+LOWERING_PIPELINE_TEMPLATE = (
     "builtin.module("
     + ",".join(
         [
@@ -229,8 +229,8 @@
             # use the PyTorch assembler conventions
             # enable vectorization with VL=16 (more or less assumes AVX512 for float)
             # allow 32-bit index optimizations (unsafe for very large dimensions)
-            "sparse-assembler{direct-out}",
-            "sparsification-and-bufferization{vl=16 enable-simd-index32}",
+            "sparse-assembler{{direct-out}}",
+            "sparsification-and-bufferization{{{sp_options}}}",
             "sparse-storage-specifier-to-llvm",
             # Buffer deallocation pass does not know how to handle realloc.
             "func.func(expand-realloc)",
@@ -240,7 +240,7 @@
"func.func(refback-generalize-tensor-concat)",
# Bufferize.
"func.func(tm-tensor-bufferize)",
- "one-shot-bufferize{copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map}",
+ "one-shot-bufferize{{copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map}}",
"refback-mlprogram-bufferize",
"func.func(finalizing-bufferize)",
"func.func(buffer-deallocation)",
@@ -266,7 +266,7 @@
             # allow fp reductions to reassociate
             # allow 32-bit index optimizations (unsafe for very large dimensions)
             # assume we are running on a good ol' Intel X86 (disable for ARM/other)
-            "convert-vector-to-llvm{reassociate-fp-reductions force-32bit-vector-indices enable-x86vector}",
+            "convert-vector-to-llvm{{reassociate-fp-reductions force-32bit-vector-indices enable-x86vector}}",
             "convert-func-to-llvm",
             "convert-cf-to-llvm",
             "convert-complex-to-llvm",
@@ -280,10 +280,17 @@
 class MpactBackendCompiler:
     """Main entry-point for the MPACT backend compiler."""

-    def __init__(self, opt_level):
+    def __init__(self, opt_level, use_sp_it):
         self.opt_level = opt_level
+        self.use_sp_it = use_sp_it

     def compile(self, imported_module: Module) -> MpactCompiledArtifact:
+        sp_options = (
+            "sparse-emit-strategy=sparse-iterator"
+            if self.use_sp_it
+            else "vl=16 enable-simd-index32"
+        )
+        LOWERING_PIPELINE = LOWERING_PIPELINE_TEMPLATE.format(sp_options=sp_options)
         """Compiles an imported module, with a flat list of functions.

         The module is expected to be in linalg-on-tensors + scalar code form.
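In `{{{sp_options}}}` the outer brace pairs are the escaped literal braces and the inner `{sp_options}` is the one real substitution field, so the two settings of use_sp_it expand as follows (sketch of the expansion only):

    template = "sparsification-and-bufferization{{{sp_options}}}"

    # use_sp_it=False (default):
    print(template.format(sp_options="vl=16 enable-simd-index32"))
    # -> sparsification-and-bufferization{vl=16 enable-simd-index32}

    # use_sp_it=True:
    print(template.format(sp_options="sparse-emit-strategy=sparse-iterator"))
    # -> sparsification-and-bufferization{sparse-emit-strategy=sparse-iterator}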
@@ -456,7 +463,7 @@
     return fx_importer.module


-def mpact_jit_compile(f, *args, opt_level=2, **kwargs):
+def mpact_jit_compile(f, *args, opt_level=2, use_sp_it=False, **kwargs):
     """This method compiles the given callable using the MPACT backend."""
     # Import module and lower into Linalg IR.
     module = export_and_import(f, *args, **kwargs)
@@ -471,7 +478,7 @@
         enable_ir_printing=False,
     )
     # Compile with MPACT backend compiler.
-    backend = MpactBackendCompiler(opt_level=opt_level)
+    backend = MpactBackendCompiler(opt_level=opt_level, use_sp_it=use_sp_it)
     compiled = backend.compile(module)
     invoker = backend.load(compiled)
     return invoker, f
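With the flag threaded through, the emit strategy is chosen per compilation; everything else keeps the previous default behavior. A usage sketch, with `net` and `x` standing in for any kernel module and its input (not part of this diff):

    from mpact.mpactbackend import mpact_jit_compile

    # Default: vectorized sparsification (vl=16 enable-simd-index32).
    invoker, fn = mpact_jit_compile(net, x)

    # Opt in to the sparse-iterator emit strategy instead.
    invoker, fn = mpact_jit_compile(net, x, use_sp_it=True)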
diff --git a/test/python/sqsum.py b/test/python/sqsum.py
new file mode 100644
index 0000000..2d21204
--- /dev/null
+++ b/test/python/sqsum.py
@@ -0,0 +1,35 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+import numpy as np
+
+from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
+
+from mpact.models.kernels import SqSum
+
+net = SqSum()
+
+# Construct adjacency matrix.
+V = 8
+edges = np.array([[0, 1], [0, 4], [1, 4], [3, 4], [4, 3]], dtype=np.int32)
+E = edges.shape[0]
+adj_mat = torch.sparse_coo_tensor(edges.T, torch.ones(E), (V, V), dtype=torch.int64)
+
+#
+# CHECK: pytorch
+# CHECK: tensor(5)
+# CHECK: mpact
+# CHECK: 5
+
+# Run it with PyTorch.
+print("pytorch")
+res = net(adj_mat)
+print(res)
+
+# Run it with MPACT.
+print("mpact")
+# TODO: make this work; expose `sparse-emit-strategy=sparse-iterator` to the
+# mini-pipeline.
+# res = mpact_jit(net, adj_mat, use_sp_it=True)
+res = mpact_jit(net, adj_mat)
+print(res)
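As a sanity check on the CHECK values: the edge list holds five distinct directed edges ([3, 4] and [4, 3] are different entries), each stored as a 1, so the sum of squares is 5. A standalone sketch that reproduces the expected number without the MPACT backend:

    import numpy as np
    import torch

    edges = np.array([[0, 1], [0, 4], [1, 4], [3, 4], [4, 3]], dtype=np.int32)
    adj = torch.sparse_coo_tensor(
        edges.T, torch.ones(edges.shape[0]), (8, 8), dtype=torch.int64
    )
    dense = adj.to_dense()
    print((dense * dense).sum())  # tensor(5)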