[mpact][passes] set up transform passes structure. (#16)
diff --git a/.gitignore b/.gitignore index 72af6fc..a6d9b9f 100644 --- a/.gitignore +++ b/.gitignore
@@ -1,3 +1,7 @@ *_venv/ __pycache__ /build*/ + +# lsp files +.cache/ +compile_commands.json
diff --git a/CMakeLists.txt b/CMakeLists.txt index 0063ba9..a9964e3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt
@@ -40,6 +40,23 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +function(mpact_target_includes target) + set(_dirs + $<BUILD_INTERFACE:${MLIR_INCLUDE_DIRS}> + $<BUILD_INTERFACE:${MPACT_SOURCE_DIR}/include> + $<BUILD_INTERFACE:${MPACT_BINARY_DIR}/include> + ) + # In LLVM parlance, the actual target may just be an interface and may not + # be responsible for actually compiling anything. The corresponding obj. + # target, when present, is just used for compilation and does not + # contribute to the interface properties. + # TODO: Normalize this upstream. + target_include_directories(${target} PUBLIC ${_dirs}) + if(TARGET obj.${target}) + target_include_directories(obj.${target} PRIVATE ${_dirs}) + endif() +endfunction() + add_subdirectory(include) add_subdirectory(lib) add_subdirectory(tools)
diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt index e69de29..711b39d 100644 --- a/include/CMakeLists.txt +++ b/include/CMakeLists.txt
@@ -0,0 +1 @@ +add_subdirectory(mpact)
diff --git a/include/mpact/CMakeLists.txt b/include/mpact/CMakeLists.txt new file mode 100644 index 0000000..e31af32 --- /dev/null +++ b/include/mpact/CMakeLists.txt
@@ -0,0 +1 @@ +add_subdirectory(Transforms)
diff --git a/include/mpact/Transforms/CMakeLists.txt b/include/mpact/Transforms/CMakeLists.txt new file mode 100644 index 0000000..17587b1 --- /dev/null +++ b/include/mpact/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(MPACTTransformsPassIncGen) + +add_mlir_doc(Passes MPACTTransformsPass ./ -gen-pass-doc)
diff --git a/include/mpact/Transforms/Passes.h b/include/mpact/Transforms/Passes.h new file mode 100644 index 0000000..389184b --- /dev/null +++ b/include/mpact/Transforms/Passes.h
@@ -0,0 +1,22 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the MPACT Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef MPACT_TRANSFORMS_PASSES_H +#define MPACT_TRANSFORMS_PASSES_H + +namespace mlir { +namespace mpact { + +/// Registers all mpact transform passes. +void registerTransformPasses(); + +} // namespace mpact +} // namespace mlir + +#endif // MPACT_TRANSFORMS_PASSES_H
diff --git a/include/mpact/Transforms/Passes.td b/include/mpact/Transforms/Passes.td new file mode 100644 index 0000000..c83f5d4 --- /dev/null +++ b/include/mpact/Transforms/Passes.td
@@ -0,0 +1,51 @@ +//===-- Passes.td - Transforms pass definition file --------*- tablegen -*-===// +// +// Part of the MPACT Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions for passes within the Transforms/ directory. +// +//===----------------------------------------------------------------------===// + +#ifndef MPACT_TRANSFORMS_PASSES +#define MPACT_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def SparseEncodingPropagation : Pass<"sparse-encoding-propagation", "func::FuncOp"> { + let summary = "Propagate sparse tensor encodings"; + let description = [{ + A pass that propagates sparse tensor encodings. + + Background: To avoid introducing repetitive operations, sparse tensors + in MLIR try to reuse tensor operations whenever available. However, most + tensor operations are canonicalized/transformed without the knowledge + of sparsity. The pass tries to propagate missing sparse encodings. + + For example: + ```mlir + %s = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1] + : tensor<2x3xf32, #sparse> to tensor<2x1xf32, #sparse> + + // After rank reducing (by tensor dialect transformation) + %t = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1] + : tensor<2x3xf32, #sparse> to tensor<2xf32> + %s = tensor.expand_shape [[0, 1]] %t + : tensor<2xf32> to tensor<2x1xf32, #sparse> + + // After sparsity propagation + %t = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1] + : tensor<2x3xf32, #sparse> to tensor<2xf32, #sparse1> + %s = tensor.expand_shape [[0, 1]] %t + : tensor<2xf32, #sparse1> to tensor<2x1xf32, #sparse> + ``` + }]; + + let constructor = "mlir::mpact::createSparseEncodingPropagationPass()"; + let dependentDialects = []; +} + +#endif // MPACT_TRANSFORMS_PASSES
diff --git a/include/mpact/Transforms/Sparsity/SparseEncodingPropagate.h b/include/mpact/Transforms/Sparsity/SparseEncodingPropagate.h new file mode 100644 index 0000000..a829400 --- /dev/null +++ b/include/mpact/Transforms/Sparsity/SparseEncodingPropagate.h
@@ -0,0 +1,24 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the MPACT Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef MPACT_TRANSFORMS_SPARSITY_SPARSEENCODINGPROPAGATE_H +#define MPACT_TRANSFORMS_SPARSITY_SPARSEENCODINGPROPAGATE_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace mpact { +std::unique_ptr<OperationPass<func::FuncOp>> +createSparseEncodingPropagationPass(); +} +} // namespace mlir + +#endif // MPACT_TRANSFORMS_SPARSITY_SPARSEENCODINGPROPAGATE_H
diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index e69de29..e31af32 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt
@@ -0,0 +1 @@ +add_subdirectory(Transforms)
diff --git a/lib/Transforms/CMakeLists.txt b/lib/Transforms/CMakeLists.txt new file mode 100644 index 0000000..edd737f --- /dev/null +++ b/lib/Transforms/CMakeLists.txt
@@ -0,0 +1,13 @@ +add_subdirectory(Sparsity) + +set(linked_libs MPACTSparsityPropagation) + +add_mlir_library(MPACTTransformPasses + Passes.cpp + + DEPENDS + MPACTTransformsPassIncGen + + LINK_LIBS PUBLIC + ${linked_libs} +)
diff --git a/lib/Transforms/Passes.cpp b/lib/Transforms/Passes.cpp new file mode 100644 index 0000000..2363ac1 --- /dev/null +++ b/lib/Transforms/Passes.cpp
@@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// +// Part of the MPACT Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "mpact/Transforms/Passes.h" +#include "mpact/Transforms/Sparsity/SparseEncodingPropagate.h" + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { +#define GEN_PASS_REGISTRATION +#include "mpact/Transforms/Passes.h.inc" +} // end namespace + +void mlir::mpact::registerTransformPasses() { ::registerPasses(); }
diff --git a/lib/Transforms/Sparsity/CMakeLists.txt b/lib/Transforms/Sparsity/CMakeLists.txt new file mode 100644 index 0000000..9323021 --- /dev/null +++ b/lib/Transforms/Sparsity/CMakeLists.txt
@@ -0,0 +1,15 @@ +add_mlir_conversion_library(MPACTSparsityPropagation + SparseEncodingPropagate.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/mpact/Transforms/Sparsity + + DEPENDS + MPACTTransformsPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass +) + +mpact_target_includes(MPACTSparsityPropagation)
diff --git a/lib/Transforms/Sparsity/SparseEncodingPropagate.cpp b/lib/Transforms/Sparsity/SparseEncodingPropagate.cpp new file mode 100644 index 0000000..f42db3b --- /dev/null +++ b/lib/Transforms/Sparsity/SparseEncodingPropagate.cpp
@@ -0,0 +1,35 @@ +//===- SparseEncodingPropagate.cpp ---------------------------------------===// +// +// Part of the MPACT Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mpact/Transforms/Sparsity/SparseEncodingPropagate.h" + +namespace mlir { +#define GEN_PASS_DEF_SPARSEENCODINGPROPAGATION +#include "mpact/Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +// ----------------------------------------------------------------------------- +// The pass +// ----------------------------------------------------------------------------- + +namespace { +struct SparseEncodingPropagation + : public impl::SparseEncodingPropagationBase<SparseEncodingPropagation> { + SparseEncodingPropagation() = default; + SparseEncodingPropagation(const SparseEncodingPropagation &pass) = default; + + void runOnOperation() override {} +}; +} // namespace + +std::unique_ptr<OperationPass<func::FuncOp>> +mlir::mpact::createSparseEncodingPropagationPass() { + return std::make_unique<SparseEncodingPropagation>(); +}
diff --git a/tools/mpact-opt/CMakeLists.txt b/tools/mpact-opt/CMakeLists.txt index ea60836..a26b6ca 100644 --- a/tools/mpact-opt/CMakeLists.txt +++ b/tools/mpact-opt/CMakeLists.txt
@@ -11,5 +11,7 @@ TorchMLIRInitAll TorchMLIRTorchDialect TorchMLIRTorchPasses + + MPACTTransformPasses ${dependency_libraries} )
diff --git a/tools/mpact-opt/mpact_opt.cpp b/tools/mpact-opt/mpact_opt.cpp index a7e5736..87d88c3 100644 --- a/tools/mpact-opt/mpact_opt.cpp +++ b/tools/mpact-opt/mpact_opt.cpp
@@ -12,6 +12,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "mlir/Transforms/Passes.h" +#include "mpact/Transforms/Passes.h" #include "torch-mlir/InitAll.h" #ifdef TORCH_MLIR_ENABLE_STABLEHLO @@ -21,6 +22,8 @@ using namespace mlir; int main(int argc, char **argv) { + mlir::mpact::registerTransformPasses(); + mlir::torch::registerAllPasses(); // Core Transforms