"""A Mosaic example for implementing flash attention [1].
[1] Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention
with io-awareness." Advances in Neural Information Processing Systems 35 (2022):
16344-16359. (https://arxiv.org/pdf/2205.14135.pdf)
"""
import dataclasses
import functools
import math as pymath
import time
from typing import Any, Sequence
from absl import app
import jax
from jax import core
from jax import random
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import math
from mlir.dialects import scf
from mlir.dialects import vector
import numpy as np


@dataclasses.dataclass
class VectorFactory:
elem_thunk: Any
def __getitem__(self, idxs):
if isinstance(idxs, int):
idxs = (idxs,)
return ir.VectorType.get(idxs, self.elem_thunk())


F32 = VectorFactory(ir.F32Type.get)
I32 = VectorFactory(lambda: ir.IntegerType.get_signless(32))


@functools.partial(jax.jit, static_argnames=["tn", "tm", "scale"])
def flash_attention_kernel(
query: jax.Array, # (batch, num_heads, seq_len, d_model)
key: jax.Array, # (batch, num_heads, seq_len, d_model)
value: jax.Array, # (batch, num_heads, seq_len, d_model)
bias: jax.Array, # (batch, num_heads, seq_len, seq_len)
    scale: float = 1.0,
*,
tn: int = 512,
tm: int = 512,
) -> jax.Array:
"""Build the Mosaic flash attention kernel.
Args:
query: an array with shape (batch, num_heads, seq_len, d_model).
key: an array with shape (batch, num_heads, seq_len, d_model).
value: an array with shape (batch, num_heads, seq_len, d_model).
bias: an array with shape (batch, num_heads, seq_len, seq_len).
scale: the scaling factor for (Q @ K^T + bias).
tn: the block size of query. The shape of each block is (tn, d_model).
tm: the block size of key, value. The shape of each block is (tm, d_model).
Returns:
attention: an array with shape (batch, num_heads, seq_len, d_model).
"""
batch, num_heads, seq_len, d_model = query.shape
with ir.Context() as ctx, ir.Location.unknown():
tpu.register_dialect(ctx)
i32 = ir.IntegerType.get_signless(32)
i64 = ir.IntegerType.get_signless(64)
f32 = ir.F32Type.get()
@func.FuncOp.from_py_func(
i32, # b
i32, # h
i32, # n
i32, # m
ir.MemRefType.get((batch, num_heads, tn, d_model), f32), # query_ref
ir.MemRefType.get((batch, num_heads, tm, d_model), f32), # key_ref
ir.MemRefType.get((batch, num_heads, tm, d_model), f32), # value_ref
ir.MemRefType.get((batch, num_heads, tn, tm), f32), # bias_ref
ir.MemRefType.get((batch, num_heads, tn, d_model), f32), # out_ref
ir.MemRefType.get((1, 1, tn, 1), f32), # global_row_max_ref
ir.MemRefType.get((1, 1, tn, 1), f32), # global_row_sum_ref
name="main",
)
# pylint: disable=unused-argument
def _flash_attention_kernel(
b,
h,
n,
m,
query_ref,
key_ref,
value_ref,
bias_ref,
out_ref,
global_row_max_ref,
global_row_sum_ref,
func_op,
):
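      # Materialize index constants lazily and cache them at the entry block,
      # so a single constant op is reused by every load/store below.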
constants = {}
def c(val):
if val not in constants:
with ir.InsertionPoint.at_block_begin(func_op.entry_block):
ty = ir.IndexType.get()
constants[val] = arith.ConstantOp(ty, ir.IntegerAttr.get(ty, val))
return constants[val]
# Get (query @ key^T + bias)
key_tile = vector.LoadOp(
F32[1, 1, tm, d_model],
key_ref,
[c(0), c(0), c(0), c(0)],
)
query_tile = vector.LoadOp(
F32[1, 1, tn, d_model],
query_ref,
[c(0), c(0), c(0), c(0)],
)
bias_tile = vector.LoadOp(
F32[1, 1, tn, tm],
bias_ref,
[c(0), c(0), c(0), c(0)],
)
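      # The contraction reduces over k = d_model. The lhs map is (i, k) and
      # the rhs map is (j, k), so the key block is consumed in its stored
      # (tm, d_model) layout (no transpose is materialized) and the bias tile
      # serves as the accumulator.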
qk_matmul_2d = vector.ContractionOp(
F32[tn, tm],
vector.ShapeCastOp(F32[tn, d_model], query_tile),
vector.ShapeCastOp(F32[tm, d_model], key_tile),
vector.ShapeCastOp(F32[tn, tm], bias_tile),
indexing_maps=ir.ArrayAttr.get([
ir.Attribute.parse("affine_map<(i, j, k) -> (i, k)>"),
ir.Attribute.parse("affine_map<(i, j, k) -> (j, k)>"),
ir.Attribute.parse("affine_map<(i, j, k) -> (i, j)>"),
]),
iterator_types=ir.ArrayAttr.get([
ir.Attribute.parse("#vector.iterator_type<parallel>"),
ir.Attribute.parse("#vector.iterator_type<parallel>"),
ir.Attribute.parse("#vector.iterator_type<reduction>"),
]),
)
qk_matmul = vector.ShapeCastOp(F32[1, 1, tn, tm], qk_matmul_2d)
scales = arith.ConstantOp(
F32[1, 1, tn, tm],
ir.DenseElementsAttr.get_splat(
F32[1, 1, tn, tm], ir.FloatAttr.get_f32(scale)
),
)
# Get (query @ key^T + bias) * scales
qk_tile = arith.MulFOp(qk_matmul, scales)
# Get local row max for qk_tile (tn, tm).
local_row_max = vector.MultiDimReductionOp(
kind=ir.Attribute.parse("#vector.kind<maximumf>"),
source=qk_tile,
acc=arith.ConstantOp(
F32[1, 1, tn],
ir.DenseElementsAttr.get_splat(
F32[1, 1, tn], ir.FloatAttr.get_f32(-pymath.inf)
),
),
reduction_dims=ir.Attribute.parse("[3]"),
)
local_row_max = vector.ShapeCastOp(F32[1, 1, tn, 1], local_row_max)
# Get exp(qk_tile - local_row_max)
local_exp_diff = math.ExpOp(
arith.SubFOp(
qk_tile, vector.BroadcastOp(F32[1, 1, tn, tm], local_row_max)
)
)
# Get local row sum for qk_tile (tn, tm)
local_row_sum = vector.MultiDimReductionOp(
kind=ir.Attribute.parse("#vector.kind<add>"),
source=local_exp_diff,
acc=arith.ConstantOp(
F32[1, 1, tn],
ir.DenseElementsAttr.get_splat(
F32[1, 1, tn], ir.FloatAttr.get_f32(0)
),
),
reduction_dims=ir.Attribute.parse("[3]"),
)
local_row_sum = vector.ShapeCastOp(F32[1, 1, tn, 1], local_row_sum)
      # Initialize the scratch operands and the output when starting a new
      # row block (i.e. when m == 0).
is_row_start = arith.CmpIOp(
ir.IntegerAttr.get(i64, 0),
m,
arith.ConstantOp(i32, ir.IntegerAttr.get(i32, 0)),
)
if_op = scf.IfOp(is_row_start.result)
with ir.InsertionPoint(if_op.then_block):
vector.StoreOp(
arith.ConstantOp(
F32[1, 1, tn, 1],
ir.DenseElementsAttr.get_splat(
F32[1, 1, tn, 1], ir.FloatAttr.get_f32(float("-inf"))
),
),
global_row_max_ref,
[c(0), c(0), c(0), c(0)],
)
vector.StoreOp(
arith.ConstantOp(
F32[1, 1, tn, 1],
ir.DenseElementsAttr.get_splat(
F32[1, 1, tn, 1], ir.FloatAttr.get_f32(0.0)
),
),
global_row_sum_ref,
[c(0), c(0), c(0), c(0)],
)
vector.StoreOp(
arith.ConstantOp(
F32[1, 1, tn, d_model],
ir.DenseElementsAttr.get_splat(
F32[1, 1, tn, d_model], ir.FloatAttr.get_f32(0.0)
),
),
out_ref,
[c(0), c(0), c(0), c(0)],
)
scf.YieldOp([])
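      # Online-softmax merge of the running row statistics with this block:
      #   m_new = max(m_old, m_local)
      #   l_new = l_old * exp(m_old - m_new) + l_local * exp(m_local - m_new)
      # where m_* are row maxima and l_* are row sums of exp(s - m_*).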
# Get global row max for big tile (tn, m)
global_row_max = vector.LoadOp(
F32[1, 1, tn, 1],
global_row_max_ref,
[c(0), c(0), c(0), c(0)],
)
# Calculate the new global row max
new_global_row_max = arith.MaximumFOp(global_row_max, local_row_max)
vector.StoreOp(
new_global_row_max, global_row_max_ref, [c(0), c(0), c(0), c(0)]
)
# Get global row sum for big tile (tn, m)
global_row_sum = vector.LoadOp(
F32[1, 1, tn, 1],
global_row_sum_ref,
[c(0), c(0), c(0), c(0)],
)
# Calculate the new global row sum
exp1 = math.ExpOp(arith.SubFOp(global_row_max, new_global_row_max))
exp2 = math.ExpOp(arith.SubFOp(local_row_max, new_global_row_max))
new_global_row_sum_lhs = arith.MulFOp(global_row_sum, exp1)
new_global_row_sum = arith.AddFOp(
new_global_row_sum_lhs,
arith.MulFOp(
local_row_sum,
exp2,
),
)
vector.StoreOp(
new_global_row_sum, global_row_sum_ref, [c(0), c(0), c(0), c(0)]
)
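      # out_ref holds the already-normalized partial output
      # o_old = acc_old / l_old, so the update below recovers the numerator,
      # rescales both terms to the new max, and renormalizes:
      #   o_new = (o_old * l_old * exp(m_old - m_new)
      #            + (exp(s - m_local) @ V) * exp(m_local - m_new)) / l_new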
# Calculate the new output rhs
value_tile = vector.LoadOp(
F32[1, 1, tm, d_model],
value_ref,
[c(0), c(0), c(0), c(0)],
)
acc_tile = arith.ConstantOp(
F32[tn, d_model],
ir.DenseElementsAttr.get_splat(
F32[tn, d_model], ir.FloatAttr.get_f32(0.0)
),
)
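      # Second contraction: the (tn, tm) block of exponentiated scores times
      # the (tm, d_model) value block. Here the rhs map is (k, j), i.e. a
      # plain matmul accumulated into zeros.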
out_rhs = vector.ContractionOp(
F32[tn, d_model],
vector.ShapeCastOp(F32[tn, tm], local_exp_diff),
vector.ShapeCastOp(F32[tm, d_model], value_tile),
acc_tile,
indexing_maps=ir.ArrayAttr.get([
ir.Attribute.parse("affine_map<(i, j, k) -> (i, k)>"),
ir.Attribute.parse("affine_map<(i, j, k) -> (k, j)>"),
ir.Attribute.parse("affine_map<(i, j, k) -> (i, j)>"),
]),
iterator_types=ir.ArrayAttr.get([
ir.Attribute.parse("#vector.iterator_type<parallel>"),
ir.Attribute.parse("#vector.iterator_type<parallel>"),
ir.Attribute.parse("#vector.iterator_type<reduction>"),
]),
)
out_rhs = vector.ShapeCastOp(F32[1, 1, tn, d_model], out_rhs)
out_rhs = arith.MulFOp(
out_rhs,
vector.BroadcastOp(F32[1, 1, tn, d_model], exp2),
)
# Calculate the new output lhs
out_lhs = vector.LoadOp(
F32[1, 1, tn, d_model],
out_ref,
[c(0), c(0), c(0), c(0)],
)
out_lhs = arith.MulFOp(
out_lhs,
vector.BroadcastOp(F32[1, 1, tn, d_model], new_global_row_sum_lhs),
)
# Calculate the new output
out_tile = arith.DivFOp(
arith.AddFOp(out_lhs, out_rhs),
vector.BroadcastOp(F32[1, 1, tn, d_model], new_global_row_sum),
)
vector.StoreOp(
out_tile,
out_ref,
[c(0), c(0), c(0), c(0)],
)
f = _flash_attention_kernel.func_op
f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
[batch, num_heads, seq_len // tn, seq_len // tm]
)
f.attributes["dimension_semantics"] = ir.ArrayAttr.get([
ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
ir.Attribute.parse("#tpu.dimension_semantics<arbitrary>"),
])
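    # The innermost grid dimension (m) is "arbitrary" rather than "parallel":
    # each step reads the running max/sum and partial output written by the
    # previous step, so the m iterations must execute in order.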
f.attributes["window_params"] = ir.ArrayAttr.get([
# query_ref (b, h, tn, d_model)
ir.DictAttr.get({
"transform_indices": ir.Attribute.parse(
"affine_map<(b, h, n, m) -> (b, h, n, 0)>"
),
}),
# key_ref (b, h, tm, d_model)
ir.DictAttr.get({
"transform_indices": ir.Attribute.parse(
"affine_map<(b, h, n, m) -> (b, h, m, 0)>"
),
}),
# value_ref (b, h, tm, d_model)
ir.DictAttr.get({
"transform_indices": ir.Attribute.parse(
"affine_map<(b, h, n, m) -> (b, h, m, 0)>"
),
}),
# bias_ref (b, h, tn, tm)
ir.DictAttr.get({
"transform_indices": ir.Attribute.parse(
"affine_map<(b, h, n, m) -> (b, h, n, m)>"
),
}),
# out_ref (b, h, tn, d_model)
ir.DictAttr.get({
"transform_indices": ir.Attribute.parse(
"affine_map<(b, h, n, m) -> (b, h, n, 0)>"
),
}),
])
f.attributes["scratch_operands"] = ir.IntegerAttr.get(i64, 2)
assert f.verify(), f
module = ir.Module.create()
module.body.append(f)
ir.SymbolTable(module.operation).insert(f)
return mosaic.as_tpu_kernel(
module,
out_type=core.ShapedArray(
(batch, num_heads, seq_len, d_model), jnp.float32
),
)(query, key, value, bias)


@functools.partial(jax.jit, static_argnames=["scale"])
@jax.default_matmul_precision("bfloat16")
def baseline(
    q: jax.Array, k: jax.Array, v: jax.Array, bias: jax.Array, scale: float = 1.0
) -> jax.Array:
qk = jnp.einsum("bhnd,bhmd->bhnm", q, k)
weights = jax.nn.softmax((qk + bias) * scale, axis=-1)
return jnp.einsum("bhnm,bhmk->bhnk", weights, v)


def main(argv: Sequence[str]) -> None:
if len(argv) == 1:
num_bench_iters = 1000
elif len(argv) == 2:
num_bench_iters = int(argv[1])
else:
raise app.UsageError("Too many command-line arguments.")
batch, num_heads, seq_len, d_model = 2, 8, 1024, 128
tn, tm = 512, 512
assert seq_len % tn == 0 and seq_len % tm == 0
assert d_model % 128 == 0
assert tn % 128 == 0 and tm % 128 == 0
assert batch > 0
assert num_heads > 0
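  # d_model and the block sizes are multiples of 128 to match the TPU's
  # 128-wide lane tiling.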
flash_attention = functools.partial(flash_attention_kernel, tn=tn, tm=tm)
full_shape = (batch, num_heads, seq_len, d_model)
k1, k2, k3, k4 = random.split(random.PRNGKey(1234), 4)
query = random.normal(k1, full_shape, jnp.float32)
key = random.normal(k2, full_shape, jnp.float32)
value = random.normal(k3, full_shape, jnp.float32)
bias = random.normal(k4, (batch, num_heads, seq_len, seq_len), jnp.float32)
query = query.astype(jnp.bfloat16).astype(jnp.float32)
key = key.astype(jnp.bfloat16).astype(jnp.float32)
value = value.astype(jnp.bfloat16).astype(jnp.float32)
bias = bias.astype(jnp.bfloat16).astype(jnp.float32)
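  # Rounding the inputs through bfloat16 keeps the float32 Mosaic kernel and
  # the bfloat16-precision XLA baseline numerically comparable.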
scale = 1 / pymath.sqrt(d_model)
# Make sure the implementation is correct + warm-up the caches.
np.testing.assert_allclose(
flash_attention(query, key, value, bias, scale=scale),
baseline(query, key, value, bias, scale=scale),
atol=1e-2,
rtol=1e-2,
)
s = time.perf_counter()
for _ in range(num_bench_iters):
flash_attention(query, key, value, bias, scale=scale).block_until_ready()
e = time.perf_counter()
d1 = e - s
s = time.perf_counter()
for _ in range(num_bench_iters):
baseline(query, key, value, bias, scale=scale).block_until_ready()
e = time.perf_counter()
d2 = e - s
print("Run-time ratio (custom/XLA):", d1 / d2)


if __name__ == "__main__":
app.run(main)