load("//devtools/python/blaze:pytype.bzl", "pytype_strict_library")
load("//learning/brain/google/xla:libtpu_build_defs.bzl", "tpu_cc_library")
load("//third_party/bazel_rules/rules_python/python:py_test.bzl", "py_test")

package(
    # "//platforms/xla:__subpackages__" grants visibility to //platforms/xla
    # and every package beneath it, which already includes
    # //platforms/xla/mosaic/..., so no separate mosaic entry is needed.
    default_visibility = ["//platforms/xla:__subpackages__"],
)

# multidevice_test.py is instantiated once per entry in the variant list at the
# bottom of this comprehension. Each target is named
# "multidevice_test" + variant["suffix"] and receives that variant's args/tags.
[
    py_test(
        name = "multidevice_test" + variant["suffix"],
        srcs = ["multidevice_test.py"],
        args = variant["args"],
        main = "multidevice_test.py",
        tags = variant["tags"],
        deps = [
            ":test_util",
            "//learning/brain/research/jax:tpu_simulator_support",
            "//learning/brain/research/jax:tpu_support",
            "//platforms/xla/mosaic/python/dialects:llo_dialect",
            "//testing/pybase",
            "//testing/pybase:parameterized",
            "//third_party/llvm/llvm-project/mlir:IR",
            "//third_party/py/jax",
            "//third_party/py/mlir:arithmetic_dialect",
            "//third_party/py/mlir:builtin_dialect",
            "//third_party/py/mlir:ir",
            "//third_party/py/mlir:memref_dialect",
            "//third_party/py/mlir:vector_dialect",
            "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
            "//third_party/py/numpy",
        ],
    )
    for variant in [
        {
            # Passes a 2,2,1 chips-per-host bound with viperfish/lite flags.
            "suffix": "_vl_grm",
            "args": [
                "--deepsea_chips_per_host_bounds=2,2,1",
                "--deepsea_version=viperfish",
                "--deepsea_variant=lite",
            ],
            # Takes way too long with sanitizers or debug mode.
            "tags": [
                "noasan",
                "notsan",
                "nodebug",
            ],
        },
        {
            # No extra flags; selected purely by its requires-* tag.
            "suffix": "_vl",
            "args": [],
            "tags": ["requires-viperlite:8"],
        },
    ]
]

# Tags shared by most ISS variants of tpu_custom_call_test below; they keep
# those variants out of ASan, TSan and debug-mode test runs.
_TPU_CUSTOM_CALL_TEST_ISS_SLOW_TAGS = [
    "noasan",
    "notsan",
    "nodebug",
]

