[mpact][sparse] add sqsum kernel test. (#47)
diff --git a/python/mpact/models/kernels.py b/python/mpact/models/kernels.py
index dce3355..d18d88a 100644
--- a/python/mpact/models/kernels.py
+++ b/python/mpact/models/kernels.py
@@ -31,6 +31,11 @@
         return torch.mul(x, torch.mm(y, z))
 
 
+class SqSum(torch.nn.Module):
+    def forward(self, x):
+        return (x * x).sum()
+
+
 class FeatureScale(torch.nn.Module):
     def forward(self, F):
         sum_vector = torch.sum(F, dim=1)
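For reference, a minimal sketch of what the new kernel computes, here on an illustrative dense input (the test below feeds it a sparse COO tensor instead):

    import torch
    from mpact.models.kernels import SqSum

    net = SqSum()
    x = torch.tensor([[1.0, 2.0], [0.0, 3.0]])
    print(net(x))  # tensor(14.) == 1^2 + 2^2 + 0^2 + 3^2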
diff --git a/python/mpact/mpactbackend.py b/python/mpact/mpactbackend.py
index d944376..3e8ae4b 100644
--- a/python/mpact/mpactbackend.py
+++ b/python/mpact/mpactbackend.py
@@ -214,7 +214,7 @@
     return invoke
 
 
-LOWERING_PIPELINE = (
+LOWERING_PIPELINE_TEMPLATE = (
     "builtin.module("
     + ",".join(
         [
@@ -229,8 +229,8 @@
             # use the PyTorch assembler conventions
             # enable vectorization with VL=16 (more or less assumes AVX512 for float)
             # allow 32-bit index optimizations (unsafe for very large dimensions)
-            "sparse-assembler{direct-out}",
-            "sparsification-and-bufferization{vl=16 enable-simd-index32}",
+            "sparse-assembler{{direct-out}}",
+            "sparsification-and-bufferization{{{sp_options}}}",
             "sparse-storage-specifier-to-llvm",
             # Buffer deallocation pass does not know how to handle realloc.
             "func.func(expand-realloc)",
@@ -240,7 +240,7 @@
             "func.func(refback-generalize-tensor-concat)",
             # Bufferize.
             "func.func(tm-tensor-bufferize)",
-            "one-shot-bufferize{copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map}",
+            "one-shot-bufferize{{copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map}}",
             "refback-mlprogram-bufferize",
             "func.func(finalizing-bufferize)",
             "func.func(buffer-deallocation)",
@@ -266,7 +266,7 @@
             # allow fp reductions to reassociate
             # allow 32-bit index optimizations (unsafe for very large dimensions)
             # assume we are running on a good ol' Intel X86 (disable for ARM/other)
-            "convert-vector-to-llvm{reassociate-fp-reductions force-32bit-vector-indices enable-x86vector}",
+            "convert-vector-to-llvm{{reassociate-fp-reductions force-32bit-vector-indices enable-x86vector}}",
             "convert-func-to-llvm",
             "convert-cf-to-llvm",
             "convert-complex-to-llvm",
@@ -280,10 +280,17 @@
 class MpactBackendCompiler:
     """Main entry-point for the MPACT backend compiler."""
 
-    def __init__(self, opt_level):
+    def __init__(self, opt_level, use_sp_it):
         self.opt_level = opt_level
+        self.use_sp_it = use_sp_it
 
     def compile(self, imported_module: Module) -> MpactCompiledArtifact:
+        sp_options = (
+            "sparse-emit-strategy=sparse-iterator"
+            if self.use_sp_it
+            else "vl=16 enable-simd-index32"
+        )
+        LOWERING_PIPELINE = LOWERING_PIPELINE_TEMPLATE.format(sp_options=sp_options)
         """Compiles an imported module, with a flat list of functions.
 
         The module is expected to be in linalg-on-tensors + scalar code form.
@@ -456,7 +463,7 @@
     return fx_importer.module
 
 
-def mpact_jit_compile(f, *args, opt_level=2, **kwargs):
+def mpact_jit_compile(f, *args, opt_level=2, use_sp_it=False, **kwargs):
     """This method compiles the given callable using the MPACT backend."""
     # Import module and lower into Linalg IR.
     module = export_and_import(f, *args, **kwargs)
@@ -471,7 +478,7 @@
         enable_ir_printing=False,
     )
     # Compile with MPACT backend compiler.
-    backend = MpactBackendCompiler(opt_level=opt_level)
+    backend = MpactBackendCompiler(opt_level=opt_level, use_sp_it=use_sp_it)
     compiled = backend.compile(module)
     invoker = backend.load(compiled)
     return invoker, f
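A note on the doubled braces above: since LOWERING_PIPELINE is now produced from LOWERING_PIPELINE_TEMPLATE via str.format, every literal { and } in the MLIR pass options must be escaped as {{ and }} to survive substitution. A standalone sketch of the mechanism (not the actual pipeline string):

    template = "sparsification-and-bufferization{{{sp_options}}}"
    print(template.format(sp_options="vl=16 enable-simd-index32"))
    # -> sparsification-and-bufferization{vl=16 enable-simd-index32}
    print(template.format(sp_options="sparse-emit-strategy=sparse-iterator"))
    # -> sparsification-and-bufferization{sparse-emit-strategy=sparse-iterator}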
diff --git a/test/python/sqsum.py b/test/python/sqsum.py
new file mode 100644
index 0000000..2d21204
--- /dev/null
+++ b/test/python/sqsum.py
@@ -0,0 +1,35 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+import numpy as np
+
+from mpact.mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
+
+from mpact.models.kernels import SqSum
+
+net = SqSum()
+
+# Construct adjacency matrix.
+V = 8
+edges = np.array([[0, 1], [0, 4], [1, 4], [3, 4], [4, 3]], dtype=np.int32)
+E = edges.shape[0]
+adj_mat = torch.sparse_coo_tensor(edges.T, torch.ones(E), (V, V), dtype=torch.int64)
+
+#
+# CHECK: pytorch
+# CHECK: tensor(5)
+# CHECK: mpact
+# CHECK: 5
+
+# Run it with PyTorch.
+print("pytorch")
+res = net(adj_mat)
+print(res)
+
+# Run it with MPACT.
+print("mpact")
+# TODO: make this work, expose `sparse-emit-strategy=sparse-iterator` to
+# mini-pipeline.
+# res = mpact_jit(net, adj_mat, use_sp_it=True)
+res = mpact_jit(net, adj_mat)
+print(res)
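The expected value in the CHECK lines follows from the construction: the adjacency matrix stores five edges with value 1 each, so squaring and summing yields 5. An illustrative dense cross-check of the same computation (not part of the test itself):

    import torch
    dense = torch.zeros(8, 8, dtype=torch.int64)
    for i, j in [(0, 1), (0, 4), (1, 4), (3, 4), (4, 3)]:
        dense[i, j] = 1
    print((dense * dense).sum())  # tensor(5)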