[mpact][compiler] remove all e2e boilerplate code (#34)
Since the MPACT compiler and invoker are now stand-alone product classes,
we no longer need to carry the abstract base class scaffolding that was
only used for torch-mlir e2e testing.
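For reference, the stand-alone flow now looks roughly like the sketch
below (the import path, opt_level value, and method name are illustrative
assumptions; compare the updated call site later in this diff):

    # Minimal usage sketch, assuming `module` already holds the
    # linalg-on-tensors IR produced by the TorchFX -> Linalg lowering
    # in mpactbackend.py.
    from mpact.mpactbackend import MpactBackendCompiler

    backend = MpactBackendCompiler(opt_level=2)  # opt level chosen for illustration
    compiled = backend.compile(module)   # lowers linalg-on-tensors IR to LLVM
    invoker = backend.load(compiled)     # returns an MpactBackendInvoker
    # invoker.foo(...)                   # invoke a compiled method by name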
diff --git a/python/mpact/mpactbackend.py b/python/mpact/mpactbackend.py
index 379dcfe..5607082 100644
--- a/python/mpact/mpactbackend.py
+++ b/python/mpact/mpactbackend.py
@@ -1,9 +1,7 @@
# Initialize mpact python extension.
import mpact._mlir_libs._mpact
-import abc
import ctypes
-from enum import Enum
from io import StringIO
import numpy as np
import os
@@ -26,44 +24,8 @@
SUPPORT_LIB = os.getenv("SUPPORT_LIB", default=None)
SHARED_LIBS = [] if SUPPORT_LIB is None else [SUPPORT_LIB]
-# A type shared between the result of `LinalgOnTensorsBackend.compile` and the
-# input to `LinalgOnTensorsBackend.load`. Each backend will likely have a
-# different definition of this type.
-CompiledArtifact = TypeVar("CompiledArtifact")
-
-# A wrapper around a backend-specific loaded program representation
-# that uniformly translates the `x.method(...)` interface expected of
-# Torch modules into appropriate lower-level operations.
-Invoker = TypeVar("Invoker")
-
-
-class LinalgOnTensorsBackend(abc.ABC):
- """The interface to an linalg-on-tensors backend.
-
- Backends are recommended to raise meaningful exceptions in case of error,
- ideally with easy reproduction instructions.
- """
-
- @abc.abstractmethod
- def compile(self, module: Module) -> CompiledArtifact:
- """Compile the provided MLIR module into a compiled artifact.
-
- The module adheres to the linalg-on-tensors backend contract
- (see the VerifyLinalgOnTensorsBackendContract pass).
-
- The compiled artifact can be any type, but must be correctly
- interpreted by the `load` method.
- """
-
- @abc.abstractmethod
- def load(self, artifact: CompiledArtifact) -> Invoker:
- """Load the compiled artifact into a uniformly invokable form.
-
- The compiled artifact is the result of a previous call to `compile`.
-
- See the description of `Invoker` for the requirements on the returned
- type.
- """
+# The result of MPACT compile() and input to load().
+MpactCompiledArtifact = TypeVar("MpactCompiledArtifact")
def get_module_name_for_debug_dump(module):
@@ -315,14 +277,13 @@
)
-class MpactBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend):
- """Main entry-point for the MPACT backend."""
+class MpactBackendCompiler:
+ """Main entry-point for the MPACT backend compiler."""
def __init__(self, opt_level):
- super().__init__()
self.opt_level = opt_level
- def compile(self, imported_module: Module):
+ def compile(self, imported_module: Module) -> MpactCompiledArtifact:
"""Compiles an imported module, with a flat list of functions.
The module is expected to be in linalg-on-tensors + scalar code form.
@@ -334,13 +295,19 @@
run_pipeline_with_repro_report(
imported_module,
LOWERING_PIPELINE,
- "Lowering Linalg-on-Tensors IR to LLVM with MpactBackend",
+ "Lowering Linalg-on-Tensors IR to LLVM with MpactBackendCompiler",
enable_ir_printing=False,
)
return imported_module
- def load(self, module) -> MpactBackendInvoker:
- """Loads a compiled artifact into the runtime."""
+ def load(self, module: MpactCompiledArtifact) -> MpactBackendInvoker:
+ """Loads a compiled artifact into the runtime.
+
+ Args:
+ module: The result of a previous call to `compile`.
+ Returns:
+ MpactBackendInvoker to call a compiled method (viz. `invoker.foo(...)`).
+ """
return MpactBackendInvoker(module, self.opt_level)
@@ -483,8 +450,8 @@
"Lowering TorchFX IR -> Linalg IR",
enable_ir_printing=False,
)
- # Compile with MPACT backend.
- backend = MpactBackendLinalgOnTensorsBackend(opt_level=opt_level)
+ # Compile with MPACT backend compiler.
+ backend = MpactBackendCompiler(opt_level=opt_level)
compiled = backend.compile(module)
invoker = backend.load(compiled)
return invoker, f
@@ -493,8 +460,8 @@
def mpact_jit_run(invoker, f, *args, **kwargs):
"""This method runs the given callable using the given MPACT invoker."""
xargs = []
- # Prepare the buffer parameters (assume all dense).
- # TODO: filters out scalar arguments, anything else?
+ # Prepare all the named buffer parameters (assume all dense).
+ # All scalar arguments are filtered out since they appear inline.
params = dict(f.named_buffers(remove_duplicate=True))
params_flat, params_spec = torch.utils._pytree.tree_flatten(params)
for p in params_flat: