| """Common utilities for Mosaic tests.""" |
| |
import functools
import unittest

import hypothesis
import hypothesis.strategies as st
import jax
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import mhlo
import numpy as np

from google3.platforms.xla.mosaic.python.dialects import llo
from google3.testing.pybase import parameterized
| |
| |
# Make Hypothesis runs reproducible: generate examples deterministically, never
# read or write an example database, disable per-example deadlines, and print
# the reproduction blob whenever an example fails.
hypothesis.settings.register_profile(
    "deterministic",
    derandomize=True,
    database=None,
    print_blob=True,
    deadline=None,
)
hypothesis.settings.load_profile("deterministic")
| |
| |
class VectorFactory:
  """Builds MLIR vector types over a lazily constructed element type.

  Indexing with a shape (``F32[8, 128]``, or ``F32[16]`` for a single int)
  returns the matching ``ir.VectorType``; calling the factory (``F32()``)
  returns the bare element type.

  Don't create instances of this class yourself. Use the ones already provided
  below (F32, I1, ...).
  """

  def __init__(self, elem_thunk):
    # Kept as a thunk so the element type is only built on use: the factories
    # in this module are created at import time, before any test has entered
    # an MLIR context.
    self.elem_thunk = elem_thunk

  def __getitem__(self, idxs):
    # A bare int means a rank-1 shape; anything else is used as the shape.
    shape = (idxs,) if isinstance(idxs, int) else idxs
    return ir.VectorType.get(shape, self.elem_thunk())

  def __call__(self):
    return self.elem_thunk()
| |
| |
# Ready-made factories, e.g. ``F32[8, 128]`` is an 8x128 f32 vector type and
# ``I32()`` is the scalar i32 type. They cover the same element types as the
# ``MosaicTestCase`` type properties (f32, bf16, i1, i8, i16, i32, i64).
F32 = VectorFactory(ir.F32Type.get)
BF16 = VectorFactory(ir.BF16Type.get)
I1 = VectorFactory(lambda: ir.IntegerType.get_signless(1))
I8 = VectorFactory(lambda: ir.IntegerType.get_signless(8))
I16 = VectorFactory(lambda: ir.IntegerType.get_signless(16))
I32 = VectorFactory(lambda: ir.IntegerType.get_signless(32))
I64 = VectorFactory(lambda: ir.IntegerType.get_signless(64))
| |
| |
class MosaicTestCase(parameterized.TestCase):
  """A specialization of TestCase for Mosaic tests.

  Every test runs with a fresh MLIR context and an unknown default location
  entered, and with the LLO, TPU and MHLO dialects registered in that context.
  Both are exited through cleanups, so they are released even when ``setUp``
  fails part-way or a subclass overrides ``tearDown`` without calling super.
  """

  def setUp(self):
    super().setUp()
    # Unlike tearDown, cleanups also run when setUp raises, so a failure in
    # any step below cannot leave this context entered for later tests.
    # Cleanups run LIFO: the location is exited and dropped before the
    # context is; dropping the attributes releases this test's references.
    self.ctx = ir.Context()
    self.addCleanup(delattr, self, "ctx")
    self.ctx.__enter__()
    self.addCleanup(self.ctx.__exit__, None, None, None)
    self.loc = ir.Location.unknown()
    self.addCleanup(delattr, self, "loc")
    self.loc.__enter__()
    self.addCleanup(self.loc.__exit__, None, None, None)
    llo.register_llo_dialect(self.ctx)
    tpu.register_dialect(self.ctx)
    mhlo.register_mhlo_dialect(self.ctx)
    self._np_rng_instance = None

  @property
  def _np_rng(self):
    """NumPy generator seeded with 1234, created lazily once per test."""
    if (rng := self._np_rng_instance) is None:
      rng = self._np_rng_instance = np.random.default_rng(1234)
    return rng

  def normal(self, shape, dtype):
    """Returns standard-normal samples of ``shape``, cast to ``dtype``."""
    return self._np_rng.standard_normal(shape).astype(dtype)

  @property
  def i1(self):
    """Signless 1-bit integer type."""
    return ir.IntegerType.get_signless(1)

  @property
  def i8(self):
    """Signless 8-bit integer type."""
    return ir.IntegerType.get_signless(8)

  @property
  def i16(self):
    """Signless 16-bit integer type."""
    return ir.IntegerType.get_signless(16)

  @property
  def i32(self):
    """Signless 32-bit integer type."""
    return ir.IntegerType.get_signless(32)

  @property
  def i64(self):
    """Signless 64-bit integer type."""
    return ir.IntegerType.get_signless(64)

  @property
  def f32(self):
    """MLIR f32 type."""
    return ir.F32Type.get()

  @property
  def bf16(self):
    """MLIR bf16 type."""
    return ir.BF16Type.get()

  @property
  def f32_vreg(self):
    """The (8, 128) f32 vector type."""
    return ir.VectorType.get((8, 128), self.f32)

  def f32v(self, *shape):
    """Returns an f32 vector type with the given dimensions."""
    return ir.VectorType.get(shape, self.f32)

  def i32v(self, *shape):
    """Returns an i32 vector type with the given dimensions."""
    return ir.VectorType.get(shape, self.i32)

  @property
  def index(self):
    """MLIR index type."""
    return ir.IndexType.get()
