"""Utilities for exporting benchmark results to MLCompass."""
import json
import os
import time
from typing import Any
from absl import flags
from absl import logging
from absl.testing import absltest
from jax import random
import jax.numpy as jnp
import numpy as np
from google3.perftools.accelerators.xprof.api.python import xprof_analysis_client
from google3.perftools.accelerators.xprof.api.python import xprof_session
from google3.platforms.xla.mosaic.examples import matmul
from google3.pyglib import gfile
_BENCHMARK_OUTPUT_DIR = flags.DEFINE_string(
"benchmark_output_dir", default=None, help="Benchmark output directory."
)
def get_xprof_metrics(xprof_session_id: str) -> dict[str, Any]:
"""Extract the metrics that we care about from Xprof.
Args:
xprof_session_id: ID of XProf session to obtain metrics from.
Returns:
A dict with the metrics. For now, the only metric is "kernel_time".
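
  Example (illustrative; `sess` is an xprof_session.XprofSession that was
  started around the code under measurement):

    session_id = sess.end_session_and_get_session_id()
    metrics = get_xprof_metrics(session_id)
    kernel_time = metrics["kernel_time"]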
"""
# TODO(tlongeri): There is similar code that could be unified in
# google3/third_party/py/jax_triton/google/pallas_tpu/ops/splash_attention/splash_attention_benchmark.py
# google3/learning/brain/research/megablox/benchmarks/common.py
xprof_client = xprof_analysis_client.XprofAnalysisClient()
_, trace = xprof_client.get_profile_data(
"trace_viewer.json", xprof_session_id
)
jtrace = json.loads(trace)
relevant_events = []
used_pid = None
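  # Keep only jitted-computation events (names starting with "jit_") that carry
  # a "run_id" in their args, restricted to a single pid (one core).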
for e in jtrace["traceEvents"]:
if not e["args"] or "run_id" not in e["args"]:
continue
if not e["name"].startswith("jit_"):
continue
if used_pid is None:
used_pid = e["pid"]
if e["pid"] != used_pid:
continue # Only gather events from one of the cores
relevant_events.append(e)
times = np.asarray([e["dur"] for e in relevant_events], dtype=np.float64)
return {"kernel_time": times.sum()}
def report_benchmark(
    name: str | None = None,
succeeded: bool = True,
wall_time: float | None = None,
metrics: dict[str, Any] | None = None,
iters: int | None = None,
xprof_session_id: str | None = None,
extras: dict[str, Any] | None = None,
) -> None:
"""Writes benchmark results report.
Based on _MLCompassBenchmark._report_benchmark_results in
//learning/deepmind/benchmarks/mlcompass/benchmark.py.
Args:
    name: Name of the test to export data for (required).
    succeeded: Whether the test succeeded.
    wall_time: Elapsed time from start to end of the benchmark (required).
    metrics: Additional numerical metrics.
    iters: Number of iterations run.
xprof_session_id: XProf session of benchmarks (optional). Some XProf
measurements are extracted and appended to metrics (see
get_xprof_metrics).
extras: Extra metrics or data.
Results are reported as a JSON string with the following schema:
```
{
"name": <class.testMethod>
"succeeded": true / false
"wall_time": float (containing wall-time for the benchmark)
"metrics": {
"iters": int (if iters was provided)
"string" -> float map of other performance metrics
}
"extras": {
"string" -> "string" map containing anything else of interest
}
}
```
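
  Example (illustrative; `elapsed` and `session_id` are placeholder values
  measured by the caller):

    report_benchmark(
        name="MatmulBenchmarks.test_matmul_jax",
        wall_time=elapsed,
        iters=50,
        xprof_session_id=session_id,
    )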
"""
if name is None:
raise TypeError("name is required")
if wall_time is None:
raise TypeError("wall_time is required")
if metrics is None:
metrics = {}
if xprof_session_id:
metrics.update(get_xprof_metrics(xprof_session_id))
if iters is not None:
metrics["iters"] = iters
report = {}
report["name"] = name
report["succeeded"] = succeeded
report["wall_time"] = wall_time
if metrics:
report["metrics"] = metrics
if extras:
report["extras"] = extras
logging.info("Benchmark report: %s", report)
output_dir = _BENCHMARK_OUTPUT_DIR.value
if output_dir:
if not gfile.Exists(output_dir):
gfile.MakeDirs(output_dir)
file_name = os.path.join(output_dir, name + ".json")
with gfile.GFile(file_name, "w") as fout:
json.dump(report, fout)
else:
print(report)
class MatmulBenchmarks(absltest.TestCase):
"""Benchmarks for Mosaic vs. JAX on matmul."""
num_iters = 50
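  # Matrix dimensions: x has shape (n, k) and y has shape (k, m).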
n, m, k = 2048, 1024, 512
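  # Tile sizes (tn, tm, tk) forwarded to matmul.create_matmul.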
tn, tk, tm = 512, 512, 512
@classmethod
def _report_benchmark(cls, name, **kwargs):
"""Prepend class name to test and forward to global report_benchmark."""
name = f"{cls.__name__}.{name}"
return report_benchmark(name=name, **kwargs)
def test_matmul_mosaic(self):
"""Benchmark of matmul using Mosaic."""
custom_matmul = matmul.create_matmul(
self.n, self.m, self.k, self.tn, self.tm, self.tk
)
k1, k2 = random.split(random.PRNGKey(1234))
x = random.normal(k1, (self.n, self.k), jnp.float32)
y = random.normal(k2, (self.k, self.m), jnp.float32)
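    # Round inputs to bfloat16 precision (cast down, then back to float32).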
x = x.astype(jnp.bfloat16).astype(jnp.float32)
y = y.astype(jnp.bfloat16).astype(jnp.float32)
# Warm-up
custom_matmul(x, y).block_until_ready()
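    # Profile the timed loop with XProf; the resulting session id lets
    # report_benchmark pull kernel-time metrics via get_xprof_metrics.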
xprof_sess = xprof_session.XprofSession()
xprof_sess.start_session(enable_python_tracer=True, host_trace_level=2)
s = time.perf_counter()
for _ in range(self.num_iters):
custom_matmul(x, y).block_until_ready()
e = time.perf_counter()
xprof_session_id = xprof_sess.end_session_and_get_session_id()
self._report_benchmark(
name="test_matmul_mosaic",
wall_time=e - s,
xprof_session_id=xprof_session_id,
)
def test_matmul_jax(self):
"""Benchmark of matmul using JAX."""
k1, k2 = random.split(random.PRNGKey(1234))
x = random.normal(k1, (self.n, self.k), jnp.float32)
y = random.normal(k2, (self.k, self.m), jnp.float32)
x = x.astype(jnp.bfloat16).astype(jnp.float32)
y = y.astype(jnp.bfloat16).astype(jnp.float32)
# Warm-up
(x @ y).block_until_ready()
xprof_sess = xprof_session.XprofSession()
xprof_sess.start_session(
device_name="jellyfish", enable_python_tracer=True, host_trace_level=2
)
s = time.perf_counter()
for _ in range(self.num_iters):
(x @ y).block_until_ready()
e = time.perf_counter()
xprof_session_id = xprof_sess.end_session_and_get_session_id()
self._report_benchmark(
name="test_matmul_jax",
wall_time=e - s,
xprof_session_id=xprof_session_id,
)
if __name__ == "__main__":
absltest.main()