# tpu_custom_call_test.py is instantiated once per (pipeline, variant) pair.
# Target names are "tpu_custom_call_test" + variant["suffix"] + py_suffix:
#   * py_suffix "_py" sets MOSAIC_USE_PYTHON_PIPELINE=True,
#   * py_suffix ""    sets MOSAIC_USE_PYTHON_PIPELINE=False.
# Variants whose suffix contains "_iss" all pass --xla_tpu_autofdo=false; the
# remaining variants are selected by a requires-<chip> tag.
[
    py_test(
        name = "tpu_custom_call_test" + variant["suffix"] + py_suffix,
        timeout = "long",
        srcs = ["tpu_custom_call_test.py"],
        args = variant["args"],
        env = {"MOSAIC_USE_PYTHON_PIPELINE": use_python_pipeline},
        main = "tpu_custom_call_test.py",
        shard_count = variant["shard_count"],
        tags = variant["tags"],
        deps = [
            ":test_util",
            "//learning/brain/research/jax:tpu_iss_support",
            "//learning/brain/research/jax:tpu_support",
            "//platforms/xla/mosaic/python/dialects:llo_dialect",
            "//platforms/xla/mosaic/tests:mlir_interpreter",
            "//testing/pybase",
            "//testing/pybase:parameterized",
            "//third_party/llvm/llvm-project/mlir:IR",
            "//third_party/py/absl/flags",
            "//third_party/py/absl/testing:flagsaver",
            "//third_party/py/hypothesis",
            "//third_party/py/jax",
            "//third_party/py/mlir:arithmetic_dialect",
            "//third_party/py/mlir:builtin_dialect",
            "//third_party/py/mlir:control_flow_dialect",
            "//third_party/py/mlir:ir",
            "//third_party/py/mlir:memref_dialect",
            "//third_party/py/mlir:scf_dialect",
            "//third_party/py/mlir:vector_dialect",
            "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
            "//third_party/py/numpy",
            "//third_party/tensorflow/compiler/xla/python:xla_extension",
        ],
    )
    for py_suffix, use_python_pipeline in [
        ("_py", "True"),
        ("", "False"),
    ]
    for variant in [
        {
            "suffix": "_jf_iss",
            "args": [
                "--xla_tpu_mosaic_fusion",
                "--xla_tpu_autofdo=false",
            ],
            "tags": _TPU_CUSTOM_CALL_TEST_ISS_SLOW_TAGS,
            "shard_count": 8,
        },
        {
            "suffix": "_pf_iss",
            "args": [
                "--deepsea_version=pufferfish",
                "--xla_tpu_autofdo=false",
            ],
            "tags": _TPU_CUSTOM_CALL_TEST_ISS_SLOW_TAGS,
            "shard_count": 10,
        },
        {
            "suffix": "_vl_iss",
            "args": [
                "--deepsea_version=viperfish",
                "--deepsea_variant=lite",
                "--xla_tpu_autofdo=false",
            ],
            "tags": [
                "notap",  # vl_iss on TAP drops precision http://shortn/_6XHhU0fZml.
            ],
            "shard_count": 10,
        },
        {
            "suffix": "_gl_iss",
            "args": [
                "--deepsea_version=ghostlite",
                "--xla_tpu_autofdo=false",
            ],
            "tags": _TPU_CUSTOM_CALL_TEST_ISS_SLOW_TAGS,
            "shard_count": 10,
        },
        {
            "suffix": "_gf_iss",
            "args": [
                "--deepsea_version=ghostfish",
                "--xla_tpu_autofdo=false",
            ],
            "tags": ["notap"],
            "shard_count": 10,
        },
        {
            "suffix": "_pf_megacore_iss",
            "args": [
                "--deepsea_version=pufferfish",
                "--deepsea_chip_config_name=megacore",
                "--iss_topology_mode=ISS_TOPOLOGY_MODE_MULTI_TC",
                "--xla_tpu_autofdo=false",
            ],
            "tags": _TPU_CUSTOM_CALL_TEST_ISS_SLOW_TAGS,
            "shard_count": 10,
        },
        {
            "suffix": "_jf",
            "args": ["--xla_tpu_mosaic_fusion"],
            "tags": ["requires-jellyfish"],
            "shard_count": 1,
        },
        {
            "suffix": "_pf",
            "args": [],
            "tags": ["requires-pufferfish"],
            "shard_count": 4,
        },
        {
            "suffix": "_vl",
            "args": [],
            "tags": ["requires-viperlite"],
            "shard_count": 10,
        },
        {
            "suffix": "_pf_fusion",
            "args": ["--xla_tpu_mosaic_fusion"],
            "tags": ["requires-pufferfish"],
            "shard_count": 4,
        },
        {
            "suffix": "_vl_fusion",
            "args": ["--xla_tpu_mosaic_fusion"],
            "tags": ["requires-viperlite"],
            "shard_count": 10,
        },
    ]
]

# Single (non-parameterized) test target for apply_vector_layout_test.py.
# It uses a long timeout and is split across 4 shards; no args, env or tags are
# set. It depends on the local :test_util helpers and on the mlir_interpreter
# target.
py_test(
    name = "apply_vector_layout_test",
    timeout = "long",
    srcs = ["apply_vector_layout_test.py"],
    main = "apply_vector_layout_test.py",
    shard_count = 4,
    deps = [
        ":test_util",
        "//platforms/xla/mosaic/tests:mlir_interpreter",
        "//testing/pybase",
        "//testing/pybase:parameterized",
        "//third_party/py/hypothesis",
        "//third_party/py/jax",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:math_dialect",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/numpy",
    ],
)

