[mpact][compiler] MPACT PyTorch compiler with tests
This provides a completely end-to-end pipeline for
the MPACT PyTorch compiler.
This supports:
result = mpact_jit(net, ...) # compile and run
invoker, fn = mpact_jit_compile(net, ...) # compile, then
result = mpact_jit_run(invoker, fn, ...) # run
To test:
cmake --build build --target check-mpact
diff --git a/.gitignore b/.gitignore
index eac555c..b6e8e5b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
mpact_venv/
+__pycache__
/build/
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..b935118
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,18 @@
+#-------------------------------------------------------------------------------
+# The MPACT Compiler
+#-------------------------------------------------------------------------------
+
+cmake_minimum_required(VERSION 3.12)
+
+project(mpact VERSION 1.0 LANGUAGES CXX C)
+
+set(CMAKE_C_STANDARD 11)
+set(CMAKE_CXX_STANDARD 17)
+
+set(MPACT_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
+set(MPACT_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
+message(STATUS "Building the MPACT compiler at ${MPACT_SOURCE_DIR} (into ${MPACT_BINARY_DIR})")
+
+add_subdirectory(benchmark)
+add_subdirectory(python)
+add_subdirectory(test)
diff --git a/README.md b/README.md
index 4bed1da..fa611b7 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,12 @@
source mpact_venv/bin/activate # for each session
```
+Also make sure to set the Python paths as follows.
+
+```shell
+export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/python
+```
+
### Install build requirements
Note that currently we rely on `torch-mlir` requirements defined in the
@@ -45,3 +51,28 @@
python -m pip install -r externals/torch-mlir/requirements.txt
python -m pip install -r externals/torch-mlir/torchvision-requirements.txt
```
+
+### Building the MPACT compiler in-tree
+
+The following command generates configuration files to build the MPACT compiler
+project completely *in-tree*, which means that both LLVM as well as torch-mlir
+are built from source.
+
+```shell
+cmake -GNinja -Bbuild \
+ -DCMAKE_BUILD_TYPE=Release \
+ -DPython3_FIND_VIRTUALENV=ONLY \
+ -DLLVM_ENABLE_PROJECTS=mlir \
+ -DLLVM_EXTERNAL_PROJECTS="torch-mlir;mpact" \
+ -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="${PWD}/externals/torch-mlir" \
+ -DLLVM_EXTERNAL_MPACT_SOURCE_DIR="${PWD}" \
+ -DLLVM_TARGETS_TO_BUILD=host \
+ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
+ externals/torch-mlir/externals/llvm-project/llvm
+```
+
+Run the following to ensure the MPACT compiler builds and runs correctly.
+
+```shell
+cmake --build build --target check-mpact
+```
diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt
new file mode 100644
index 0000000..2d67cdc
--- /dev/null
+++ b/benchmark/CMakeLists.txt
@@ -0,0 +1,5 @@
+#-------------------------------------------------------------------------------
+# The MPACT Compiler Benchmarks
+#-------------------------------------------------------------------------------
+
+# TODO(yinying): add all our benchmarks under benchmark/python/*
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
new file mode 100644
index 0000000..4d49440
--- /dev/null
+++ b/python/CMakeLists.txt
@@ -0,0 +1,5 @@
+#-------------------------------------------------------------------------------
+# The MPACT Compiler Python Modules
+#-------------------------------------------------------------------------------
+
+declare_mlir_python_sources(MPACTPythonSources)
diff --git a/python/mpactbackend.py b/python/mpactbackend.py
new file mode 100644
index 0000000..80b58df
--- /dev/null
+++ b/python/mpactbackend.py
@@ -0,0 +1,416 @@
+import ctypes
+import numpy as np
+import torch
+
+from typing import Any, Callable, Optional, Tuple, Dict
+
+from torch_mlir import ir
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
+from torch_mlir.dialects import torch as torch_d
+from torch_mlir.execution_engine import *
+from torch_mlir.extras.fx_importer import FxImporter, SparsityMeta
+from torch_mlir.ir import *
+from torch_mlir.passmanager import *
+from torch_mlir.runtime import *
+
+from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
+ LinalgOnTensorsBackend,
+)
+
+
+def assert_arg_type_is_supported(ty):
+ SUPPORTED = [
+ np.float16,
+ np.float32,
+ np.float64,
+ np.uint8,
+ np.int8,
+ np.int32,
+ np.int64,
+ np.bool_,
+ np.complex64,
+ np.complex128,
+ ]
+ assert (
+ ty in SUPPORTED
+ ), f"Only numpy arrays with dtypes in {SUPPORTED} are supported, but got {ty}"
+
+
+memref_type_to_np_dtype = {
+ "mrf16": np.float16,
+ "mrf32": np.float32,
+ "mrf64": np.float64,
+ "mri1": np.bool_,
+ "mri8": np.int8,
+ "mri32": np.int32,
+ "mri64": np.int64,
+ "mrc32": np.complex64,
+ "mrc64": np.complex128,
+}
+elemental_type_to_ctype = {
+ "i1": ctypes.c_bool,
+ "i8": ctypes.c_byte,
+ "i64": ctypes.c_int,
+ "f32": ctypes.c_float,
+ "f64": ctypes.c_double,
+}
+
+CONSUME_RETURN_FUNC_PREFIX = "refbackend_consume_func_return_"
+
+SPARSE_LAYOUTS = [
+ torch.sparse_coo,
+ torch.sparse_csr,
+ torch.sparse_csc,
+ torch.sparse_bsr,
+ torch.sparse_bsc,
+]
+
+
+def get_return_funcs(module):
+ return_prefix_len = len(CONSUME_RETURN_FUNC_PREFIX)
+ return_funcs = []
+ with module.context:
+ for func in module.body:
+ # Returns strings of the form `"refbackend.."` so `"` is deleted.
+ func_name = str(func.attributes["sym_name"]).replace('"', "")
+ if func_name[:return_prefix_len] == CONSUME_RETURN_FUNC_PREFIX:
+ return_funcs.append(func_name)
+
+ return return_funcs
+
+
+def get_ctype_func(func_name):
+ return_prefix_len = len(CONSUME_RETURN_FUNC_PREFIX)
+ ret_types = func_name[return_prefix_len:].split("_")
+ ctypes_arg = [None]
+ for type in ret_types:
+ if type in elemental_type_to_ctype:
+ ctypes_arg.append(elemental_type_to_ctype[type])
+ elif type in memref_type_to_np_dtype:
+ ctypes_arg.append(ctypes.POINTER(UnrankedMemRefDescriptor))
+ else:
+ assert False, f"Not supported type: {type}"
+
+ return ctypes.CFUNCTYPE(*ctypes_arg), ret_types
+
+
+class MpactBackendInvoker:
+ def __init__(self, module, opt_level=2, shared_libs=[]):
+ self.ee = ExecutionEngine(module, opt_level=opt_level, shared_libs=shared_libs)
+ self.result = None
+
+ return_funcs = get_return_funcs(module)
+
+ for ret_func in return_funcs:
+ ctype_wrapper, ret_types = get_ctype_func(ret_func)
+
+ def consume_return_funcs(*args):
+ self.result = tuple(
+ [
+ (
+ arg
+ if type in elemental_type_to_ctype
+ else unranked_memref_to_numpy(
+ arg, memref_type_to_np_dtype[type]
+ )
+ )
+ for arg, type in zip(args, ret_types)
+ ]
+ )
+ if len(self.result) == 1:
+ self.result = self.result[0]
+
+ self.ee.register_runtime(ret_func, ctype_wrapper(consume_return_funcs))
+
+ def __getattr__(self, function_name: str):
+ def invoke(*args):
+ ffi_args = []
+ for arg in args:
+ assert_arg_type_is_supported(arg.dtype)
+ ffi_args.append(
+ ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(arg)))
+ )
+
+ self.ee.invoke(function_name, *ffi_args)
+ result = self.result
+ assert result is not None, "Invocation didn't produce a result"
+ self.result = None
+ return result
+
+ return invoke
+
+
+LOWERING_PIPELINE = (
+ "builtin.module("
+ + ",".join(
+ [
+ "func.func(linalg-generalize-named-ops)",
+ "func.func(linalg-fuse-elementwise-ops)",
+ "convert-shape-to-std",
+ # MLIR Sparsifier mini-pipeline.
+ "sparse-assembler{direct-out}",
+ "sparsification-and-bufferization",
+ "sparse-storage-specifier-to-llvm",
+ # Buffer deallocation pass does not know how to handle realloc.
+ "func.func(expand-realloc)",
+ # Generalize pad and concat after sparse compiler, as they are handled
+ # differently when the operations involve sparse operands.
+ "func.func(refback-generalize-tensor-pad)",
+ "func.func(refback-generalize-tensor-concat)",
+ # Bufferize.
+ "func.func(scf-bufferize)",
+ "func.func(tm-tensor-bufferize)",
+ "func.func(empty-tensor-to-alloc-tensor)",
+ "func.func(linalg-bufferize)",
+ "func-bufferize",
+ "arith-bufferize",
+ "refback-mlprogram-bufferize",
+ "func.func(tensor-bufferize)",
+ "func.func(finalizing-bufferize)",
+ "func.func(buffer-deallocation)",
+ # Inline sparse helper methods where useful (but after dealloc).
+ "inline",
+ "refback-munge-calling-conventions",
+ "func.func(tm-tensor-to-loops)",
+ "func.func(refback-munge-memref-copy)",
+ "func.func(convert-linalg-to-loops)",
+ "func.func(lower-affine)",
+ "convert-scf-to-cf",
+ "func.func(refback-expand-ops-for-llvm)",
+ "func.func(arith-expand)",
+ "func.func(convert-math-to-llvm)",
+ "convert-math-to-libm",
+ "expand-strided-metadata",
+ "finalize-memref-to-llvm",
+ "lower-affine",
+ "convert-bufferization-to-memref",
+ "finalize-memref-to-llvm",
+ "func.func(convert-arith-to-llvm)",
+ "convert-vector-to-llvm",
+ "convert-func-to-llvm",
+ "convert-cf-to-llvm",
+ "convert-complex-to-llvm",
+ "reconcile-unrealized-casts",
+ ]
+ )
+ + ")"
+)
+
+
+class MpactBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend):
+ """Main entry-point for the MPACT backend."""
+
+ def __init__(self):
+ super().__init__()
+
+ def compile(self, imported_module: Module):
+ """Compiles an imported module, with a flat list of functions.
+ The module is expected to be in linalg-on-tensors + scalar code form.
+
+ Args:
+ imported_module: The MLIR module in the torch dialect.
+ Returns:
+ An opaque artifact that can be passed to `load`.
+ """
+ run_pipeline_with_repro_report(
+ imported_module,
+ LOWERING_PIPELINE,
+ "Lowering Linalg-on-Tensors IR to LLVM with MpactBackend",
+ enable_ir_printing=False,
+ )
+ return imported_module
+
+ def load(self, module) -> MpactBackendInvoker:
+ """Loads a compiled artifact into the runtime."""
+ return MpactBackendInvoker(module)
+
+
+def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
+ """
+ Returns a meta data tuple for the given sparse tensor.
+
+ NOTE: this will be fully replaced by fx graph SparseTensorMetadata
+ """
+ sparse_dim = a.sparse_dim()
+ dense_dim = a.dense_dim()
+ batch_dim = a.ndim - dense_dim - sparse_dim
+ blocksize = None
+ if a.layout is torch.sparse_coo:
+ return SparsityMeta(
+ a.layout,
+ batch_dim,
+ sparse_dim,
+ dense_dim,
+ blocksize,
+ a._indices().dtype,
+ a._indices().dtype,
+ )
+ elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
+ if a.layout is torch.sparse_bsr:
+ blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
+ return SparsityMeta(
+ a.layout,
+ batch_dim,
+ sparse_dim,
+ dense_dim,
+ blocksize,
+ a.crow_indices().dtype,
+ a.col_indices().dtype,
+ )
+ elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
+ if a.layout is torch.sparse_bsc:
+ blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
+ return SparsityMeta(
+ a.layout,
+ batch_dim,
+ sparse_dim,
+ dense_dim,
+ blocksize,
+ a.ccol_indices().dtype,
+ a.row_indices().dtype,
+ )
+ else:
+ raise RuntimeError(f"Unsupported sparse layout for {a}")
+
+
+def sparse_export(
+ f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
+) -> torch.export.ExportedProgram:
+ """
+ This is a ***temporary*** wrapper around `torch.export.export`
+ that eventually should be removed and simply replaced by the
+ standard API for exporting traced graphs.
+
+ But until issue
+
+ https://github.com/pytorch/pytorch/pull/117907
+
+ is addressed, this wrapper provides support for the sparse
+ tensor types by first converting all operands to dense tensors,
+ building the traced graph as for the dense case, then annotating
+ sparse parameters with their actual sparse layout attributes,
+ followed by some simple propagation rules. This temporary solution
+ accelerates testing torch-mlir with PyTorch sparse tensors until
+ the issue is resolved upstream.
+ """
+ # Convert all arguments to dense.
+ dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args)
+ mask = [a.layout in SPARSE_LAYOUTS for a in args]
+ # Build the regular FX traced graph with only dense arguments
+ # (the current version would crash otherwise, see issue above).
+ prog = torch.export.export(f, dargs, kwargs)
+ # Annotate sparse arguments in the graph and apply some very
+ # basic propagation rules for sparsity.
+ specs = prog.graph_signature.input_specs
+ alen = len(specs)
+ k = 0
+ for i, node in enumerate(prog.graph.nodes):
+ if node.op == "placeholder":
+ # Argument.
+ spec = specs[i]
+ if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
+ if mask[k]:
+ node.meta["sparsity"] = sparse_metadata(args[k])
+ k = k + 1
+ elif node.op == "call_function":
+ # TODO: use upstream _opname implementation when available
+ opname = node.target._schema.name.split("::")[1]
+ # Zero preserving elt-wise unary op.
+ if opname in {"abs", "neg", "relu", "sin"}:
+ node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
+ elif opname == "_to_sparse":
+ dim = len(node.meta.get("val").shape)
+ node.meta["sparsity"] = SparsityMeta(
+ torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
+ )
+ # TODO: Uncomment this to hack sparsity into the network.
+ # elif opname == "_to_dense":
+ # # hack (assumes we never really want the to_dense for now)
+ # node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
+ elif opname == "select" and node.args[0].meta.get("sparsity", None):
+ dim = len(node.meta.get("val").shape)
+ node.meta["sparsity"] = SparsityMeta(
+ torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
+ )
+ elif opname == "stack" and node.args[0][0].meta.get("sparsity", None):
+ dim = len(node.meta.get("val").shape)
+ node.meta["sparsity"] = SparsityMeta(
+ torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
+ )
+ return prog
+
+
+def export_and_import(f, *args, **kwargs):
+ """This method implements Stella's importer, stripped down to essentials."""
+ context = ir.Context()
+ torch_d.register_dialect(context)
+ fx_importer = FxImporter(context=context)
+ prog = sparse_export(f, args, kwargs)
+ fx_importer.import_frozen_program(prog)
+ return fx_importer.module
+
+
+def mpact_jit_compile(f, *args, **kwargs):
+ """This method compiles the given callable using the MPACT backend."""
+ # Import module and lower into Linalg IR.
+ module = export_and_import(f, *args, **kwargs)
+ run_pipeline_with_repro_report(
+ module,
+ (
+ "builtin.module("
+ "func.func(torch-decompose-complex-ops),"
+ "torch-backend-to-linalg-on-tensors-backend-pipeline)"
+ ),
+ "Lowering TorchFX IR -> Linalg IR",
+ enable_ir_printing=False,
+ )
+ # Compile with MPACT backend.
+ backend = MpactBackendLinalgOnTensorsBackend()
+ compiled = backend.compile(module)
+ invoker = backend.load(compiled)
+ return invoker, f
+
+
+def mpact_jit_run(invoker, f, *args, **kwargs):
+ """This method runs the given callable using the given MPACT invoker."""
+ xargs = []
+ # Prepare the buffer parameters (assume all dense).
+ # TODO: filters out scalar arguments, anything else?
+ params = dict(f.named_buffers(remove_duplicate=True))
+ params_flat, params_spec = torch.utils._pytree.tree_flatten(params)
+ for p in params_flat:
+ if len(p.shape) > 0:
+ xargs.append(p.numpy())
+ # Prepare input parameters. Sparse input tensors are split into
+ # their composite tensors. All PyTorch tensors are converted
+ # to their backing numpy arrays. Note that the output consists
+ # of numpy arrays as well, which can trivially be reconstructed
+ # into PyTorch tensors (dense and sparse).
+ for a in args:
+ if a.layout is torch.sparse_coo:
+ # Construct the additional position array required by MLIR with data
+ # array([0, nnz]). The COO format always uses int64 indices.
+ xargs.append(np.array([0, a._nnz()], dtype=np.int64))
+ # Transform a tensor<ndim x nnz> into ndim x tensor<nnz> to conform
+ # to the MLIR SoA COO representation.
+ for idx in a._indices():
+ xargs.append(idx.numpy())
+ xargs.append(a._values().numpy())
+ elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
+ xargs.append(a.crow_indices().numpy())
+ xargs.append(a.col_indices().numpy())
+ xargs.append(a.values().numpy())
+ elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
+ xargs.append(a.ccol_indices().numpy())
+ xargs.append(a.row_indices().numpy())
+ xargs.append(a.values().numpy())
+ else:
+ xargs.append(a.numpy())
+ # Invoke.
+ return invoker.main(*xargs)
+
+
+def mpact_jit(f, *args, **kwargs):
+ """This method compiles and runs the given callable using the MPACT backend."""
+ invoker, fn = mpact_jit_compile(f, *args, **kwargs)
+ return mpact_jit_run(invoker, fn, *args, **kwargs)
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
new file mode 100644
index 0000000..378640d
--- /dev/null
+++ b/test/CMakeLists.txt
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------------------------
+# The MPACT Compiler Tests
+#-------------------------------------------------------------------------------
+
+configure_lit_site_cfg(
+ ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
+ ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
+ MAIN_CONFIG
+ ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py
+)
+
+set(MPACT_TEST_DEPENDS
+ FileCheck count not
+ MPACTPythonSources
+ TorchMLIRPythonModules
+ torch-mlir-opt
+ )
+
+add_lit_testsuite(check-mpact "Running the MPACT regression tests"
+ ${CMAKE_CURRENT_BINARY_DIR}
+ DEPENDS ${MPACT_TEST_DEPENDS}
+ )
+set_target_properties(check-mpact PROPERTIES FOLDER "Tests")
+
+add_lit_testsuites(MPACT ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})
diff --git a/test/lit.cfg.py b/test/lit.cfg.py
new file mode 100644
index 0000000..3a9297d
--- /dev/null
+++ b/test/lit.cfg.py
@@ -0,0 +1,80 @@
+#-------------------------------------------------------------------------------
+# The MPACT Compiler LIT Configuration
+#-------------------------------------------------------------------------------
+
+import os
+import platform
+import re
+import subprocess
+import tempfile
+
+import lit.formats
+import lit.util
+
+from lit.llvm import llvm_config
+from lit.llvm.subst import ToolSubst
+from lit.llvm.subst import FindTool
+
+# The name of this test suite.
+config.name = "MPACT"
+
+# The test format.
+config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
+
+# A list of file extensions to treat as test files.
+config.suffixes = [".py"]
+
+# A list of files to exclude from the test suite.
+config.excludes = [
+ "CMakeLists.txt",
+ "README.txt",
+ "LICENSE.txt",
+ "lit.cfg.py",
+ "lit.site.cfg.py",
+]
+
+# The root path where tests are located.
+config.test_source_root = os.path.dirname(__file__)
+
+# The root path where tests should be run.
+config.test_exec_root = os.path.join(config.mpact_obj_root, "test")
+config.standalone_tools_dir = os.path.join(config.mpact_obj_root, "bin")
+
+# Substitutions.
+config.substitutions.append(("%PATH%", config.environment["PATH"]))
+config.substitutions.append(("%shlibext", config.llvm_shlib_ext))
+
+# Tweak the PATH to include the tools dir.
+llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True)
+llvm_config.with_environment(
+ "PATH", os.path.join(config.llvm_build_dir, "bin"), append_path=True
+)
+llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"])
+
+# On Windows the path to python could contains spaces in which case it needs to
+# be provided in quotes. This is the equivalent of how %python is setup in
+# llvm/utils/lit/lit/llvm/config.py.
+if "Windows" in config.host_os:
+ config.python_executable = '"%s"' % (config.python_executable)
+
+# Tools.
+tool_dirs = [
+ config.standalone_tools_dir,
+ config.llvm_tools_dir,
+ config.mpact_obj_root,
+]
+tools = [
+ "torch-mlir-opt",
+ ToolSubst("%PYTHON", config.python_executable, unresolved="ignore"),
+]
+
+llvm_config.add_tool_substitutions(tools, tool_dirs)
+
+llvm_config.with_environment(
+ "PYTHONPATH",
+ [
+ os.path.join(config.mpact_src_root, "python"),
+ os.path.join(config.torch_mlir_obj_root, "python_packages/torch_mlir"),
+ ],
+ append_path=True,
+)
diff --git a/test/lit.site.cfg.py.in b/test/lit.site.cfg.py.in
new file mode 100644
index 0000000..01f3d01
--- /dev/null
+++ b/test/lit.site.cfg.py.in
@@ -0,0 +1,23 @@
+@LIT_SITE_CFG_IN_HEADER@
+
+import sys
+
+config.host_os = "@HOST_OS@"
+config.mpact_src_root = "@MPACT_SOURCE_DIR@"
+config.mpact_obj_root = "@MPACT_BINARY_DIR@"
+config.torch_mlir_obj_root = "@LLVM_BINARY_DIR@/tools/torch-mlir"
+config.llvm_src_root = "@LLVM_SOURCE_DIR@"
+config.llvm_obj_root = "@LLVM_BINARY_DIR@"
+config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
+config.llvm_build_dir = "@CMAKE_BINARY_DIR@"
+config.llvm_lib_dir = "@LLVM_LIBS_DIR@"
+config.llvm_shlib_dir = "@SHLIBDIR@"
+config.llvm_shlib_ext = "@SHLIBEXT@"
+config.llvm_exe_ext = "@EXEEXT@"
+config.python_executable = "@Python3_EXECUTABLE@"
+
+import lit.llvm
+lit.llvm.initialize(lit_config, config)
+
+# Let the main config do the real work.
+lit_config.load_config(config, "@MPACT_SOURCE_DIR@/test/lit.cfg.py")
diff --git a/test/python/sparse_gcn.py b/test/python/sparse_gcn.py
new file mode 100644
index 0000000..f2d7b2f
--- /dev/null
+++ b/test/python/sparse_gcn.py
@@ -0,0 +1,67 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+
+from mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
+
+
+class GraphConv(torch.nn.Module):
+ def __init__(self, input_dim, output_dim):
+ super(GraphConv, self).__init__()
+ self.kernel = torch.nn.Parameter(torch.Tensor(input_dim, output_dim))
+ torch.nn.init.ones_(self.kernel)
+ self.bias = torch.nn.Parameter(torch.Tensor(output_dim))
+ torch.nn.init.ones_(self.bias)
+
+ def forward(self, inp, adj_mat):
+ # Input matrix times weight matrix.
+ support = torch.mm(inp, self.kernel)
+ # Sparse adjacency matrix times support matrix.
+ output = torch.spmm(adj_mat, support)
+ # Add bias.
+ output = output + self.bias
+ return output
+
+
+net = GraphConv(4, 4)
+
+# Get random (but reproducible) matrices.
+torch.manual_seed(0)
+inp = torch.rand(4, 4)
+adj_mat = torch.rand(4, 4).to_sparse()
+
+#
+# CHECK: pytorch
+# CHECK: tensor({{\[}}[4.4778, 4.4778, 4.4778, 4.4778],
+# CHECK: [5.7502, 5.7502, 5.7502, 5.7502],
+# CHECK: [4.6980, 4.6980, 4.6980, 4.6980],
+# CHECK: [3.6407, 3.6407, 3.6407, 3.6407]{{\]}})
+# CHECK: mpact compile and run
+# CHECK: {{\[}}[4.477828 4.477828 4.477828 4.477828 ]
+# CHECK: [5.7501717 5.7501717 5.7501717 5.7501717]
+# CHECK: [4.697952 4.697952 4.697952 4.697952 ]
+# CHECK: [3.640687 3.640687 3.640687 3.640687 ]{{\]}}
+# CHECK: mpact compile
+# CHECK: mpact run
+# CHECK: {{\[}}[4.477828 4.477828 4.477828 4.477828 ]
+# CHECK: [5.7501717 5.7501717 5.7501717 5.7501717]
+# CHECK: [4.697952 4.697952 4.697952 4.697952 ]
+# CHECK: [3.640687 3.640687 3.640687 3.640687 ]{{\]}}
+#
+with torch.no_grad():
+ # Run it with PyTorch.
+ print("pytorch")
+ res = net(inp, adj_mat)
+ print(res)
+
+ # Run it with MPACT (compile and run at once).
+ print("mpact compile and run")
+ res = mpact_jit(net, inp, adj_mat)
+ print(res)
+
+ # Run it with MPACT (with separate compile and run steps).
+ print("mpact compile")
+ invoker, fn = mpact_jit_compile(net, inp, adj_mat)
+ print("mpact run")
+ res = mpact_jit_run(invoker, fn, inp, adj_mat)
+ print(res)
diff --git a/test/python/sparse_lif.py b/test/python/sparse_lif.py
new file mode 100644
index 0000000..5f57f32
--- /dev/null
+++ b/test/python/sparse_lif.py
@@ -0,0 +1,87 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+
+from mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
+
+
+def spike(input):
+ return (input >= 0).float()
+
+
+def sqSum(input):
+ return (input * input).sum()
+
+
+class LIF(torch.nn.Module):
+ def __init__(self):
+ super(LIF, self).__init__()
+ self.thresh = 1.0
+ self.decay = 0.5
+ self.act = spike
+
+ def forward(self, X):
+ """A filter that yields a binary-valued sparse tensor."""
+ mem = 0
+ spike_pot = []
+ T = X.size(-1)
+ for t in range(T):
+ mem = mem * self.decay + X[..., t]
+ spike = self.act(mem - self.thresh)
+ spike = spike.to_sparse().to_dense() # prop hack
+ mem = mem * (1.0 - spike)
+ spike_pot.append(spike)
+ spike_pot = torch.stack(spike_pot, dim=-1)
+ return spike_pot
+
+
+class tdLayer(torch.nn.Module):
+ def __init__(self, layer):
+ super(tdLayer, self).__init__()
+ self.layer = layer
+
+ def forward(self, X):
+ T = X.size(-1)
+ out = []
+ for t in range(T):
+ m = self.layer(X[..., t])
+ out.append(m)
+ out = torch.stack(out, dim=-1)
+ return out
+
+
+class Block(torch.nn.Module):
+ def __init__(self):
+ super(Block, self).__init__()
+ self.spike = LIF()
+ self.layer = tdLayer(sqSum)
+
+ def forward(self, X):
+ out = self.spike(X)
+ out = self.layer(out)
+ return out
+
+
+net = Block()
+
+# Get a random (but reproducible) input, so that a
+# general sparse tensor appears after LIF.
+torch.manual_seed(0)
+x = torch.rand(2, 3, 8, 8)
+
+#
+# CHECK: pytorch
+# CHECK: tensor([ 0., 11., 9., 11., 13., 11., 10., 12.])
+# CHECK: mpact
+# CHECK: [ 0. 11. 9. 11. 13. 11. 10. 12.]
+#
+
+# Run it with PyTorch.
+print("pytorch")
+res = net(x)
+print(res)
+
+# Run it with MPACT.
+print("mpact")
+res = mpact_jit(net, x)
+print(res)
diff --git a/test/python/spmv.py b/test/python/spmv.py
new file mode 100644
index 0000000..cb0e491
--- /dev/null
+++ b/test/python/spmv.py
@@ -0,0 +1,33 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+
+from mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run
+
+class SpMVNet(torch.nn.Module):
+ def forward(self, x, v):
+ return torch.mv(x, v)
+
+net = SpMVNet()
+
+# Get a fixed vector and matrix (which we make 2x2 block "sparse").
+dense_vector = torch.arange(1, 11, dtype=torch.float32)
+dense_input = torch.arange(1, 101, dtype=torch.float32).view(10, 10)
+sparse_matrix = dense_input.to_sparse_bsr(blocksize=(2, 2))
+
+#
+# CHECK: pytorch
+# CHECK: tensor([ 385., 935., 1485., 2035., 2585., 3135., 3685., 4235., 4785., 5335.])
+# CHECK: mpact
+# CHECK: [ 385. 935. 1485. 2035. 2585. 3135. 3685. 4235. 4785. 5335.]
+#
+
+# Run it with PyTorch.
+print("pytorch")
+res = net(sparse_matrix, dense_vector)
+print(res)
+
+# Run it with MPACT.
+print("mpact")
+res = mpact_jit(net, sparse_matrix, dense_vector)
+print(res)