blob: 109b1854f06fdd0ca9e792051c26bd83a2024baf [file] [log] [blame]
load("//devtools/python/blaze:pytype.bzl", "pytype_strict_binary", "pytype_strict_library")
load("//third_party/llvm/build_defs:lit.bzl", "glob_lit_tests")
load("//tools/build_defs/testing:bzl_library.bzl", "bzl_library")
load("build_defs.bzl", "mosaic_sc_test")
# Unless a target sets its own `visibility`, it is visible only inside
# //platforms/xla/mosaic and that package's subpackages.
package(default_visibility = ["//platforms/xla/mosaic:__subpackages__"])
# lit tests over the `.mlir` files of this package, run through
# lit/lit_test.sh. The tools the tests invoke come from :lit_test_utilities.
# NOTE(review): glob_lit_tests presumably emits one test per matching file —
# confirm against //third_party/llvm/build_defs:lit.bzl.
glob_lit_tests(
    name = "all_tests",
    data = [":lit_test_utilities"],
    driver = "lit/lit_test.sh",
    test_file_exts = ["mlir"],
)
# Runtime tools for the lit tests in :all_tests: Mosaic's `mlir-tpu-opt`
# and LLVM's FileCheck. Marked testonly, so only testonly targets may
# depend on it.
filegroup(
    name = "lit_test_utilities",
    testonly = True,
    data = [
        "//platforms/xla/mosaic:mlir-tpu-opt",
        "//third_party/llvm/llvm-project/llvm:FileCheck",
    ],
)
# Python command-line tool (llo_dump_diff.py) built on absl app/flags/logging.
# NOTE(review): from its name it presumably diffs LLO dumps — confirm in the
# source.
pytype_strict_binary(
    name = "llo_dump_diff",
    srcs = ["llo_dump_diff.py"],
    deps = [
        "//third_party/py/absl:app",
        "//third_party/py/absl/flags",
        "//third_party/py/absl/logging",
        "//third_party/py/google/protobuf:use_fast_cpp_protos",  # Automatically added go/proto_python_upb_flip
    ],
)
# Test-only Python library (mlir_interpreter.py). It depends on JAX, NumPy
# and the MLIR Python bindings (core `ir` plus the `func` dialect).
pytype_strict_library(
    name = "mlir_interpreter",
    testonly = True,
    srcs = ["mlir_interpreter.py"],
    deps = [
        "//third_party/py/jax",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/numpy",
    ],
)
# Test whose entry point is sparsecore_tile_index.py. It is declared with the
# mosaic_sc_test macro loaded from build_defs.bzl.
mosaic_sc_test(
    name = "sparsecore_tile_index",
    main = "sparsecore_tile_index.py",
    deps = [
        # JAX TPU support targets (plain, simulator and ISS variants, going by
        # their names).
        "//learning/brain/research/jax:tpu_iss_support",
        "//learning/brain/research/jax:tpu_simulator_support",
        "//learning/brain/research/jax:tpu_support",

        # Python dialect bindings from the sparse_core MLO IR package.
        "//platforms/xla/sparse_core/mlo/ir:sc_py_dialect",

        # absl, JAX (including Mosaic) and the MLIR Python bindings.
        "//third_party/py/absl:app",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
    ],
)
# Test whose entry point is sparsecore_vload_arith.py. It is declared with the
# mosaic_sc_test macro loaded from build_defs.bzl.
mosaic_sc_test(
    name = "sparsecore_vload_arith",
    main = "sparsecore_vload_arith.py",
    deps = [
        # JAX TPU support targets (plain, simulator and ISS variants, going by
        # their names).
        "//learning/brain/research/jax:tpu_iss_support",
        "//learning/brain/research/jax:tpu_simulator_support",
        "//learning/brain/research/jax:tpu_support",

        # Python dialect bindings from the sparse_core MLO IR package.
        "//platforms/xla/sparse_core/mlo/ir:sc_py_dialect",

        # absl, JAX (including Mosaic) and the MLIR Python bindings.
        "//third_party/py/absl:app",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
    ],
)
# Test whose entry point is sparsecore_persistent_args.py. It is declared with
# the mosaic_sc_test macro loaded from build_defs.bzl.
mosaic_sc_test(
    name = "sparsecore_persistent_args",
    main = "sparsecore_persistent_args.py",
    deps = [
        # JAX TPU support targets (plain, simulator and ISS variants, going by
        # their names).
        "//learning/brain/research/jax:tpu_iss_support",
        "//learning/brain/research/jax:tpu_simulator_support",
        "//learning/brain/research/jax:tpu_support",

        # Python dialect bindings from the sparse_core MLO IR package.
        "//platforms/xla/sparse_core/mlo/ir:sc_py_dialect",

        # absl, JAX (including Mosaic) and the MLIR Python bindings.
        "//third_party/py/absl:app",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
    ],
)
# Starlark library target for build_defs.bzl, the file this BUILD loads
# mosaic_sc_test from, together with the .bzl targets it depends on.
# Visibility is private to this package.
# NOTE(review): parse_tests = False presumably disables the macro's generated
# parse test — confirm against bzl_library.bzl.
bzl_library(
    name = "build_defs_bzl",
    srcs = ["build_defs.bzl"],
    parse_tests = False,
    visibility = ["//visibility:private"],
    deps = [
        "//third_party/bazel_rules/rules_python/python:py_test_bzl",
        "//third_party/bazel_skylib/lib:dicts",
    ],
)