"""Common utilities for Mosaic tests."""
import functools
import unittest
import hypothesis
import hypothesis.strategies as st
import jax
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import mhlo
import numpy as np
from google3.platforms.xla.mosaic.python.dialects import llo
from google3.testing.pybase import parameterized
hypothesis.settings.register_profile(
"deterministic",
database=None,
derandomize=True,
deadline=None,
print_blob=True,
)
hypothesis.settings.load_profile("deterministic")
class VectorFactory:
"""A convenience class for constructing MLIR vector types.
Don't create instances of this class yourself. Use the ones already provided
below (F32, I1, ...).
"""
def __init__(self, elem_thunk):
self.elem_thunk = elem_thunk
def __getitem__(self, idxs):
if isinstance(idxs, int):
idxs = (idxs,)
return ir.VectorType.get(idxs, self.elem_thunk())
def __call__(self):
return self.elem_thunk()
F32 = VectorFactory(ir.F32Type.get)
I1 = VectorFactory(lambda: ir.IntegerType.get_signless(1))
I32 = VectorFactory(lambda: ir.IntegerType.get_signless(32))
I64 = VectorFactory(lambda: ir.IntegerType.get_signless(64))
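
# A minimal illustrative sketch (an addition to this listing, not part of the
# original module): the factories above require an active MLIR context, so the
# example lives in a function instead of running at import time.
def _example_vector_factories():
  with ir.Context():
    vreg = F32[8, 128]  # vector<8x128xf32>
    row = I32[128]      # an int index is promoted to the 1-D shape (128,)
    elem = F32()        # calling a factory yields the bare element type
    assert vreg.shape == [8, 128]
    assert row.shape == [128]
    assert str(elem) == "f32"
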
class MosaicTestCase(parameterized.TestCase):
"""A specialization of TestCase for Mosaic tests.
This class ensures that every test has a well defined MLIR context and default
location set. It also preloads the LLO and TPU dialects and passes into the
dialect.
"""
def setUp(self):
super().setUp()
self.ctx = ir.Context()
self.ctx.__enter__()
self.loc = ir.Location.unknown()
self.loc.__enter__()
llo.register_llo_dialect(self.ctx)
tpu.register_dialect(self.ctx)
mhlo.register_mhlo_dialect(self.ctx)
self._np_rng_instance = None
def tearDown(self):
self.loc.__exit__(None, None, None)
self.ctx.__exit__(None, None, None)
del self.loc, self.ctx
super().tearDown()
@property
def _np_rng(self):
if (rng := self._np_rng_instance) is None:
rng = self._np_rng_instance = np.random.default_rng(1234)
return rng
def normal(self, shape, dtype):
return self._np_rng.standard_normal(shape).astype(dtype)
@property
def i1(self):
return ir.IntegerType.get_signless(1)
@property
def i8(self):
return ir.IntegerType.get_signless(8)
@property
def i16(self):
return ir.IntegerType.get_signless(16)
@property
def i32(self):
return ir.IntegerType.get_signless(32)
@property
def i64(self):
return ir.IntegerType.get_signless(64)
@property
def f32(self):
return ir.F32Type.get()
@property
def bf16(self):
return ir.BF16Type.get()
@property
def f32_vreg(self):
return ir.VectorType.get((8, 128), self.f32)
def f32v(self, *shape):
return ir.VectorType.get(shape, self.f32)
def i32v(self, *shape):
return ir.VectorType.get(shape, self.i32)
@property
def index(self):
return ir.IndexType.get()
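
# A hypothetical usage sketch of MosaicTestCase (the test class and body are
# illustrative, not from the original module). setUp has already entered the
# context and location, so MLIR types can be built directly in test bodies:
#
#   class MyOpTest(MosaicTestCase):
#
#     def test_vreg(self):
#       vty = self.f32_vreg                        # vector<8x128xf32>
#       self.assertEqual(vty.shape, [8, 128])
#       data = self.normal((8, 128), np.float32)   # seeded, reproducible RNG
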
def normalize_generation(generation: int | str) -> int:
  """Converts a TPU codename (e.g. "dragonfish") to its generation number."""
  if isinstance(generation, int):
    return generation
if generation == "jellyfish":
return 2
if generation == "dragonfish":
return 3
if generation == "pufferfish":
return 4
if generation == "viperfish":
return 5
raise NotImplementedError(f"Unrecognized TPU generation: {generation}")
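
# A minimal sketch of the mapping above (illustrative addition; safe to run
# without any TPU attached):
def _example_normalize_generation():
  assert normalize_generation(4) == 4              # ints pass through
  assert normalize_generation("dragonfish") == 3   # codenames map to versions
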
def only_tpus_from(generation: int | str) -> ...:
  """Skips the test on older TPU generations."""
  generation = normalize_generation(generation)

  def wrapper(test_orig):
    @functools.wraps(test_orig)  # keep the wrapped test's name for reporting
    def test(*args, **kwargs):
      if generation_under_test() < generation:
        raise unittest.SkipTest(f"Test requires TPU v{generation}")
      return test_orig(*args, **kwargs)

    return test

  return wrapper
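
# A hypothetical usage sketch of the decorator above (the test name and body
# are illustrative):
#
#   class MyOpTest(MosaicTestCase):
#
#     @only_tpus_from("pufferfish")  # equivalent to only_tpus_from(4)
#     def test_needs_v4(self):
#       ...  # raises SkipTest("Test requires TPU v4") on older generations
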
def generation_under_test() -> int:
  """Returns the TPU generation of the first available device."""
  kind = jax.devices()[0].device_kind
  if kind.endswith(" lite"):
    kind = kind[: -len(" lite")]
  # Device kinds look like "TPU v4"; the assert guards the single-digit
  # version parse below.
  assert kind[:-1] == "TPU v", kind
  return int(kind[-1])
@st.composite
def shapes(draw: st.DrawFn, allow_1d: bool = True) -> tuple[int, ...]:
  """Draws an array shape: either 1-D, or 2- to 4-D with small leading dims."""
  if allow_1d:
    one_dim = draw(st.booleans())
    if one_dim:
      return (draw(st.integers(1, 32768)),)
  leading = draw(st.lists(st.integers(1, 4), min_size=0, max_size=2))
  minor = draw(st.integers(1, 512))
  second_minor = draw(st.integers(1, 64))
  return (*leading, second_minor, minor)
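
# A hypothetical usage sketch: drawing shapes with Hypothesis under the
# deterministic profile registered above (the test class is illustrative):
#
#   class MyShapeTest(MosaicTestCase):
#
#     @hypothesis.given(shapes(allow_1d=False))
#     def test_rank(self, shape):
#       self.assertGreaterEqual(len(shape), 2)  # 2 to 4 dims when 1-D is off
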
def and_flatmap(first, second) -> ...:
"""Like standard Hypothesis flatmap, but returns a pair of both results."""
return first.flatmap(lambda f: st.tuples(st.just(f), second(f)))
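
# A minimal sketch of and_flatmap (illustrative addition): pair each drawn
# shape with an axis index that is valid for that shape. Building the strategy
# is lazy, so this is cheap at import time.
_example_shape_and_axis = and_flatmap(
    shapes(), lambda shape: st.integers(0, len(shape) - 1)
)
# Draws pairs like ((4, 8, 128), 1).
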
@st.composite
def aligned_shape(
draw: st.DrawFn,
rank: int,
min_factor: int,
max_factor: int,
*,
major: int = 8,
minor: int = 128,
) -> tuple[int, ...]:
"""Generate a random shape whose last 2 dims are aligned with (major, minor) tile layout."""
assert rank > 1, f"{rank=} should be at least 2"
shape = draw(
st.lists(
st.integers(min_factor, max_factor), min_size=rank, max_size=rank
)
)
  shape[-1] *= minor
  shape[-2] *= major
  return tuple(shape)  # drawn as a list; the annotation promises a tuple
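
# A hypothetical usage sketch: rank-3 shapes whose two minor dims are
# multiples of the default (8, 128) vreg tile (the test body is illustrative):
#
#   @hypothesis.given(aligned_shape(rank=3, min_factor=1, max_factor=4))
#   def test_tiled(self, shape):
#     assert shape[-2] % 8 == 0 and shape[-1] % 128 == 0
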
def infer_tile_strides(
    mem_shape: tuple[int, ...], tile_shape: tuple[int, ...]
) -> list[int]:
  """Calculates row-major tile strides from the memref shape and the tile."""
  tile_strides = [None] * len(mem_shape)
  stride = 1
  for i in range(len(mem_shape) - 1, -1, -1):
    tile_strides[i] = stride
    tile_i = i + len(tile_shape) - len(mem_shape)
    if 0 <= tile_i < len(tile_shape):
      # Tiled dimensions advance by the ceil-divided number of tiles.
      stride *= (mem_shape[i] + tile_shape[tile_i] - 1) // tile_shape[tile_i]
    else:
      # Leading (untiled) dimensions advance by their full extent.
      stride *= mem_shape[i]
  return tile_strides
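
# A worked example (illustrative addition): tiling a (3, 16, 256) memref by
# (8, 128) gives ceil(256/128) = 2 tiles in the minor dim and ceil(16/8) = 2
# in the second-minor dim, so the row-major tile strides are [4, 2, 1].
def _example_infer_tile_strides():
  assert infer_tile_strides((16, 256), (8, 128)) == [2, 1]
  assert infer_tile_strides((3, 16, 256), (8, 128)) == [4, 2, 1]
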
def jax_to_mlir_dtype(jax_dtype) -> ir.Type:
  """Converts a JAX dtype to the corresponding MLIR element type."""
  jax_dtype = jnp.dtype(jax_dtype)
if jax_dtype == jnp.dtype(jnp.float32):
return ir.F32Type.get()
elif jax_dtype == jnp.dtype(jnp.bfloat16):
return ir.BF16Type.get()
elif jax_dtype == jnp.dtype(jnp.int4):
return ir.IntegerType.get_signless(4)
elif jnp.issubdtype(jax_dtype, jnp.integer):
return ir.IntegerType.get_signless(jax_dtype.itemsize * 8)
raise NotImplementedError(jax_dtype)
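
# A minimal sketch (illustrative addition): the ir.*Type getters read the
# current MLIR context, so one must be active.
def _example_jax_to_mlir_dtype():
  with ir.Context():
    assert str(jax_to_mlir_dtype(jnp.float32)) == "f32"
    assert str(jax_to_mlir_dtype(jnp.int8)) == "i8"
    assert str(jax_to_mlir_dtype(jnp.int4)) == "i4"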