# Header-only C++ library: it exposes megacore_adjuster_base.h and has no srcs.
# It is declared with tpu_cc_library (loaded from libtpu_build_defs.bzl) rather
# than cc_library.
# NOTE(review): deps_for_viperfish presumably adds the sparse_core deps only
# when building for viperfish; confirm in libtpu_build_defs.bzl.
tpu_cc_library(
    name = "megacore_adjuster_base",
    hdrs = ["megacore_adjuster_base.h"],
    compatible_with = ["//buildenv/target:libtpu"],
    deps_for_viperfish = [
        "//platforms/xla/sparse_core:debug_info",
        "//platforms/xla/sparse_core:lowering_util",
    ],
    deps = [
        "//learning/brain/tpu/runtime:libtpu_support_macros",
        "//platforms/xla/service/jellyfish:llo_ir",
        "//platforms/xla/service/jellyfish:target_base",
        "//platforms/xla/service/jellyfish/lowering:math_util",
        "//third_party/absl/algorithm:container",
        "//third_party/absl/log",
        "//third_party/absl/log:check",
        "//third_party/absl/types:optional",
        "//third_party/absl/types:span",
        "//third_party/llvm/llvm-project/llvm:Support",
        "//third_party/llvm/llvm-project/mlir:ArithDialect",
        "//third_party/llvm/llvm-project/mlir:IR",
        "//third_party/py/jax/jaxlib/mosaic:tpu_dialect",
        "//third_party/tensorflow/compiler/xla:status",
        "//third_party/tensorflow/compiler/xla:status_macros",
        "//third_party/tensorflow/compiler/xla:util",
        "//third_party/tensorflow/compiler/xla:xla_data_proto_cc",
        "//third_party/tensorflow/tsl/platform:statusor",
    ],
)

