Add OpenMP parallelization options to the MPACT backend compiler
diff --git a/python/mpact/mpactbackend.py b/python/mpact/mpactbackend.py
index 425413a..ec944fa 100644
--- a/python/mpact/mpactbackend.py
+++ b/python/mpact/mpactbackend.py
@@ -247,10 +247,12 @@
"func.func(refback-munge-memref-copy)",
"func.func(convert-linalg-to-loops)",
"func.func(lower-affine)",
+ "convert-scf-to-openmp{{{omp_options}}}",
"convert-scf-to-cf",
"func.func(refback-expand-ops-for-llvm)",
"func.func(arith-expand)",
"func.func(convert-math-to-llvm)",
+ "convert-openmp-to-llvm",
"convert-math-to-libm",
"expand-strided-metadata",
"finalize-memref-to-llvm",
@@ -276,9 +278,13 @@
class MpactBackendCompiler:
"""Main entry-point for the MPACT backend compiler."""
- def __init__(self, opt_level, use_sp_it):
+ def __init__(self, opt_level, use_sp_it, parallel,
+ enable_ir_printing, num_threads):
self.opt_level = opt_level
self.use_sp_it = use_sp_it
+ self.parallel = parallel
+ self.enable_ir_printing = enable_ir_printing
+ self.num_threads = num_threads
def compile(self, imported_module: Module) -> MpactCompiledArtifact:
sp_options = (
@@ -286,7 +292,13 @@
if self.use_sp_it
else "vl=16 enable-simd-index32"
)
- LOWERING_PIPELINE = LOWERING_PIPELINE_TEMPLATE.format(sp_options=sp_options)
+ omp_options = (f"num-threads={self.num_threads}")
+ # TODO: pass the parallelization strategy to the sparsifier options once the
+ # MLIR bump is completed; until then self.parallel is stored but unused.
+ # if self.parallel:
+ # sp_options += f" parallelization-strategy={self.parallel}"
+ LOWERING_PIPELINE = LOWERING_PIPELINE_TEMPLATE.format(
+ sp_options=sp_options, omp_options=omp_options)
"""Compiles an imported module, with a flat list of functions.
The module is expected to be in linalg-on-tensors + scalar code form.
@@ -299,7 +311,7 @@
imported_module,
LOWERING_PIPELINE,
"Lowering Linalg-on-Tensors IR to LLVM with MpactBackendCompiler",
- enable_ir_printing=False,
+ enable_ir_printing=self.enable_ir_printing,
)
return imported_module
@@ -461,7 +473,9 @@
return fx_importer.module
-def mpact_jit_compile(f, *args, opt_level=2, use_sp_it=False, **kwargs):
+def mpact_jit_compile(f, *args, opt_level=2, use_sp_it=False,
+ parallel="none", enable_ir_printing=False,
+ num_threads = 1, **kwargs):
"""This method compiles the given callable using the MPACT backend."""
# Import module and lower into Linalg IR.
module = export_and_import(f, *args, **kwargs)
@@ -473,10 +487,14 @@
"torch-backend-to-linalg-on-tensors-backend-pipeline)"
),
"Lowering TorchFX IR -> Linalg IR",
- enable_ir_printing=False,
+ enable_ir_printing=enable_ir_printing,
)
# Compile with MPACT backend compiler.
- backend = MpactBackendCompiler(opt_level=opt_level, use_sp_it=use_sp_it)
+ backend = MpactBackendCompiler(opt_level=opt_level,
+ use_sp_it=use_sp_it,
+ parallel=parallel,
+ enable_ir_printing=enable_ir_printing,
+ num_threads=num_threads)
compiled = backend.compile(module)
invoker = backend.load(compiled)
return invoker, f
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 43c3ab9..30b9164 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -23,4 +23,9 @@
)
set_target_properties(check-mpact PROPERTIES FOLDER "Tests")
+# TODO: link the OpenMP library into check-mpact (link step below is disabled).
+find_package(OpenMP REQUIRED)
+add_compile_options(${OpenMP_CXX_FLAGS})
+# target_link_libraries(check-mpact OpenMP::OpenMP_CXX)
+
add_lit_testsuites(MPACT ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})
diff --git a/test/python/parallel.py b/test/python/parallel.py
new file mode 100644
index 0000000..edcf859
--- /dev/null
+++ b/test/python/parallel.py
@@ -0,0 +1,40 @@
+# RUN: %PYTHON -s %s 2>&1 | FileCheck %s
+
+import gc
+import sys
+import torch
+import numpy as np
+
+from mpact.mpactbackend import mpact_jit
+
+from mpact.models.kernels import MMNet
+
+
+def run_test(f, *args, **kwargs):
+ print("TEST:", f.__name__, file=sys.stderr)
+ f(*args, **kwargs)
+ gc.collect()
+
+net = MMNet()
+
+# Construct dense and sparse matrices.
+X = torch.arange(0, 16, dtype=torch.float32).view(4, 4)
+Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)
+A = torch.tensor(
+ [
+ [0.0, 1.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 2.0],
+ [0.0, 0.0, 0.0, 0.0],
+ [3.0, 0.0, 0.0, 0.0],
+ ],
+ dtype=torch.float32,
+)
+S = A.to_sparse_csr()
+
+# Run it with MPACT.
+# TODO: re-enable the omp.parallel check (disabled below by spelling it C-HECK).
+# C-HECK: omp.parallel
+# CHECK: openmp
+run_test(mpact_jit, net, X, Y,
+ parallel="any-storage-any-loop", enable_ir_printing=True,
+ num_threads=10)