[mpact][passes] set up transform passes structure. (#16)
diff --git a/.gitignore b/.gitignore
index 72af6fc..a6d9b9f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,7 @@
*_venv/
__pycache__
/build*/
+
+# lsp files
+.cache/
+compile_commands.json
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0063ba9..a9964e3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -40,6 +40,23 @@
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
+function(mpact_target_includes target)
+ set(_dirs
+ $<BUILD_INTERFACE:${MLIR_INCLUDE_DIRS}>
+ $<BUILD_INTERFACE:${MPACT_SOURCE_DIR}/include>
+ $<BUILD_INTERFACE:${MPACT_BINARY_DIR}/include>
+ )
+ # In LLVM parlance, the actual target may just be an interface and may not
+ # be responsible for actually compiling anything. The corresponding obj.
+ # target, when present, is just used for compilation and does not
+ # contribute to the interface properties.
+ # TODO: Normalize this upstream.
+ target_include_directories(${target} PUBLIC ${_dirs})
+ if(TARGET obj.${target})
+ target_include_directories(obj.${target} PRIVATE ${_dirs})
+ endif()
+endfunction()
+
add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(tools)
diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt
index e69de29..711b39d 100644
--- a/include/CMakeLists.txt
+++ b/include/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(mpact)
diff --git a/include/mpact/CMakeLists.txt b/include/mpact/CMakeLists.txt
new file mode 100644
index 0000000..e31af32
--- /dev/null
+++ b/include/mpact/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(Transforms)
diff --git a/include/mpact/Transforms/CMakeLists.txt b/include/mpact/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..17587b1
--- /dev/null
+++ b/include/mpact/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls)
+add_public_tablegen_target(MPACTTransformsPassIncGen)
+
+add_mlir_doc(Passes MPACTTransformsPass ./ -gen-pass-doc)
diff --git a/include/mpact/Transforms/Passes.h b/include/mpact/Transforms/Passes.h
new file mode 100644
index 0000000..389184b
--- /dev/null
+++ b/include/mpact/Transforms/Passes.h
@@ -0,0 +1,22 @@
+//===------------------------------------------------------------*- C++ -*-===//
+//
+// Part of the MPACT Project, under the Apache License v2.0 with LLVM
+// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MPACT_TRANSFORMS_PASSES_H
+#define MPACT_TRANSFORMS_PASSES_H
+
+namespace mlir {
+namespace mpact {
+
+/// Registers all mpact transform passes.
+void registerTransformPasses();
+
+} // namespace mpact
+} // namespace mlir
+
+#endif // MPACT_TRANSFORMS_PASSES_H
diff --git a/include/mpact/Transforms/Passes.td b/include/mpact/Transforms/Passes.td
new file mode 100644
index 0000000..c83f5d4
--- /dev/null
+++ b/include/mpact/Transforms/Passes.td
@@ -0,0 +1,51 @@
+//===-- Passes.td - Transforms pass definition file --------*- tablegen -*-===//
+//
+// Part of the MPACT Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions for passes within the Transforms/ directory.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MPACT_TRANSFORMS_PASSES
+#define MPACT_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def SparseEncodingPropagation : Pass<"sparse-encoding-propagation", "func::FuncOp"> {
+ let summary = "Propagate sparse tensor encodings";
+ let description = [{
+ A pass that propagates sparse tensor encodings.
+
+ Background: To avoid introducing repetitive operations, sparse tensors
+ in MLIR try to reuse tensor operations whenever available. However, most
+ tensor operations are canonicalized/transformed without the knowledge
+ of sparsity. The pass tries to propagate missing sparse encodings.
+
+ For example:
+ ```mlir
+ %s = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
+ : tensor<2x3xf32, #sparse> to tensor<2x1xf32, #sparse>
+
+ // After rank reducing (by tensor dialect transformation)
+ %t = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
+ : tensor<2x3xf32, #sparse> to tensor<2xf32>
+ %s = tensor.expand_shape [[0, 1]] %t
+ : tensor<2xf32> to tensor<2x1xf32, #sparse>
+
+ // After sparsity propagation
+ %t = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
+ : tensor<2x3xf32, #sparse> to tensor<2xf32, #sparse1>
+ %s = tensor.expand_shape [[0, 1]] %t
+ : tensor<2xf32, #sparse1> to tensor<2x1xf32, #sparse>
+ ```
+ }];
+
+ let constructor = "mlir::mpact::createSparseEncodingPropagationPass()";
+ let dependentDialects = [];
+}
+
+#endif // MPACT_TRANSFORMS_PASSES
diff --git a/include/mpact/Transforms/Sparsity/SparseEncodingPropagate.h b/include/mpact/Transforms/Sparsity/SparseEncodingPropagate.h
new file mode 100644
index 0000000..a829400
--- /dev/null
+++ b/include/mpact/Transforms/Sparsity/SparseEncodingPropagate.h
@@ -0,0 +1,24 @@
+//===------------------------------------------------------------*- C++ -*-===//
+//
+// Part of the MPACT Project, under the Apache License v2.0 with LLVM
+// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MPACT_TRANSFORMS_SPARSITY_SPARSEENCODINGPROPAGATE_H
+#define MPACT_TRANSFORMS_SPARSITY_SPARSEENCODINGPROPAGATE_H
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace mpact {
+std::unique_ptr<OperationPass<func::FuncOp>>
+createSparseEncodingPropagationPass();
+}
+} // namespace mlir
+
+#endif // MPACT_TRANSFORMS_SPARSITY_SPARSEENCODINGPROPAGATE_H
diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index e69de29..e31af32 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(Transforms)
diff --git a/lib/Transforms/CMakeLists.txt b/lib/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..edd737f
--- /dev/null
+++ b/lib/Transforms/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_subdirectory(Sparsity)
+
+set(linked_libs MPACTSparsityPropagation)
+
+add_mlir_library(MPACTTransformPasses
+ Passes.cpp
+
+ DEPENDS
+ MPACTTransformsPassIncGen
+
+ LINK_LIBS PUBLIC
+ ${linked_libs}
+)
diff --git a/lib/Transforms/Passes.cpp b/lib/Transforms/Passes.cpp
new file mode 100644
index 0000000..2363ac1
--- /dev/null
+++ b/lib/Transforms/Passes.cpp
@@ -0,0 +1,22 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the MPACT Project, under the Apache License v2.0 with LLVM
+// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mpact/Transforms/Passes.h"
+#include "mpact/Transforms/Sparsity/SparseEncodingPropagate.h"
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+#define GEN_PASS_REGISTRATION
+#include "mpact/Transforms/Passes.h.inc"
+} // end namespace
+
+void mlir::mpact::registerTransformPasses() { ::registerPasses(); }
diff --git a/lib/Transforms/Sparsity/CMakeLists.txt b/lib/Transforms/Sparsity/CMakeLists.txt
new file mode 100644
index 0000000..9323021
--- /dev/null
+++ b/lib/Transforms/Sparsity/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_conversion_library(MPACTSparsityPropagation
+ SparseEncodingPropagate.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${PROJECT_SOURCE_DIR}/include/mpact/Transforms/Sparsity
+
+ DEPENDS
+ MPACTTransformsPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRPass
+)
+
+mpact_target_includes(MPACTSparsityPropagation)
diff --git a/lib/Transforms/Sparsity/SparseEncodingPropagate.cpp b/lib/Transforms/Sparsity/SparseEncodingPropagate.cpp
new file mode 100644
index 0000000..f42db3b
--- /dev/null
+++ b/lib/Transforms/Sparsity/SparseEncodingPropagate.cpp
@@ -0,0 +1,35 @@
+//===- SparseEncodingPropagate.cpp ---------------------------------------===//
+//
+// Part of the MPACT Project, under the Apache License v2.0 with LLVM
+// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mpact/Transforms/Sparsity/SparseEncodingPropagate.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SPARSEENCODINGPROPAGATION
+#include "mpact/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+// -----------------------------------------------------------------------------
+// The pass
+// -----------------------------------------------------------------------------
+
+namespace {
+struct SparseEncodingPropagation
+ : public impl::SparseEncodingPropagationBase<SparseEncodingPropagation> {
+ SparseEncodingPropagation() = default;
+ SparseEncodingPropagation(const SparseEncodingPropagation &pass) = default;
+
+ void runOnOperation() override {}
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::mpact::createSparseEncodingPropagationPass() {
+ return std::make_unique<SparseEncodingPropagation>();
+}
diff --git a/tools/mpact-opt/CMakeLists.txt b/tools/mpact-opt/CMakeLists.txt
index ea60836..a26b6ca 100644
--- a/tools/mpact-opt/CMakeLists.txt
+++ b/tools/mpact-opt/CMakeLists.txt
@@ -11,5 +11,7 @@
TorchMLIRInitAll
TorchMLIRTorchDialect
TorchMLIRTorchPasses
+
+ MPACTTransformPasses
${dependency_libraries}
)
diff --git a/tools/mpact-opt/mpact_opt.cpp b/tools/mpact-opt/mpact_opt.cpp
index a7e5736..87d88c3 100644
--- a/tools/mpact-opt/mpact_opt.cpp
+++ b/tools/mpact-opt/mpact_opt.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Transforms/Passes.h"
+#include "mpact/Transforms/Passes.h"
#include "torch-mlir/InitAll.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
@@ -21,6 +22,8 @@
using namespace mlir;
int main(int argc, char **argv) {
+ mlir::mpact::registerTransformPasses();
+
mlir::torch::registerAllPasses();
// Core Transforms