"""A Mosaic kernel implementing the ROPE embedding kernel."""
import functools
import time
from typing import Sequence

from absl import app
import jax
from jax import core
from jax import random
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import math
from mlir.dialects import vector
import numpy as np


class _VectorFactory:
  """Constructs MLIR vector types with a fixed element type."""

  def __init__(self, elem_thunk):
    self.elem_thunk = elem_thunk

  def __getitem__(self, idxs):
    if isinstance(idxs, int):
      idxs = (idxs,)
    return ir.VectorType.get(idxs, self.elem_thunk())


F32 = _VectorFactory(ir.F32Type.get)
I32 = _VectorFactory(lambda: ir.IntegerType.get_signless(32))
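# For example, F32[8, 128] builds the MLIR vector type vector<8x128xf32>.

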
def _build_kernel(full_shape: tuple[int, int, int, int],
                  tile_shape: tuple[int, int, int, int],
                  constant_fold_freq: bool | None = None,
                  seq_minor: bool = False):
  """Builds the Mosaic ROPE kernel.

  Args:
    full_shape: A size-4 tuple containing the batch size, number of heads,
      sequence length, and model dimension.
    tile_shape: A size-4 tuple with the same structure as full_shape.
    constant_fold_freq: Whether to constant fold the computation of the
      frequencies. The impact on runtime should be negligible, but constant
      folding can help with precision (float exponentials are imprecise on
      TPUs). Defaults to True when the model dimension is at most 512.
    seq_minor: If True, the inputs (and shape arguments) are expected to have
      the sequence and model dimensions swapped, i.e. to be laid out as
      [batch, num_heads, d_model, seq_len].

  Returns:
    A JAX function implementing the ROPE kernel.
  """
  if seq_minor:
    _, _, d_model, _ = full_shape  # [batch, num_heads, d_model, seq_len]
  else:
    _, _, _, d_model = full_shape  # [batch, num_heads, seq_len, d_model]
  # Tiling over d_model is not implemented.
  if tile_shape[-1 - seq_minor] != d_model:
    raise NotImplementedError
  if constant_fold_freq is None:
    constant_fold_freq = d_model <= 512

  i32 = ir.IntegerType.get_signless(32)
  i64 = ir.IntegerType.get_signless(64)
  f32 = ir.F32Type.get()

  @func.FuncOp.from_py_func(
      i32, i32, i32, i32,
      ir.MemRefType.get(tile_shape, f32),
      ir.MemRefType.get(tile_shape, f32),
      name="main",
  )
  def kernel_main(i, j, k, l, x_ref, y_ref, func_op):  # pylint: disable=unused-argument
    constants = {}

    # Returns a cached scalar constant of type ty, hoisted to the entry block.
    def c(val, ty=None):
      ty = ir.IndexType.get() if ty is None else ty
      if (val, ty) not in constants:
        with ir.InsertionPoint.at_block_begin(func_op.entry_block):
          constants[(val, ty)] = arith.ConstantOp(
              ty, ir.IntegerAttr.get(ty, val))
      return constants[(val, ty)]
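
    # The tile is processed in vector-register-sized steps along the sequence
    # axis: 8 sublanes per step when the sequence axis is second-to-last, or
    # 128 lanes per step when it is last (f32 TPU vregs are 8x128).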
    if seq_minor:
      loop_tile_shape = (1, 1, tile_shape[-2], 128)
      tile_size_model, tile_size_seq = tile_shape[-2:]
      seq_loop_step = 128
    else:
      loop_tile_shape = (1, 1, 8, tile_shape[-1])
      tile_size_seq, tile_size_model = tile_shape[-2:]
      seq_loop_step = 8

    def i32_splat(val):
      return arith.ConstantOp(
          I32[loop_tile_shape],
          ir.DenseElementsAttr.get_splat(
              I32[loop_tile_shape], ir.IntegerAttr.get(i32, val)))
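    # For example, i32_splat(2) is a loop_tile_shape vector of i32 twos.
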
    seq_tile = l if seq_minor else k
    s_base = arith.MulIOp(seq_tile, c(tile_size_seq, i32))

    # The frequency only varies along the model axis.
    if constant_fold_freq:
      exponent = np.arange(0, d_model, 2, dtype=np.float32) / d_model
      freq_1d = np.repeat(10000 ** -exponent, 2, axis=-1)
      if seq_minor:
        freq_val = np.broadcast_to(freq_1d[None, None, :, None],
                                   loop_tile_shape)
      else:
        freq_val = np.broadcast_to(freq_1d, loop_tile_shape)
      freq = arith.ConstantOp(
          F32[loop_tile_shape],
          ir.DenseFPElementsAttr.get(np.ascontiguousarray(freq_val),
                                     type=F32[loop_tile_shape]),
      )
    else:
      if seq_minor:
        raise NotImplementedError
      c2 = i32_splat(2)
      # iota[j] = 2 * (j // 2), i.e. the even index of each feature pair.
      iota = arith.MulIOp(
          arith.DivSIOp(
              tpu.IotaOp(I32[loop_tile_shape], dimension=3), c2), c2)
      c_d_model_inv = arith.ConstantOp(
          F32[loop_tile_shape],
          ir.DenseElementsAttr.get_splat(
              F32[loop_tile_shape], ir.FloatAttr.get_f32(1 / d_model)))
      exponent = arith.MulFOp(
          arith.SIToFPOp(F32[loop_tile_shape], iota), c_d_model_inv)
      c10000 = arith.ConstantOp(
          F32[loop_tile_shape],
          ir.DenseElementsAttr.get_splat(F32[loop_tile_shape],
                                         ir.FloatAttr.get_f32(10000.)))
      freq = math.PowFOp(c10000, arith.NegFOp(exponent))
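      # This matches the constant-folded path above (freq = 10000**(-2i / d)),
      # but evaluates the exponential at runtime with PowF, hence the
      # precision caveat in the docstring.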

    # Indices that swap each adjacent pair of elements along the model axis.
    gather_indices = (
        np.arange(tile_size_model, dtype=np.int32)
        .reshape(-1, 2)[:, ::-1]
        .reshape(-1)
    )
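    # For example, with tile_size_model == 8 this is [1, 0, 3, 2, 5, 4, 7, 6].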
    gather_indices = ir.DenseI32ArrayAttr.get(gather_indices)
    lsb_mask = i32_splat(0x00000001)
    neg_mask = arith.CmpIOp(
        ir.IntegerAttr.get(i64, 0),  # eq
        arith.AndIOp(
            tpu.IotaOp(I32[loop_tile_shape], dimension=(3 - seq_minor)),
            lsb_mask),
        i32_splat(0))
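    # neg_mask is True at even offsets along the model axis, where the swapped
    # element picks up a minus sign (the -x_odd term in rope_baseline).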

    if tile_size_seq % seq_loop_step != 0:
      raise NotImplementedError
    for s_offset in range(0, tile_size_seq, seq_loop_step):
      s = arith.AddIOp(s_base, c(s_offset, i32))
      pos = arith.SIToFPOp(
          F32[loop_tile_shape],
          arith.AddIOp(tpu.IotaOp(I32[loop_tile_shape],
                                  dimension=(2 + seq_minor)),
                       vector.BroadcastOp(I32[loop_tile_shape], s)))
      adj_pos = arith.MulFOp(pos, freq)
      # TODO(apaszke): Consider constant folding this if seq_len is small.
      adj_pos_cos = math.CosOp(adj_pos)
      adj_pos_sin = math.SinOp(adj_pos)
      for b in range(tile_shape[0]):
        for h in range(tile_shape[1]):
          if seq_minor:
            indices = [c(b), c(h), c(0), c(s_offset)]
          else:
            indices = [c(b), c(h), c(s_offset), c(0)]
          x = vector.LoadOp(F32[loop_tile_shape], x_ref, indices)
          x_rotated_no_neg = tpu.GatherOp(
              F32[loop_tile_shape], x, gather_indices,
              dimension=(3 - seq_minor))
          x_rotated = arith.SelectOp(
              neg_mask, arith.NegFOp(x_rotated_no_neg), x_rotated_no_neg)
          # result = x * cos(adj_pos) + rotate(x) * sin(adj_pos)
          result = arith.AddFOp(arith.MulFOp(x, adj_pos_cos),
                                arith.MulFOp(x_rotated, adj_pos_sin))
          vector.StoreOp(result, y_ref, indices)

  f = kernel_main.func_op
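  # Mosaic grid metadata: iteration_bounds gives the number of tiles along
  # each grid dimension, dimension_semantics marks all of them as parallel
  # (safe to execute in any order), and window_params describes how grid
  # indices map to the windows of the two memref arguments.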
f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
[s // t for s, t in zip(full_shape, tile_shape)])
f.attributes["dimension_semantics"] = ir.ArrayAttr.get(
[ir.Attribute.parse("#tpu.dimension_semantics<parallel>")] * 4)
f.attributes["window_params"] = ir.ArrayAttr.get([
ir.DictAttr.get({
"window_bounds": ir.DenseI64ArrayAttr.get(tile_shape),
"transform_indices":
ir.Attribute.parse("affine_map<(n, m, k, q) -> (n, m, k, q)>"),
}),
ir.DictAttr.get({
"window_bounds": ir.DenseI64ArrayAttr.get(tile_shape),
"transform_indices":
ir.Attribute.parse("affine_map<(n, m, k, q) -> (n, m, k, q)>"),
})
])
assert f.verify(), f
m = ir.Module.create()
m.body.append(f)
ir.SymbolTable(m.operation).insert(f)
return mosaic.as_tpu_kernel(
m, out_type=core.ShapedArray(full_shape, jnp.float32))
@functools.partial(jax.jit, static_argnames=["tile_shape", "seq_minor"])
def rope(
x: jax.Array,
tile_shape: tuple[int, int, int, int] | None = None,
seq_minor: bool = False,
) -> jax.Array:
if tile_shape is None:
tile_shape = (x.shape[0], x.shape[1], min(128, x.shape[2]), x.shape[3])
with ir.Context() as ctx, ir.Location.unknown():
tpu.register_dialect(ctx)
kernel = _build_kernel(x.shape, tile_shape, seq_minor=seq_minor)
return kernel(x)
@functools.partial(jax.jit, static_argnames=["seq_minor"])
def rope_baseline(x: jax.Array, seq_minor: bool = False) -> jax.Array: # pylint: disable=missing-function-docstring
if seq_minor:
x = jnp.swapaxes(x, -1, -2)
_, _, seq_len, head_dim = x.shape # [batch, num_heads, seq_len, head_dim]
exponent = jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim
inv_freq = 10000 ** exponent
pos = jnp.arange(seq_len, dtype=jnp.float32)[None, None, :, None]
adj_pos = jnp.repeat(pos / inv_freq[None, None, None, :], 2, axis=-1)
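  # Rotate each adjacent pair (x[2i], x[2i + 1]) by the angle adj_pos:
  # (x0, x1) -> (x0 * cos - x1 * sin, x1 * cos + x0 * sin).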
  x_even, x_odd = x[..., ::2], x[..., 1::2]
  x_rotated = jnp.stack((-x_odd, x_even), axis=-1).reshape(x.shape)
  y = x * jnp.cos(adj_pos) + x_rotated * jnp.sin(adj_pos)
  if seq_minor:
    y = jnp.swapaxes(y, -1, -2)
  return y


def main(argv: Sequence[str]) -> None:
  if len(argv) == 1:
    num_bench_iters = 15000
  elif len(argv) == 2:
    num_bench_iters = int(argv[1])
  else:
    raise app.UsageError("Too many command-line arguments.")

  batch_size = 32
  seq_len = 1024
  num_heads = 8
  d_model = 64
  seq_minor = True
  if seq_minor:
    full_shape = (batch_size, num_heads, d_model, seq_len)
    tile_shape = (8, num_heads, d_model, 128)
  else:
    full_shape = (batch_size, num_heads, seq_len, d_model)
    tile_shape = (batch_size, num_heads, 128, d_model)
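  # In both layouts the sequence axis is tiled by 128 and d_model is kept
  # whole; the seq-minor variant additionally tiles the batch axis.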
  x = random.normal(random.PRNGKey(1234), full_shape, jnp.float32)

  # Make sure the implementation is correct + warm up the caches.
  baseline = functools.partial(rope_baseline, seq_minor=seq_minor)
  kernel = functools.partial(rope, tile_shape=tile_shape, seq_minor=seq_minor)
  np.testing.assert_allclose(baseline(x), kernel(x), atol=2e-4, rtol=2e-4)

  s = time.perf_counter()
  for _ in range(num_bench_iters):
    kernel(x).block_until_ready()
  e = time.perf_counter()
  d1 = e - s

  s = time.perf_counter()
  for _ in range(num_bench_iters):
    baseline(x).block_until_ready()
  e = time.perf_counter()
  d2 = e - s

  print("Run-time ratio (XLA/Mosaic): ", d2 / d1)
  if min(d1, d2) < 1:
    print(
        f"WARNING: Only {min(d1, d2):.3f}s spent benchmarking one of the"
        " versions!"
    )


if __name__ == "__main__":
  app.run(main)