r"""Testing utility for verifying that a change doesn't affect generated code.

This tool is mainly useful for gaining confidence that compiler cleanups which
should be non-functional changes really are. In baseline mode, it clones a
fresh workspace at HEAD (or at the CL given by --baseline) and gathers the LLO
generated for all examples that can be run locally. The directory given by
--baseline_path must not already exist. For example:

  blaze run //platforms/xla/mosaic/tests:llo_dump_diff -- \
    --mode baseline --baseline_path=/tmp/mosaic_llo

In diff mode, it uses the current workspace to generate the LLO dumps and
compares them against a previously generated baseline:

  blaze run //platforms/xla/mosaic/tests:llo_dump_diff -- \
    --mode diff --baseline_path=/tmp/mosaic_llo
"""
| |
from collections.abc import Sequence
import os
import shutil
import subprocess
import tempfile

from absl import app
from absl import flags
from absl import logging
| |
# Selects the workflow: produce a baseline, or compare against an existing one.
_MODE = flags.DEFINE_enum(
    name="mode",
    default=None,
    enum_values=["baseline", "diff"],
    help="Generate baseline LLO dumps or compare against them.",
    required=True,
)
# Optional CL that baseline_main syncs its fresh clone to before dumping.
_BASELINE = flags.DEFINE_string(
    name="baseline",
    default=None,
    help="CL to use for baseline generation (HEAD if unspecified)",
)
# Directory that baseline mode creates and fills, and that diff mode reads.
_BASELINE_PATH = flags.DEFINE_string(
    name="baseline_path",
    default=None,
    help="Path to the baseline files",
    required=True,
)
| |
| |
def _query_iss_targets(g3):
  """Returns the example targets whose LLO should be dumped (never empty)."""
  query_output = subprocess.check_output(
      ["blaze", "query", "//platforms/xla/mosaic/examples/..."], cwd=g3)
  targets = query_output.decode("ascii").split()
  # Only use examples that we can run locally and that do use Mosaic.
  iss_targets = [t for t in targets if t.endswith("iss") and "llo" not in t]
  if not iss_targets:
    # An empty selection would silently yield an empty baseline/comparison.
    raise RuntimeError(
        "No runnable Mosaic examples found under "
        "//platforms/xla/mosaic/examples/...")
  return iss_targets


def _dump_llo(g3, target, dump_path):
  """Runs `target` with Mosaic dumping enabled and moves its dump to a file."""
  with tempfile.TemporaryDirectory() as target_temp:
    subprocess.check_call(
        ["blaze", "run", target, "--", "--xla_mosaic_dump_to=" + target_temp],
        cwd=g3)
    dump_files = os.listdir(target_temp)
    if len(dump_files) != 1:
      raise RuntimeError(
          f"Expected to find exactly one Mosaic call, but {target} produced "
          f"{len(dump_files)} dump file(s): {sorted(dump_files)}")
    (dump_file,) = dump_files
    # Unlike os.rename, shutil.move also works when the temporary directory
    # and the destination live on different filesystems.
    shutil.move(os.path.join(target_temp, dump_file), dump_path)


def gather_llo(g3, destination):
  """Gathers LLO dumps for Mosaic ops from examples.

  Queries //platforms/xla/mosaic/examples/..., keeps the targets whose label
  ends in "iss" and does not contain "llo", builds them, then runs each one
  with --xla_mosaic_dump_to pointing at a fresh temporary directory. Every run
  must produce exactly one dump, which is moved to `destination/<name>`,
  where <name> is the part of the target label after the last ":".

  Args:
    g3: Path to the google3 directory of the workspace to build and run in.
    destination: Existing directory that receives one file per example.

  Raises:
    FileNotFoundError: If `g3` or `destination` is not an existing directory.
    RuntimeError: If no example targets are found, two targets map to the same
      dump name, or a run does not produce exactly one dump file.
    subprocess.CalledProcessError: If a blaze command fails.
  """
  if not os.path.isdir(g3):
    raise FileNotFoundError(f"Workspace directory does not exist: {g3}")
  if not os.path.isdir(destination):
    raise FileNotFoundError(
        f"Destination directory does not exist: {destination}")
  iss_targets = _query_iss_targets(g3)
  logging.info("Using targets: %s", iss_targets)
  # Dumps are keyed by target name alone, so same-named targets in different
  # packages would overwrite each other; reject that before the slow build.
  target_by_name = {}
  for target in iss_targets:
    name = target.rpartition(":")[2]
    if name in target_by_name:
      raise RuntimeError(
          f"Targets {target_by_name[name]} and {target} both map to the dump "
          f"name {name!r}")
    target_by_name[name] = target
  subprocess.check_call(["blaze", "build"] + iss_targets, cwd=g3)
  for name, target in target_by_name.items():
    _dump_llo(g3, target, os.path.join(destination, name))
