[mpact][compiler] MPACT PyTorch compiler with tests This provides a completely end-to-end pipeline for the MPACT PyTorch compiler. This supports: result = mpact_jit(net, ...) # compile and run invoker, fn = mpact_jit_compile(net, ...) # compile, then result = mpact_jit_run(invoker, fn, ...) # run To test: cmake --build build --target check-mpact
diff --git a/.gitignore b/.gitignore index eac555c..b6e8e5b 100644 --- a/.gitignore +++ b/.gitignore
@@ -1,2 +1,3 @@ mpact_venv/ +__pycache__ /build/
diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..b935118 --- /dev/null +++ b/CMakeLists.txt
@@ -0,0 +1,18 @@ +#------------------------------------------------------------------------------- +# The MPACT Compiler +#------------------------------------------------------------------------------- + +cmake_minimum_required(VERSION 3.12) + +project(mpact VERSION 1.0 LANGUAGES CXX C) + +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) + +set(MPACT_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}") +set(MPACT_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}") +message(STATUS "Building the MPACT compiler at ${MPACT_SOURCE_DIR} (into ${MPACT_BINARY_DIR})") + +add_subdirectory(benchmark) +add_subdirectory(python) +add_subdirectory(test)
diff --git a/README.md b/README.md index 4bed1da..fa611b7 100644 --- a/README.md +++ b/README.md
@@ -35,6 +35,12 @@ source mpact_venv/bin/activate # for each session ``` +Also make sure to set the Python paths as follows. + +```shell +export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/python +``` + ### Install build requirements Note that currently we rely on `torch-mlir` requirements defined in the @@ -45,3 +51,28 @@ python -m pip install -r externals/torch-mlir/requirements.txt python -m pip install -r externals/torch-mlir/torchvision-requirements.txt ``` + +### Building the MPACT compiler in-tree + +The following command generates configuration files to build the MPACT compiler +project completely *in-tree*, which means that both LLVM as well as torch-mlir +are built from source. + +```shell +cmake -GNinja -Bbuild \ + -DCMAKE_BUILD_TYPE=Release \ + -DPython3_FIND_VIRTUALENV=ONLY \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_EXTERNAL_PROJECTS="torch-mlir;mpact" \ + -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="${PWD}/externals/torch-mlir" \ + -DLLVM_EXTERNAL_MPACT_SOURCE_DIR="${PWD}" \ + -DLLVM_TARGETS_TO_BUILD=host \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + externals/torch-mlir/externals/llvm-project/llvm +``` + +Run the following to ensure the MPACT compiler builds and runs correctly. + +```shell +cmake --build build --target check-mpact +```
diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt new file mode 100644 index 0000000..2d67cdc --- /dev/null +++ b/benchmark/CMakeLists.txt
@@ -0,0 +1,5 @@ +#------------------------------------------------------------------------------- +# The MPACT Compiler Benchmarks +#------------------------------------------------------------------------------- + +# TODO(yinying): add all our benchmarks under benchmark/python/*
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt new file mode 100644 index 0000000..4d49440 --- /dev/null +++ b/python/CMakeLists.txt
@@ -0,0 +1,5 @@ +#------------------------------------------------------------------------------- +# The MPACT Compiler Python Modules +#------------------------------------------------------------------------------- + +declare_mlir_python_sources(MPACTPythonSources)
diff --git a/python/mpactbackend.py b/python/mpactbackend.py new file mode 100644 index 0000000..80b58df --- /dev/null +++ b/python/mpactbackend.py
@@ -0,0 +1,416 @@ +import ctypes +import numpy as np +import torch + +from typing import Any, Callable, Optional, Tuple, Dict + +from torch_mlir import ir +from torch_mlir.compiler_utils import run_pipeline_with_repro_report +from torch_mlir.dialects import torch as torch_d +from torch_mlir.execution_engine import * +from torch_mlir.extras.fx_importer import FxImporter, SparsityMeta +from torch_mlir.ir import * +from torch_mlir.passmanager import * +from torch_mlir.runtime import * + +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import ( + LinalgOnTensorsBackend, +) + + +def assert_arg_type_is_supported(ty): + SUPPORTED = [ + np.float16, + np.float32, + np.float64, + np.uint8, + np.int8, + np.int32, + np.int64, + np.bool_, + np.complex64, + np.complex128, + ] + assert ( + ty in SUPPORTED + ), f"Only numpy arrays with dtypes in {SUPPORTED} are supported, but got {ty}" + + +memref_type_to_np_dtype = { + "mrf16": np.float16, + "mrf32": np.float32, + "mrf64": np.float64, + "mri1": np.bool_, + "mri8": np.int8, + "mri32": np.int32, + "mri64": np.int64, + "mrc32": np.complex64, + "mrc64": np.complex128, +} +elemental_type_to_ctype = { + "i1": ctypes.c_bool, + "i8": ctypes.c_byte, + "i64": ctypes.c_int, + "f32": ctypes.c_float, + "f64": ctypes.c_double, +} + +CONSUME_RETURN_FUNC_PREFIX = "refbackend_consume_func_return_" + +SPARSE_LAYOUTS = [ + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, +] + + +def get_return_funcs(module): + return_prefix_len = len(CONSUME_RETURN_FUNC_PREFIX) + return_funcs = [] + with module.context: + for func in module.body: + # Returns strings of the form `"refbackend.."` so `"` is deleted. + func_name = str(func.attributes["sym_name"]).replace('"', "") + if func_name[:return_prefix_len] == CONSUME_RETURN_FUNC_PREFIX: + return_funcs.append(func_name) + + return return_funcs + + +def get_ctype_func(func_name): + return_prefix_len = len(CONSUME_RETURN_FUNC_PREFIX) + ret_types = func_name[return_prefix_len:].split("_") + ctypes_arg = [None] + for type in ret_types: + if type in elemental_type_to_ctype: + ctypes_arg.append(elemental_type_to_ctype[type]) + elif type in memref_type_to_np_dtype: + ctypes_arg.append(ctypes.POINTER(UnrankedMemRefDescriptor)) + else: + assert False, f"Not supported type: {type}" + + return ctypes.CFUNCTYPE(*ctypes_arg), ret_types + + +class MpactBackendInvoker: + def __init__(self, module, opt_level=2, shared_libs=[]): + self.ee = ExecutionEngine(module, opt_level=opt_level, shared_libs=shared_libs) + self.result = None + + return_funcs = get_return_funcs(module) + + for ret_func in return_funcs: + ctype_wrapper, ret_types = get_ctype_func(ret_func) + + def consume_return_funcs(*args): + self.result = tuple( + [ + ( + arg + if type in elemental_type_to_ctype + else unranked_memref_to_numpy( + arg, memref_type_to_np_dtype[type] + ) + ) + for arg, type in zip(args, ret_types) + ] + ) + if len(self.result) == 1: + self.result = self.result[0] + + self.ee.register_runtime(ret_func, ctype_wrapper(consume_return_funcs)) + + def __getattr__(self, function_name: str): + def invoke(*args): + ffi_args = [] + for arg in args: + assert_arg_type_is_supported(arg.dtype) + ffi_args.append( + ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(arg))) + ) + + self.ee.invoke(function_name, *ffi_args) + result = self.result + assert result is not None, "Invocation didn't produce a result" + self.result = None + return result + + return invoke + + +LOWERING_PIPELINE = ( + "builtin.module(" + + ",".join( + [ + "func.func(linalg-generalize-named-ops)", + "func.func(linalg-fuse-elementwise-ops)", + "convert-shape-to-std", + # MLIR Sparsifier mini-pipeline. + "sparse-assembler{direct-out}", + "sparsification-and-bufferization", + "sparse-storage-specifier-to-llvm", + # Buffer deallocation pass does not know how to handle realloc. + "func.func(expand-realloc)", + # Generalize pad and concat after sparse compiler, as they are handled + # differently when the operations involve sparse operands. + "func.func(refback-generalize-tensor-pad)", + "func.func(refback-generalize-tensor-concat)", + # Bufferize. + "func.func(scf-bufferize)", + "func.func(tm-tensor-bufferize)", + "func.func(empty-tensor-to-alloc-tensor)", + "func.func(linalg-bufferize)", + "func-bufferize", + "arith-bufferize", + "refback-mlprogram-bufferize", + "func.func(tensor-bufferize)", + "func.func(finalizing-bufferize)", + "func.func(buffer-deallocation)", + # Inline sparse helper methods where useful (but after dealloc). + "inline", + "refback-munge-calling-conventions", + "func.func(tm-tensor-to-loops)", + "func.func(refback-munge-memref-copy)", + "func.func(convert-linalg-to-loops)", + "func.func(lower-affine)", + "convert-scf-to-cf", + "func.func(refback-expand-ops-for-llvm)", + "func.func(arith-expand)", + "func.func(convert-math-to-llvm)", + "convert-math-to-libm", + "expand-strided-metadata", + "finalize-memref-to-llvm", + "lower-affine", + "convert-bufferization-to-memref", + "finalize-memref-to-llvm", + "func.func(convert-arith-to-llvm)", + "convert-vector-to-llvm", + "convert-func-to-llvm", + "convert-cf-to-llvm", + "convert-complex-to-llvm", + "reconcile-unrealized-casts", + ] + ) + + ")" +) + + +class MpactBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend): + """Main entry-point for the MPACT backend.""" + + def __init__(self): + super().__init__() + + def compile(self, imported_module: Module): + """Compiles an imported module, with a flat list of functions. + The module is expected to be in linalg-on-tensors + scalar code form. + + Args: + imported_module: The MLIR module in the torch dialect. + Returns: + An opaque artifact that can be passed to `load`. + """ + run_pipeline_with_repro_report( + imported_module, + LOWERING_PIPELINE, + "Lowering Linalg-on-Tensors IR to LLVM with MpactBackend", + enable_ir_printing=False, + ) + return imported_module + + def load(self, module) -> MpactBackendInvoker: + """Loads a compiled artifact into the runtime.""" + return MpactBackendInvoker(module) + + +def sparse_metadata(a: torch.Tensor) -> SparsityMeta: + """ + Returns a meta data tuple for the given sparse tensor. + + NOTE: this will be fully replaced by fx graph SparseTensorMetadata + """ + sparse_dim = a.sparse_dim() + dense_dim = a.dense_dim() + batch_dim = a.ndim - dense_dim - sparse_dim + blocksize = None + if a.layout is torch.sparse_coo: + return SparsityMeta( + a.layout, + batch_dim, + sparse_dim, + dense_dim, + blocksize, + a._indices().dtype, + a._indices().dtype, + ) + elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: + if a.layout is torch.sparse_bsr: + blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3] + return SparsityMeta( + a.layout, + batch_dim, + sparse_dim, + dense_dim, + blocksize, + a.crow_indices().dtype, + a.col_indices().dtype, + ) + elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc: + if a.layout is torch.sparse_bsc: + blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3] + return SparsityMeta( + a.layout, + batch_dim, + sparse_dim, + dense_dim, + blocksize, + a.ccol_indices().dtype, + a.row_indices().dtype, + ) + else: + raise RuntimeError(f"Unsupported sparse layout for {a}") + + +def sparse_export( + f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None +) -> torch.export.ExportedProgram: + """ + This is a ***temporary*** wrapper around `torch.export.export` + that eventually should be removed and simply replaced by the + standard API for exporting traced graphs. + + But until issue + + https://github.com/pytorch/pytorch/pull/117907 + + is addressed, this wrapper provides support for the sparse + tensor types by first converting all operands to dense tensors, + building the traced graph as for the dense case, then annotating + sparse parameters with their actual sparse layout attributes, + followed by some simple propagation rules. This temporary solution + accelerates testing torch-mlir with PyTorch sparse tensors until + the issue is resolved upstream. + """ + # Convert all arguments to dense. + dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args) + mask = [a.layout in SPARSE_LAYOUTS for a in args] + # Build the regular FX traced graph with only dense arguments + # (the current version would crash otherwise, see issue above). + prog = torch.export.export(f, dargs, kwargs) + # Annotate sparse arguments in the graph and apply some very + # basic propagation rules for sparsity. + specs = prog.graph_signature.input_specs + alen = len(specs) + k = 0 + for i, node in enumerate(prog.graph.nodes): + if node.op == "placeholder": + # Argument. + spec = specs[i] + if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: + if mask[k]: + node.meta["sparsity"] = sparse_metadata(args[k]) + k = k + 1 + elif node.op == "call_function": + # TODO: use upstream _opname implementation when available + opname = node.target._schema.name.split("::")[1] + # Zero preserving elt-wise unary op. + if opname in {"abs", "neg", "relu", "sin"}: + node.meta["sparsity"] = node.args[0].meta.get("sparsity", None) + elif opname == "_to_sparse": + dim = len(node.meta.get("val").shape) + node.meta["sparsity"] = SparsityMeta( + torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64 + ) + # TODO: Uncomment this to hack sparsity into the network. + # elif opname == "_to_dense": + # # hack (assumes we never really want the to_dense for now) + # node.meta["sparsity"] = node.args[0].meta.get("sparsity", None) + elif opname == "select" and node.args[0].meta.get("sparsity", None): + dim = len(node.meta.get("val").shape) + node.meta["sparsity"] = SparsityMeta( + torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64 + ) + elif opname == "stack" and node.args[0][0].meta.get("sparsity", None): + dim = len(node.meta.get("val").shape) + node.meta["sparsity"] = SparsityMeta( + torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64 + ) + return prog + + +def export_and_import(f, *args, **kwargs): + """This method implements Stella's importer, stripped down to essentials.""" + context = ir.Context() + torch_d.register_dialect(context) + fx_importer = FxImporter(context=context) + prog = sparse_export(f, args, kwargs) + fx_importer.import_frozen_program(prog) + return fx_importer.module + + +def mpact_jit_compile(f, *args, **kwargs): + """This method compiles the given callable using the MPACT backend.""" + # Import module and lower into Linalg IR. + module = export_and_import(f, *args, **kwargs) + run_pipeline_with_repro_report( + module, + ( + "builtin.module(" + "func.func(torch-decompose-complex-ops)," + "torch-backend-to-linalg-on-tensors-backend-pipeline)" + ), + "Lowering TorchFX IR -> Linalg IR", + enable_ir_printing=False, + ) + # Compile with MPACT backend. + backend = MpactBackendLinalgOnTensorsBackend() + compiled = backend.compile(module) + invoker = backend.load(compiled) + return invoker, f + + +def mpact_jit_run(invoker, f, *args, **kwargs): + """This method runs the given callable using the given MPACT invoker.""" + xargs = [] + # Prepare the buffer parameters (assume all dense). + # TODO: filters out scalar arguments, anything else? + params = dict(f.named_buffers(remove_duplicate=True)) + params_flat, params_spec = torch.utils._pytree.tree_flatten(params) + for p in params_flat: + if len(p.shape) > 0: + xargs.append(p.numpy()) + # Prepare input parameters. Sparse input tensors are split into + # their composite tensors. All PyTorch tensors are converted + # to their backing numpy arrays. Note that the output consists + # of numpy arrays as well, which can trivially be reconstructed + # into PyTorch tensors (dense and sparse). + for a in args: + if a.layout is torch.sparse_coo: + # Construct the additional position array required by MLIR with data + # array([0, nnz]). The COO format always uses int64 indices. + xargs.append(np.array([0, a._nnz()], dtype=np.int64)) + # Transform a tensor<ndim x nnz> into ndim x tensor<nnz> to conform + # to the MLIR SoA COO representation. + for idx in a._indices(): + xargs.append(idx.numpy()) + xargs.append(a._values().numpy()) + elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: + xargs.append(a.crow_indices().numpy()) + xargs.append(a.col_indices().numpy()) + xargs.append(a.values().numpy()) + elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc: + xargs.append(a.ccol_indices().numpy()) + xargs.append(a.row_indices().numpy()) + xargs.append(a.values().numpy()) + else: + xargs.append(a.numpy()) + # Invoke. + return invoker.main(*xargs) + + +def mpact_jit(f, *args, **kwargs): + """This method compiles and runs the given callable using the MPACT backend.""" + invoker, fn = mpact_jit_compile(f, *args, **kwargs) + return mpact_jit_run(invoker, fn, *args, **kwargs)
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt new file mode 100644 index 0000000..378640d --- /dev/null +++ b/test/CMakeLists.txt
@@ -0,0 +1,25 @@ +#------------------------------------------------------------------------------- +# The MPACT Compiler Tests +#------------------------------------------------------------------------------- + +configure_lit_site_cfg( + ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in + ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py + MAIN_CONFIG + ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py +) + +set(MPACT_TEST_DEPENDS + FileCheck count not + MPACTPythonSources + TorchMLIRPythonModules + torch-mlir-opt + ) + +add_lit_testsuite(check-mpact "Running the MPACT regression tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${MPACT_TEST_DEPENDS} + ) +set_target_properties(check-mpact PROPERTIES FOLDER "Tests") + +add_lit_testsuites(MPACT ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})
diff --git a/test/lit.cfg.py b/test/lit.cfg.py new file mode 100644 index 0000000..3a9297d --- /dev/null +++ b/test/lit.cfg.py
@@ -0,0 +1,80 @@ +#------------------------------------------------------------------------------- +# The MPACT Compiler LIT Configuration +#------------------------------------------------------------------------------- + +import os +import platform +import re +import subprocess +import tempfile + +import lit.formats +import lit.util + +from lit.llvm import llvm_config +from lit.llvm.subst import ToolSubst +from lit.llvm.subst import FindTool + +# The name of this test suite. +config.name = "MPACT" + +# The test format. +config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) + +# A list of file extensions to treat as test files. +config.suffixes = [".py"] + +# A list of files to exclude from the test suite. +config.excludes = [ + "CMakeLists.txt", + "README.txt", + "LICENSE.txt", + "lit.cfg.py", + "lit.site.cfg.py", +] + +# The root path where tests are located. +config.test_source_root = os.path.dirname(__file__) + +# The root path where tests should be run. +config.test_exec_root = os.path.join(config.mpact_obj_root, "test") +config.standalone_tools_dir = os.path.join(config.mpact_obj_root, "bin") + +# Substitutions. +config.substitutions.append(("%PATH%", config.environment["PATH"])) +config.substitutions.append(("%shlibext", config.llvm_shlib_ext)) + +# Tweak the PATH to include the tools dir. +llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True) +llvm_config.with_environment( + "PATH", os.path.join(config.llvm_build_dir, "bin"), append_path=True +) +llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"]) + +# On Windows the path to python could contains spaces in which case it needs to +# be provided in quotes. This is the equivalent of how %python is setup in +# llvm/utils/lit/lit/llvm/config.py. +if "Windows" in config.host_os: + config.python_executable = '"%s"' % (config.python_executable) + +# Tools. +tool_dirs = [ + config.standalone_tools_dir, + config.llvm_tools_dir, + config.mpact_obj_root, +] +tools = [ + "torch-mlir-opt", + ToolSubst("%PYTHON", config.python_executable, unresolved="ignore"), +] + +llvm_config.add_tool_substitutions(tools, tool_dirs) + +llvm_config.with_environment( + "PYTHONPATH", + [ + os.path.join(config.mpact_src_root, "python"), + os.path.join(config.torch_mlir_obj_root, "python_packages/torch_mlir"), + ], + append_path=True, +)
diff --git a/test/lit.site.cfg.py.in b/test/lit.site.cfg.py.in new file mode 100644 index 0000000..01f3d01 --- /dev/null +++ b/test/lit.site.cfg.py.in
@@ -0,0 +1,23 @@ +@LIT_SITE_CFG_IN_HEADER@ + +import sys + +config.host_os = "@HOST_OS@" +config.mpact_src_root = "@MPACT_SOURCE_DIR@" +config.mpact_obj_root = "@MPACT_BINARY_DIR@" +config.torch_mlir_obj_root = "@LLVM_BINARY_DIR@/tools/torch-mlir" +config.llvm_src_root = "@LLVM_SOURCE_DIR@" +config.llvm_obj_root = "@LLVM_BINARY_DIR@" +config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" +config.llvm_build_dir = "@CMAKE_BINARY_DIR@" +config.llvm_lib_dir = "@LLVM_LIBS_DIR@" +config.llvm_shlib_dir = "@SHLIBDIR@" +config.llvm_shlib_ext = "@SHLIBEXT@" +config.llvm_exe_ext = "@EXEEXT@" +config.python_executable = "@Python3_EXECUTABLE@" + +import lit.llvm +lit.llvm.initialize(lit_config, config) + +# Let the main config do the real work. +lit_config.load_config(config, "@MPACT_SOURCE_DIR@/test/lit.cfg.py")
diff --git a/test/python/sparse_gcn.py b/test/python/sparse_gcn.py new file mode 100644 index 0000000..f2d7b2f --- /dev/null +++ b/test/python/sparse_gcn.py
@@ -0,0 +1,67 @@ +# RUN: %PYTHON %s | FileCheck %s + +import torch + +from mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run + + +class GraphConv(torch.nn.Module): + def __init__(self, input_dim, output_dim): + super(GraphConv, self).__init__() + self.kernel = torch.nn.Parameter(torch.Tensor(input_dim, output_dim)) + torch.nn.init.ones_(self.kernel) + self.bias = torch.nn.Parameter(torch.Tensor(output_dim)) + torch.nn.init.ones_(self.bias) + + def forward(self, inp, adj_mat): + # Input matrix times weight matrix. + support = torch.mm(inp, self.kernel) + # Sparse adjacency matrix times support matrix. + output = torch.spmm(adj_mat, support) + # Add bias. + output = output + self.bias + return output + + +net = GraphConv(4, 4) + +# Get random (but reproducible) matrices. +torch.manual_seed(0) +inp = torch.rand(4, 4) +adj_mat = torch.rand(4, 4).to_sparse() + +# +# CHECK: pytorch +# CHECK: tensor({{\[}}[4.4778, 4.4778, 4.4778, 4.4778], +# CHECK: [5.7502, 5.7502, 5.7502, 5.7502], +# CHECK: [4.6980, 4.6980, 4.6980, 4.6980], +# CHECK: [3.6407, 3.6407, 3.6407, 3.6407]{{\]}}) +# CHECK: mpact compile and run +# CHECK: {{\[}}[4.477828 4.477828 4.477828 4.477828 ] +# CHECK: [5.7501717 5.7501717 5.7501717 5.7501717] +# CHECK: [4.697952 4.697952 4.697952 4.697952 ] +# CHECK: [3.640687 3.640687 3.640687 3.640687 ]{{\]}} +# CHECK: mpact compile +# CHECK: mpact run +# CHECK: {{\[}}[4.477828 4.477828 4.477828 4.477828 ] +# CHECK: [5.7501717 5.7501717 5.7501717 5.7501717] +# CHECK: [4.697952 4.697952 4.697952 4.697952 ] +# CHECK: [3.640687 3.640687 3.640687 3.640687 ]{{\]}} +# +with torch.no_grad(): + # Run it with PyTorch. + print("pytorch") + res = net(inp, adj_mat) + print(res) + + # Run it with MPACT (compile and run at once). + print("mpact compile and run") + res = mpact_jit(net, inp, adj_mat) + print(res) + + # Run it with MPACT (with separate compile and run steps). + print("mpact compile") + invoker, fn = mpact_jit_compile(net, inp, adj_mat) + print("mpact run") + res = mpact_jit_run(invoker, fn, inp, adj_mat) + print(res)
diff --git a/test/python/sparse_lif.py b/test/python/sparse_lif.py new file mode 100644 index 0000000..5f57f32 --- /dev/null +++ b/test/python/sparse_lif.py
@@ -0,0 +1,87 @@ +# RUN: %PYTHON %s | FileCheck %s + +import torch + +from mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run + + +def spike(input): + return (input >= 0).float() + + +def sqSum(input): + return (input * input).sum() + + +class LIF(torch.nn.Module): + def __init__(self): + super(LIF, self).__init__() + self.thresh = 1.0 + self.decay = 0.5 + self.act = spike + + def forward(self, X): + """A filter that yields a binary-valued sparse tensor.""" + mem = 0 + spike_pot = [] + T = X.size(-1) + for t in range(T): + mem = mem * self.decay + X[..., t] + spike = self.act(mem - self.thresh) + spike = spike.to_sparse().to_dense() # prop hack + mem = mem * (1.0 - spike) + spike_pot.append(spike) + spike_pot = torch.stack(spike_pot, dim=-1) + return spike_pot + + +class tdLayer(torch.nn.Module): + def __init__(self, layer): + super(tdLayer, self).__init__() + self.layer = layer + + def forward(self, X): + T = X.size(-1) + out = [] + for t in range(T): + m = self.layer(X[..., t]) + out.append(m) + out = torch.stack(out, dim=-1) + return out + + +class Block(torch.nn.Module): + def __init__(self): + super(Block, self).__init__() + self.spike = LIF() + self.layer = tdLayer(sqSum) + + def forward(self, X): + out = self.spike(X) + out = self.layer(out) + return out + + +net = Block() + +# Get a random (but reproducible) input, so that a +# general sparse tensor appears after LIF. +torch.manual_seed(0) +x = torch.rand(2, 3, 8, 8) + +# +# CHECK: pytorch +# CHECK: tensor([ 0., 11., 9., 11., 13., 11., 10., 12.]) +# CHECK: mpact +# CHECK: [ 0. 11. 9. 11. 13. 11. 10. 12.] +# + +# Run it with PyTorch. +print("pytorch") +res = net(x) +print(res) + +# Run it with MPACT. +print("mpact") +res = mpact_jit(net, x) +print(res)
diff --git a/test/python/spmv.py b/test/python/spmv.py new file mode 100644 index 0000000..cb0e491 --- /dev/null +++ b/test/python/spmv.py
@@ -0,0 +1,33 @@ +# RUN: %PYTHON %s | FileCheck %s + +import torch + +from mpactbackend import mpact_jit, mpact_jit_compile, mpact_jit_run + +class SpMVNet(torch.nn.Module): + def forward(self, x, v): + return torch.mv(x, v) + +net = SpMVNet() + +# Get a fixed vector and matrix (which we make 2x2 block "sparse"). +dense_vector = torch.arange(1, 11, dtype=torch.float32) +dense_input = torch.arange(1, 101, dtype=torch.float32).view(10, 10) +sparse_matrix = dense_input.to_sparse_bsr(blocksize=(2, 2)) + +# +# CHECK: pytorch +# CHECK: tensor([ 385., 935., 1485., 2035., 2585., 3135., 3685., 4235., 4785., 5335.]) +# CHECK: mpact +# CHECK: [ 385. 935. 1485. 2035. 2585. 3135. 3685. 4235. 4785. 5335.] +# + +# Run it with PyTorch. +print("pytorch") +res = net(sparse_matrix, dense_vector) +print(res) + +# Run it with MPACT. +print("mpact") +res = mpact_jit(net, sparse_matrix, dense_vector) +print(res)