[mpact][compiler] extract linalg module import into own method (#76)
diff --git a/python/mpact/mpactbackend.py b/python/mpact/mpactbackend.py index 72b440d..eb09e0b 100644 --- a/python/mpact/mpactbackend.py +++ b/python/mpact/mpactbackend.py
@@ -319,9 +319,8 @@ return fx_importer.module -def mpact_jit_compile(f, *args, opt_level=2, use_sp_it=False, **kwargs): - """This method compiles the given callable using the MPACT backend.""" - # Import module and lower into Linalg IR. +def mpact_linalg(f, *args, **kwargs): + """Imports a function as module and lowers it into Linalg IR.""" module = export_and_import(f, *args, **kwargs) run_pipeline_with_repro_report( module, @@ -333,7 +332,12 @@ "Lowering TorchFX IR -> Linalg IR", enable_ir_printing=False, ) - # Compile with MPACT backend compiler. + return module + + +def mpact_jit_compile(f, *args, opt_level=2, use_sp_it=False, **kwargs): + """This method compiles the given callable using the MPACT backend.""" + module = mpact_linalg(f, *args, **kwargs) backend = MpactBackendCompiler(opt_level=opt_level, use_sp_it=use_sp_it) compiled = backend.compile(module) invoker = backend.load(compiled)
diff --git a/test/python/mm_print.py b/test/python/mm_print.py new file mode 100644 index 0000000..976c10c --- /dev/null +++ b/test/python/mm_print.py
@@ -0,0 +1,31 @@ +# RUN: %PYTHON %s | FileCheck %s + +import torch +import numpy as np + +from mpact.mpactbackend import mpact_linalg + +from mpact.models.kernels import MMNet + + +net = MMNet() + +X = torch.arange(0, 16, dtype=torch.float32).view(4, 4) +Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4) + +# +# CHECK: module { +# CHECK: func.func @main(%[[A0:.*]]: tensor<4x4xf32>, %[[A1:.*]]: tensor<4x4xf32>) -> tensor<4x4xf32> { +# CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32 +# CHECK: %[[T0:.*]] = tensor.empty() : tensor<4x4xf32> +# CHECK: %[[T1:.*]] = linalg.fill ins(%[[C0]] : f32) outs(%[[T0]] : tensor<4x4xf32>) -> tensor<4x4xf32> +# CHECK: %[[T2:.*]] = linalg.matmul +# CHECK-SAME: ins(%[[A0]], %[[A1]] : tensor<4x4xf32>, tensor<4x4xf32>) +# CHECK-SAME: outs(%[[T1]] : tensor<4x4xf32>) -> tensor<4x4xf32> +# CHECK: return %2 : tensor<4x4xf32> +# CHECK: } +# CHECK: } +# + +linalg = mpact_linalg(net, X, Y) +print(linalg)