| |
| |
def baseline_main():
  """Gathers LLO dumps for Mosaic ops from examples at a baseline commit.

  Clones a fresh temporary workspace restricted to //platforms/xla/mosaic.
  If --baseline is set, it syncs that clone to the given CL; otherwise the
  clone is used as-is. It then writes the dumps into --baseline_path, which
  must not already exist.

  Raises:
    RuntimeError: If --baseline_path already exists.
    subprocess.CalledProcessError: If an hg or blaze command fails.
  """
  baseline_path = _BASELINE_PATH.value
  # Check before the slow clone so an accidental clobber fails immediately.
  if os.path.exists(baseline_path):
    raise RuntimeError(f"Baseline directory already exists: {baseline_path}")
  with tempfile.TemporaryDirectory() as workspace:
    logging.info("Created new workspace: %s", workspace)
    subprocess.check_call(["hg", "gclone", workspace,
                           "--include", "//platforms/xla/mosaic"])
    if _BASELINE.value is not None:
      # hg operates on the repository containing its working directory, so the
      # sync must run inside the fresh clone, not wherever we were launched.
      subprocess.check_call(["hg", "sync", "@" + _BASELINE.value],
                            cwd=workspace)
    g3 = os.path.join(workspace, "google3")
    try:
      # exist_ok=False still guards against the path appearing meanwhile.
      os.makedirs(baseline_path, exist_ok=False)
    except FileExistsError:
      raise RuntimeError(
          f"Baseline directory already exists: {baseline_path}"
      ) from None
    gather_llo(g3, baseline_path)
| |
| |
def diff_main():
  """Gathers LLO dumps from examples and compares them to baseline.

  Uses the workspace that `blaze run` reports in BUILD_WORKSPACE_DIRECTORY to
  produce fresh dumps. It then compares them file by file, with `diff`,
  against the dumps in --baseline_path. Files present on only one side are
  reported as well.

  Raises:
    RuntimeError: If the workspace cannot be determined, the baseline
      directory does not exist, or any dump differs or is missing on one side.
  """
  baseline_path = _BASELINE_PATH.value
  # See https://g3doc.corp.google.com/devtools/blaze/g3doc/user-manual.html#run
  if (g3 := os.getenv("BUILD_WORKSPACE_DIRECTORY")) is None:
    raise RuntimeError("Failed to infer the google3 workspace")
  # Fail before the slow build if there is nothing to compare against.
  if not os.path.isdir(baseline_path):
    raise RuntimeError(f"Baseline directory does not exist: {baseline_path}")
  logging.info("Gathering new LLO dumps from %s", g3)
  mismatches = []
  with tempfile.TemporaryDirectory() as llo_dir:
    gather_llo(g3, llo_dir)
    new_files = set(os.listdir(llo_dir))
    baseline_files = set(os.listdir(baseline_path))
    for f in sorted(new_files | baseline_files):
      if f not in baseline_files:
        logging.error("MISSING FROM BASELINE: %s", f)
        mismatches.append(f)
      elif f not in new_files:
        logging.error("MISSING FROM NEW DUMPS: %s", f)
        mismatches.append(f)
      elif subprocess.call(["diff", os.path.join(llo_dir, f),
                            os.path.join(baseline_path, f)]) == 0:
        logging.info("OK: %s", f)
      else:
        logging.error("GENERATED DIFFERENT CODE: %s", f)
        mismatches.append(f)
  # Surface mismatches through the exit status, not only through the logs.
  if mismatches:
    raise RuntimeError(f"LLO dumps do not match the baseline: {mismatches}")
| |
| |
def main(argv: Sequence[str]) -> None:
  """Entry point: rejects positional arguments and runs the --mode handler."""
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  mode = _MODE.value
  handlers = {"baseline": baseline_main, "diff": diff_main}
  handler = handlers.get(mode)
  if handler is None:
    raise ValueError(f"Invalid mode: {mode}")
  return handler()
| |
| |
if __name__ == "__main__":
  # Hand control to absl, which invokes main() with the remaining argv.
  app.run(main=main)