| |
| |
| def normalize_generation(generation: int | str) -> int: |
| if isinstance(generation, int): |
| return generation |
| if generation == "jellyfish": |
| return 2 |
| if generation == "dragonfish": |
| return 3 |
| if generation == "pufferfish": |
| return 4 |
| if generation == "viperfish": |
| return 5 |
| raise NotImplementedError(f"Unrecognized TPU generation: {generation}") |
| |
| |
def only_tpus_from(generation: int | str) -> ...:
  """Skips the test on older TPU generations.

  Args:
    generation: Minimum TPU generation, as a number or as a codename accepted
      by ``normalize_generation``.

  Returns:
    A decorator for test functions. The decorated test raises
    ``unittest.SkipTest`` at run time when ``generation_under_test()`` is lower
    than ``generation``.
  """
  generation = normalize_generation(generation)

  def wrapper(test_orig):
    # functools.wraps preserves the original test's __name__, __doc__ and
    # __wrapped__, so anything that identifies tests by __name__ sees the real
    # test instead of this anonymous ``test`` wrapper.
    @functools.wraps(test_orig)
    def test(*args, **kwargs):
      if generation_under_test() < generation:
        raise unittest.SkipTest(f"Test requires TPU v{generation}")
      return test_orig(*args, **kwargs)

    return test

  return wrapper
| |
| |
def generation_under_test():
  """Returns the TPU generation number of the first JAX device.

  The device kind must look like ``"TPU v<N>"``, optionally followed by
  ``" lite"``; ``<N>`` may have any number of decimal digits.

  Raises:
    AssertionError: if the device kind does not have that form (for example
      when the first device is not a TPU).
  """
  kind = jax.devices()[0].device_kind
  stripped = kind.removesuffix(" lite")
  prefix = "TPU v"
  version = stripped[len(prefix):]
  # Explicit raise instead of ``assert`` so the check survives ``python -O``;
  # the exception type is kept as AssertionError for compatibility.
  if not stripped.startswith(prefix) or not version.isdecimal():
    raise AssertionError(kind)
  return int(version)
| |
| |
@st.composite
def shapes(draw: st.DrawFn, allow_1d: bool = True) -> tuple[int, ...]:
  """Draws an array shape for property-based tests.

  If ``allow_1d`` is true, a drawn boolean first decides whether to return a
  rank-1 shape with a dimension in [1, 32768]. Otherwise the shape is up to two
  leading dims in [1, 4], followed by a second-minor dim in [1, 64] and a minor
  dim in [1, 512].
  """
  if allow_1d and draw(st.booleans()):
    return (draw(st.integers(1, 32768)),)

  # Keep the draw order (leading, minor, second-minor) unchanged: it
  # determines which shapes the deterministic profile generates.
  leading_dims = draw(st.lists(st.integers(1, 4), min_size=0, max_size=2))
  minor_dim = draw(st.integers(1, 512))
  second_minor_dim = draw(st.integers(1, 64))
  return tuple(leading_dims) + (second_minor_dim, minor_dim)
