blob: 4762ae517c80364e065826ebb4dc8bfd0a356118 [file] [log] [blame]
"""Build rules for Mosaic tests."""
load("//third_party/bazel_rules/rules_python/python:py_test.bzl", "py_test")
load("//third_party/bazel_skylib/lib:dicts.bzl", "dicts")
# Platforms every mosaic_sc_test is replicated across. Each entry is a tuple of
# (target name suffix, tags attached to that target, args list). Every args list
# holds a single space-separated string of flags.
TARGETS = [
    (
        "vf_megachip_iss",
        ["notap"],
        [" ".join([
            "--deepsea_version=viperfish",
            "--deepsea_chip_config_name=megachip_tccontrol",
            "--deepsea_platform_type=iss",
            "--iss_topology_mode=ISS_TOPOLOGY_ALL_CORES_SHARING_HBM",
            "--tpu_use_continuations=false",
        ])],
    ),
    # Do not target GRM on TAP to prevent timeouts.
    (
        "vf_megachip_grm",
        ["notap"],
        [" ".join([
            "--deepsea_version=viperfish",
            "--deepsea_chip_config_name=megachip_tccontrol",
            "--deepsea_platform_type=grm",
            "--tpu_use_continuations=true",
        ])],
    ),
    (
        "vf_megachip_hardware",
        ["notap", "requires-viperfish"],
        [" ".join([
            "--deepsea_version=viperfish",
            "--deepsea_chip_config_name=megachip_tccontrol",
            "--deepsea_platform_type=hardware",
            "--tpu_use_continuations=true",
        ])],
    ),
]
def mosaic_sc_test(name, main, target_kwargs = {}, **kwargs):
    """Defines a set of targets that run the test on multiple TPU generations.

    One py_test named `<name>_<suffix>` is generated for every entry in
    TARGETS. Each generated test uses `main` both as its entry point and as its
    only source file.

    Per-target kwargs are built from up to three layers, from least to most
    specific:
      1. `**kwargs` (defaults for every generated test),
      2. `target_kwargs[<platform type>]`, for "iss" or "grm", applied to every
         target whose suffix ends with that platform type,
      3. `target_kwargs[<suffix>]` (e.g. "vf_megachip_iss").
    For ordinary keys the most specific layer wins. "args" and "tags" are
    instead concatenated across all layers, so no layer's values are dropped.

    Args:
      name: test name; prefix of every generated target name.
      main: test file.
      target_kwargs: A dict of dicts, indexed by target suffix or platform
        type ("iss"/"grm"). The matching dicts are used to pass in additional
        kwargs, layered as described above.
      **kwargs: Default arguments passed in to all generated py_test rules.
    """
    for suffix, target_tags, target_args in TARGETS:
        # Collect kwarg layers from least to most specific.
        layers = [kwargs]
        for platform_type in ["iss", "grm"]:
            if suffix.endswith(platform_type):
                layers.append(target_kwargs.get(platform_type, {}))
                break
        layers.append(target_kwargs.get(suffix, {}))

        # dicts.add returns a new dict where later (more specific) layers win,
        # so popping from it below never mutates a caller-provided dict.
        merged = dicts.add(*layers)

        # "args" and "tags" accumulate across layers instead of overriding.
        # `x = x + y` (not `+=`) keeps select() values working.
        extra_args = []
        extra_tags = []
        for layer in layers:
            extra_args = extra_args + layer.get("args", [])
            extra_tags = extra_tags + layer.get("tags", [])
        merged.pop("args", None)
        merged.pop("tags", None)

        py_test(
            name = name + "_" + suffix,
            main = main,
            args = target_args + extra_args,
            srcs = [main],
            tags = extra_tags + target_tags,
            **merged
        )