| load("//third_party/llvm/llvm-project/mlir:tblgen.bzl", "gentbl_cc_library", "td_library") |
| |
| # Unless a target overrides it, everything here is visible to this package and its |
| # subpackages, and to the jellyfish MLIR subtree. |
| package(default_visibility = [ |
| ":__subpackages__", |
| "//platforms/xla/service/jellyfish/mlir:__subpackages__", |
| ]) |
| |
| # Mosaic is still an experimental project. If you want to use it, please get in touch |
| # with us first. It also depends on MLIR and LLVM, both of which provide only loose SLA guarantees. |
| |
| ################################################################################ |
| # TPU dialect |
| |
| # TPU dialect passes: every .cc file under dialect/tpu/transforms/ is compiled |
| # into this single library, exposed through tpu_passes.h. |
| cc_library( |
| name = "tpu_passes", |
| srcs = glob(["dialect/tpu/transforms/*.cc"]), |
| hdrs = ["dialect/tpu/tpu_passes.h"], |
| compatible_with = ["//buildenv/target:libtpu"], |
| # Dependencies are grouped by origin; the concatenation keeps the sorted order. |
| deps = [ |
| ":llo_dialect", # TODO(apaszke): Split out the conversion pass |
| ":tpu_inc_gen", |
| ] + [ |
| # TPU runtime and the jellyfish backend. |
| "//learning/brain/tpu/runtime:tpu_version", |
| "//learning/brain/tpu/runtime/topology:tpu_topology", |
| "//learning/brain/tpu/runtime/topology:tpu_topology_serdes", |
| "//platforms/xla/service/jellyfish:target_base", |
| "//platforms/xla/service/jellyfish:tpu_constants", |
| "//platforms/xla/service/jellyfish:transfer_size_util", |
| ] + [ |
| # Abseil. |
| "//third_party/absl/log", |
| "//third_party/absl/log:check", |
| "//third_party/absl/types:span", |
| ] + [ |
| # LLVM / MLIR. |
| "//third_party/llvm/llvm-project/llvm:Support", |
| "//third_party/llvm/llvm-project/mlir:ArithDialect", |
| "//third_party/llvm/llvm-project/mlir:ControlFlowDialect", |
| "//third_party/llvm/llvm-project/mlir:FuncDialect", |
| "//third_party/llvm/llvm-project/mlir:IR", |
| "//third_party/llvm/llvm-project/mlir:MathDialect", |
| "//third_party/llvm/llvm-project/mlir:MemRefDialect", |
| "//third_party/llvm/llvm-project/mlir:Pass", |
| "//third_party/llvm/llvm-project/mlir:SCFDialect", |
| "//third_party/llvm/llvm-project/mlir:SCFTransforms", |
| "//third_party/llvm/llvm-project/mlir:Support", |
| "//third_party/llvm/llvm-project/mlir:TransformUtils", |
| "//third_party/llvm/llvm-project/mlir:VectorDialect", |
| ] + [ |
| # Mosaic TPU dialect and XLA utilities. |
| "//third_party/py/jax/jaxlib/mosaic:tpu_dialect", |
| "//third_party/tensorflow/compiler/xla:shape_util", |
| "//third_party/tensorflow/compiler/xla/mlir/utils:type_util", |
| ], |
| ) |
| |
| # Pass declarations generated by mlir-tblgen from tpu_passes.td. The -name flag |
| # sets the pass-group name (TPUInternal) used in the generated code. |
| gentbl_cc_library( |
| name = "tpu_inc_gen", |
| compatible_with = ["//buildenv/target:libtpu"], |
| tbl_outs = [( |
| ["-gen-pass-decls", "-name=TPUInternal"], |
| "dialect/tpu/tpu_passes.h.inc", |
| )], |
| tblgen = "//third_party/llvm/llvm-project/mlir:mlir-tblgen", |
| td_file = "dialect/tpu/tpu_passes.td", |
| deps = [":tpu_td_files"], |
| ) |
| |
| # TableGen source for the TPU passes, plus the .td libraries it may include. |
| td_library( |
| name = "tpu_td_files", |
| srcs = ["dialect/tpu/tpu_passes.td"], |
| compatible_with = ["//buildenv/target:libtpu"], |
| deps = [ |
| # Upstream MLIR TableGen bases and interfaces. |
| "//third_party/llvm/llvm-project/mlir:BuiltinDialectTdFiles", |
| "//third_party/llvm/llvm-project/mlir:ControlFlowInterfacesTdFiles", |
| "//third_party/llvm/llvm-project/mlir:InferTypeOpInterfaceTdFiles", |
| "//third_party/llvm/llvm-project/mlir:OpBaseTdFiles", |
| "//third_party/llvm/llvm-project/mlir:PassBaseTdFiles", |
| "//third_party/llvm/llvm-project/mlir:SideEffectInterfacesTdFiles", |
| ] + [ |
| # TPU .td files from jaxlib/mosaic. |
| "//third_party/py/jax/jaxlib/mosaic:tpu_td_files", |
| ], |
| ) |
| |
| ################################################################################ |
| # LLO dialect |
| |
| # LLO dialect implementation (dialect, ops, builder). It builds on the generated |
| # sources from :llo_inc_gen and :llo_builder_inc_gen. |
| cc_library( |
| name = "llo_dialect", |
| srcs = [ |
| "dialect/llo/llo_builder.cc", |
| "dialect/llo/llo_dialect.cc", |
| "dialect/llo/llo_ops.cc", |
| ], |
| hdrs = [ |
| "dialect/llo/llo_builder.h", |
| "dialect/llo/llo_dialect.h", |
| ], |
| compatible_with = ["//buildenv/target:libtpu"], |
| # Dependencies are grouped by origin; the concatenation keeps the sorted order. |
| deps = [ |
| # Generated TableGen outputs. |
| ":llo_builder_inc_gen", |
| ":llo_inc_gen", |
| ] + [ |
| # Jellyfish backend. |
| "//platforms/xla/service/jellyfish:dma_strides", |
| "//platforms/xla/service/jellyfish:execution_profiler", |
| "//platforms/xla/service/jellyfish:execution_profiler_traceme", |
| "//platforms/xla/service/jellyfish:llo_ir", |
| "//platforms/xla/service/jellyfish:llo_program_shared_registry", |
| "//platforms/xla/service/jellyfish:target_base", |
| "//platforms/xla/service/jellyfish/lowering:fusion_util", |
| "//platforms/xla/service/jellyfish/lowering:net_util", |
| ] + [ |
| # Abseil. |
| "//third_party/absl/log", |
| "//third_party/absl/log:check", |
| ] + [ |
| # LLVM / MLIR. |
| "//third_party/llvm/llvm-project/llvm:Support", |
| "//third_party/llvm/llvm-project/mlir:FuncDialect", |
| "//third_party/llvm/llvm-project/mlir:IR", |
| "//third_party/llvm/llvm-project/mlir:Pass", |
| "//third_party/llvm/llvm-project/mlir:SCFDialect", |
| "//third_party/llvm/llvm-project/mlir:Support", |
| "//third_party/llvm/llvm-project/mlir:TransformUtils", |
| ] + [ |
| # XLA utilities. |
| "//third_party/tensorflow/compiler/xla:shape_util", |
| "//third_party/tensorflow/compiler/xla/mlir/utils:type_util", |
| ], |
| ) |
| |
| # Every mlir-tblgen output for llo.td, one (flags, output file) pair per line: |
| # ops, dialect, enums, and attribute decls/defs, plus the LLO pass declarations. |
| gentbl_cc_library( |
| name = "llo_inc_gen", |
| compatible_with = ["//buildenv/target:libtpu"], |
| tbl_outs = [ |
| (["-gen-op-decls"], "dialect/llo/llo_ops.h.inc"), |
| (["-gen-op-defs"], "dialect/llo/llo_ops.cc.inc"), |
| (["-gen-dialect-decls"], "dialect/llo/llo_dialect.h.inc"), |
| (["-gen-dialect-defs"], "dialect/llo/llo_dialect.cc.inc"), |
| (["-gen-enum-decls"], "dialect/llo/llo_enums.h.inc"), |
| (["-gen-enum-defs"], "dialect/llo/llo_enums.cc.inc"), |
| (["-gen-attrdef-decls"], "dialect/llo/llo_attr_defs.h.inc"), |
| (["-gen-attrdef-defs"], "dialect/llo/llo_attr_defs.cc.inc"), |
| (["-gen-pass-decls", "-name=LLO"], "dialect/llo/llo_pass_defs.h.inc"), |
| ], |
| tblgen = "//third_party/llvm/llvm-project/mlir:mlir-tblgen", |
| td_file = "dialect/llo/llo.td", |
| deps = [":llo_td_files"], |
| ) |
| |
| # TableGen source for the LLO dialect, plus the upstream MLIR .td libraries it may include. |
| td_library( |
| name = "llo_td_files", |
| srcs = ["dialect/llo/llo.td"], |
| compatible_with = ["//buildenv/target:libtpu"], |
| deps = [ |
| # Upstream MLIR TableGen bases and interfaces. |
| "//third_party/llvm/llvm-project/mlir:ControlFlowInterfacesTdFiles", |
| "//third_party/llvm/llvm-project/mlir:InferTypeOpInterfaceTdFiles", |
| "//third_party/llvm/llvm-project/mlir:OpBaseTdFiles", |
| "//third_party/llvm/llvm-project/mlir:PassBaseTdFiles", |
| "//third_party/llvm/llvm-project/mlir:SideEffectInterfacesTdFiles", |
| ], |
| ) |
| |
| # Runs the in-tree :llo_builder_gen TableGen backend (not mlir-tblgen) over llo.td |
| # to produce llo_builder.cc.inc. No extra backend flags are passed. |
| gentbl_cc_library( |
| name = "llo_builder_inc_gen", |
| compatible_with = ["//buildenv/target:libtpu"], |
| tbl_outs = [ |
| ( |
| [], |
| "dialect/llo/llo_builder.cc.inc", |
| ), |
| ], |
| tblgen = ":llo_builder_gen", |
| td_file = "dialect/llo/llo.td", |
| deps = [":llo_td_files"], |
| ) |
| |
| # Custom TableGen backend used as the `tblgen` tool of :llo_builder_inc_gen. |
| cc_binary( |
| name = "llo_builder_gen", |
| srcs = [ |
| "dialect/llo/llo_builder_gen.cc", |
| ], |
| compatible_with = ["//buildenv/target:libtpu"], |
| deps = [ |
| # LLVM support and TableGen core. |
| "//third_party/llvm/llvm-project/llvm:Support", |
| "//third_party/llvm/llvm-project/llvm:TableGen", |
| ] + [ |
| # MLIR support and TableGen helpers. |
| "//third_party/llvm/llvm-project/mlir:Support", |
| "//third_party/llvm/llvm-project/mlir:TableGen", |
| ], |
| ) |
| |
| ################################################################################ |
| # mlir-tpu-opt |
| |
| # mlir-opt style driver (built on MlirOptLib) that links in the TPU and LLO |
| # dialects and passes, along with upstream MLIR and MHLO registrations. |
| cc_binary( |
| name = "mlir-tpu-opt", |
| srcs = ["mlir-tpu-opt.cc"], |
| compatible_with = ["//buildenv/target:libtpu"], |
| deps = [ |
| # Targets from this package. |
| ":llo_dialect", |
| ":tpu_passes", |
| "//base", |
| ] + [ |
| # Upstream MLIR dialects, passes, extensions and the opt driver library. |
| "//third_party/llvm/llvm-project/mlir:AllExtensions", |
| "//third_party/llvm/llvm-project/mlir:AllPassesAndDialects", |
| "//third_party/llvm/llvm-project/mlir:MlirOptLib", |
| ] + [ |
| # Mosaic TPU dialect and MHLO. |
| "//third_party/py/jax/jaxlib/mosaic:tpu_dialect", |
| "//third_party/tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", |
| "//third_party/tensorflow/compiler/xla/mlir_hlo:mhlo_passes", |
| ], |
| ) |
| |
| ################################################################################ |
| # Custom call kernel name / PartIR integration |
| |
| # Looks up the kernel name of a Mosaic custom call; consumed by |
| # :partir_callback_registration. Uses the jellyfish backend_configs proto and |
| # protobuf's JSON utilities. |
| # NOTE(review): unlike most targets in this file this has no compatible_with |
| # libtpu; presumably intentional, confirm before adding it. |
| cc_library( |
| name = "custom_call_kernel_name", |
| srcs = ["custom_call_kernel_name.cc"], |
| hdrs = ["custom_call_kernel_name.h"], |
| deps = [ |
| "//platforms/xla/service/jellyfish/lowering:backend_configs_cc_proto", |
| ] + [ |
| # Abseil. |
| "//third_party/absl/status", |
| "//third_party/absl/status:statusor", |
| "//third_party/absl/strings:string_view", |
| ] + [ |
| "//third_party/llvm/llvm-project/mlir:IR", |
| "//third_party/protobuf/util:json_util", |
| ], |
| ) |
| |
| # Helper library that registers a PartIR callback which extracts the kernel name |
| # from a Mosaic custom call op. It must be linked into the libraries that |
| # implement the PartIR tile-mapping registry callbacks. |
| # There are no hdrs, so nothing references this code directly; alwayslink keeps |
| # the object file in the link so the registration is not dropped. |
| # NOTE(review): registration presumably runs from a //base:googleinit module |
| # initializer; confirm in partir_callback_registration.cc. |
| # The explicit visibility replaces the package default_visibility. |
| cc_library( |
| name = "partir_callback_registration", |
| srcs = ["partir_callback_registration.cc"], |
| visibility = ["//third_party/py/jax_triton/google/pallas_tpu:__subpackages__"], |
| deps = [ |
| ":custom_call_kernel_name", |
| "//base:googleinit", |
| "//learning/deepmind/partir/compiler/rewrites:custom_call_registry", |
| ] + [ |
| # Abseil. |
| "//third_party/absl/log:check", |
| "//third_party/absl/status", |
| "//third_party/absl/status:statusor", |
| ] + [ |
| # LLVM / MLIR / MHLO. |
| "//third_party/llvm/llvm-project/llvm:Support", |
| "//third_party/llvm/llvm-project/mlir:IR", |
| "//third_party/tensorflow/compiler/xla/mlir_hlo", |
| ], |
| alwayslink = True, |
| ) |