blob: 6903f34eea0a75d0832611d65af248503222d22a [file] [log] [blame] [edit]
load("//third_party/llvm/llvm-project/mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
# By default, targets here are visible to this package's own subtree and to
# //platforms/xla/service/jellyfish/mlir and its subpackages. Individual targets
# may widen this with an explicit `visibility` attribute.
package(
    default_visibility = [
        ":__subpackages__",
        "//platforms/xla/service/jellyfish/mlir:__subpackages__",
    ],
)
# Mosaic is still an experimental project. If you want to use it, please get in touch
# with us. We also depend on MLIR and LLVM, both of which only provide loose SLA guarantees.
################################################################################
# TPU dialect
# Internal TPU-dialect passes. Every .cc file directly under
# dialect/tpu/transforms/ is compiled into this library; the public entry point
# is dialect/tpu/tpu_passes.h, which relies on the pass declarations generated
# by :tpu_inc_gen.
cc_library(
    name = "tpu_passes",
    srcs = glob(["dialect/tpu/transforms/*.cc"]),
    hdrs = ["dialect/tpu/tpu_passes.h"],
    compatible_with = ["//buildenv/target:libtpu"],
    deps = [
        ":llo_dialect",  # TODO(apaszke): Split out the conversion pass
        ":tpu_inc_gen",
        "//learning/brain/tpu/runtime:tpu_version",
        "//learning/brain/tpu/runtime/topology:tpu_topology",
        "//learning/brain/tpu/runtime/topology:tpu_topology_serdes",
        "//platforms/xla/service/jellyfish:target_base",
        "//platforms/xla/service/jellyfish:tpu_constants",
        "//platforms/xla/service/jellyfish:transfer_size_util",
        "//third_party/absl/log",
        "//third_party/absl/log:check",
        "//third_party/absl/types:span",
        "//third_party/llvm/llvm-project/llvm:Support",
        "//third_party/llvm/llvm-project/mlir:ArithDialect",
        "//third_party/llvm/llvm-project/mlir:ControlFlowDialect",
        "//third_party/llvm/llvm-project/mlir:FuncDialect",
        "//third_party/llvm/llvm-project/mlir:IR",
        "//third_party/llvm/llvm-project/mlir:MathDialect",
        "//third_party/llvm/llvm-project/mlir:MemRefDialect",
        "//third_party/llvm/llvm-project/mlir:Pass",
        "//third_party/llvm/llvm-project/mlir:SCFDialect",
        "//third_party/llvm/llvm-project/mlir:SCFTransforms",
        "//third_party/llvm/llvm-project/mlir:Support",
        "//third_party/llvm/llvm-project/mlir:TransformUtils",
        "//third_party/llvm/llvm-project/mlir:VectorDialect",
        "//third_party/py/jax/jaxlib/mosaic:tpu_dialect",
        "//third_party/tensorflow/compiler/xla:shape_util",
        "//third_party/tensorflow/compiler/xla/mlir/utils:type_util",
    ],
)
# Runs `mlir-tblgen -gen-pass-decls` over dialect/tpu/tpu_passes.td and emits
# dialect/tpu/tpu_passes.h.inc. `-name=TPUInternal` selects the group name used
# in the generated pass-registration helpers.
gentbl_cc_library(
    name = "tpu_inc_gen",
    compatible_with = ["//buildenv/target:libtpu"],
    tbl_outs = [
        (
            [
                "-gen-pass-decls",
                "-name=TPUInternal",
            ],
            "dialect/tpu/tpu_passes.h.inc",
        ),
    ],
    tblgen = "//third_party/llvm/llvm-project/mlir:mlir-tblgen",
    td_file = "dialect/tpu/tpu_passes.td",
    deps = [":tpu_td_files"],
)
# TableGen sources for the internal TPU passes. The deps make MLIR's base .td
# libraries and the upstream Mosaic TPU .td files available to tpu_passes.td.
td_library(
    name = "tpu_td_files",
    srcs = ["dialect/tpu/tpu_passes.td"],
    compatible_with = ["//buildenv/target:libtpu"],
    deps = [
        "//third_party/llvm/llvm-project/mlir:BuiltinDialectTdFiles",
        "//third_party/llvm/llvm-project/mlir:ControlFlowInterfacesTdFiles",
        "//third_party/llvm/llvm-project/mlir:InferTypeOpInterfaceTdFiles",
        "//third_party/llvm/llvm-project/mlir:OpBaseTdFiles",
        "//third_party/llvm/llvm-project/mlir:PassBaseTdFiles",
        "//third_party/llvm/llvm-project/mlir:SideEffectInterfacesTdFiles",
        "//third_party/py/jax/jaxlib/mosaic:tpu_td_files",
    ],
)
################################################################################
# LLO dialect
# C++ implementation of the LLO dialect (dialect, ops, and builder).
# Generated op/dialect/enum/attr/pass code comes from :llo_inc_gen; the
# generated builder code comes from :llo_builder_inc_gen.
cc_library(
    name = "llo_dialect",
    srcs = [
        "dialect/llo/llo_builder.cc",
        "dialect/llo/llo_dialect.cc",
        "dialect/llo/llo_ops.cc",
    ],
    hdrs = [
        "dialect/llo/llo_builder.h",
        "dialect/llo/llo_dialect.h",
    ],
    compatible_with = ["//buildenv/target:libtpu"],
    deps = [
        ":llo_builder_inc_gen",
        ":llo_inc_gen",
        "//platforms/xla/service/jellyfish:dma_strides",
        "//platforms/xla/service/jellyfish:execution_profiler",
        "//platforms/xla/service/jellyfish:execution_profiler_traceme",
        "//platforms/xla/service/jellyfish:llo_ir",
        "//platforms/xla/service/jellyfish:llo_program_shared_registry",
        "//platforms/xla/service/jellyfish:target_base",
        "//platforms/xla/service/jellyfish/lowering:fusion_util",
        "//platforms/xla/service/jellyfish/lowering:net_util",
        "//third_party/absl/log",
        "//third_party/absl/log:check",
        "//third_party/llvm/llvm-project/llvm:Support",
        "//third_party/llvm/llvm-project/mlir:FuncDialect",
        "//third_party/llvm/llvm-project/mlir:IR",
        "//third_party/llvm/llvm-project/mlir:Pass",
        "//third_party/llvm/llvm-project/mlir:SCFDialect",
        "//third_party/llvm/llvm-project/mlir:Support",
        "//third_party/llvm/llvm-project/mlir:TransformUtils",
        "//third_party/tensorflow/compiler/xla:shape_util",
        "//third_party/tensorflow/compiler/xla/mlir/utils:type_util",
    ],
)
# All mlir-tblgen output for the LLO dialect, generated from dialect/llo/llo.td.
gentbl_cc_library(
    name = "llo_inc_gen",
    compatible_with = ["//buildenv/target:libtpu"],
    tbl_outs = [
        # Operation declarations and definitions.
        (
            ["-gen-op-decls"],
            "dialect/llo/llo_ops.h.inc",
        ),
        (
            ["-gen-op-defs"],
            "dialect/llo/llo_ops.cc.inc",
        ),
        # Dialect class declaration and definition.
        (
            ["-gen-dialect-decls"],
            "dialect/llo/llo_dialect.h.inc",
        ),
        (
            ["-gen-dialect-defs"],
            "dialect/llo/llo_dialect.cc.inc",
        ),
        # Enum declarations and definitions.
        (
            ["-gen-enum-decls"],
            "dialect/llo/llo_enums.h.inc",
        ),
        (
            ["-gen-enum-defs"],
            "dialect/llo/llo_enums.cc.inc",
        ),
        # Attribute declarations and definitions.
        (
            ["-gen-attrdef-decls"],
            "dialect/llo/llo_attr_defs.h.inc",
        ),
        (
            ["-gen-attrdef-defs"],
            "dialect/llo/llo_attr_defs.cc.inc",
        ),
        # Pass declarations; `-name=LLO` picks the group name for the
        # generated registration helpers.
        (
            [
                "-gen-pass-decls",
                "-name=LLO",
            ],
            "dialect/llo/llo_pass_defs.h.inc",
        ),
    ],
    tblgen = "//third_party/llvm/llvm-project/mlir:mlir-tblgen",
    td_file = "dialect/llo/llo.td",
    deps = [":llo_td_files"],
)
# TableGen source for the LLO dialect, plus the MLIR base .td libraries that
# llo.td is allowed to include.
td_library(
    name = "llo_td_files",
    srcs = ["dialect/llo/llo.td"],
    compatible_with = ["//buildenv/target:libtpu"],
    deps = [
        "//third_party/llvm/llvm-project/mlir:ControlFlowInterfacesTdFiles",
        "//third_party/llvm/llvm-project/mlir:InferTypeOpInterfaceTdFiles",
        "//third_party/llvm/llvm-project/mlir:OpBaseTdFiles",
        "//third_party/llvm/llvm-project/mlir:PassBaseTdFiles",
        "//third_party/llvm/llvm-project/mlir:SideEffectInterfacesTdFiles",
    ],
)
# Runs the custom :llo_builder_gen TableGen backend (not mlir-tblgen) over
# dialect/llo/llo.td to produce the builder implementation fragment. The backend
# is invoked with no extra command-line flags.
gentbl_cc_library(
    name = "llo_builder_inc_gen",
    compatible_with = ["//buildenv/target:libtpu"],
    tbl_outs = [
        (
            [],
            "dialect/llo/llo_builder.cc.inc",
        ),
    ],
    tblgen = ":llo_builder_gen",
    td_file = "dialect/llo/llo.td",
    deps = [":llo_td_files"],
)
# Custom TableGen backend binary; used as the `tblgen` tool by
# :llo_builder_inc_gen.
cc_binary(
    name = "llo_builder_gen",
    srcs = ["dialect/llo/llo_builder_gen.cc"],
    compatible_with = ["//buildenv/target:libtpu"],
    deps = [
        "//third_party/llvm/llvm-project/llvm:Support",
        "//third_party/llvm/llvm-project/llvm:TableGen",
        "//third_party/llvm/llvm-project/mlir:Support",
        "//third_party/llvm/llvm-project/mlir:TableGen",
    ],
)
################################################################################
# mlir-tpu-opt
# Opt-style driver (built on MLIR's MlirOptLib). It links in this package's TPU
# passes and LLO dialect, the upstream Mosaic TPU dialect, all upstream MLIR
# passes/dialects/extensions, and the MHLO dialects and passes.
cc_binary(
    name = "mlir-tpu-opt",
    srcs = ["mlir-tpu-opt.cc"],
    compatible_with = ["//buildenv/target:libtpu"],
    deps = [
        ":llo_dialect",
        ":tpu_passes",
        "//base",
        "//third_party/llvm/llvm-project/mlir:AllExtensions",
        "//third_party/llvm/llvm-project/mlir:AllPassesAndDialects",
        "//third_party/llvm/llvm-project/mlir:MlirOptLib",
        "//third_party/py/jax/jaxlib/mosaic:tpu_dialect",
        "//third_party/tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration",
        "//third_party/tensorflow/compiler/xla/mlir_hlo:mhlo_passes",
    ],
)
################################################################################
# Utility for obtaining the kernel name of a Mosaic custom call. It depends on
# the jellyfish backend-config proto and protobuf's JSON utilities; presumably
# the name is read from a JSON-encoded backend config (confirm in the .cc).
# Unlike most targets in this file, it carries no `compatible_with`.
cc_library(
    name = "custom_call_kernel_name",
    srcs = [
        "custom_call_kernel_name.cc",
    ],
    hdrs = [
        "custom_call_kernel_name.h",
    ],
    deps = [
        "//platforms/xla/service/jellyfish/lowering:backend_configs_cc_proto",
        "//third_party/absl/status",
        "//third_party/absl/status:statusor",
        "//third_party/absl/strings:string_view",
        "//third_party/llvm/llvm-project/mlir:IR",
        "//third_party/protobuf/util:json_util",
    ],
)
# Helper library to register a PartIR callback to extract the kernel name from a Mosaic custom
# call Op. This needs to be linked by the libraries that implement the PartIR tile mapping registry
# callbacks.
# Registration-only library: it exposes no headers. `alwayslink = True` keeps
# its object files in the final link even when nothing references them
# directly. `visibility` overrides the package default so only pallas_tpu
# (jax_triton) targets may depend on it.
cc_library(
    name = "partir_callback_registration",
    srcs = ["partir_callback_registration.cc"],
    visibility = ["//third_party/py/jax_triton/google/pallas_tpu:__subpackages__"],
    deps = [
        ":custom_call_kernel_name",
        "//base:googleinit",
        "//learning/deepmind/partir/compiler/rewrites:custom_call_registry",
        "//third_party/absl/log:check",
        "//third_party/absl/status",
        "//third_party/absl/status:statusor",
        "//third_party/llvm/llvm-project/llvm:Support",
        "//third_party/llvm/llvm-project/mlir:IR",
        "//third_party/tensorflow/compiler/xla/mlir_hlo",
    ],
    alwayslink = True,
)