| """A Mosaic kernel implementing the ROPE embedding kernel.""" |
| |
| import functools |
| import time |
| from typing import Sequence |
| |
| from absl import app |
| import jax |
| from jax import core |
| from jax import random |
| from jax.experimental import mosaic |
| from jax.experimental.mosaic.dialects import tpu |
| import jax.numpy as jnp |
| from mlir import ir |
| from mlir.dialects import arith |
| from mlir.dialects import func |
| from mlir.dialects import math |
| from mlir.dialects import vector |
| import numpy as np |
| |
| |
| class _VectorFactory: |
| |
| def __init__(self, elem_thunk): |
| self.elem_thunk = elem_thunk |
| |
| def __getitem__(self, idxs): |
| if isinstance(idxs, int): |
| idxs = (idxs,) |
| return ir.VectorType.get(idxs, self.elem_thunk()) |
| |
| F32 = _VectorFactory(ir.F32Type.get) |
| I32 = _VectorFactory(lambda: ir.IntegerType.get_signless(32)) |
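# For example, with an ir.Context active, F32[1, 1, 8, 128] constructs the
# MLIR type vector<1x1x8x128xf32>, and I32[8] constructs vector<8xi32>. The
# element-type thunks run lazily, so these module-level factories can be
# defined before any MLIR context exists.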


def _build_kernel(full_shape: tuple[int, int, int, int],
                  tile_shape: tuple[int, int, int, int],
                  constant_fold_freq: bool | None = None,
                  seq_minor: bool = False):
  """Builds the Mosaic RoPE kernel.

  Args:
    full_shape: A size-4 tuple containing the batch size, number of heads,
      sequence length, and model dimension.
    tile_shape: A size-4 tuple with the same structure as full_shape.
    constant_fold_freq: Whether to constant-fold the frequency computation.
      The impact on runtime should be negligible, but it can help with
      precision (float exponentials are imprecise on TPUs). If None, it
      defaults to True whenever d_model is at most 512.
    seq_minor: If True, the expected layout of the inputs (and of the shape
      arguments) has the sequence length and model dimension swapped.

  Returns:
    A JAX function implementing the RoPE kernel.
  """
  if seq_minor:
    _, _, d_model, _ = full_shape  # [batch, num_heads, d_model, seq_len]
  else:
    _, _, _, d_model = full_shape  # [batch, num_heads, seq_len, d_model]

  # Tiling over d_model is not implemented: the tile must span it fully. Note
  # that seq_minor is used as an int here (bools are ints in Python), so the
  # subscript selects index -2 when seq_minor and -1 otherwise.
  if tile_shape[-1 - seq_minor] != d_model:
    raise NotImplementedError("Tiling over d_model is not supported")
  if constant_fold_freq is None:
    constant_fold_freq = d_model <= 512

  i32 = ir.IntegerType.get_signless(32)
  i64 = ir.IntegerType.get_signless(64)
  f32 = ir.F32Type.get()

  @func.FuncOp.from_py_func(
      i32, i32, i32, i32,
      ir.MemRefType.get(tile_shape, f32),
      ir.MemRefType.get(tile_shape, f32),
      name="main",
  )
  def kernel_main(i, j, k, l, x_ref, y_ref, func_op):  # pylint: disable=unused-argument
    # Deduplicate constants: each (value, type) pair is materialized exactly
    # once, at the beginning of the entry block.
    constants = {}
    def c(val, ty=None):
      ty = ir.IndexType.get() if ty is None else ty
      if (val, ty) not in constants:
        with ir.InsertionPoint.at_block_begin(func_op.entry_block):
          constants[(val, ty)] = arith.ConstantOp(
              ty, ir.IntegerAttr.get(ty, val))
      return constants[(val, ty)]

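    # The native TPU vector tile for f32 is 8 sublanes x 128 lanes, so we step
    # through the sequence axis in chunks of 128 when it is the minormost
    # dimension and in chunks of 8 when it is second-to-last; the model axis
    # is always processed in full.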
    if seq_minor:
      loop_tile_shape = (1, 1, tile_shape[-2], 128)
      tile_size_model, tile_size_seq = tile_shape[-2:]
      seq_loop_step = 128
    else:
      loop_tile_shape = (1, 1, 8, tile_shape[-1])
      tile_size_seq, tile_size_model = tile_shape[-2:]
      seq_loop_step = 8

    def i32_splat(val):
      return arith.ConstantOp(
          I32[loop_tile_shape],
          ir.DenseElementsAttr.get_splat(
              I32[loop_tile_shape], ir.IntegerAttr.get(i32, val)))

    seq_tile = l if seq_minor else k
    s_base = arith.MulIOp(seq_tile, c(tile_size_seq, i32))

    # The frequency only varies along the model axis.
    if constant_fold_freq:
      exponent = np.arange(0, d_model, 2, dtype=np.float32) / d_model
      freq_1d = np.repeat(10000 ** -exponent, 2, axis=-1)
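      # For example, for d_model == 4: exponent == [0., 0.5] and
      # freq_1d == [1., 1., 0.01, 0.01]; each frequency appears twice because
      # it rotates an (even, odd) pair of model-axis entries.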
      if seq_minor:
        freq_val = np.broadcast_to(freq_1d[None, None, :, None],
                                   loop_tile_shape)
      else:
        freq_val = np.broadcast_to(freq_1d, loop_tile_shape)
      freq = arith.ConstantOp(
          F32[loop_tile_shape],
          ir.DenseFPElementsAttr.get(np.ascontiguousarray(freq_val),
                                     type=F32[loop_tile_shape]),
      )
    else:
      if seq_minor:
        raise NotImplementedError(
            "Dynamic frequency computation is unsupported when seq_minor")
      # Round the model-axis iota down to the nearest even value:
      # 0, 0, 2, 2, 4, 4, ...
      c2 = i32_splat(2)
      iota = arith.MulIOp(
          arith.DivSIOp(
              tpu.IotaOp(I32[loop_tile_shape], dimension=3), c2), c2)
      c_d_model_inv = arith.ConstantOp(
          F32[loop_tile_shape],
          ir.DenseElementsAttr.get_splat(
              F32[loop_tile_shape], ir.FloatAttr.get_f32(1 / d_model)
          ))
      exponent = arith.MulFOp(
          arith.SIToFPOp(F32[loop_tile_shape], iota), c_d_model_inv)
      c10000 = arith.ConstantOp(
          F32[loop_tile_shape],
          ir.DenseElementsAttr.get_splat(F32[loop_tile_shape],
                                         ir.FloatAttr.get_f32(10000.)))
      freq = math.PowFOp(c10000, arith.NegFOp(exponent))

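    # The gather below swaps each adjacent (even, odd) pair of lanes along the
    # model axis: for tile_size_model == 4 the pattern is [1, 0, 3, 2].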
    gather_indices = (
        np.arange(tile_size_model, dtype=np.int32)
        .reshape(-1, 2)[:, ::-1]
        .reshape(-1)
    )
    gather_indices = ir.DenseI32ArrayAttr.get(gather_indices)
    lsb_mask = i32_splat(0x00000001)
    neg_mask = arith.CmpIOp(
        ir.IntegerAttr.get(i64, 0),  # predicate 0 is "eq"
        arith.AndIOp(
            tpu.IotaOp(I32[loop_tile_shape], dimension=(3 - seq_minor)),
            lsb_mask),
        i32_splat(0))
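    # neg_mask is true on even lanes, which after the pair swap hold the odd
    # elements, so the select below yields (-x1, x0, -x3, x2, ...).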

    if tile_size_seq % seq_loop_step != 0:
      raise NotImplementedError(
          f"Sequence tile size must be a multiple of {seq_loop_step}")
    for s_offset in range(0, tile_size_seq, seq_loop_step):
      s = arith.AddIOp(s_base, c(s_offset, i32))
      pos = arith.SIToFPOp(
          F32[loop_tile_shape],
          arith.AddIOp(tpu.IotaOp(I32[loop_tile_shape],
                                  dimension=(2 + seq_minor)),
                       vector.BroadcastOp(I32[loop_tile_shape], s)))
      adj_pos = arith.MulFOp(pos, freq)
      # TODO(apaszke): Consider constant folding this if seq_len is small.
      adj_pos_cos = math.CosOp(adj_pos)
      adj_pos_sin = math.SinOp(adj_pos)
      for b in range(tile_shape[0]):
        for h in range(tile_shape[1]):
          if seq_minor:
            indices = [c(b), c(h), c(0), c(s_offset)]
          else:
            indices = [c(b), c(h), c(s_offset), c(0)]
          x = vector.LoadOp(F32[loop_tile_shape], x_ref, indices)
          x_rotated_no_neg = tpu.GatherOp(
              F32[loop_tile_shape], x, gather_indices,
              dimension=(3 - seq_minor))
          x_rotated = arith.SelectOp(
              neg_mask, arith.NegFOp(x_rotated_no_neg), x_rotated_no_neg)
          result = arith.AddFOp(arith.MulFOp(x, adj_pos_cos),
                                arith.MulFOp(x_rotated, adj_pos_sin))
          vector.StoreOp(result, y_ref, indices)

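  # Annotate the kernel with its Mosaic grid: iteration_bounds is the number
  # of tiles along each grid dimension (all parallel here), and window_params
  # carries one entry per memref operand (x_ref and y_ref), each windowed with
  # the identity index transform.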
  f = kernel_main.func_op
  f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
      [s // t for s, t in zip(full_shape, tile_shape)])
  f.attributes["dimension_semantics"] = ir.ArrayAttr.get(
      [ir.Attribute.parse("#tpu.dimension_semantics<parallel>")] * 4)
  f.attributes["window_params"] = ir.ArrayAttr.get([
      ir.DictAttr.get({
          "window_bounds": ir.DenseI64ArrayAttr.get(tile_shape),
          "transform_indices":
              ir.Attribute.parse("affine_map<(n, m, k, q) -> (n, m, k, q)>"),
      }),
      ir.DictAttr.get({
          "window_bounds": ir.DenseI64ArrayAttr.get(tile_shape),
          "transform_indices":
              ir.Attribute.parse("affine_map<(n, m, k, q) -> (n, m, k, q)>"),
      }),
  ])

  assert f.verify(), f
  m = ir.Module.create()
  m.body.append(f)
  ir.SymbolTable(m.operation).insert(f)

  return mosaic.as_tpu_kernel(
      m, out_type=core.ShapedArray(full_shape, jnp.float32))


@functools.partial(jax.jit, static_argnames=["tile_shape", "seq_minor"])
def rope(
    x: jax.Array,
    tile_shape: tuple[int, int, int, int] | None = None,
    seq_minor: bool = False,
) -> jax.Array:
  if tile_shape is None:
    # Tiling over d_model is unsupported, so the default tile spans the model
    # axis in full and splits the sequence axis into chunks of (at most) 128,
    # whichever of the two minor dimensions the sequence occupies.
    if seq_minor:
      tile_shape = (x.shape[0], x.shape[1], x.shape[2], min(128, x.shape[3]))
    else:
      tile_shape = (x.shape[0], x.shape[1], min(128, x.shape[2]), x.shape[3])
  with ir.Context() as ctx, ir.Location.unknown():
    tpu.register_dialect(ctx)
    kernel = _build_kernel(x.shape, tile_shape, seq_minor=seq_minor)
    return kernel(x)
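
# Example usage (hypothetical shapes), assuming the default layout where the
# sequence axis is second-to-last:
#   x = random.normal(random.PRNGKey(0), (2, 4, 256, 64), jnp.float32)
#   y = rope(x)  # Tiles the 256-long sequence axis in chunks of 128.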


@functools.partial(jax.jit, static_argnames=["seq_minor"])
def rope_baseline(x: jax.Array, seq_minor: bool = False) -> jax.Array:
  """Reference XLA implementation of RoPE used for testing and benchmarking."""
  if seq_minor:
    x = jnp.swapaxes(x, -1, -2)
  _, _, seq_len, head_dim = x.shape  # [batch, num_heads, seq_len, head_dim]
  exponent = jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim
  inv_freq = 10000 ** exponent
  pos = jnp.arange(seq_len, dtype=jnp.float32)[None, None, :, None]
  adj_pos = jnp.repeat(pos / inv_freq[None, None, None, :], 2, axis=-1)
  x_even, x_odd = x[..., ::2], x[..., 1::2]
  x_rotated = jnp.stack((-x_odd, x_even), axis=-1).reshape(x.shape)
  y = x * jnp.cos(adj_pos) + x_rotated * jnp.sin(adj_pos)
  if seq_minor:
    y = jnp.swapaxes(y, -1, -2)
  return y
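
# Both implementations rotate each (even, odd) pair (x[2i], x[2i + 1]) at
# position p by the angle theta_i = p / 10000**(2i / head_dim):
#   y[2i]     = x[2i]     * cos(theta_i) - x[2i + 1] * sin(theta_i)
#   y[2i + 1] = x[2i + 1] * cos(theta_i) + x[2i]     * sin(theta_i)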


def main(argv: Sequence[str]) -> None:
  if len(argv) == 1:
    num_bench_iters = 15000
  elif len(argv) == 2:
    num_bench_iters = int(argv[1])
  else:
    raise app.UsageError("Too many command-line arguments.")

  batch_size = 32
  seq_len = 1024
  num_heads = 8
  d_model = 64
  seq_minor = True

  if seq_minor:
    full_shape = (batch_size, num_heads, d_model, seq_len)
    tile_shape = (8, num_heads, d_model, 128)
  else:
    full_shape = (batch_size, num_heads, seq_len, d_model)
    tile_shape = (batch_size, num_heads, 128, d_model)
  x = random.normal(random.PRNGKey(1234), full_shape, jnp.float32)

  # Make sure the implementation is correct and warm up the compilation caches.
  baseline = functools.partial(rope_baseline, seq_minor=seq_minor)
  kernel = functools.partial(rope, tile_shape=tile_shape, seq_minor=seq_minor)
  np.testing.assert_allclose(baseline(x), kernel(x), atol=2e-4, rtol=2e-4)

  s = time.perf_counter()
  for _ in range(num_bench_iters):
    kernel(x).block_until_ready()
  e = time.perf_counter()
  d_mosaic = e - s
  s = time.perf_counter()
  for _ in range(num_bench_iters):
    baseline(x).block_until_ready()
  e = time.perf_counter()
  d_xla = e - s
  print("Run-time ratio (XLA/Mosaic):", d_xla / d_mosaic)
  if min(d_mosaic, d_xla) < 1:
    print(
        f"WARNING: only {min(d_mosaic, d_xla):.3f}s was spent benchmarking"
        " one of the versions; increase the iteration count for a more"
        " reliable measurement."
    )


if __name__ == "__main__":
  app.run(main)