blob: f9d19b817969339bf1c6866bd25eb28cda4b2708 [file] [log] [blame]
load("//devtools/python/blaze:pytype.bzl", "pytype_strict_library")
load("//learning/brain/google/xla:libtpu_build_defs.bzl", "tpu_cc_library")
load("//third_party/bazel_rules/rules_python/python:py_test.bzl", "py_test")
package(
    # "//platforms/xla:__subpackages__" grants visibility to //platforms/xla and
    # every package beneath it, recursively, which already includes
    # //platforms/xla/mosaic and its subpackages. A separate
    # "//platforms/xla/mosaic:__subpackages__" entry would be redundant.
    default_visibility = [
        "//platforms/xla:__subpackages__",
    ],
)
# One multidevice_test target is generated per (suffix, args, tags) entry in
# the list at the bottom of this comprehension. Every generated target shares
# the same source file and deps. Its name is "multidevice_test" + suffix.
[
py_test(
name = "multidevice_test" + suffix,
srcs = ["multidevice_test.py"],
args = args,
main = "multidevice_test.py",
tags = tags,
deps = [
":test_util",
"//learning/brain/research/jax:tpu_simulator_support",
"//learning/brain/research/jax:tpu_support",
"//platforms/xla/mosaic/python/dialects:llo_dialect",
"//testing/pybase",
"//testing/pybase:parameterized",
"//third_party/llvm/llvm-project/mlir:IR",
"//third_party/py/jax",
"//third_party/py/mlir:arithmetic_dialect",
"//third_party/py/mlir:builtin_dialect",
"//third_party/py/mlir:ir",
"//third_party/py/mlir:memref_dialect",
"//third_party/py/mlir:vector_dialect",
"//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
"//third_party/py/numpy",
],
)
for suffix, args, tags in [
(
"_vl_grm",
# Passes the viperfish "lite" variant flags with per-host chip bounds
# 2,2,1. NOTE(review): this variant has no requires-* tag, and the deps
# include tpu_simulator_support, so it presumably runs without physical
# hardware. Confirm.
[
"--deepsea_chips_per_host_bounds=2,2,1",
"--deepsea_version=viperfish",
"--deepsea_variant=lite",
],
[
"noasan",
"notsan",
"nodebug",
], # Takes way too long with sanitizers or debug mode.
),
(
"_vl",
# No extra flags. NOTE(review): the ":8" suffix on the tag presumably
# requests 8 viperlite devices. Confirm the tag semantics.
[],
["requires-viperlite:8"],
),
]
]
# Test matrix for tpu_custom_call_test. The outer loop selects the pipeline
# mode (py_suffix, env). The inner loop selects the target configuration
# (suffix, args, tags, shard_count). One target is generated per combination,
# named "tpu_custom_call_test" + suffix + py_suffix.
[
py_test(
name = "tpu_custom_call_test" + suffix + py_suffix,
timeout = "long",
srcs = ["tpu_custom_call_test.py"],
args = args,
env = env,
main = "tpu_custom_call_test.py",
shard_count = shard_count,
tags = tags,
deps = [
":test_util",
"//learning/brain/research/jax:tpu_iss_support",
"//learning/brain/research/jax:tpu_support",
"//platforms/xla/mosaic/python/dialects:llo_dialect",
"//platforms/xla/mosaic/tests:mlir_interpreter",
"//testing/pybase",
"//testing/pybase:parameterized",
"//third_party/llvm/llvm-project/mlir:IR",
"//third_party/py/absl/flags",
"//third_party/py/absl/testing:flagsaver",
"//third_party/py/hypothesis",
"//third_party/py/jax",
"//third_party/py/mlir:arithmetic_dialect",
"//third_party/py/mlir:builtin_dialect",
"//third_party/py/mlir:control_flow_dialect",
"//third_party/py/mlir:ir",
"//third_party/py/mlir:memref_dialect",
"//third_party/py/mlir:scf_dialect",
"//third_party/py/mlir:vector_dialect",
"//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
"//third_party/py/numpy",
"//third_party/tensorflow/compiler/xla/python:xla_extension",
],
)
# Pipeline mode. "_py" targets run with MOSAIC_USE_PYTHON_PIPELINE=True.
# Unsuffixed targets run with MOSAIC_USE_PYTHON_PIPELINE=False.
for py_suffix, env in [
(
"_py",
{"MOSAIC_USE_PYTHON_PIPELINE": "True"},
),
(
"",
{"MOSAIC_USE_PYTHON_PIPELINE": "False"},
),
]
# Target configurations, as (suffix, args, tags, shard_count).
# - Entries whose suffix ends in "_iss" pass --xla_tpu_autofdo=false and
#   carry no requires-* tag. NOTE(review): they presumably run on a
#   simulator (deps include tpu_iss_support). Confirm.
# - The remaining entries carry a requires-* tag naming the hardware they
#   need.
for suffix, args, tags, shard_count in [
(
"_jf_iss",
# NOTE(review): unlike the other *_iss entries, this one passes no
# --deepsea_version flag. It presumably relies on the default version.
# Confirm.
[
"--xla_tpu_mosaic_fusion",
"--xla_tpu_autofdo=false",
],
[
"noasan",
"notsan",
"nodebug",
],
8,
),
(
"_pf_iss",
[
"--deepsea_version=pufferfish",
"--xla_tpu_autofdo=false",
],
[
"noasan",
"notsan",
"nodebug",
],
10,
),
(
"_vl_iss",
[
"--deepsea_version=viperfish",
"--deepsea_variant=lite",
"--xla_tpu_autofdo=false",
],
# NOTE(review): unlike most *_iss entries (all except _gf_iss), this one
# is not tagged noasan/notsan/nodebug. Confirm that this is intentional.
[
"notap", # vl_iss on TAP drops precision http://shortn/_6XHhU0fZml.
],
10,
),
(
"_gl_iss",
[
"--deepsea_version=ghostlite",
"--xla_tpu_autofdo=false",
],
[
"noasan",
"notsan",
"nodebug",
],
10,
),
(
"_gf_iss",
[
"--deepsea_version=ghostfish",
"--xla_tpu_autofdo=false",
],
# NOTE(review): only "notap" here, with no noasan/notsan/nodebug, unlike
# most other *_iss entries. Confirm that this is intentional.
["notap"],
10,
),
(
"_pf_megacore_iss",
[
"--deepsea_version=pufferfish",
"--deepsea_chip_config_name=megacore",
"--iss_topology_mode=ISS_TOPOLOGY_MODE_MULTI_TC",
"--xla_tpu_autofdo=false",
],
[
"noasan",
"notsan",
"nodebug",
],
10,
),
(
"_jf",
["--xla_tpu_mosaic_fusion"],
["requires-jellyfish"],
1,
),
(
"_pf",
[],
["requires-pufferfish"],
4,
),
(
"_vl",
[],
["requires-viperlite"],
10,
),
# The *_fusion entries rerun the _pf / _vl hardware configurations with
# --xla_tpu_mosaic_fusion added.
(
"_pf_fusion",
["--xla_tpu_mosaic_fusion"],
["requires-pufferfish"],
4,
),
(
"_vl_fusion",
["--xla_tpu_mosaic_fusion"],
["requires-viperlite"],
10,
),
]
]
# Single, non-parameterized test target for apply_vector_layout_test.py. It
# uses a long timeout and 4 shards, and sets no args or tags.
# NOTE(review): with no requires-* tag it presumably needs no TPU hardware
# (deps include //platforms/xla/mosaic/tests:mlir_interpreter). Confirm.
py_test(
name = "apply_vector_layout_test",
timeout = "long",
srcs = ["apply_vector_layout_test.py"],
main = "apply_vector_layout_test.py",
shard_count = 4,
deps = [
":test_util",
"//platforms/xla/mosaic/tests:mlir_interpreter",
"//testing/pybase",
"//testing/pybase:parameterized",
"//third_party/py/hypothesis",
"//third_party/py/jax",
"//third_party/py/mlir:arithmetic_dialect",
"//third_party/py/mlir:func_dialect",
"//third_party/py/mlir:ir",
"//third_party/py/mlir:math_dialect",
"//third_party/py/mlir:vector_dialect",
"//third_party/py/numpy",
],
)
# Header-only library (hdrs, no srcs) built through the tpu_cc_library macro
# and marked compatible with the libtpu build target.
tpu_cc_library(
name = "megacore_adjuster_base",
hdrs = ["megacore_adjuster_base.h"],
compatible_with = ["//buildenv/target:libtpu"],
# NOTE(review): deps_for_viperfish is an attribute of the tpu_cc_library
# macro. These deps are presumably added only to viperfish builds. Confirm
# against libtpu_build_defs.bzl.
deps_for_viperfish = [
"//platforms/xla/sparse_core:debug_info",
"//platforms/xla/sparse_core:lowering_util",
],
deps = [
"//learning/brain/tpu/runtime:libtpu_support_macros",
"//platforms/xla/service/jellyfish:llo_ir",
"//platforms/xla/service/jellyfish:target_base",
"//platforms/xla/service/jellyfish/lowering:math_util",
"//third_party/absl/algorithm:container",
"//third_party/absl/log",
"//third_party/absl/log:check",
"//third_party/absl/types:optional",
"//third_party/absl/types:span",
"//third_party/llvm/llvm-project/llvm:Support",
"//third_party/llvm/llvm-project/mlir:ArithDialect",
"//third_party/llvm/llvm-project/mlir:IR",
"//third_party/py/jax/jaxlib/mosaic:tpu_dialect",
"//third_party/tensorflow/compiler/xla:status",
"//third_party/tensorflow/compiler/xla:status_macros",
"//third_party/tensorflow/compiler/xla:util",
"//third_party/tensorflow/compiler/xla:xla_data_proto_cc",
"//third_party/tensorflow/tsl/platform:statusor",
],
)
# C++ library built from custom_call_emitter.cc and marked compatible with the
# libtpu build target. It builds on :megacore_adjuster_base.
cc_library(
name = "custom_call_emitter",
srcs = ["custom_call_emitter.cc"],
compatible_with = ["//buildenv/target:libtpu"],
deps = [
":megacore_adjuster_base",
"//base:googleinit",
"//base:vlog",
"//learning/brain/tpu/runtime:libtpu_support_macros",
"//learning/brain/tpu/runtime:tpu_version",
"//platforms/xla/mosaic:llo_dialect",
"//platforms/xla/mosaic:tpu_passes",
"//platforms/xla/service:fusion_util",
"//platforms/xla/service/jellyfish:async_collective_fusion_util",
"//platforms/xla/service/jellyfish:custom_call_registration",
"//platforms/xla/service/jellyfish:fusion_options",
"//platforms/xla/service/jellyfish:hardware_layout",
"//platforms/xla/service/jellyfish:hlo_deduplication",
"//platforms/xla/service/jellyfish:llo_ir",
"//platforms/xla/service/jellyfish:llo_program_shared_registry",
"//platforms/xla/service/jellyfish:llo_sdc_reporter",
"//platforms/xla/service/jellyfish:memory_space_enum",
"//platforms/xla/service/jellyfish:target_base",
"//platforms/xla/service/jellyfish:tpu_compilation_environment",
"//platforms/xla/service/jellyfish:tpu_compilation_environment_cc_proto",
"//platforms/xla/service/jellyfish:tpu_instruction_fusion",
"//platforms/xla/service/jellyfish:transfer_size_util",
"//platforms/xla/service/jellyfish/cost_model",
"//platforms/xla/service/jellyfish/cost_model:cost_model_util",
"//platforms/xla/service/jellyfish/lowering:backend_config_util",
"//platforms/xla/service/jellyfish/lowering:backend_configs_cc_proto",
"//platforms/xla/service/jellyfish/lowering:deep_copy_util",
"//platforms/xla/service/jellyfish/lowering:fusion_emitter",
"//platforms/xla/service/jellyfish/lowering:fusion_util",
"//platforms/xla/service/jellyfish/lowering:group_utils",
"//platforms/xla/service/jellyfish/lowering:lowering_util",
"//platforms/xla/service/jellyfish/lowering:net_util",
"//platforms/xla/service/jellyfish/lowering:op_emitter",
"//platforms/xla/service/jellyfish/lowering:param_input",
"//platforms/xla/service/jellyfish/lowering:pipeline_emitter",
"//platforms/xla/service/jellyfish/lowering:windowing_util",
"//third_party/absl/algorithm:container",
"//third_party/absl/container:flat_hash_map",
"//third_party/absl/flags:flag",
"//third_party/absl/log",
"//third_party/absl/log:check",
"//third_party/absl/status",
"//third_party/absl/strings",
"//third_party/absl/strings:str_format",
"//third_party/absl/strings:string_view",
"//third_party/absl/types:optional",
"//third_party/absl/types:span",
"//third_party/llvm/llvm-project/llvm:Support",
"//third_party/llvm/llvm-project/mlir:FuncDialect",
"//third_party/llvm/llvm-project/mlir:FuncExtensions",
"//third_party/llvm/llvm-project/mlir:IR",
"//third_party/llvm/llvm-project/mlir:LinalgDialect",
"//third_party/llvm/llvm-project/mlir:Parser",
"//third_party/llvm/llvm-project/mlir:Pass",
"//third_party/llvm/llvm-project/mlir:Support",
"//third_party/llvm/llvm-project/mlir:TensorDialect",
"//third_party/llvm/llvm-project/mlir:Transforms",
"//third_party/py/jax/jaxlib/mosaic:tpu_dialect",
"//third_party/tensorflow/compiler/xla:shape_util",
"//third_party/tensorflow/compiler/xla:status",
"//third_party/tensorflow/compiler/xla:statusor",
"//third_party/tensorflow/compiler/xla:util",
"//third_party/tensorflow/compiler/xla:xla_data_proto_cc",
"//third_party/tensorflow/compiler/xla/hlo/ir:hlo",
"//third_party/tensorflow/compiler/xla/mlir/utils:type_util",
"//third_party/tensorflow/compiler/xla/mlir_hlo",
"//third_party/tensorflow/compiler/xla/mlir_hlo:mhlo_passes",
"//third_party/tensorflow/compiler/xla/service:hlo_cost_analysis",
"//third_party/tensorflow/core/platform:path",
"//third_party/tensorflow/tsl/platform:env",
"//third_party/tensorflow/tsl/platform:errors",
"//third_party/tensorflow/tsl/platform:status",
"//third_party/tensorflow/tsl/platform:statusor",
"//util/task:status",
],
# alwayslink = True makes Bazel link every object file of this library into
# dependent binaries, even when none of its symbols are referenced directly.
# NOTE(review): this is presumably required because the emitter registers
# itself through static initializers (deps include custom_call_registration
# and //base:googleinit). Confirm before removing.
alwayslink = True,
)
# Shared Python helpers for the tests in this package. The multidevice_test,
# tpu_custom_call_test and apply_vector_layout_test targets depend on it as
# ":test_util". testonly = True restricts it to tests and other testonly
# targets.
pytype_strict_library(
name = "test_util",
testonly = True,
srcs = ["test_util.py"],
deps = [
"//platforms/xla/mosaic/python/dialects:llo_dialect",
"//testing/pybase:parameterized",
"//third_party/py/hypothesis",
"//third_party/py/jax",
"//third_party/py/jax:mosaic",
"//third_party/py/mlir:func_dialect",
"//third_party/py/mlir:ir",
"//third_party/py/mlir:mhlo_dialect",
"//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
"//third_party/py/numpy",
],
)