[mpact][compiler] add stable hlo pipeline (#78)
adds a lowering to stable hlo method in addition
to lowering to linalg; note that this can be used
as an alternative path into the mpact pipeline
diff --git a/python/mpact/mpactbackend.py b/python/mpact/mpactbackend.py
index eb09e0b..346e5a2 100644
--- a/python/mpact/mpactbackend.py
+++ b/python/mpact/mpactbackend.py
@@ -320,7 +320,7 @@
def mpact_linalg(f, *args, **kwargs):
- """Imports a function as module and lowers it into Linalg IR."""
+ """Imports a callable as module and lowers it into Linalg IR."""
module = export_and_import(f, *args, **kwargs)
run_pipeline_with_repro_report(
module,
@@ -335,6 +335,22 @@
return module
+def mpact_stablehlo(f, *args, **kwargs):
+ """Imports a callable as module and lowers it into StableHLO IR."""
+ module = export_and_import(f, *args, **kwargs)
+ run_pipeline_with_repro_report(
+ module,
+ (
+ "builtin.module("
+ "func.func(torch-decompose-complex-ops),"
+ "torch-backend-to-stablehlo-backend-pipeline)"
+ ),
+ "Lowering TorchFX IR -> StableHLO IR",
+ enable_ir_printing=False,
+ )
+ return module
+
+
def mpact_jit_compile(f, *args, opt_level=2, use_sp_it=False, **kwargs):
"""This method compiles the given callable using the MPACT backend."""
module = mpact_linalg(f, *args, **kwargs)
diff --git a/test/python/mm_print.py b/test/python/mm_print.py
index 976c10c..8065e77 100644
--- a/test/python/mm_print.py
+++ b/test/python/mm_print.py
@@ -3,7 +3,7 @@
import torch
import numpy as np
-from mpact.mpactbackend import mpact_linalg
+from mpact.mpactbackend import mpact_linalg, mpact_stablehlo
from mpact.models.kernels import MMNet
@@ -29,3 +29,14 @@
linalg = mpact_linalg(net, X, Y)
print(linalg)
+
+#
+# CHECK: module {
+# CHECK: func.func @main(%[[A0:.*]]: tensor<4x4xf32>, %[[A1:.*]]: tensor<4x4xf32>) -> tensor<4x4xf32> {
+# CHECK: %[[T0:.*]] = stablehlo.dot_general %[[A0]], %[[A1]], contracting_dims = [1] x [0] : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
+# CHECK: return %[[T0]] : tensor<4x4xf32>
+# CHECK: }
+# CHECK: }
+
+stablehlo = mpact_stablehlo(net, X, Y)
+print(stablehlo)