# C++ library built from custom_call_emitter.cc. It builds on
# :megacore_adjuster_base and the jellyfish lowering/fusion libraries, and is
# restricted to the libtpu target environment via compatible_with.
cc_library(
    name = "custom_call_emitter",
    srcs = ["custom_call_emitter.cc"],
    compatible_with = ["//buildenv/target:libtpu"],
    deps = [
        ":megacore_adjuster_base",
        "//base:googleinit",
        "//base:vlog",
        "//learning/brain/tpu/runtime:libtpu_support_macros",
        "//learning/brain/tpu/runtime:tpu_version",
        "//platforms/xla/mosaic:llo_dialect",
        "//platforms/xla/mosaic:tpu_passes",
        "//platforms/xla/service:fusion_util",
        "//platforms/xla/service/jellyfish:async_collective_fusion_util",
        "//platforms/xla/service/jellyfish:custom_call_registration",
        "//platforms/xla/service/jellyfish:fusion_options",
        "//platforms/xla/service/jellyfish:hardware_layout",
        "//platforms/xla/service/jellyfish:hlo_deduplication",
        "//platforms/xla/service/jellyfish:llo_ir",
        "//platforms/xla/service/jellyfish:llo_program_shared_registry",
        "//platforms/xla/service/jellyfish:llo_sdc_reporter",
        "//platforms/xla/service/jellyfish:memory_space_enum",
        "//platforms/xla/service/jellyfish:target_base",
        "//platforms/xla/service/jellyfish:tpu_compilation_environment",
        "//platforms/xla/service/jellyfish:tpu_compilation_environment_cc_proto",
        "//platforms/xla/service/jellyfish:tpu_instruction_fusion",
        "//platforms/xla/service/jellyfish:transfer_size_util",
        "//platforms/xla/service/jellyfish/cost_model",
        "//platforms/xla/service/jellyfish/cost_model:cost_model_util",
        "//platforms/xla/service/jellyfish/lowering:backend_config_util",
        "//platforms/xla/service/jellyfish/lowering:backend_configs_cc_proto",
        "//platforms/xla/service/jellyfish/lowering:deep_copy_util",
        "//platforms/xla/service/jellyfish/lowering:fusion_emitter",
        "//platforms/xla/service/jellyfish/lowering:fusion_util",
        "//platforms/xla/service/jellyfish/lowering:group_utils",
        "//platforms/xla/service/jellyfish/lowering:lowering_util",
        "//platforms/xla/service/jellyfish/lowering:net_util",
        "//platforms/xla/service/jellyfish/lowering:op_emitter",
        "//platforms/xla/service/jellyfish/lowering:param_input",
        "//platforms/xla/service/jellyfish/lowering:pipeline_emitter",
        "//platforms/xla/service/jellyfish/lowering:windowing_util",
        "//third_party/absl/algorithm:container",
        "//third_party/absl/container:flat_hash_map",
        "//third_party/absl/flags:flag",
        "//third_party/absl/log",
        "//third_party/absl/log:check",
        "//third_party/absl/status",
        "//third_party/absl/strings",
        "//third_party/absl/strings:str_format",
        "//third_party/absl/strings:string_view",
        "//third_party/absl/types:optional",
        "//third_party/absl/types:span",
        "//third_party/llvm/llvm-project/llvm:Support",
        "//third_party/llvm/llvm-project/mlir:FuncDialect",
        "//third_party/llvm/llvm-project/mlir:FuncExtensions",
        "//third_party/llvm/llvm-project/mlir:IR",
        "//third_party/llvm/llvm-project/mlir:LinalgDialect",
        "//third_party/llvm/llvm-project/mlir:Parser",
        "//third_party/llvm/llvm-project/mlir:Pass",
        "//third_party/llvm/llvm-project/mlir:Support",
        "//third_party/llvm/llvm-project/mlir:TensorDialect",
        "//third_party/llvm/llvm-project/mlir:Transforms",
        "//third_party/py/jax/jaxlib/mosaic:tpu_dialect",
        "//third_party/tensorflow/compiler/xla:shape_util",
        "//third_party/tensorflow/compiler/xla:status",
        "//third_party/tensorflow/compiler/xla:statusor",
        "//third_party/tensorflow/compiler/xla:util",
        "//third_party/tensorflow/compiler/xla:xla_data_proto_cc",
        "//third_party/tensorflow/compiler/xla/hlo/ir:hlo",
        "//third_party/tensorflow/compiler/xla/mlir/utils:type_util",
        "//third_party/tensorflow/compiler/xla/mlir_hlo",
        "//third_party/tensorflow/compiler/xla/mlir_hlo:mhlo_passes",
        "//third_party/tensorflow/compiler/xla/service:hlo_cost_analysis",
        "//third_party/tensorflow/core/platform:path",
        "//third_party/tensorflow/tsl/platform:env",
        "//third_party/tensorflow/tsl/platform:errors",
        "//third_party/tensorflow/tsl/platform:status",
        "//third_party/tensorflow/tsl/platform:statusor",
        "//util/task:status",
    ],
    # Link this library's objects into dependents even if no symbol is
    # referenced directly.
    # NOTE(review): presumably required because the .cc registers itself at
    # static-init time (cf. the googleinit / custom_call_registration deps);
    # confirm before removing.
    alwayslink = True,
)

# Python test helpers from test_util.py, consumed as ":test_util" by the py_test
# targets in this package. testonly = True means only testonly targets
# (e.g. tests) may depend on it.
pytype_strict_library(
    name = "test_util",
    testonly = True,
    srcs = ["test_util.py"],
    deps = [
        "//platforms/xla/mosaic/python/dialects:llo_dialect",
        "//testing/pybase:parameterized",
        "//third_party/py/hypothesis",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:mhlo_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
        "//third_party/py/numpy",
    ],
)