| |
| |
def and_flatmap(first, second) -> ...:
  """Like standard Hypothesis flatmap, but returns a pair of both results.

  ``second`` maps a value drawn from ``first`` to a strategy; the returned
  strategy yields ``(value, value_drawn_from_second(value))`` tuples.
  """

  def _pair_with(value):
    return st.tuples(st.just(value), second(value))

  return first.flatmap(_pair_with)
| |
| |
@st.composite
def aligned_shape(
    draw: st.DrawFn,
    rank: int,
    min_factor: int,
    max_factor: int,
    *,
    major: int = 8,
    minor: int = 128,
) -> list[int]:
  """Draws a random shape whose last 2 dims are multiples of (major, minor).

  Every dimension is first drawn as a factor in ``[min_factor, max_factor]``.
  The last dimension is then multiplied by ``minor`` and the second-to-last by
  ``major``, so the trailing two dims are aligned with a ``(major, minor)``
  tile layout.

  Args:
    draw: Hypothesis draw function (supplied by ``st.composite``).
    rank: Number of dimensions; must be at least 2.
    min_factor: Smallest per-dimension factor (inclusive).
    max_factor: Largest per-dimension factor (inclusive).
    major: Alignment of the second-to-last dimension.
    minor: Alignment of the last dimension.

  Returns:
    The shape as a list of ``rank`` ints. Note: a list, not a tuple.
  """
  assert rank > 1, f"{rank=} should be at least 2"
  shape = draw(
      st.lists(
          st.integers(min_factor, max_factor), min_size=rank, max_size=rank
      )
  )
  shape[-1] *= minor
  shape[-2] *= major
  return shape
| |
| |
def infer_tile_strides(
    mem_shape: tuple[int, ...], tile_shape: tuple[int, ...]
) -> list[int]:
  """Calculate the tile strides based on the memref shape and tile.

  The tile is aligned with the trailing dimensions of ``mem_shape``. Walking
  from the last dimension to the first, each stride is the product of the
  extents of all later dimensions. A tiled dimension contributes its number of
  tiles (rounded up); an untiled leading dimension contributes its full size.
  """
  untiled_rank = len(mem_shape) - len(tile_shape)
  strides_minor_first = []
  running = 1
  for dim in reversed(range(len(mem_shape))):
    strides_minor_first.append(running)
    tile_dim = dim - untiled_rank
    if tile_dim >= 0:
      tile = tile_shape[tile_dim]
      running *= (mem_shape[dim] + tile - 1) // tile  # ceil division
    else:
      running *= mem_shape[dim]
  strides_minor_first.reverse()
  return strides_minor_first
| |
| |
def jax_to_mlir_dtype(jax_dtype):
  """Returns the MLIR element type corresponding to a JAX/NumPy dtype.

  float32, bfloat16 and float16 map to the MLIR float types of the same
  format. Integer dtypes, signed or unsigned, map to a signless integer type of
  the same bit width.

  Args:
    jax_dtype: Anything accepted by ``jnp.dtype``.

  Raises:
    NotImplementedError: For dtypes without a mapping here.
  """
  jax_dtype = jnp.dtype(jax_dtype)
  if jax_dtype == jnp.dtype(jnp.float32):
    return ir.F32Type.get()
  elif jax_dtype == jnp.dtype(jnp.bfloat16):
    return ir.BF16Type.get()
  elif jax_dtype == jnp.dtype(jnp.float16):
    return ir.F16Type.get()
  elif jax_dtype == jnp.dtype(jnp.int4):
    return ir.IntegerType.get_signless(4)
  elif jnp.issubdtype(jax_dtype, jnp.integer):
    # Use the bit width rather than itemsize * 8: sub-byte integer dtypes
    # (e.g. uint4) still occupy a whole byte, so itemsize overstates them.
    return ir.IntegerType.get_signless(jnp.iinfo(jax_dtype).bits)
  raise NotImplementedError(jax_dtype)