blob: b4752c3f313d7597f3df7b87749544288818e1da [file] [log] [blame]
#ifndef PLATFORMS_XLA_MOSAIC_DIALECT_TPU_TPU_PASSES_H_
#define PLATFORMS_XLA_MOSAIC_DIALECT_TPU_TPU_PASSES_H_
#include <memory>
#include "platforms/xla/service/jellyfish/target.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/py/jax/jaxlib/mosaic/dialect/tpu/layout.h" // IWYU pragma: keep.
#include "third_party/py/jax/jaxlib/mosaic/dialect/tpu/tpu_dialect.h" // IWYU pragma: keep.
namespace mlir::tpu {

// Forward declaration of the TPU dialect class so it can be named in this
// header without depending on its full definition.
// NOTE(review): tpu_dialect.h is also included by this header
// (IWYU pragma: keep). If that header already defines TPUDialect, this
// declaration is redundant but harmless.
class TPUDialect;

}  // namespace mlir::tpu
namespace mlir::tpu {

// Creates the LowerToLLO pass, which operates on `func::FuncOp`.
//
// NOTE(review): judging only by the name, this lowers Mosaic TPU dialect ops
// to LLO; confirm against the pass implementation.
//
// Parameters:
//   target: hardware target description, taken by const reference.
//     NOTE(review): confirm whether the returned pass keeps a reference to
//     `target`; if it does, `target` must outlive the pass.
//   unsafe_allow_multicore_remote_dma: opt-in flag, false by default.
//     Presumably relaxes a safety restriction on multi-core remote DMA
//     (inferred from the `unsafe_` prefix) — TODO confirm exact semantics.
// Returns: the pass; ownership transfers to the caller.
std::unique_ptr<OperationPass<func::FuncOp>> createLowerToLLOPass(
    const xla::jellyfish::Target& target,
    bool unsafe_allow_multicore_remote_dma = false);

// Creates the PartialLowerToLLO pass, which operates on `func::FuncOp` and
// takes no target or options. Ownership of the returned pass transfers to
// the caller.
// NOTE(review): presumably lowers only part of the program to LLO
// (name-based); confirm against the pass implementation.
std::unique_ptr<OperationPass<func::FuncOp>> createPartialLowerToLLOPass();

// MLIR TableGen convention: defining GEN_PASS_REGISTRATION before including
// the generated pass header pulls in the pass registration hooks. The
// generated code may refer to the factory functions declared above, so keep
// this #include after them and inside this namespace.
#define GEN_PASS_REGISTRATION
#include "platforms/xla/mosaic/dialect/tpu/tpu_passes.h.inc"

}  // namespace mlir::tpu
#endif // PLATFORMS_XLA_MOSAIC_DIALECT_TPU_TPU_PASSES_H_