| """A Mosaic example for implementing flash attention [1]. |
| |
[1] Dao, Tri, et al. "FlashAttention: Fast and Memory-Efficient Exact Attention
with IO-Awareness." Advances in Neural Information Processing Systems 35 (2022):
16344-16359. (https://arxiv.org/pdf/2205.14135.pdf)
| """ |
| |
| import dataclasses |
| import functools |
| import math as pymath |
| import time |
| from typing import Any, Sequence |
| |
| from absl import app |
| import jax |
| from jax import core |
| from jax import random |
| from jax.experimental import mosaic |
| from jax.experimental.mosaic.dialects import tpu |
| import jax.numpy as jnp |
| from mlir import ir |
| from mlir.dialects import arith |
| from mlir.dialects import func |
| from mlir.dialects import math |
| from mlir.dialects import scf |
| from mlir.dialects import vector |
| import numpy as np |
| |
| |
| @dataclasses.dataclass |
| class VectorFactory: |
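  """Constructs MLIR vector types with a fixed element type.

  `elem_thunk` is a zero-argument callable so the element type is built
  lazily, once an MLIR context is active (e.g. `F32[1, 1, tn, tm]`).
  """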
| elem_thunk: Any |
| |
| def __getitem__(self, idxs): |
| if isinstance(idxs, int): |
| idxs = (idxs,) |
| return ir.VectorType.get(idxs, self.elem_thunk()) |
| |
| |
| F32 = VectorFactory(ir.F32Type.get) |
| I32 = VectorFactory(lambda: ir.IntegerType.get_signless(32)) |
| |
| |
| @functools.partial(jax.jit, static_argnames=["tn", "tm", "scale"]) |
| def flash_attention_kernel( |
| query: jax.Array, # (batch, num_heads, seq_len, d_model) |
| key: jax.Array, # (batch, num_heads, seq_len, d_model) |
| value: jax.Array, # (batch, num_heads, seq_len, d_model) |
| bias: jax.Array, # (batch, num_heads, seq_len, seq_len) |
    scale: float = 1.0,
| *, |
| tn: int = 512, |
| tm: int = 512, |
| ) -> jax.Array: |
| """Build the Mosaic flash attention kernel. |
| |
| Args: |
| query: an array with shape (batch, num_heads, seq_len, d_model). |
| key: an array with shape (batch, num_heads, seq_len, d_model). |
| value: an array with shape (batch, num_heads, seq_len, d_model). |
| bias: an array with shape (batch, num_heads, seq_len, seq_len). |
| scale: the scaling factor for (Q @ K^T + bias). |
| tn: the block size of query. The shape of each block is (tn, d_model). |
| tm: the block size of key, value. The shape of each block is (tm, d_model). |
| |
| Returns: |
| attention: an array with shape (batch, num_heads, seq_len, d_model). |
| """ |
| batch, num_heads, seq_len, d_model = query.shape |
| |
| with ir.Context() as ctx, ir.Location.unknown(): |
| tpu.register_dialect(ctx) |
| i32 = ir.IntegerType.get_signless(32) |
| i64 = ir.IntegerType.get_signless(64) |
| f32 = ir.F32Type.get() |
| |
| @func.FuncOp.from_py_func( |
| i32, # b |
| i32, # h |
| i32, # n |
| i32, # m |
| ir.MemRefType.get((batch, num_heads, tn, d_model), f32), # query_ref |
| ir.MemRefType.get((batch, num_heads, tm, d_model), f32), # key_ref |
| ir.MemRefType.get((batch, num_heads, tm, d_model), f32), # value_ref |
| ir.MemRefType.get((batch, num_heads, tn, tm), f32), # bias_ref |
| ir.MemRefType.get((batch, num_heads, tn, d_model), f32), # out_ref |
| ir.MemRefType.get((1, 1, tn, 1), f32), # global_row_max_ref |
| ir.MemRefType.get((1, 1, tn, 1), f32), # global_row_sum_ref |
| name="main", |
| ) |
| # pylint: disable=unused-argument |
| def _flash_attention_kernel( |
| b, |
| h, |
| n, |
| m, |
| query_ref, |
| key_ref, |
| value_ref, |
| bias_ref, |
| out_ref, |
| global_row_max_ref, |
| global_row_sum_ref, |
| func_op, |
| ): |
| constants = {} |
| |
| def c(val): |
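        """Returns a cached index constant defined in the entry block."""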
| if val not in constants: |
| with ir.InsertionPoint.at_block_begin(func_op.entry_block): |
| ty = ir.IndexType.get() |
| constants[val] = arith.ConstantOp(ty, ir.IntegerAttr.get(ty, val)) |
| return constants[val] |
| |
| # Get (query @ key^T + bias) |
| key_tile = vector.LoadOp( |
| F32[1, 1, tm, d_model], |
| key_ref, |
| [c(0), c(0), c(0), c(0)], |
| ) |
| query_tile = vector.LoadOp( |
| F32[1, 1, tn, d_model], |
| query_ref, |
| [c(0), c(0), c(0), c(0)], |
| ) |
| bias_tile = vector.LoadOp( |
| F32[1, 1, tn, tm], |
| bias_ref, |
| [c(0), c(0), c(0), c(0)], |
| ) |
| qk_matmul_2d = vector.ContractionOp( |
| F32[tn, tm], |
| vector.ShapeCastOp(F32[tn, d_model], query_tile), |
| vector.ShapeCastOp(F32[tm, d_model], key_tile), |
| vector.ShapeCastOp(F32[tn, tm], bias_tile), |
| indexing_maps=ir.ArrayAttr.get([ |
| ir.Attribute.parse("affine_map<(i, j, k) -> (i, k)>"), |
| ir.Attribute.parse("affine_map<(i, j, k) -> (j, k)>"), |
| ir.Attribute.parse("affine_map<(i, j, k) -> (i, j)>"), |
| ]), |
| iterator_types=ir.ArrayAttr.get([ |
| ir.Attribute.parse("#vector.iterator_type<parallel>"), |
| ir.Attribute.parse("#vector.iterator_type<parallel>"), |
| ir.Attribute.parse("#vector.iterator_type<reduction>"), |
| ]), |
| ) |
| qk_matmul = vector.ShapeCastOp(F32[1, 1, tn, tm], qk_matmul_2d) |
| scales = arith.ConstantOp( |
| F32[1, 1, tn, tm], |
| ir.DenseElementsAttr.get_splat( |
| F32[1, 1, tn, tm], ir.FloatAttr.get_f32(scale) |
| ), |
| ) |
| |
| # Get (query @ key^T + bias) * scales |
| qk_tile = arith.MulFOp(qk_matmul, scales) |
| |
| # Get local row max for qk_tile (tn, tm). |
| local_row_max = vector.MultiDimReductionOp( |
| kind=ir.Attribute.parse("#vector.kind<maximumf>"), |
| source=qk_tile, |
| acc=arith.ConstantOp( |
| F32[1, 1, tn], |
| ir.DenseElementsAttr.get_splat( |
| F32[1, 1, tn], ir.FloatAttr.get_f32(-pymath.inf) |
| ), |
| ), |
| reduction_dims=ir.Attribute.parse("[3]"), |
| ) |
| local_row_max = vector.ShapeCastOp(F32[1, 1, tn, 1], local_row_max) |
| |
| # Get exp(qk_tile - local_row_max) |
| local_exp_diff = math.ExpOp( |
| arith.SubFOp( |
| qk_tile, vector.BroadcastOp(F32[1, 1, tn, tm], local_row_max) |
| ) |
| ) |
| |
| # Get local row sum for qk_tile (tn, tm) |
| local_row_sum = vector.MultiDimReductionOp( |
| kind=ir.Attribute.parse("#vector.kind<add>"), |
| source=local_exp_diff, |
| acc=arith.ConstantOp( |
| F32[1, 1, tn], |
| ir.DenseElementsAttr.get_splat( |
| F32[1, 1, tn], ir.FloatAttr.get_f32(0) |
| ), |
| ), |
| reduction_dims=ir.Attribute.parse("[3]"), |
| ) |
| local_row_sum = vector.ShapeCastOp(F32[1, 1, tn, 1], local_row_sum) |
| |
      # Initialize the scratch operands and the output when starting a new row
      # of key blocks (CmpIOp predicate 0 is eq, i.e. m == 0).
| is_row_start = arith.CmpIOp( |
| ir.IntegerAttr.get(i64, 0), |
| m, |
| arith.ConstantOp(i32, ir.IntegerAttr.get(i32, 0)), |
| ) |
| if_op = scf.IfOp(is_row_start.result) |
| with ir.InsertionPoint(if_op.then_block): |
| vector.StoreOp( |
| arith.ConstantOp( |
| F32[1, 1, tn, 1], |
| ir.DenseElementsAttr.get_splat( |
| F32[1, 1, tn, 1], ir.FloatAttr.get_f32(float("-inf")) |
| ), |
| ), |
| global_row_max_ref, |
| [c(0), c(0), c(0), c(0)], |
| ) |
| vector.StoreOp( |
| arith.ConstantOp( |
| F32[1, 1, tn, 1], |
| ir.DenseElementsAttr.get_splat( |
| F32[1, 1, tn, 1], ir.FloatAttr.get_f32(0.0) |
| ), |
| ), |
| global_row_sum_ref, |
| [c(0), c(0), c(0), c(0)], |
| ) |
| vector.StoreOp( |
| arith.ConstantOp( |
| F32[1, 1, tn, d_model], |
| ir.DenseElementsAttr.get_splat( |
| F32[1, 1, tn, d_model], ir.FloatAttr.get_f32(0.0) |
| ), |
| ), |
| out_ref, |
| [c(0), c(0), c(0), c(0)], |
| ) |
| scf.YieldOp([]) |
| |
      # Load the running global row max over the key blocks processed so far.
| global_row_max = vector.LoadOp( |
| F32[1, 1, tn, 1], |
| global_row_max_ref, |
| [c(0), c(0), c(0), c(0)], |
| ) |
| |
| # Calculate the new global row max |
| new_global_row_max = arith.MaximumFOp(global_row_max, local_row_max) |
| vector.StoreOp( |
| new_global_row_max, global_row_max_ref, [c(0), c(0), c(0), c(0)] |
| ) |
| |
      # Load the running global row sum over the key blocks processed so far.
| global_row_sum = vector.LoadOp( |
| F32[1, 1, tn, 1], |
| global_row_sum_ref, |
| [c(0), c(0), c(0), c(0)], |
| ) |
| |
      # Calculate the new global row sum:
      #   new_sum = old_sum * exp(old_max - new_max)
      #             + local_sum * exp(local_max - new_max)
| exp1 = math.ExpOp(arith.SubFOp(global_row_max, new_global_row_max)) |
| exp2 = math.ExpOp(arith.SubFOp(local_row_max, new_global_row_max)) |
| new_global_row_sum_lhs = arith.MulFOp(global_row_sum, exp1) |
| new_global_row_sum = arith.AddFOp( |
| new_global_row_sum_lhs, |
| arith.MulFOp( |
| local_row_sum, |
| exp2, |
| ), |
| ) |
| vector.StoreOp( |
| new_global_row_sum, global_row_sum_ref, [c(0), c(0), c(0), c(0)] |
| ) |
| |
      # Calculate the new output rhs: this key/value block's contribution,
      # exp(qk - local_max) @ V, rescaled by exp(local_max - new_max).
| value_tile = vector.LoadOp( |
| F32[1, 1, tm, d_model], |
| value_ref, |
| [c(0), c(0), c(0), c(0)], |
| ) |
| acc_tile = arith.ConstantOp( |
| F32[tn, d_model], |
| ir.DenseElementsAttr.get_splat( |
| F32[tn, d_model], ir.FloatAttr.get_f32(0.0) |
| ), |
| ) |
| out_rhs = vector.ContractionOp( |
| F32[tn, d_model], |
| vector.ShapeCastOp(F32[tn, tm], local_exp_diff), |
| vector.ShapeCastOp(F32[tm, d_model], value_tile), |
| acc_tile, |
| indexing_maps=ir.ArrayAttr.get([ |
| ir.Attribute.parse("affine_map<(i, j, k) -> (i, k)>"), |
| ir.Attribute.parse("affine_map<(i, j, k) -> (k, j)>"), |
| ir.Attribute.parse("affine_map<(i, j, k) -> (i, j)>"), |
| ]), |
| iterator_types=ir.ArrayAttr.get([ |
| ir.Attribute.parse("#vector.iterator_type<parallel>"), |
| ir.Attribute.parse("#vector.iterator_type<parallel>"), |
| ir.Attribute.parse("#vector.iterator_type<reduction>"), |
| ]), |
| ) |
| out_rhs = vector.ShapeCastOp(F32[1, 1, tn, d_model], out_rhs) |
| out_rhs = arith.MulFOp( |
| out_rhs, |
| vector.BroadcastOp(F32[1, 1, tn, d_model], exp2), |
| ) |
| |
      # Calculate the new output lhs: the previously stored output, multiplied
      # back by its old row sum and rescaled by exp(old_max - new_max).
| out_lhs = vector.LoadOp( |
| F32[1, 1, tn, d_model], |
| out_ref, |
| [c(0), c(0), c(0), c(0)], |
| ) |
| |
| out_lhs = arith.MulFOp( |
| out_lhs, |
| vector.BroadcastOp(F32[1, 1, tn, d_model], new_global_row_sum_lhs), |
| ) |
| |
      # Calculate the new output: combine both contributions and renormalize
      # by the new global row sum.
| out_tile = arith.DivFOp( |
| arith.AddFOp(out_lhs, out_rhs), |
| vector.BroadcastOp(F32[1, 1, tn, d_model], new_global_row_sum), |
| ) |
| vector.StoreOp( |
| out_tile, |
| out_ref, |
| [c(0), c(0), c(0), c(0)], |
| ) |
| |
| f = _flash_attention_kernel.func_op |
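    # Grid: one kernel invocation per (batch, head, query block, key block).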
| f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get( |
| [batch, num_heads, seq_len // tn, seq_len // tm] |
| ) |
| f.attributes["dimension_semantics"] = ir.ArrayAttr.get([ |
| ir.Attribute.parse("#tpu.dimension_semantics<parallel>"), |
| ir.Attribute.parse("#tpu.dimension_semantics<parallel>"), |
| ir.Attribute.parse("#tpu.dimension_semantics<parallel>"), |
| ir.Attribute.parse("#tpu.dimension_semantics<arbitrary>"), |
| ]) |
| f.attributes["window_params"] = ir.ArrayAttr.get([ |
| # query_ref (b, h, tn, d_model) |
| ir.DictAttr.get({ |
| "transform_indices": ir.Attribute.parse( |
| "affine_map<(b, h, n, m) -> (b, h, n, 0)>" |
| ), |
| }), |
| # key_ref (b, h, tm, d_model) |
| ir.DictAttr.get({ |
| "transform_indices": ir.Attribute.parse( |
| "affine_map<(b, h, n, m) -> (b, h, m, 0)>" |
| ), |
| }), |
| # value_ref (b, h, tm, d_model) |
| ir.DictAttr.get({ |
| "transform_indices": ir.Attribute.parse( |
| "affine_map<(b, h, n, m) -> (b, h, m, 0)>" |
| ), |
| }), |
| # bias_ref (b, h, tn, tm) |
| ir.DictAttr.get({ |
| "transform_indices": ir.Attribute.parse( |
| "affine_map<(b, h, n, m) -> (b, h, n, m)>" |
| ), |
| }), |
| # out_ref (b, h, tn, d_model) |
| ir.DictAttr.get({ |
| "transform_indices": ir.Attribute.parse( |
| "affine_map<(b, h, n, m) -> (b, h, n, 0)>" |
| ), |
| }), |
| ]) |
| f.attributes["scratch_operands"] = ir.IntegerAttr.get(i64, 2) |
| assert f.verify(), f |
| |
| module = ir.Module.create() |
| module.body.append(f) |
| ir.SymbolTable(module.operation).insert(f) |
| |
| return mosaic.as_tpu_kernel( |
| module, |
| out_type=core.ShapedArray( |
| (batch, num_heads, seq_len, d_model), jnp.float32 |
| ), |
| )(query, key, value, bias) |
| |
| |
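# Reference attention in plain jnp, used both for the correctness check and as
# the baseline in the benchmark below.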
| @functools.partial(jax.jit, static_argnames=["scale"]) |
| @jax.default_matmul_precision("bfloat16") |
| def baseline( |
    q: jax.Array, k: jax.Array, v: jax.Array, bias: jax.Array, scale: float = 1.0
| ) -> jax.Array: |
| qk = jnp.einsum("bhnd,bhmd->bhnm", q, k) |
| weights = jax.nn.softmax((qk + bias) * scale, axis=-1) |
| return jnp.einsum("bhnm,bhmk->bhnk", weights, v) |
| |
| |
| def main(argv: Sequence[str]) -> None: |
| if len(argv) == 1: |
| num_bench_iters = 1000 |
| elif len(argv) == 2: |
| num_bench_iters = int(argv[1]) |
| else: |
| raise app.UsageError("Too many command-line arguments.") |
| |
| batch, num_heads, seq_len, d_model = 2, 8, 1024, 128 |
| tn, tm = 512, 512 |
| assert seq_len % tn == 0 and seq_len % tm == 0 |
| assert d_model % 128 == 0 |
| assert tn % 128 == 0 and tm % 128 == 0 |
| assert batch > 0 |
| assert num_heads > 0 |
| |
| flash_attention = functools.partial(flash_attention_kernel, tn=tn, tm=tm) |
| |
| full_shape = (batch, num_heads, seq_len, d_model) |
| |
| k1, k2, k3, k4 = random.split(random.PRNGKey(1234), 4) |
| query = random.normal(k1, full_shape, jnp.float32) |
| key = random.normal(k2, full_shape, jnp.float32) |
| value = random.normal(k3, full_shape, jnp.float32) |
| bias = random.normal(k4, (batch, num_heads, seq_len, seq_len), jnp.float32) |
| |
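  # Round the inputs through bfloat16 so both implementations consume the same
  # bfloat16-representable values, which keeps the comparison tolerances tight.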
| query = query.astype(jnp.bfloat16).astype(jnp.float32) |
| key = key.astype(jnp.bfloat16).astype(jnp.float32) |
| value = value.astype(jnp.bfloat16).astype(jnp.float32) |
| bias = bias.astype(jnp.bfloat16).astype(jnp.float32) |
| scale = 1 / pymath.sqrt(d_model) |
| |
  # Check that the kernel matches the baseline and warm up the caches.
| np.testing.assert_allclose( |
| flash_attention(query, key, value, bias, scale=scale), |
| baseline(query, key, value, bias, scale=scale), |
| atol=1e-2, |
| rtol=1e-2, |
| ) |
| |
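  # Benchmark both implementations; block_until_ready() makes each iteration
  # wait for device execution rather than just dispatch.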
| s = time.perf_counter() |
| for _ in range(num_bench_iters): |
| flash_attention(query, key, value, bias, scale=scale).block_until_ready() |
| e = time.perf_counter() |
| d1 = e - s |
| s = time.perf_counter() |
| for _ in range(num_bench_iters): |
| baseline(query, key, value, bias, scale=scale).block_until_ready() |
| e = time.perf_counter() |
| d2 = e - s |
| print("Run-time ratio (custom/XLA):", d1 / d2) |
| |
| |
| if __name__ == "__main__": |
| app.run(main) |