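Reformat Python sources (line wrapping, trailing commas); also comment out find_package(OpenMP REQUIRED) in test/CMakeLists.txt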
diff --git a/python/mpact/mpactbackend.py b/python/mpact/mpactbackend.py
index ec944fa..0d358d0 100644
--- a/python/mpact/mpactbackend.py
+++ b/python/mpact/mpactbackend.py
@@ -278,8 +278,7 @@
class MpactBackendCompiler:
"""Main entry-point for the MPACT backend compiler."""
- def __init__(self, opt_level, use_sp_it, parallel,
- enable_ir_printing, num_threads):
+ def __init__(self, opt_level, use_sp_it, parallel, enable_ir_printing, num_threads):
self.opt_level = opt_level
self.use_sp_it = use_sp_it
self.parallel = parallel
@@ -292,13 +291,14 @@
if self.use_sp_it
else "vl=16 enable-simd-index32"
)
- omp_options = (f"num-threads={self.num_threads}")
+ omp_options = f"num-threads={self.num_threads}"
# TODO: enable the parallelization strategy
# once MLIR bump is completed.
# if self.parallel:
# sp_options += f" parallelization-strategy={self.parallel}"
LOWERING_PIPELINE = LOWERING_PIPELINE_TEMPLATE.format(
- sp_options=sp_options, omp_options=omp_options)
+ sp_options=sp_options, omp_options=omp_options
+ )
"""Compiles an imported module, with a flat list of functions.
The module is expected to be in linalg-on-tensors + scalar code form.
@@ -473,9 +473,16 @@
return fx_importer.module
-def mpact_jit_compile(f, *args, opt_level=2, use_sp_it=False,
- parallel="none", enable_ir_printing=False,
- num_threads = 1, **kwargs):
+def mpact_jit_compile(
+ f,
+ *args,
+ opt_level=2,
+ use_sp_it=False,
+ parallel="none",
+ enable_ir_printing=False,
+ num_threads=1,
+ **kwargs,
+):
"""This method compiles the given callable using the MPACT backend."""
# Import module and lower into Linalg IR.
module = export_and_import(f, *args, **kwargs)
@@ -490,11 +497,13 @@
enable_ir_printing=enable_ir_printing,
)
# Compile with MPACT backend compiler.
- backend = MpactBackendCompiler(opt_level=opt_level,
- use_sp_it=use_sp_it,
- parallel=parallel,
- enable_ir_printing=enable_ir_printing,
- num_threads=num_threads)
+ backend = MpactBackendCompiler(
+ opt_level=opt_level,
+ use_sp_it=use_sp_it,
+ parallel=parallel,
+ enable_ir_printing=enable_ir_printing,
+ num_threads=num_threads,
+ )
compiled = backend.compile(module)
invoker = backend.load(compiled)
return invoker, f
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 30b9164..b93e6e2 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -24,8 +24,8 @@
set_target_properties(check-mpact PROPERTIES FOLDER "Tests")
# TODO: find omp library.
-find_package(OpenMP REQUIRED)
-add_compile_options(${OpenMP_CXX_FLAGS})
+# find_package(OpenMP REQUIRED)
+# add_compile_options(${OpenMP_CXX_FLAGS})
# target_link_libraries(check-mpact OpenMP::OpenMP_CXX)
add_lit_testsuites(MPACT ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})
diff --git a/test/python/parallel.py b/test/python/parallel.py
index edcf859..427155a 100644
--- a/test/python/parallel.py
+++ b/test/python/parallel.py
@@ -15,6 +15,7 @@
f(*args, **kwargs)
gc.collect()
+
net = MMNet()
# Construct dense and sparse matrices.
@@ -35,6 +36,12 @@
# TODO: enable the check test.
# C-HECK: omp.parallel
# CHECK: openmp
-run_test(mpact_jit, net, X, Y,
- parallel="any-storage-any-loop", enable_ir_printing=True,
- num_threads=10)
+run_test(
+ mpact_jit,
+ net,
+ X,
+ Y,
+ parallel="any-storage-any-loop",
+ enable_ir_printing=True,
+ num_threads=10,
+)