[mpact][compiler] add torch aten decomposition features (#53)
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index aab4f9a..4d72267 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt
@@ -52,9 +52,10 @@ MLIRPythonExtension.Core MLIRPythonExtension.RegisterEverything - # We need the FxImporter from torch-mlir + # We need various modules form torch-mlir. TorchMLIRPythonSources.Importers TorchMLIRPythonSources.Dialects + TorchMLIRPythonSources.PublicAPI TorchMLIRPythonExtensions MPACTPythonSources
diff --git a/python/mpact/mpactbackend.py b/python/mpact/mpactbackend.py index 3e8ae4b..a98bc60 100644 --- a/python/mpact/mpactbackend.py +++ b/python/mpact/mpactbackend.py
@@ -15,6 +15,7 @@ from mpact.ir import Module from mpact.dialects import torch as torch_d from mpact.execution_engine import * +from mpact.extras.fx_decomp_util import get_decomposition_table from mpact.extras.fx_importer import FxImporter, SparsityMeta from mpact.ir import * from mpact.passmanager import * @@ -61,14 +62,6 @@ pm.enable_ir_printing() pm.run(module.operation) except Exception as e: - # TODO: More robust. - # - don't arbitrarily clutter up /tmp. When a test suite has many - # tests, this can be a big disk cost (also, /tmp/ is frequently a - # RAM fs, which increases worries about capacity). - # - don't have colliding filenames (hard to do without cluttering - # up /tmp) - # - if we do have have colliding filenames, writes should at least - # avoid being racy. filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir") with open(filename, "w") as f: f.write(asm_for_error_report) @@ -398,6 +391,9 @@ # Build the regular FX traced graph with only dense arguments # (the current version would crash otherwise, see issue above). prog = torch.export.export(f, dargs, kwargs) + decomposition_table = get_decomposition_table() + if decomposition_table: + prog = prog.run_decompositions(decomposition_table) # Annotate sparse arguments in the graph and apply some very # basic propagation rules for sparsity. specs = prog.graph_signature.input_specs @@ -412,7 +408,6 @@ node.meta["sparsity"] = sparse_metadata(args[k]) k = k + 1 elif node.op == "call_function": - # TODO: use upstream _opname implementation when available opname = node.target._schema.name.split("::")[1] # Zero preserving elt-wise unary op. if opname in {"abs", "neg", "relu", "sin"}: