| load("//devtools/python/blaze:pytype.bzl", "pytype_strict_library") |
| load("//learning/brain/google/xla:libtpu_build_defs.bzl", "tpu_cc_library") |
| load("//third_party/bazel_rules/rules_python/python:py_test.bzl", "py_test") |
| |
package(
    # "//platforms/xla:__subpackages__" grants visibility to //platforms/xla
    # and every package beneath it, which already includes
    # //platforms/xla/mosaic and all of its subpackages. A separate
    # "//platforms/xla/mosaic:__subpackages__" entry would be redundant.
    default_visibility = ["//platforms/xla:__subpackages__"],
)
| |
# Declares one multidevice_test<suffix> py_test per entry in the variant list
# below. Each variant entry is a tuple of
# (target-name suffix, extra command-line args, extra tags).
[
    py_test(
        name = "multidevice_test" + variant_suffix,
        srcs = ["multidevice_test.py"],
        args = variant_args,
        main = "multidevice_test.py",
        tags = variant_tags,
        deps = [
            ":test_util",
            "//learning/brain/research/jax:tpu_simulator_support",
            "//learning/brain/research/jax:tpu_support",
            "//platforms/xla/mosaic/python/dialects:llo_dialect",
            "//testing/pybase",
            "//testing/pybase:parameterized",
            "//third_party/llvm/llvm-project/mlir:IR",
            "//third_party/py/jax",
            "//third_party/py/mlir:arithmetic_dialect",
            "//third_party/py/mlir:builtin_dialect",
            "//third_party/py/mlir:ir",
            "//third_party/py/mlir:memref_dialect",
            "//third_party/py/mlir:vector_dialect",
            "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
            "//third_party/py/numpy",
        ],
    )
    for variant_suffix, variant_args, variant_tags in [
        # Passes deepsea flags selecting viperfish "lite" with 2,2,1
        # chips-per-host bounds.
        (
            "_vl_grm",
            [
                "--deepsea_chips_per_host_bounds=2,2,1",
                "--deepsea_version=viperfish",
                "--deepsea_variant=lite",
            ],
            # Takes way too long with sanitizers or debug mode.
            [
                "noasan",
                "notsan",
                "nodebug",
            ],
        ),
        # No extra args. The tag presumably requests viperlite hardware with
        # 8 devices; confirm against the test infrastructure.
        (
            "_vl",
            [],
            ["requires-viperlite:8"],
        ),
    ]
]
| |
# Declares tpu_custom_call_test<variant><pipeline> py_tests for the cross
# product of two pipeline settings and the hardware variants listed below.
#
# Pipeline axis: targets with the "_py" suffix set MOSAIC_USE_PYTHON_PIPELINE
# to "True"; unsuffixed targets set it to "False". This presumably toggles the
# Python Mosaic pipeline; confirm in tpu_custom_call_test.py.
#
# Variant axis: each entry is a tuple of
# (target-name suffix, extra command-line args, extra tags, shard_count).
# Every "_iss" variant passes --xla_tpu_autofdo=false. The "_iss" suffix
# presumably marks simulator runs (the deps include tpu_iss_support); verify.
[
    py_test(
        name = "tpu_custom_call_test" + variant_suffix + pipeline_suffix,
        timeout = "long",
        srcs = ["tpu_custom_call_test.py"],
        args = variant_args,
        env = pipeline_env,
        main = "tpu_custom_call_test.py",
        shard_count = variant_shards,
        tags = variant_tags,
        deps = [
            ":test_util",
            "//learning/brain/research/jax:tpu_iss_support",
            "//learning/brain/research/jax:tpu_support",
            "//platforms/xla/mosaic/python/dialects:llo_dialect",
            "//platforms/xla/mosaic/tests:mlir_interpreter",
            "//testing/pybase",
            "//testing/pybase:parameterized",
            "//third_party/llvm/llvm-project/mlir:IR",
            "//third_party/py/absl/flags",
            "//third_party/py/absl/testing:flagsaver",
            "//third_party/py/hypothesis",
            "//third_party/py/jax",
            "//third_party/py/mlir:arithmetic_dialect",
            "//third_party/py/mlir:builtin_dialect",
            "//third_party/py/mlir:control_flow_dialect",
            "//third_party/py/mlir:ir",
            "//third_party/py/mlir:memref_dialect",
            "//third_party/py/mlir:scf_dialect",
            "//third_party/py/mlir:vector_dialect",
            "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
            "//third_party/py/numpy",
            "//third_party/tensorflow/compiler/xla/python:xla_extension",
        ],
    )
    for pipeline_suffix, pipeline_env in [
        (
            "_py",
            {"MOSAIC_USE_PYTHON_PIPELINE": "True"},
        ),
        (
            "",
            {"MOSAIC_USE_PYTHON_PIPELINE": "False"},
        ),
    ]
    for variant_suffix, variant_args, variant_tags, variant_shards in [
        # ---- Variants that pass --xla_tpu_autofdo=false ("_iss") ----
        (
            "_jf_iss",
            [
                "--xla_tpu_mosaic_fusion",
                "--xla_tpu_autofdo=false",
            ],
            [
                "noasan",
                "notsan",
                "nodebug",
            ],
            8,
        ),
        (
            "_pf_iss",
            [
                "--deepsea_version=pufferfish",
                "--xla_tpu_autofdo=false",
            ],
            [
                "noasan",
                "notsan",
                "nodebug",
            ],
            10,
        ),
        (
            "_vl_iss",
            [
                "--deepsea_version=viperfish",
                "--deepsea_variant=lite",
                "--xla_tpu_autofdo=false",
            ],
            [
                "notap",  # vl_iss on TAP drops precision http://shortn/_6XHhU0fZml.
            ],
            10,
        ),
        (
            "_gl_iss",
            [
                "--deepsea_version=ghostlite",
                "--xla_tpu_autofdo=false",
            ],
            [
                "noasan",
                "notsan",
                "nodebug",
            ],
            10,
        ),
        (
            "_gf_iss",
            [
                "--deepsea_version=ghostfish",
                "--xla_tpu_autofdo=false",
            ],
            ["notap"],
            10,
        ),
        (
            "_pf_megacore_iss",
            [
                "--deepsea_version=pufferfish",
                "--deepsea_chip_config_name=megacore",
                "--iss_topology_mode=ISS_TOPOLOGY_MODE_MULTI_TC",
                "--xla_tpu_autofdo=false",
            ],
            [
                "noasan",
                "notsan",
                "nodebug",
            ],
            10,
        ),
        # ---- Variants tagged "requires-<chip>" ----
        # These presumably run on real hardware; confirm.
        (
            "_jf",
            ["--xla_tpu_mosaic_fusion"],
            ["requires-jellyfish"],
            1,
        ),
        (
            "_pf",
            [],
            ["requires-pufferfish"],
            4,
        ),
        (
            "_vl",
            [],
            ["requires-viperlite"],
            10,
        ),
        (
            "_pf_fusion",
            ["--xla_tpu_mosaic_fusion"],
            ["requires-pufferfish"],
            4,
        ),
        (
            "_vl_fusion",
            ["--xla_tpu_mosaic_fusion"],
            ["requires-viperlite"],
            10,
        ),
    ]
]
| |
# Single (non-parameterized) py_test for apply_vector_layout_test.py, split
# into 4 shards with the "long" timeout.
py_test(
    name = "apply_vector_layout_test",
    timeout = "long",
    srcs = ["apply_vector_layout_test.py"],
    main = "apply_vector_layout_test.py",
    shard_count = 4,
    deps = [
        ":test_util",
        "//platforms/xla/mosaic/tests:mlir_interpreter",
        "//testing/pybase",
        "//testing/pybase:parameterized",
        "//third_party/py/hypothesis",
        "//third_party/py/jax",
        "//third_party/py/mlir:arithmetic_dialect",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:math_dialect",
        "//third_party/py/mlir:vector_dialect",
        "//third_party/py/numpy",
    ],
)
| |
# Header-only C++ library (megacore_adjuster_base.h), built through the
# tpu_cc_library macro and compatible with the libtpu build environment.
tpu_cc_library(
    name = "megacore_adjuster_base",
    hdrs = ["megacore_adjuster_base.h"],
    compatible_with = ["//buildenv/target:libtpu"],
    # Extra deps handled by the tpu_cc_library macro. Presumably they are added
    # only for viperfish-enabled builds; confirm in libtpu_build_defs.bzl.
    deps_for_viperfish = [
        "//platforms/xla/sparse_core:debug_info",
        "//platforms/xla/sparse_core:lowering_util",
    ],
    deps = [
        "//learning/brain/tpu/runtime:libtpu_support_macros",
        "//platforms/xla/service/jellyfish:llo_ir",
        "//platforms/xla/service/jellyfish:target_base",
        "//platforms/xla/service/jellyfish/lowering:math_util",
        "//third_party/absl/algorithm:container",
        "//third_party/absl/log",
        "//third_party/absl/log:check",
        "//third_party/absl/types:optional",
        "//third_party/absl/types:span",
        "//third_party/llvm/llvm-project/llvm:Support",
        "//third_party/llvm/llvm-project/mlir:ArithDialect",
        "//third_party/llvm/llvm-project/mlir:IR",
        "//third_party/py/jax/jaxlib/mosaic:tpu_dialect",
        "//third_party/tensorflow/compiler/xla:status",
        "//third_party/tensorflow/compiler/xla:status_macros",
        "//third_party/tensorflow/compiler/xla:util",
        "//third_party/tensorflow/compiler/xla:xla_data_proto_cc",
        "//third_party/tensorflow/tsl/platform:statusor",
    ],
)
| |
# C++ library built from custom_call_emitter.cc. It depends on
# :megacore_adjuster_base and is compatible with the libtpu build environment.
cc_library(
    name = "custom_call_emitter",
    srcs = ["custom_call_emitter.cc"],
    compatible_with = ["//buildenv/target:libtpu"],
    deps = [
        ":megacore_adjuster_base",
        "//base:googleinit",
        "//base:vlog",
        "//learning/brain/tpu/runtime:libtpu_support_macros",
        "//learning/brain/tpu/runtime:tpu_version",
        "//platforms/xla/mosaic:llo_dialect",
        "//platforms/xla/mosaic:tpu_passes",
        "//platforms/xla/service:fusion_util",
        "//platforms/xla/service/jellyfish:async_collective_fusion_util",
        "//platforms/xla/service/jellyfish:custom_call_registration",
        "//platforms/xla/service/jellyfish:fusion_options",
        "//platforms/xla/service/jellyfish:hardware_layout",
        "//platforms/xla/service/jellyfish:hlo_deduplication",
        "//platforms/xla/service/jellyfish:llo_ir",
        "//platforms/xla/service/jellyfish:llo_program_shared_registry",
        "//platforms/xla/service/jellyfish:llo_sdc_reporter",
        "//platforms/xla/service/jellyfish:memory_space_enum",
        "//platforms/xla/service/jellyfish:target_base",
        "//platforms/xla/service/jellyfish:tpu_compilation_environment",
        "//platforms/xla/service/jellyfish:tpu_compilation_environment_cc_proto",
        "//platforms/xla/service/jellyfish:tpu_instruction_fusion",
        "//platforms/xla/service/jellyfish:transfer_size_util",
        "//platforms/xla/service/jellyfish/cost_model",
        "//platforms/xla/service/jellyfish/cost_model:cost_model_util",
        "//platforms/xla/service/jellyfish/lowering:backend_config_util",
        "//platforms/xla/service/jellyfish/lowering:backend_configs_cc_proto",
        "//platforms/xla/service/jellyfish/lowering:deep_copy_util",
        "//platforms/xla/service/jellyfish/lowering:fusion_emitter",
        "//platforms/xla/service/jellyfish/lowering:fusion_util",
        "//platforms/xla/service/jellyfish/lowering:group_utils",
        "//platforms/xla/service/jellyfish/lowering:lowering_util",
        "//platforms/xla/service/jellyfish/lowering:net_util",
        "//platforms/xla/service/jellyfish/lowering:op_emitter",
        "//platforms/xla/service/jellyfish/lowering:param_input",
        "//platforms/xla/service/jellyfish/lowering:pipeline_emitter",
        "//platforms/xla/service/jellyfish/lowering:windowing_util",
        "//third_party/absl/algorithm:container",
        "//third_party/absl/container:flat_hash_map",
        "//third_party/absl/flags:flag",
        "//third_party/absl/log",
        "//third_party/absl/log:check",
        "//third_party/absl/status",
        "//third_party/absl/strings",
        "//third_party/absl/strings:str_format",
        "//third_party/absl/strings:string_view",
        "//third_party/absl/types:optional",
        "//third_party/absl/types:span",
        "//third_party/llvm/llvm-project/llvm:Support",
        "//third_party/llvm/llvm-project/mlir:FuncDialect",
        "//third_party/llvm/llvm-project/mlir:FuncExtensions",
        "//third_party/llvm/llvm-project/mlir:IR",
        "//third_party/llvm/llvm-project/mlir:LinalgDialect",
        "//third_party/llvm/llvm-project/mlir:Parser",
        "//third_party/llvm/llvm-project/mlir:Pass",
        "//third_party/llvm/llvm-project/mlir:Support",
        "//third_party/llvm/llvm-project/mlir:TensorDialect",
        "//third_party/llvm/llvm-project/mlir:Transforms",
        "//third_party/py/jax/jaxlib/mosaic:tpu_dialect",
        "//third_party/tensorflow/compiler/xla:shape_util",
        "//third_party/tensorflow/compiler/xla:status",
        "//third_party/tensorflow/compiler/xla:statusor",
        "//third_party/tensorflow/compiler/xla:util",
        "//third_party/tensorflow/compiler/xla:xla_data_proto_cc",
        "//third_party/tensorflow/compiler/xla/hlo/ir:hlo",
        "//third_party/tensorflow/compiler/xla/mlir/utils:type_util",
        "//third_party/tensorflow/compiler/xla/mlir_hlo",
        "//third_party/tensorflow/compiler/xla/mlir_hlo:mhlo_passes",
        "//third_party/tensorflow/compiler/xla/service:hlo_cost_analysis",
        "//third_party/tensorflow/core/platform:path",
        "//third_party/tensorflow/tsl/platform:env",
        "//third_party/tensorflow/tsl/platform:errors",
        "//third_party/tensorflow/tsl/platform:status",
        "//third_party/tensorflow/tsl/platform:statusor",
        "//util/task:status",
    ],
    # Links this library's objects even when no symbol from it is referenced.
    # Presumably this keeps static initializers (e.g. custom-call registration
    # via custom_call_registration / googleinit) from being dropped; confirm.
    alwayslink = True,
)
| |
# Test-only Python helper library (test_util.py), type-checked with pytype in
# strict mode. The py_tests above depend on it via ":test_util".
pytype_strict_library(
    name = "test_util",
    testonly = True,
    srcs = ["test_util.py"],
    deps = [
        "//platforms/xla/mosaic/python/dialects:llo_dialect",
        "//testing/pybase:parameterized",
        "//third_party/py/hypothesis",
        "//third_party/py/jax",
        "//third_party/py/jax:mosaic",
        "//third_party/py/mlir:func_dialect",
        "//third_party/py/mlir:ir",
        "//third_party/py/mlir:mhlo_dialect",
        "//third_party/py/mlir/_mlir_libs:_mlirRegisterEverything",
        "//third_party/py/numpy",
    ],
)