blob: a42d1f6d6ceb31efd26a230bcdf2a23923646dbb [file] [log] [blame]
load("//devtools/python/blaze:pytype.bzl", "pytype_strict_library")
load("//devtools/python/blaze:strict.bzl", "py_strict_test")
load("//third_party/bazel_rules/rules_python/python:py_test.bzl", "py_test")
# Examples showing how to use the Mosaic compiler.
load("//tools/build_defs/testing:bzl_library.bzl", "bzl_library")
load("build_defs.bzl", "mosaic_example")
# TODO(apaszke): Remove dependencies on the LLO and TPU dialects once they are no longer needed.
# Benchmark driver: benchmark.py is the entry point, and matmul.py is bundled in
# as an additional source. The "manual" tag keeps it out of wildcard target
# patterns such as //...; "notap" presumably keeps it out of continuous testing
# too. NOTE(review): "requires-net:external" is presumably needed by the xprof
# clients listed in deps — confirm.
mosaic_example(
    name = "benchmark",
    additional_srcs = ["matmul.py"],
    additional_tags = [
        "requires-net:external",
        "manual",
        "notap",
    ],
    main = "benchmark.py",
    visibility = ["//visibility:private"],
    deps = [
        "//file/localfile",
        "//learning/brain/research/jax:tpu_support",
        "//perftools/accelerators/xprof/api/python:xprof_analysis_client",
        "//perftools/accelerators/xprof/api/python:xprof_session",
        "//platforms/xla/mosaic/python/dialects:llo_dialect",
        "//pyglib:gfile",
        "//third_party/py/absl/flags",
        "//third_party/py/absl/logging",
        "//third_party/py/absl/testing:absltest",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:builtin_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:memref_dialect",
        "//third_party/py/mlir:pass_manager",
        "//third_party/py/mlir:scf_dialect",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
        "//third_party/py/numpy",
    ],
)
# Matmul example (matmul.py). target_kwargs supplies per-platform overrides;
# presumably mosaic_example (build_defs.bzl) expands this into one target per
# platform — confirm. On "iss" the script receives "1", which skips benchmarking.
# NOTE(review): this is the only matmul-family target that depends on
# TensorFlow; confirm matmul.py still imports it.
mosaic_example(
    name = "matmul",
    main = "matmul.py",
    target_kwargs = {
        "iss": {"args": ["1"]},  # Don't benchmark.
    },
    deps = [
        "//learning/brain/research/jax:tpu_iss_support",
        "//learning/brain/research/jax:tpu_support",
        "//platforms/xla/mosaic/python/dialects:llo_dialect",
        "//third_party/py/absl:app",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:builtin_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:memref_dialect",
        "//third_party/py/mlir:pass_manager",
        "//third_party/py/mlir:scf_dialect",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
        "//third_party/py/tensorflow",
        "//third_party/py/tensorflow:tensorflow_google",
        "//third_party/tensorflow/core:protos_all_py_pb2",
    ],
)
# Triangular matrix multiply example (trmm.py). On the "iss" platform it skips
# benchmarking ("1") and is excluded from ASAN runs.
mosaic_example(
    name = "trmm",
    main = "trmm.py",
    target_kwargs = {
        "iss": {
            "args": ["1"],  # Don't benchmark.
            "tags": ["noasan"],  # Don't run under ASAN; it takes too long.
        },
    },
    deps = [
        "//learning/brain/research/jax:tpu_iss_support",
        "//learning/brain/research/jax:tpu_support",
        "//platforms/xla/mosaic/python/dialects:llo_dialect",
        "//third_party/py/absl:app",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:builtin_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:memref_dialect",
        "//third_party/py/mlir:pass_manager",
        "//third_party/py/mlir:scf_dialect",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
    ],
)
# Example built from rope.py. Currently disabled ("manual" and "notap") until
# vector_constants is supported on the C++ side. On "iss" it skips
# benchmarking and is excluded from ASAN and TSAN runs.
mosaic_example(
    name = "rope",
    additional_tags = [
        # TODO(tlongeri): Remove once vector_constants is supported in C++.
        "manual",
        "notap",
    ],
    main = "rope.py",
    target_kwargs = {
        "iss": {
            "args": ["1"],  # Don't benchmark.
            "tags": [
                "noasan",  # Too slow under ASAN.
                "notsan",  # Too slow under TSAN.
            ],
        },
    },
    deps = [
        "//learning/brain/research/jax:tpu_iss_support",
        "//learning/brain/research/jax:tpu_support",
        "//platforms/xla/mosaic/python/dialects:llo_dialect",
        "//third_party/py/absl:app",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:builtin_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:math_dialect",
        "//third_party/py/mlir:pass_manager",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
    ],
)
# Block-sparse matmul example (block_sparse_matmul.py). On "iss" it runs with
# the "small" argument instead of benchmarking. It also depends on the Mosaic
# MLIR interpreter from //platforms/xla/mosaic/tests.
mosaic_example(
    name = "block_sparse_matmul",
    main = "block_sparse_matmul.py",
    target_kwargs = {
        "iss": {"args": ["small"]},  # Don't benchmark.
    },
    deps = [
        "//learning/brain/research/jax:tpu_iss_support",
        "//learning/brain/research/jax:tpu_support",
        "//platforms/xla/mosaic/python/dialects:llo_dialect",
        "//platforms/xla/mosaic/tests:mlir_interpreter",
        "//third_party/py/absl:app",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:builtin_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:math_dialect",
        "//third_party/py/mlir:memref_dialect",
        "//third_party/py/mlir:pass_manager",
        "//third_party/py/mlir:scf_dialect",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
    ],
)
# Segmented matmul example (segmented_matmul.py). On "iss" it runs with the
# "small" argument instead of benchmarking. On "vl" it is excluded from MSAN
# runs because the build runs out of memory on Forge.
mosaic_example(
    name = "segmented_matmul",
    main = "segmented_matmul.py",
    target_kwargs = {
        "iss": {"args": ["small"]},  # Don't benchmark.
        "vl": {"tags": ["nomsan"]},  # Forge OOM.
    },
    deps = [
        "//learning/brain/research/jax:tpu_iss_support",
        "//learning/brain/research/jax:tpu_support",
        "//platforms/xla/mosaic/python/dialects:llo_dialect",
        "//third_party/py/absl:app",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:builtin_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:memref_dialect",
        "//third_party/py/mlir:pass_manager",
        "//third_party/py/mlir:scf_dialect",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
    ],
)
# Plain py_test for matmul_llo_jf.py, outside the mosaic_example macro.
# main is not set, so Bazel defaults it to "<name>.py", i.e. matmul_llo_jf.py.
# Tagged to require Jellyfish hardware.
py_test(
    name = "matmul_llo_jf",
    srcs = ["matmul_llo_jf.py"],
    tags = ["requires-jellyfish"],
    deps = [
        "//learning/brain/research/jax:tpu_support",
        "//platforms/xla/mosaic/python/dialects:llo_dialect",
        "//third_party/py/absl:app",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:scf_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
        "//third_party/py/numpy",
    ],
)
# Runs the same matmul_llo_jf.py with ISS support added and benchmarking
# disabled ("1"). It reuses :matmul_llo_jf for its deps and, unlike that
# target, carries no "requires-jellyfish" tag.
py_test(
    name = "matmul_llo_jf_iss",
    srcs = ["matmul_llo_jf.py"],
    args = ["1"],  # Don't benchmark.
    main = "matmul_llo_jf.py",
    deps = [
        ":matmul_llo_jf",
        "//learning/brain/research/jax:tpu_iss_support",
    ],
)
# Plain py_test for trmm_llo_jf.py; main defaults to trmm_llo_jf.py.
# Tagged to require Jellyfish hardware.
py_test(
    name = "trmm_llo_jf",
    srcs = ["trmm_llo_jf.py"],
    tags = ["requires-jellyfish"],
    deps = [
        "//learning/brain/research/jax:tpu_support",
        "//platforms/xla/mosaic/python/dialects:llo_dialect",
        "//third_party/py/absl:app",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:scf_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
        "//third_party/py/numpy",
    ],
)
# Cholesky example (cholesky.py) with default arguments and no per-platform
# overrides. :cholesky_bounds_checked and :cholesky_small build the same
# source, so keep their deps in sync with this list.
mosaic_example(
    name = "cholesky",
    main = "cholesky.py",
    deps = [
        "//learning/brain/research/jax:tpu_iss_support",
        "//learning/brain/research/jax:tpu_support",
        "//third_party/py/absl:app",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:builtin_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:math_dialect",
        "//third_party/py/mlir:memref_dialect",
        "//third_party/py/mlir:pass_manager",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
        "//third_party/py/numpy",
    ],
)
# Runs cholesky.py with on-device bounds checks enabled, to make sure the bounds checker accepts correct programs.
py_strict_test(
    name = "cholesky_bounds_checked",
    srcs = ["cholesky.py"],
    # Turns on Mosaic's on-device bounds checking for this run.
    args = ["--xla_mosaic_on_device_checks=bounds"],
    main = "cholesky.py",
    # Only ISS support is listed here; :cholesky also lists tpu_support.
    deps = [
        "//learning/brain/research/jax:tpu_iss_support",
        "//third_party/py/absl:app",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:builtin_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:math_dialect",
        "//third_party/py/mlir:memref_dialect",
        "//third_party/py/mlir:pass_manager",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
        "//third_party/py/numpy",
    ],
)
# Same program as :cholesky (cholesky.py), run with the positional argument "8",
# presumably a smaller problem size (confirm against cholesky.py). The "--"
# presumably stops flag parsing so "8" reaches the script as a positional
# argument.
# Because this target builds the same source as :cholesky and
# :cholesky_bounds_checked, its deps must match theirs. func_dialect and numpy
# were previously missing here, which breaks strict-deps builds of cholesky.py.
mosaic_example(
    name = "cholesky_small",
    args = [
        "--",
        "8",
    ],
    main = "cholesky.py",
    deps = [
        "//learning/brain/research/jax:tpu_iss_support",
        "//learning/brain/research/jax:tpu_support",
        "//third_party/py/absl:app",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:builtin_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:math_dialect",
        "//third_party/py/mlir:memref_dialect",
        "//third_party/py/mlir:pass_manager",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
        "//third_party/py/numpy",
    ],
)
# Library wrapping collective_matmul.py so :collective_matmul_test_vl can
# import it. It has no TPU-support deps of its own; the test target brings in
# tpu_support.
pytype_strict_library(
    name = "collective_matmul",
    srcs = ["collective_matmul.py"],
    deps = [
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:builtin_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:memref_dialect",
        "//third_party/py/mlir:scf_dialect",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/numpy",
    ],
)
# Parameterized test for :collective_matmul. It is tagged to need eight
# ViperLite devices ("requires-viperlite:8") and is currently excluded via
# "notap" pending b/307309369.
py_test(
    name = "collective_matmul_test_vl",
    srcs = ["collective_matmul_test.py"],
    main = "collective_matmul_test.py",
    tags = [
        "notap",  # TODO(b/307309369): fix and re-enable
        "requires-viperlite:8",
    ],
    deps = [
        ":collective_matmul",
        "//learning/brain/research/jax:tpu_support",
        "//perftools/accelerators/xprof/api/python:xprof_session",
        "//testing/pybase",
        "//third_party/py/absl/flags",
        "//third_party/py/absl/testing:parameterized",
        "//third_party/py/jax",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
        "//third_party/py/numpy",
    ],
)
# Flash attention example (flash_attention.py). On "iss" it skips benchmarking
# ("1") and is excluded from ASAN and TSAN runs.
mosaic_example(
    name = "flash_attention",
    main = "flash_attention.py",
    target_kwargs = {
        "iss": {
            "args": ["1"],  # Don't benchmark.
            "tags": [
                "noasan",  # Too slow under ASAN.
                "notsan",  # Too slow under TSAN.
            ],
        },
    },
    deps = [
        "//learning/brain/research/jax:tpu_iss_support",
        "//learning/brain/research/jax:tpu_support",
        "//platforms/xla/mosaic/python/dialects:llo_dialect",
        "//third_party/py/absl:app",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:builtin_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:memref_dialect",
        "//third_party/py/mlir:pass_manager",
        "//third_party/py/mlir:scf_dialect",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
    ],
)
# Example built from matmul_mhlo.py. Among these examples, it is the only one
# that depends on the MLIR tensor dialect. On "iss" it skips benchmarking.
mosaic_example(
    name = "matmul_mhlo",
    main = "matmul_mhlo.py",
    target_kwargs = {
        "iss": {"args": ["1"]},  # Don't benchmark.
    },
    deps = [
        "//learning/brain/research/jax:tpu_iss_support",
        "//learning/brain/research/jax:tpu_support",
        "//third_party/py/absl:app",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:builtin_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:memref_dialect",
        "//third_party/py/mlir:pass_manager",
        "//third_party/py/mlir:tensor_dialect",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
    ],
)
# Shared library containing both SparseCore gather programs. The
# :gather1d_sparsecore and :gather2d_sparsecore examples use it, and it is also
# visible to the jax_tpu_embedding sparsecore packages. absl:app is listed
# because the same files also serve as the examples' entry points.
pytype_strict_library(
    name = "gather_sparsecore_lib",
    srcs = [
        "gather1d_sparsecore.py",
        "gather2d_sparsecore.py",
    ],
    visibility = ["//learning/brain/experimental/jax_tpu_embedding/sparsecore:__subpackages__"],
    deps = [
        "//platforms/xla/sparse_core/mlo/ir:sc_py_dialect",
        "//third_party/py/absl:app",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:control_flow_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:memref_dialect",
        "//third_party/py/mlir:scf_dialect",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/numpy",
    ],
)
# 1-D SparseCore gather example. It is restricted to the vf_megachip platforms
# listed below, and every platform receives the argument "1".
mosaic_example(
    name = "gather1d_sparsecore",
    # Only run the example on SC-supporting platforms
    enabled_platforms = [
        "vf_megachip_hardware",
        "vf_megachip_grm",
        "vf_megachip_iss",
    ],
    main = "gather1d_sparsecore.py",
    target_kwargs = {
        "iss": {"args": ["1"]},  # Don't benchmark.
        "grm": {"args": ["1"]},  # Don't benchmark.
        "hardware": {
            # NOTE(review): presumably also disables benchmarking, as on
            # iss/grm; confirm this is intended on real hardware.
            "args": ["1"],
        },
    },
    deps = [
        ":gather_sparsecore_lib",
        "//learning/brain/research/jax:tpu_iss_support",
        "//learning/brain/research/jax:tpu_simulator_support",
        "//learning/brain/research/jax:tpu_support",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
    ],
)
# 2-D SparseCore gather example. It is restricted to the vf_megachip platforms
# listed below, and every platform receives the argument "1".
mosaic_example(
    name = "gather2d_sparsecore",
    # Only run the example on SC-supporting platforms
    enabled_platforms = [
        "vf_megachip_hardware",
        "vf_megachip_grm",
        "vf_megachip_iss",
    ],
    main = "gather2d_sparsecore.py",
    target_kwargs = {
        "iss": {"args": ["1"]},  # Don't benchmark.
        "grm": {"args": ["1"]},  # Don't benchmark.
        "hardware": {
            # NOTE(review): presumably also disables benchmarking, as on
            # iss/grm; confirm this is intended on real hardware.
            "args": ["1"],
        },
    },
    deps = [
        ":gather_sparsecore_lib",
        "//learning/brain/research/jax:tpu_iss_support",
        "//learning/brain/research/jax:tpu_simulator_support",
        "//learning/brain/research/jax:tpu_support",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
    ],
)
# Exposes build_defs.bzl, which defines the mosaic_example macro used above,
# as a bzl_library. Its deps are the .bzl files it loads.
# Visibility is private, so only this package can depend on it.
bzl_library(
    name = "build_defs_bzl",
    srcs = ["build_defs.bzl"],
    parse_tests = False,
    visibility = ["//visibility:private"],
    deps = [
        "//third_party/bazel_rules/rules_python/python:py_test_bzl",
        "//third_party/bazel_skylib/lib:dicts",
    ],
)