| load("//devtools/python/blaze:pytype.bzl", "pytype_strict_library") |
| load("//devtools/python/blaze:strict.bzl", "py_strict_test") |
| load("//third_party/bazel_rules/rules_python/python:py_test.bzl", "py_test") |
| |
| # Examples showing the use of the compiler. |
| load("//tools/build_defs/testing:bzl_library.bzl", "bzl_library") |
| load("build_defs.bzl", "mosaic_example") |
| |
| # TODO(apaszke): Remove dependencies on LLO and TPU dialects when unnecessary. |
| |
| # Benchmark driver: main is benchmark.py, built together with matmul.py |
| # (additional_srcs). "manual" keeps it out of wildcard target patterns; |
| # "notap" presumably keeps it out of continuous testing; |
| # "requires-net:external" presumably grants external network access (maybe |
| # for the xprof clients in deps) -- TODO confirm. |
| mosaic_example( |
| name = "benchmark", |
| additional_srcs = ["matmul.py"], |
| additional_tags = [ |
| "requires-net:external", |
| "manual", |
| "notap", |
| ], |
| main = "benchmark.py", |
| visibility = ["//visibility:private"], |
| deps = [ |
| "//file/localfile", |
| "//learning/brain/research/jax:tpu_support", |
| "//perftools/accelerators/xprof/api/python:xprof_analysis_client", |
| "//perftools/accelerators/xprof/api/python:xprof_session", |
| "//platforms/xla/mosaic/python/dialects:llo_dialect", |
| "//pyglib:gfile", |
| "//third_party/py/absl/flags", |
| "//third_party/py/absl/logging", |
| "//third_party/py/absl/testing:absltest", |
| "//third_party/py/jax", |
| "//third_party/py/jax:mosaic", |
| "//third_party/py/mlir:arithmetic_dialect", |
| "//third_party/py/mlir:builtin_dialect", |
| "//third_party/py/mlir:func_dialect", |
| "//third_party/py/mlir:ir", |
| "//third_party/py/mlir:memref_dialect", |
| "//third_party/py/mlir:pass_manager", |
| "//third_party/py/mlir:scf_dialect", |
| "//third_party/py/mlir:vector_dialect", |
| "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything", |
| "//third_party/py/numpy", |
| ], |
| ) |
| |
| # Dense matmul example (matmul.py). target_kwargs holds per-variant overrides |
| # that mosaic_example (build_defs.bzl) presumably applies per platform -- |
| # TODO confirm. The "iss" variant passes "1" so the script skips benchmarking. |
| mosaic_example( |
| name = "matmul", |
| main = "matmul.py", |
| target_kwargs = { |
| "iss": {"args": ["1"]}, # Don't benchmark. |
| }, |
| deps = [ |
| "//learning/brain/research/jax:tpu_iss_support", |
| "//learning/brain/research/jax:tpu_support", |
| "//platforms/xla/mosaic/python/dialects:llo_dialect", |
| "//third_party/py/absl:app", |
| "//third_party/py/jax", |
| "//third_party/py/jax:mosaic", |
| "//third_party/py/mlir:arithmetic_dialect", |
| "//third_party/py/mlir:builtin_dialect", |
| "//third_party/py/mlir:func_dialect", |
| "//third_party/py/mlir:ir", |
| "//third_party/py/mlir:memref_dialect", |
| "//third_party/py/mlir:pass_manager", |
| "//third_party/py/mlir:scf_dialect", |
| "//third_party/py/mlir:vector_dialect", |
| "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything", |
| "//third_party/py/tensorflow", |
| "//third_party/py/tensorflow:tensorflow_google", |
| "//third_party/tensorflow/core:protos_all_py_pb2", |
| ], |
| ) |
| |
| # Triangular matmul example (trmm.py). The "iss" variant passes "1" to skip |
| # benchmarking and is tagged "noasan" because it is too slow under ASAN. |
| mosaic_example( |
| name = "trmm", |
| main = "trmm.py", |
| target_kwargs = { |
| "iss": { |
| "args": ["1"], # Don't benchmark. |
| "tags": ["noasan"], # Don't run asan as it takes too long. |
| }, |
| }, |
| deps = [ |
| "//learning/brain/research/jax:tpu_iss_support", |
| "//learning/brain/research/jax:tpu_support", |
| "//platforms/xla/mosaic/python/dialects:llo_dialect", |
| "//third_party/py/absl:app", |
| "//third_party/py/jax", |
| "//third_party/py/jax:mosaic", |
| "//third_party/py/mlir:arithmetic_dialect", |
| "//third_party/py/mlir:builtin_dialect", |
| "//third_party/py/mlir:func_dialect", |
| "//third_party/py/mlir:ir", |
| "//third_party/py/mlir:memref_dialect", |
| "//third_party/py/mlir:pass_manager", |
| "//third_party/py/mlir:scf_dialect", |
| "//third_party/py/mlir:vector_dialect", |
| "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything", |
| ], |
| ) |
| |
| # RoPE example (rope.py). Currently disabled by default ("manual"/"notap"); |
| # the "iss" variant skips benchmarking and is excluded from ASAN/TSAN runs. |
| mosaic_example( |
| name = "rope", |
| additional_tags = [ |
| # TODO(tlongeri): Remove these tags once vector_constants is supported in C++. |
| "manual", |
| "notap", |
| ], |
| main = "rope.py", |
| target_kwargs = { |
| "iss": { |
| "args": ["1"], # Don't benchmark. |
| "tags": [ |
| "noasan", # Too slow under ASAN. |
| "notsan", # Too slow under TSAN. |
| ], |
| }, |
| }, |
| deps = [ |
| "//learning/brain/research/jax:tpu_iss_support", |
| "//learning/brain/research/jax:tpu_support", |
| "//platforms/xla/mosaic/python/dialects:llo_dialect", |
| "//third_party/py/absl:app", |
| "//third_party/py/jax", |
| "//third_party/py/jax:mosaic", |
| "//third_party/py/mlir:arithmetic_dialect", |
| "//third_party/py/mlir:builtin_dialect", |
| "//third_party/py/mlir:func_dialect", |
| "//third_party/py/mlir:ir", |
| "//third_party/py/mlir:math_dialect", |
| "//third_party/py/mlir:pass_manager", |
| "//third_party/py/mlir:vector_dialect", |
| "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything", |
| ], |
| ) |
| |
| # Block-sparse matmul example (block_sparse_matmul.py). The "iss" variant |
| # passes "small" instead of benchmarking. |
| mosaic_example( |
| name = "block_sparse_matmul", |
| main = "block_sparse_matmul.py", |
| target_kwargs = { |
| "iss": {"args": ["small"]}, # Don't benchmark. |
| }, |
| deps = [ |
| "//learning/brain/research/jax:tpu_iss_support", |
| "//learning/brain/research/jax:tpu_support", |
| "//platforms/xla/mosaic/python/dialects:llo_dialect", |
| "//platforms/xla/mosaic/tests:mlir_interpreter", |
| "//third_party/py/absl:app", |
| "//third_party/py/jax", |
| "//third_party/py/jax:mosaic", |
| "//third_party/py/mlir:arithmetic_dialect", |
| "//third_party/py/mlir:builtin_dialect", |
| "//third_party/py/mlir:func_dialect", |
| "//third_party/py/mlir:ir", |
| "//third_party/py/mlir:math_dialect", |
| "//third_party/py/mlir:memref_dialect", |
| "//third_party/py/mlir:pass_manager", |
| "//third_party/py/mlir:scf_dialect", |
| "//third_party/py/mlir:vector_dialect", |
| "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything", |
| ], |
| ) |
| |
| # Segmented matmul example (segmented_matmul.py). The "iss" variant passes |
| # "small" instead of benchmarking; the "vl" variant is excluded from MSAN |
| # runs because they ran out of memory on Forge. |
| mosaic_example( |
| name = "segmented_matmul", |
| main = "segmented_matmul.py", |
| target_kwargs = { |
| "iss": {"args": ["small"]}, # Don't benchmark. |
| "vl": {"tags": ["nomsan"]}, # Forge OOM. |
| }, |
| deps = [ |
| "//learning/brain/research/jax:tpu_iss_support", |
| "//learning/brain/research/jax:tpu_support", |
| "//platforms/xla/mosaic/python/dialects:llo_dialect", |
| "//third_party/py/absl:app", |
| "//third_party/py/jax", |
| "//third_party/py/jax:mosaic", |
| "//third_party/py/mlir:arithmetic_dialect", |
| "//third_party/py/mlir:builtin_dialect", |
| "//third_party/py/mlir:func_dialect", |
| "//third_party/py/mlir:ir", |
| "//third_party/py/mlir:memref_dialect", |
| "//third_party/py/mlir:pass_manager", |
| "//third_party/py/mlir:scf_dialect", |
| "//third_party/py/mlir:vector_dialect", |
| "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything", |
| ], |
| ) |
| |
| # Plain py_test (not mosaic_example) running matmul_llo_jf.py. The |
| # "requires-jellyfish" tag presumably schedules it on Jellyfish TPU hardware |
| # -- TODO confirm. |
| py_test( |
| name = "matmul_llo_jf", |
| srcs = ["matmul_llo_jf.py"], |
| tags = ["requires-jellyfish"], |
| deps = [ |
| "//learning/brain/research/jax:tpu_support", |
| "//platforms/xla/mosaic/python/dialects:llo_dialect", |
| "//third_party/py/absl:app", |
| "//third_party/py/jax", |
| "//third_party/py/jax:mosaic", |
| "//third_party/py/mlir:func_dialect", |
| "//third_party/py/mlir:ir", |
| "//third_party/py/mlir:scf_dialect", |
| "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything", |
| "//third_party/py/numpy", |
| ], |
| ) |
| |
| # Runs matmul_llo_jf.py against the ISS support library instead of hardware, |
| # passing "1" to skip benchmarking. Dependencies are declared directly rather |
| # than by depending on the ":matmul_llo_jf" test target: test targets are not |
| # libraries, and that target re-supplied the very matmul_llo_jf.py listed in |
| # srcs here. Keep this list in sync with the deps of ":matmul_llo_jf". |
| py_test( |
| name = "matmul_llo_jf_iss", |
| srcs = ["matmul_llo_jf.py"], |
| args = ["1"], # Don't benchmark. |
| main = "matmul_llo_jf.py", |
| deps = [ |
| "//learning/brain/research/jax:tpu_iss_support", |
| "//learning/brain/research/jax:tpu_support", |
| "//platforms/xla/mosaic/python/dialects:llo_dialect", |
| "//third_party/py/absl:app", |
| "//third_party/py/jax", |
| "//third_party/py/jax:mosaic", |
| "//third_party/py/mlir:func_dialect", |
| "//third_party/py/mlir:ir", |
| "//third_party/py/mlir:scf_dialect", |
| "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything", |
| "//third_party/py/numpy", |
| ], |
| ) |
| |
| # Plain py_test running trmm_llo_jf.py. The "requires-jellyfish" tag |
| # presumably schedules it on Jellyfish TPU hardware -- TODO confirm. |
| py_test( |
| name = "trmm_llo_jf", |
| srcs = ["trmm_llo_jf.py"], |
| tags = ["requires-jellyfish"], |
| deps = [ |
| "//learning/brain/research/jax:tpu_support", |
| "//platforms/xla/mosaic/python/dialects:llo_dialect", |
| "//third_party/py/absl:app", |
| "//third_party/py/jax", |
| "//third_party/py/jax:mosaic", |
| "//third_party/py/mlir:func_dialect", |
| "//third_party/py/mlir:ir", |
| "//third_party/py/mlir:scf_dialect", |
| "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything", |
| "//third_party/py/numpy", |
| ], |
| ) |
| |
| # Cholesky example (cholesky.py) with default arguments. The same script also |
| # backs "cholesky_bounds_checked" and "cholesky_small" below. |
| mosaic_example( |
| name = "cholesky", |
| main = "cholesky.py", |
| deps = [ |
| "//learning/brain/research/jax:tpu_iss_support", |
| "//learning/brain/research/jax:tpu_support", |
| "//third_party/py/absl:app", |
| "//third_party/py/jax", |
| "//third_party/py/jax:mosaic", |
| "//third_party/py/mlir:arithmetic_dialect", |
| "//third_party/py/mlir:builtin_dialect", |
| "//third_party/py/mlir:func_dialect", |
| "//third_party/py/mlir:ir", |
| "//third_party/py/mlir:math_dialect", |
| "//third_party/py/mlir:memref_dialect", |
| "//third_party/py/mlir:pass_manager", |
| "//third_party/py/mlir:vector_dialect", |
| "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything", |
| "//third_party/py/numpy", |
| ], |
| ) |
| |
| # Bounds-checker test: runs cholesky.py with |
| # --xla_mosaic_on_device_checks=bounds to make sure the checker accepts a |
| # correct program (i.e. reports no false positives). |
| py_strict_test( |
| name = "cholesky_bounds_checked", |
| srcs = ["cholesky.py"], |
| args = ["--xla_mosaic_on_device_checks=bounds"], |
| main = "cholesky.py", |
| deps = [ |
| "//learning/brain/research/jax:tpu_iss_support", |
| "//third_party/py/absl:app", |
| "//third_party/py/jax", |
| "//third_party/py/jax:mosaic", |
| "//third_party/py/mlir:arithmetic_dialect", |
| "//third_party/py/mlir:builtin_dialect", |
| "//third_party/py/mlir:func_dialect", |
| "//third_party/py/mlir:ir", |
| "//third_party/py/mlir:math_dialect", |
| "//third_party/py/mlir:memref_dialect", |
| "//third_party/py/mlir:pass_manager", |
| "//third_party/py/mlir:vector_dialect", |
| "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything", |
| "//third_party/py/numpy", |
| ], |
| ) |
| |
| # Runs cholesky.py with a single positional argument "8" (presumably a small |
| # problem size -- TODO confirm). The leading "--" ends flag parsing so "8" |
| # reaches the script as a positional argument. Deps mirror the other |
| # cholesky.py targets, which declare func_dialect and numpy as direct deps |
| # (one of them under py_strict_test); they were missing here. |
| mosaic_example( |
| name = "cholesky_small", |
| args = [ |
| "--", |
| "8", |
| ], |
| main = "cholesky.py", |
| deps = [ |
| "//learning/brain/research/jax:tpu_iss_support", |
| "//learning/brain/research/jax:tpu_support", |
| "//third_party/py/absl:app", |
| "//third_party/py/jax", |
| "//third_party/py/jax:mosaic", |
| "//third_party/py/mlir:arithmetic_dialect", |
| "//third_party/py/mlir:builtin_dialect", |
| "//third_party/py/mlir:func_dialect", |
| "//third_party/py/mlir:ir", |
| "//third_party/py/mlir:math_dialect", |
| "//third_party/py/mlir:memref_dialect", |
| "//third_party/py/mlir:pass_manager", |
| "//third_party/py/mlir:vector_dialect", |
| "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything", |
| "//third_party/py/numpy", |
| ], |
| ) |
| |
| # Library target for collective_matmul.py so that test targets can depend on |
| # it. pytype_strict_library presumably also type-checks it with pytype -- |
| # TODO confirm. |
| pytype_strict_library( |
| name = "collective_matmul", |
| srcs = ["collective_matmul.py"], |
| deps = [ |
| "//third_party/py/jax", |
| "//third_party/py/jax:mosaic", |
| "//third_party/py/mlir:arithmetic_dialect", |
| "//third_party/py/mlir:builtin_dialect", |
| "//third_party/py/mlir:func_dialect", |
| "//third_party/py/mlir:ir", |
| "//third_party/py/mlir:memref_dialect", |
| "//third_party/py/mlir:scf_dialect", |
| "//third_party/py/mlir:vector_dialect", |
| "//third_party/py/numpy", |
| ], |
| ) |
| |
| # Test for the ":collective_matmul" library (collective_matmul_test.py). |
| # "requires-viperlite:8" presumably requests 8 ViperLite devices -- TODO |
| # confirm. Kept out of TAP until b/307309369 is fixed. |
| py_test( |
| name = "collective_matmul_test_vl", |
| srcs = ["collective_matmul_test.py"], |
| main = "collective_matmul_test.py", |
| tags = [ |
| "notap", # TODO(b/307309369): fix and re-enable |
| "requires-viperlite:8", |
| ], |
| deps = [ |
| ":collective_matmul", |
| "//learning/brain/research/jax:tpu_support", |
| "//perftools/accelerators/xprof/api/python:xprof_session", |
| "//testing/pybase", |
| "//third_party/py/absl/flags", |
| "//third_party/py/absl/testing:parameterized", |
| "//third_party/py/jax", |
| "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything", |
| "//third_party/py/numpy", |
| ], |
| ) |
| |
| # Flash-attention example (flash_attention.py). The "iss" variant passes "1" |
| # to skip benchmarking and is excluded from ASAN/TSAN runs (too slow). |
| mosaic_example( |
| name = "flash_attention", |
| main = "flash_attention.py", |
| target_kwargs = { |
| "iss": { |
| "args": ["1"], # Don't benchmark. |
| "tags": [ |
| "noasan", # Too slow under ASAN. |
| "notsan", # Too slow under TSAN. |
| ], |
| }, |
| }, |
| deps = [ |
| "//learning/brain/research/jax:tpu_iss_support", |
| "//learning/brain/research/jax:tpu_support", |
| "//platforms/xla/mosaic/python/dialects:llo_dialect", |
| "//third_party/py/absl:app", |
| "//third_party/py/jax", |
| "//third_party/py/jax:mosaic", |
| "//third_party/py/mlir:arithmetic_dialect", |
| "//third_party/py/mlir:builtin_dialect", |
| "//third_party/py/mlir:func_dialect", |
| "//third_party/py/mlir:ir", |
| "//third_party/py/mlir:memref_dialect", |
| "//third_party/py/mlir:pass_manager", |
| "//third_party/py/mlir:scf_dialect", |
| "//third_party/py/mlir:vector_dialect", |
| "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything", |
| ], |
| ) |
| |
| # Matmul example (matmul_mhlo.py); unlike "matmul" it also depends on the |
| # MLIR tensor dialect. The "iss" variant passes "1" to skip benchmarking. |
| mosaic_example( |
| name = "matmul_mhlo", |
| main = "matmul_mhlo.py", |
| target_kwargs = { |
| "iss": {"args": ["1"]}, # Don't benchmark. |
| }, |
| deps = [ |
| "//learning/brain/research/jax:tpu_iss_support", |
| "//learning/brain/research/jax:tpu_support", |
| "//third_party/py/absl:app", |
| "//third_party/py/jax", |
| "//third_party/py/jax:mosaic", |
| "//third_party/py/mlir:arithmetic_dialect", |
| "//third_party/py/mlir:builtin_dialect", |
| "//third_party/py/mlir:func_dialect", |
| "//third_party/py/mlir:ir", |
| "//third_party/py/mlir:memref_dialect", |
| "//third_party/py/mlir:pass_manager", |
| "//third_party/py/mlir:tensor_dialect", |
| "//third_party/py/mlir:vector_dialect", |
| "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything", |
| ], |
| ) |
| |
| # Shared library holding both SparseCore gather examples (1-D and 2-D). It is |
| # used by the gather*_sparsecore targets below and is also visible to the |
| # jax_tpu_embedding/sparsecore subpackages. |
| pytype_strict_library( |
| name = "gather_sparsecore_lib", |
| srcs = [ |
| "gather1d_sparsecore.py", |
| "gather2d_sparsecore.py", |
| ], |
| visibility = ["//learning/brain/experimental/jax_tpu_embedding/sparsecore:__subpackages__"], |
| deps = [ |
| "//platforms/xla/sparse_core/mlo/ir:sc_py_dialect", |
| "//third_party/py/absl:app", |
| "//third_party/py/jax", |
| "//third_party/py/jax:mosaic", |
| "//third_party/py/mlir:arithmetic_dialect", |
| "//third_party/py/mlir:control_flow_dialect", |
| "//third_party/py/mlir:func_dialect", |
| "//third_party/py/mlir:ir", |
| "//third_party/py/mlir:memref_dialect", |
| "//third_party/py/mlir:scf_dialect", |
| "//third_party/py/mlir:vector_dialect", |
| "//third_party/py/numpy", |
| ], |
| ) |
| |
| # 1-D SparseCore gather example (gather1d_sparsecore.py). Its source comes |
| # from ":gather_sparsecore_lib", and it is restricted to the vf_megachip |
| # platforms listed in enabled_platforms. |
| mosaic_example( |
| name = "gather1d_sparsecore", |
| # Only run the example on SC-supporting platforms |
| enabled_platforms = [ |
| "vf_megachip_hardware", |
| "vf_megachip_grm", |
| "vf_megachip_iss", |
| ], |
| main = "gather1d_sparsecore.py", |
| target_kwargs = { |
| "iss": {"args": ["1"]}, # Don't benchmark. |
| "grm": {"args": ["1"]}, # Don't benchmark. |
| "hardware": { |
| # NOTE(review): presumably "1" skips benchmarking here too, as for |
| # iss/grm -- confirm whether hardware runs should benchmark. |
| "args": ["1"], |
| }, |
| }, |
| deps = [ |
| ":gather_sparsecore_lib", |
| "//learning/brain/research/jax:tpu_iss_support", |
| "//learning/brain/research/jax:tpu_simulator_support", |
| "//learning/brain/research/jax:tpu_support", |
| "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything", |
| ], |
| ) |
| |
| # 2-D SparseCore gather example (gather2d_sparsecore.py). Its source comes |
| # from ":gather_sparsecore_lib", and it is restricted to the vf_megachip |
| # platforms listed in enabled_platforms. |
| mosaic_example( |
| name = "gather2d_sparsecore", |
| # Only run the example on SC-supporting platforms |
| enabled_platforms = [ |
| "vf_megachip_hardware", |
| "vf_megachip_grm", |
| "vf_megachip_iss", |
| ], |
| main = "gather2d_sparsecore.py", |
| target_kwargs = { |
| "iss": {"args": ["1"]}, # Don't benchmark. |
| "grm": {"args": ["1"]}, # Don't benchmark. |
| "hardware": { |
| # NOTE(review): presumably "1" skips benchmarking here too, as for |
| # iss/grm -- confirm whether hardware runs should benchmark. |
| "args": ["1"], |
| }, |
| }, |
| deps = [ |
| ":gather_sparsecore_lib", |
| "//learning/brain/research/jax:tpu_iss_support", |
| "//learning/brain/research/jax:tpu_simulator_support", |
| "//learning/brain/research/jax:tpu_support", |
| "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything", |
| ], |
| ) |
| |
| # bzl_library for build_defs.bzl (the file mosaic_example is loaded from), |
| # declaring the .bzl files it loads. parse_tests = False presumably disables |
| # the generated parse test -- TODO confirm. |
| bzl_library( |
| name = "build_defs_bzl", |
| srcs = ["build_defs.bzl"], |
| parse_tests = False, |
| visibility = ["//visibility:private"], |
| deps = [ |
| "//third_party/bazel_rules/rules_python/python:py_test_bzl", |
| "//third_party/bazel_skylib/lib:dicts", |
| ], |
| ) |