| """Utilities for exporting benchmark results to MLCompass.""" |
| |
| import json |
| import os |
| import time |
| from typing import Any |
| |
| from absl import flags |
| from absl import logging |
| from absl.testing import absltest |
| from jax import random |
| import jax.numpy as jnp |
| import numpy as np |
| |
| from google3.perftools.accelerators.xprof.api.python import xprof_analysis_client |
| from google3.perftools.accelerators.xprof.api.python import xprof_session |
| from google3.platforms.xla.mosaic.examples import matmul |
| from google3.pyglib import gfile |
| |
| _BENCHMARK_OUTPUT_DIR = flags.DEFINE_string( |
| "benchmark_output_dir", default=None, help="Benchmark output directory." |
| ) |
| |
| |
| def get_xprof_metrics(xprof_session_id: str) -> dict[str, Any]: |
| """Extract the metrics that we care about from Xprof. |
| |
| Args: |
| xprof_session_id: ID of XProf session to obtain metrics from. |
| |
| Returns: |
    A dict with the metrics. For now, the only metric is "kernel_time": the
    summed duration, in microseconds, of the "jit_*" device events from a
    single core.
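
  Example (a sketch; assumes a completed XprofSession, as in the tests below):
    session_id = xprof_sess.end_session_and_get_session_id()
    kernel_time_us = get_xprof_metrics(session_id)["kernel_time"]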
| """ |
| # TODO(tlongeri): There is similar code that could be unified in |
| # google3/third_party/py/jax_triton/google/pallas_tpu/ops/splash_attention/splash_attention_benchmark.py |
| # google3/learning/brain/research/megablox/benchmarks/common.py |
| xprof_client = xprof_analysis_client.XprofAnalysisClient() |
| _, trace = xprof_client.get_profile_data( |
| "trace_viewer.json", xprof_session_id |
| ) |
| jtrace = json.loads(trace) |
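  # A relevant trace event looks roughly like this (values are hypothetical):
  #   {"name": "jit_matmul", "pid": 9, "tid": 1, "ts": ..., "dur": 123.4,
  #    "args": {"run_id": "...", ...}}
  # "dur" follows the Chrome trace event format and is in microseconds.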
| relevant_events = [] |
| used_pid = None |
  for e in jtrace["traceEvents"]:
    # Keep only device events for jitted computations: "jit_*" names with a
    # "run_id" argument. Metadata and unrelated events are skipped.
    if not e.get("args") or "run_id" not in e["args"]:
      continue
    if not e.get("name", "").startswith("jit_"):
      continue
| if used_pid is None: |
| used_pid = e["pid"] |
| if e["pid"] != used_pid: |
| continue # Only gather events from one of the cores |
| relevant_events.append(e) |
| times = np.asarray([e["dur"] for e in relevant_events], dtype=np.float64) |
| return {"kernel_time": times.sum()} |
| |
| |
| def report_benchmark( |
    name: str | None = None,
| succeeded: bool = True, |
| wall_time: float | None = None, |
| metrics: dict[str, Any] | None = None, |
| iters: int | None = None, |
| xprof_session_id: str | None = None, |
| extras: dict[str, Any] | None = None, |
| ) -> None: |
| """Writes benchmark results report. |
| |
| Based on _MLCompassBenchmark._report_benchmark_results in |
| //learning/deepmind/benchmarks/mlcompass/benchmark.py. |
| |
| Args: |
    name: Name of the test to export data for (required).
    succeeded: Whether the test succeeded.
    wall_time: Elapsed wall-clock time of the benchmark, in seconds (required).
    metrics: Additional numerical metrics.
    iters: Number of iterations run.
    xprof_session_id: XProf session of the benchmark (optional). Some XProf
      measurements are extracted and appended to metrics (see
      get_xprof_metrics).
    extras: Extra metrics or data.
| |
  Raises:
    TypeError: If a required argument (name or wall_time) is missing.

  Results are reported as a JSON object with the following schema:
| ``` |
| { |
| "name": <class.testMethod> |
| "succeeded": true / false |
| "wall_time": float (containing wall-time for the benchmark) |
| "metrics": { |
| "iters": int (if iters was provided) |
| "string" -> float map of other performance metrics |
| } |
| "extras": { |
| "string" -> "string" map containing anything else of interest |
| } |
| } |
| ``` |
| """ |
| if name is None: |
| raise TypeError("name is required") |
| if wall_time is None: |
| raise TypeError("wall_time is required") |
  # Copy so the caller's dict is not mutated below.
  metrics = dict(metrics) if metrics else {}
| if xprof_session_id: |
| metrics.update(get_xprof_metrics(xprof_session_id)) |
| if iters is not None: |
| metrics["iters"] = iters |
  report = {
      "name": name,
      "succeeded": succeeded,
      "wall_time": wall_time,
  }
| if metrics: |
| report["metrics"] = metrics |
| if extras: |
| report["extras"] = extras |
| |
| logging.info("Benchmark report: %s", report) |
| |
| output_dir = _BENCHMARK_OUTPUT_DIR.value |
| if output_dir: |
| if not gfile.Exists(output_dir): |
| gfile.MakeDirs(output_dir) |
| |
| file_name = os.path.join(output_dir, name + ".json") |
| with gfile.GFile(file_name, "w") as fout: |
| json.dump(report, fout) |
| else: |
| print(report) |
| |
| |
| class MatmulBenchmarks(absltest.TestCase): |
| """Benchmarks for Mosaic vs. JAX on matmul.""" |
| |
| num_iters = 50 |
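  # Problem shape: x is (n, k), y is (k, m). tn/tm/tk are the tile sizes
  # passed to matmul.create_matmul.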
| n, m, k = 2048, 1024, 512 |
| tn, tk, tm = 512, 512, 512 |
| |
| @classmethod |
| def _report_benchmark(cls, name, **kwargs): |
| """Prepend class name to test and forward to global report_benchmark.""" |
| name = f"{cls.__name__}.{name}" |
| return report_benchmark(name=name, **kwargs) |
| |
| def test_matmul_mosaic(self): |
| """Benchmark of matmul using Mosaic.""" |
| custom_matmul = matmul.create_matmul( |
| self.n, self.m, self.k, self.tn, self.tm, self.tk |
| ) |
| |
| k1, k2 = random.split(random.PRNGKey(1234)) |
| x = random.normal(k1, (self.n, self.k), jnp.float32) |
| y = random.normal(k2, (self.k, self.m), jnp.float32) |
    # Round the inputs to bfloat16-representable values (kept in float32).
    x = x.astype(jnp.bfloat16).astype(jnp.float32)
    y = y.astype(jnp.bfloat16).astype(jnp.float32)
| |
    # Warm-up: compile and run once outside the timed/profiled region.
| custom_matmul(x, y).block_until_ready() |
| |
| xprof_sess = xprof_session.XprofSession() |
| xprof_sess.start_session(enable_python_tracer=True, host_trace_level=2) |
| s = time.perf_counter() |
| for _ in range(self.num_iters): |
| custom_matmul(x, y).block_until_ready() |
| e = time.perf_counter() |
| xprof_session_id = xprof_sess.end_session_and_get_session_id() |
| |
| self._report_benchmark( |
| name="test_matmul_mosaic", |
| wall_time=e - s, |
| xprof_session_id=xprof_session_id, |
| ) |
| |
| def test_matmul_jax(self): |
| """Benchmark of matmul using JAX.""" |
| k1, k2 = random.split(random.PRNGKey(1234)) |
| x = random.normal(k1, (self.n, self.k), jnp.float32) |
| y = random.normal(k2, (self.k, self.m), jnp.float32) |
    # Round the inputs to bfloat16-representable values (kept in float32).
    x = x.astype(jnp.bfloat16).astype(jnp.float32)
    y = y.astype(jnp.bfloat16).astype(jnp.float32)
| |
    # Warm-up: compile and run once outside the timed/profiled region.
| (x @ y).block_until_ready() |
| |
| xprof_sess = xprof_session.XprofSession() |
| xprof_sess.start_session( |
| device_name="jellyfish", enable_python_tracer=True, host_trace_level=2 |
| ) |
| s = time.perf_counter() |
| for _ in range(self.num_iters): |
| (x @ y).block_until_ready() |
| e = time.perf_counter() |
| xprof_session_id = xprof_sess.end_session_and_get_session_id() |
| |
| self._report_benchmark( |
| name="test_matmul_jax", |
| wall_time=e - s, |
| xprof_session_id=xprof_session_id, |
| ) |
| |
| |
| if __name__ == "__main__": |
| absltest.main() |