| """Build rules for Mosaic tests.""" |
| |
| load("//third_party/bazel_rules/rules_python/python:py_test.bzl", "py_test") |
| load("//third_party/bazel_skylib/lib:dicts.bzl", "dicts") |
| |
# Flags shared by every Viperfish megachip target below.
_VF_MEGACHIP_COMMON_FLAGS = [
    "--deepsea_version=viperfish",
    "--deepsea_chip_config_name=megachip_tccontrol",
]

# Each entry is (target suffix, tags, args). The args list holds a single
# string made of the space-joined flags for that target.
TARGETS = [
    (
        "vf_megachip_iss",
        ["notap"],
        [" ".join(_VF_MEGACHIP_COMMON_FLAGS + [
            "--deepsea_platform_type=iss",
            "--iss_topology_mode=ISS_TOPOLOGY_ALL_CORES_SHARING_HBM",
            "--tpu_use_continuations=false",
        ])],
    ),
    # Do not target GRM on TAP to prevent timeouts.
    (
        "vf_megachip_grm",
        ["notap"],
        [" ".join(_VF_MEGACHIP_COMMON_FLAGS + [
            "--deepsea_platform_type=grm",
            "--tpu_use_continuations=true",
        ])],
    ),
    (
        "vf_megachip_hardware",
        ["notap", "requires-viperfish"],
        [" ".join(_VF_MEGACHIP_COMMON_FLAGS + [
            "--deepsea_platform_type=hardware",
            "--tpu_use_continuations=true",
        ])],
    ),
]
| |
def mosaic_sc_test(name, main, target_kwargs = {}, **kwargs):
    """Defines one py_test per entry in TARGETS, each on a different Viperfish platform.

    For every `(suffix, tags, args)` entry in TARGETS, a `py_test` named
    `<name>_<suffix>` is generated. Keyword arguments for each generated test
    are resolved from three layers, in increasing order of precedence:

      1. `**kwargs`: defaults shared by every generated target.
      2. `target_kwargs[<platform_type>]`: platform_type is "iss" or "grm"
         and the target suffix ends with it.
      3. `target_kwargs[<suffix>]`: settings for one specific target.

    A more specific layer overrides a less specific one. The list-valued
    `args` and `tags` are the exception: they are concatenated across all
    layers (least specific first). The `args` lists are appended after the
    target's own args from TARGETS, and the `tags` lists are placed before
    the target's own tags.

    Args:
      name: test name; each generated target is named `<name>_<suffix>`.
      main: test file; used as `main` and as the only entry of `srcs`.
      target_kwargs: A dict of dicts, indexed by target suffix (e.g.
        "vf_megachip_iss") or by platform type ("iss" / "grm"). The selected
        dicts supply additional kwargs for the matching targets.
      **kwargs: Default arguments passed in to all generated py_test rules.
    """
    for suffix, target_tags, target_args in TARGETS:
        # Kwarg layers, least specific first.
        layers = [kwargs]
        for platform_type in ["iss", "grm"]:
            if suffix.endswith(platform_type):
                layers.append(target_kwargs.get(platform_type, {}))
                break
        layers.append(target_kwargs.get(suffix, {}))

        # Lists are accumulated rather than overridden so that no layer's
        # args/tags are silently dropped.
        extra_args = []
        extra_tags = []
        for layer in layers:
            extra_args = extra_args + layer.get("args", [])
            extra_tags = extra_tags + layer.get("tags", [])

        # dicts.add returns a new dict (later arguments win), so the pops
        # below never mutate the caller's dicts.
        this_target_kwargs = dicts.add(*layers)
        this_target_kwargs.pop("args", None)
        this_target_kwargs.pop("tags", None)

        py_test(
            name = name + "_" + suffix,
            main = main,
            args = target_args + extra_args,
            srcs = [main],
            tags = extra_tags + target_tags,
            **this_target_kwargs
        )