format
diff --git a/python/mpact/mpactbackend.py b/python/mpact/mpactbackend.py index ec944fa..0d358d0 100644 --- a/python/mpact/mpactbackend.py +++ b/python/mpact/mpactbackend.py
@@ -278,8 +278,7 @@ class MpactBackendCompiler: """Main entry-point for the MPACT backend compiler.""" - def __init__(self, opt_level, use_sp_it, parallel, - enable_ir_printing, num_threads): + def __init__(self, opt_level, use_sp_it, parallel, enable_ir_printing, num_threads): self.opt_level = opt_level self.use_sp_it = use_sp_it self.parallel = parallel @@ -292,13 +291,14 @@ if self.use_sp_it else "vl=16 enable-simd-index32" ) - omp_options = (f"num-threads={self.num_threads}") + omp_options = f"num-threads={self.num_threads}" # TODO: enable the parallelization strategy # once MLIR bump is completed. # if self.parallel: # sp_options += f" parallelization-strategy={self.parallel}" LOWERING_PIPELINE = LOWERING_PIPELINE_TEMPLATE.format( - sp_options=sp_options, omp_options=omp_options) + sp_options=sp_options, omp_options=omp_options + ) """Compiles an imported module, with a flat list of functions. The module is expected to be in linalg-on-tensors + scalar code form. @@ -473,9 +473,16 @@ return fx_importer.module -def mpact_jit_compile(f, *args, opt_level=2, use_sp_it=False, - parallel="none", enable_ir_printing=False, - num_threads = 1, **kwargs): +def mpact_jit_compile( + f, + *args, + opt_level=2, + use_sp_it=False, + parallel="none", + enable_ir_printing=False, + num_threads=1, + **kwargs, +): """This method compiles the given callable using the MPACT backend.""" # Import module and lower into Linalg IR. module = export_and_import(f, *args, **kwargs) @@ -490,11 +497,13 @@ enable_ir_printing=enable_ir_printing, ) # Compile with MPACT backend compiler. - backend = MpactBackendCompiler(opt_level=opt_level, - use_sp_it=use_sp_it, - parallel=parallel, - enable_ir_printing=enable_ir_printing, - num_threads=num_threads) + backend = MpactBackendCompiler( + opt_level=opt_level, + use_sp_it=use_sp_it, + parallel=parallel, + enable_ir_printing=enable_ir_printing, + num_threads=num_threads, + ) compiled = backend.compile(module) invoker = backend.load(compiled) return invoker, f
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 30b9164..b93e6e2 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt
@@ -24,8 +24,8 @@ set_target_properties(check-mpact PROPERTIES FOLDER "Tests") # TODO: find omp library. -find_package(OpenMP REQUIRED) -add_compile_options(${OpenMP_CXX_FLAGS}) +# find_package(OpenMP REQUIRED) +# add_compile_options(${OpenMP_CXX_FLAGS}) # target_link_libraries(check-mpact OpenMP::OpenMP_CXX) add_lit_testsuites(MPACT ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})
diff --git a/test/python/parallel.py b/test/python/parallel.py index edcf859..427155a 100644 --- a/test/python/parallel.py +++ b/test/python/parallel.py
@@ -15,6 +15,7 @@ f(*args, **kwargs) gc.collect() + net = MMNet() # Construct dense and sparse matrices. @@ -35,6 +36,12 @@ # TODO: enable the check test. # C-HECK: omp.parallel # CHECK: openmp -run_test(mpact_jit, net, X, Y, - parallel="any-storage-any-loop", enable_ir_printing=True, - num_threads=10) +run_test( + mpact_jit, + net, + X, + Y, + parallel="any-storage-any-loop", + enable_ir_printing=True, + num_threads=10, +)