[mpact][compiler] add torch aten decomposition features (#53)
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index aab4f9a..4d72267 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -52,9 +52,10 @@
MLIRPythonExtension.Core
MLIRPythonExtension.RegisterEverything
- # We need the FxImporter from torch-mlir
+ # We need various modules form torch-mlir.
TorchMLIRPythonSources.Importers
TorchMLIRPythonSources.Dialects
+ TorchMLIRPythonSources.PublicAPI
TorchMLIRPythonExtensions
MPACTPythonSources
diff --git a/python/mpact/mpactbackend.py b/python/mpact/mpactbackend.py
index 3e8ae4b..a98bc60 100644
--- a/python/mpact/mpactbackend.py
+++ b/python/mpact/mpactbackend.py
@@ -15,6 +15,7 @@
from mpact.ir import Module
from mpact.dialects import torch as torch_d
from mpact.execution_engine import *
+from mpact.extras.fx_decomp_util import get_decomposition_table
from mpact.extras.fx_importer import FxImporter, SparsityMeta
from mpact.ir import *
from mpact.passmanager import *
@@ -61,14 +62,6 @@
pm.enable_ir_printing()
pm.run(module.operation)
except Exception as e:
- # TODO: More robust.
- # - don't arbitrarily clutter up /tmp. When a test suite has many
- # tests, this can be a big disk cost (also, /tmp/ is frequently a
- # RAM fs, which increases worries about capacity).
- # - don't have colliding filenames (hard to do without cluttering
- # up /tmp)
- # - if we do have have colliding filenames, writes should at least
- # avoid being racy.
filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
with open(filename, "w") as f:
f.write(asm_for_error_report)
@@ -398,6 +391,9 @@
# Build the regular FX traced graph with only dense arguments
# (the current version would crash otherwise, see issue above).
prog = torch.export.export(f, dargs, kwargs)
+ decomposition_table = get_decomposition_table()
+ if decomposition_table:
+ prog = prog.run_decompositions(decomposition_table)
# Annotate sparse arguments in the graph and apply some very
# basic propagation rules for sparsity.
specs = prog.graph_signature.input_specs
@@ -412,7 +408,6 @@
node.meta["sparsity"] = sparse_metadata(args[k])
k = k + 1
elif node.op == "call_function":
- # TODO: use upstream _opname implementation when available
opname = node.target._schema.name.split("::")[1]
# Zero preserving elt-wise unary op.
if opname in {"abs", "neg", "relu", "sin"}: