"""An example Mosaic kernel implementing collective matrix multiplication."""
import contextlib
import functools
from typing import Callable
import jax
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import memref
from mlir.dialects import scf
from mlir.dialects import vector
import numpy as np
LEFT = False
RIGHT = True
@contextlib.contextmanager
def _trace(message, level=10):
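  """Wraps the ops built in this context in a named tpu.trace region."""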
init_trace = tpu.TraceOp([], message=message, level=level)
with ir.InsertionPoint(init_trace.body):
yield
tpu.YieldOp([])
def _constant_cache(
func_op: func.FuncOp,
) -> Callable[..., ir.Value]:
"""A helper for more lightweight integer constant creation.
Arguments:
func_op: The function op inside which the constants should be placed.
Returns:
A function that can be used to created typed constants. The default constant
type is index.
"""
constants: dict[tuple[int, ir.Type], ir.Value] = {}
def c(val: int, ty: ir.Type | None = None):
if ty is None:
ty = ir.IndexType.get()
result = constants.get((val, ty), None)
if result is None:
with ir.InsertionPoint.at_block_begin(func_op.entry_block):
constants[(val, ty)] = result = arith.ConstantOp(
ty, ir.IntegerAttr.get(ty, val)
).result
return result
return c
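
# A minimal usage sketch of the constant cache (names hypothetical): each
# (value, type) pair is materialized once, as an arith.constant hoisted to the
# entry block of the given function, e.g.
#
#   c = _constant_cache(func_op)
#   zero = c(0)          # index constant, created on first use
#   zero_again = c(0)    # the same ir.Value; no new op is emitted
#   one_i32 = c(1, i32)  # a distinct constant because the type differs
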
def allgather_kernel(
lhs_local: jax.ShapeDtypeStruct,
rhs_local: jax.ShapeDtypeStruct,
out_local: jax.ShapeDtypeStruct,
rings: np.ndarray,
collective_id: int,
buffering: int = 2,
backend: str = "tpu",
) -> Callable[[jax.Array, jax.Array], jax.Array]:
"""Builds an MLIR module for the all-gather collective matmul kernel.
The kernel implements an all-gather collective matmul where the right hand
side is the stationary operand.
Arguments:
lhs_local: A type specification for a shard of the left operand.
rhs_local: A type specification for a shard of the right operand.
out_local: A type specification for an output of the operation.
rings: A 2D 32-bit integer array of physical device ids, specifying a number
of independent rings along which the collective will be performed. The
major dimension ranges over rings, while the minor dimension ranges over
devices within a single ring.
collective_id: Any integer unique to the specified ring, across all Mosaic
collectives.
buffering: The degree of transfer buffering. Increasing buffering raises
memory usage, but can decrease synchronization costs. Should be a power
of 2.
backend: The JAX backend used to compile the programs.
Returns:
An MLIR module implementing collective matmul in Mosaic.
"""
with ir.Context() as ctx, ir.Location.unknown():
return _allgather_kernel(
ctx,
lhs_local,
rhs_local,
out_local,
rings,
collective_id,
buffering,
backend,
)
def _allgather_kernel(
ctx,
lhs_local,
rhs_local,
out_local,
rings,
collective_id,
buffering,
backend,
):
"""The implementation of allgather_kernel."""
tpu.register_dialect(ctx)
# For now we assume that RHS is the stationary input.
n, k_shard = lhs_local.shape
k_full, m = rhs_local.shape
if rings.ndim != 2:
raise ValueError("The ring should be a 2D NumPy array")
ring_len = rings.shape[-1]
if k_shard * ring_len != k_full:
raise ValueError(
f"LHS contraction dimension has size {k_shard}, ring has length"
f" {ring_len} so expected RHS to have shape {k_shard * ring_len},"
f" but got {k_full}"
)
inputs_bf16 = lhs_local.dtype == rhs_local.dtype == np.dtype(jnp.bfloat16)
output_fp32 = out_local.dtype == np.dtype(np.float32)
if not inputs_bf16 or not output_fp32:
raise ValueError("Only bf16 inputs and fp32 outputs supported")
lhs_mlir_dtype = rhs_mlir_dtype = ir.BF16Type.get()
out_mlir_dtype = ir.F32Type.get()
if ring_len % 2 != 0:
raise ValueError(
"Bidirectional collective matmul only supports evenly sized rings"
)
if rings.dtype != np.dtype(np.int32):
raise ValueError("Ring description should be encoded in 32-bit integers.")
if buffering.bit_count() != 1:
raise NotImplementedError("Buffering values should be multiples of 2")
i32 = ir.IntegerType.get_signless(32)
i64 = ir.IntegerType.get_signless(64)
index = ir.IndexType.get()
sem_mem = ir.Attribute.parse("#tpu.memory_space<semaphore_mem>")
sem_type = ir.Type.parse("!tpu.semaphore")
dma_sem_type = ir.Type.parse("!tpu.dma_semaphore")
sem_ref_type = ir.MemRefType.get((), sem_type, memory_space=sem_mem)
dma_sem_ref_type = ir.MemRefType.get((), dma_sem_type, memory_space=sem_mem)
smem = ir.Attribute.parse("#tpu.memory_space<smem>")
scratch_type = ir.MemRefType.get(
(buffering, 2, *lhs_local.shape), lhs_mlir_dtype
)
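  # The communication scratch holds `buffering` versions of two LHS-shaped
  # chunks: index LEFT carries the chunk travelling leftward around the ring,
  # index RIGHT the chunk travelling rightward.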
lhs_vec = lambda shape: ir.VectorType.get(shape, lhs_mlir_dtype)
rhs_vec = lambda shape: ir.VectorType.get(shape, rhs_mlir_dtype)
out_vec_ty = ir.VectorType.get((n, m), out_mlir_dtype)
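  # Integer encodings of the arith.cmpi predicate enum.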
pred_eq = ir.IntegerAttr.get(i64, 0)
pred_ne = ir.IntegerAttr.get(i64, 1)
pred_lt = ir.IntegerAttr.get(i64, 2) # Signed.
pred_ge = ir.IntegerAttr.get(i64, 5) # Signed.
if set(rings.flat) != set(range(rings.size)):
raise NotImplementedError(
"Device IDs used in collective should span a contiguous range starting"
" at 0"
)
ring_position = np.empty((rings.size,), dtype=np.int32)
for i, device in enumerate(rings.flat):
ring_position[device] = i
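  # ring_position is the inverse permutation of rings.flat: it maps a device
  # id to its linearized ring position. For example (hypothetical), rings =
  # [[0, 2, 3, 1]] yields ring_position = [0, 3, 1, 2].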
@func.FuncOp.from_py_func(
i32,
ir.MemRefType.get(rings.shape, i32, memory_space=smem),
ir.MemRefType.get((rings.size,), i32, memory_space=smem),
ir.MemRefType.get(lhs_local.shape, lhs_mlir_dtype),
ir.MemRefType.get((k_shard, m), rhs_mlir_dtype),
ir.MemRefType.get((k_shard, m), rhs_mlir_dtype),
ir.MemRefType.get((n, m), out_mlir_dtype),
sem_ref_type,
sem_ref_type,
dma_sem_ref_type,
dma_sem_ref_type,
dma_sem_ref_type,
dma_sem_ref_type,
scratch_type,
name="main",
)
def matmul(
step_with_init,
rings_ref,
ring_position_ref,
lhs_ref,
rhs_ref_left,
rhs_ref_right,
out_ref,
capacity_left,
capacity_right,
send_right_sem,
send_left_sem,
recv_right_sem,
recv_left_sem,
scratch,
func_op,
): # pylint: disable=unused-argument
c = _constant_cache(func_op)
my_id = arith.IndexCastOp(index, tpu.DeviceIdOp())
my_pos_lin = arith.IndexCastOp(
index, memref.LoadOp(ring_position_ref, [my_id])
)
my_ring = arith.DivSIOp(my_pos_lin, c(ring_len))
my_pos = arith.RemSIOp(my_pos_lin, c(ring_len))
first_pos = arith.CmpIOp(pred_eq, my_pos, c(0))
left_pos = arith.SelectOp(
first_pos, c(ring_len - 1), arith.SubIOp(my_pos, c(1))
)
left_neighbor = memref.LoadOp(rings_ref, [my_ring, left_pos])
last_pos = arith.CmpIOp(pred_eq, my_pos, c(ring_len - 1))
right_pos = arith.SelectOp(last_pos, c(0), arith.AddIOp(my_pos, c(1)))
right_neighbor = memref.LoadOp(rings_ref, [my_ring, right_pos])
is_init = arith.CmpIOp(pred_eq, step_with_init, c(0, i32))
def signal(side):
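      # Tell the neighbor that sends to us in the `side` direction that a
      # scratch slot on this device has been freed, granting it one unit of
      # send capacity.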
sem = capacity_right if side == RIGHT else capacity_left
neighbor = left_neighbor if side == RIGHT else right_neighbor
tpu.SemaphoreSignalOp(sem, c(1, i32), device_id=neighbor)
side_transfer_ty = ir.MemRefType.get(
(1, 1, *lhs_local.shape), lhs_mlir_dtype
)
some_scratch_slice = tpu.MemRefSliceOp(
side_transfer_ty, scratch, [c(0, i32)] * 4,
)
def shift(side, device_id, capacity_sem, send_sem, recv_sem):
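      # Forwards the current chunk of the `side`-moving stream into the next
      # buffer version on the neighbor in that direction. `version_i32` and
      # `next_version_i32` are captured from the enclosing scope; they are
      # defined in the steady-state branch below before `shift_left` and
      # `shift_right` are called.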
src = tpu.MemRefSliceOp(
side_transfer_ty,
scratch,
[version_i32, c(side, i32), c(0, i32), c(0, i32)],
)
dst = tpu.MemRefSliceOp(
side_transfer_ty,
scratch,
[next_version_i32, c(side, i32), c(0, i32), c(0, i32)],
)
# Wait until the neighbor has sufficient capacity to receive.
tpu.SemaphoreWaitOp(capacity_sem, c(1, i32))
tpu.EnqueueDMAOp(src, dst, recv_sem, source_semaphore=send_sem,
device_id=device_id)
shift_left = functools.partial(
shift, LEFT, left_neighbor, capacity_left, send_left_sem, recv_left_sem
)
shift_right = functools.partial(
shift, RIGHT, right_neighbor,
capacity_right, send_right_sem, recv_right_sem
)
def compute(side, version_i32):
twosided_lhs_ref_ty = ir.MemRefType.get(
(1, 2, *lhs_local.shape), lhs_mlir_dtype
)
current_lhs_scratch = tpu.MemRefSliceOp(
twosided_lhs_ref_ty,
scratch,
[version_i32, c(0, i32), c(0, i32), c(0, i32)],
)
rhs_ref = rhs_ref_left if side == LEFT else rhs_ref_right
rhs = vector.LoadOp(
rhs_vec((k_shard, m)), rhs_ref, [c(0), c(0)]
)
# TODO(apaszke): Support major dynamic indices in vector.LoadOp
lhs = vector.LoadOp(
lhs_vec((1, 1, *lhs_local.shape)),
current_lhs_scratch,
[c(0), c(side), c(0), c(0)],
)
lhs = vector.ShapeCastOp(lhs_vec(lhs_local.shape), lhs)
old_out = vector.LoadOp(out_vec_ty, out_ref, [c(0), c(0)]).result
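      # new_out = old_out + lhs @ rhs: the affine maps mark i and j as
      # parallel dimensions and k as the contraction (reduction) dimension.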
new_out = vector.ContractionOp(
old_out.type, lhs, rhs, old_out,
indexing_maps=ir.ArrayAttr.get([
ir.Attribute.parse("affine_map<(i, j, k) -> (i, k)>"),
ir.Attribute.parse("affine_map<(i, j, k) -> (k, j)>"),
ir.Attribute.parse("affine_map<(i, j, k) -> (i, j)>"),
]),
iterator_types=ir.ArrayAttr.get([
ir.Attribute.parse("#vector.iterator_type<parallel>"),
ir.Attribute.parse("#vector.iterator_type<parallel>"),
ir.Attribute.parse("#vector.iterator_type<reduction>"),
])
)
vector.StoreOp(new_out, out_ref, [c(0), c(0)])
if_init = scf.IfOp(is_init.result, hasElse=True)
with ir.InsertionPoint(if_init.then_block):
# Initialize the capacity flags.
tpu.SemaphoreSignalOp(capacity_right, c(buffering - 1, i32))
tpu.SemaphoreSignalOp(capacity_left, c(buffering - 1, i32))
with _trace("Init barrier"):
# We will only be communicating with neighbors, so we only have to
# synchronize with them.
barrier_semaphore = tpu.GetBarrierSemaphoreOp()
tpu.SemaphoreSignalOp(
barrier_semaphore, c(1, i32), device_id=left_neighbor
)
tpu.SemaphoreSignalOp(
barrier_semaphore, c(1, i32), device_id=right_neighbor
)
tpu.SemaphoreWaitOp(barrier_semaphore, c(2, i32))
# Initialize left part of the scratch space.
# We don't use a local DMA, because the tiling might be different.
init_lhs = vector.LoadOp(lhs_vec(lhs_local.shape), lhs_ref, [c(0), c(0)])
init_lhs = vector.ShapeCastOp(lhs_vec((1, 1, *lhs_local.shape)), init_lhs)
vector.StoreOp(init_lhs, scratch, [c(0), c(LEFT), c(0), c(0)])
# Initialize the right part of the scratch space (remote DMA).
left_scratch = tpu.MemRefSliceOp(
side_transfer_ty,
scratch,
[c(0, i32), c(LEFT, i32), c(0, i32), c(0, i32)],
)
right_scratch = tpu.MemRefSliceOp(
side_transfer_ty,
scratch,
[c(0, i32), c(RIGHT, i32), c(0, i32), c(0, i32)],
)
tpu.EnqueueDMAOp(
left_scratch, right_scratch, recv_right_sem,
source_semaphore=send_right_sem, device_id=right_neighbor
)
# Initialize the output.
out_init = arith.ConstantOp(
out_vec_ty,
ir.DenseElementsAttr.get_splat(
out_vec_ty, ir.FloatAttr.get_f32(0.0)
),
)
vector.StoreOp(out_init, out_ref, [c(0), c(0)])
# Wait for the send to complete, to satisfy the loop invariant below.
with _trace("Send right"):
tpu.WaitDMAOp(send_right_sem, some_scratch_slice)
scf.YieldOp([])
with ir.InsertionPoint(if_init.else_block):
# Invariant: left data is here, right data has already been sent.
step = arith.SubIOp(step_with_init, c(1, i32))
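      # buffering is a power of 2, so `step & (buffering - 1)` is equivalent
      # to `step % buffering`: it selects the scratch version for this step,
      # and the following slot for outgoing transfers.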
version_i32 = arith.AndIOp(step, c(buffering - 1, i32))
next_version_i32 = arith.AddIOp(version_i32, c(1, i32))
next_version_i32 = arith.AndIOp(next_version_i32, c(buffering - 1, i32))
last_step = c((ring_len // 2) - 1, i32)
more_steps = arith.CmpIOp(pred_ne, step, last_step)
if_more_steps = scf.IfOp(more_steps.result)
with ir.InsertionPoint(if_more_steps.then_block):
with _trace("Capacity left"):
shift_left()
scf.YieldOp([])
with _trace("Recv right"):
tpu.WaitDMAOp(recv_right_sem, some_scratch_slice)
if_more_steps = scf.IfOp(more_steps.result)
with ir.InsertionPoint(if_more_steps.then_block):
with _trace("Capacity right"):
shift_right()
scf.YieldOp([])
with _trace("Compute"):
compute(LEFT, version_i32)
compute(RIGHT, version_i32)
if_more_steps = scf.IfOp(more_steps.result, hasElse=True)
with ir.InsertionPoint(if_more_steps.then_block):
with _trace("Send left"):
tpu.WaitDMAOp(send_left_sem, some_scratch_slice)
signal(LEFT)
with _trace("Send right"):
tpu.WaitDMAOp(send_right_sem, some_scratch_slice)
signal(RIGHT)
with _trace("Recv left"):
tpu.WaitDMAOp(recv_left_sem, some_scratch_slice)
scf.YieldOp([])
with ir.InsertionPoint(if_more_steps.else_block):
with _trace("Final barrier"):
# We skip the increment in the last iteration to save on latency,
# so we only expect to see buffering - 1.
tpu.SemaphoreWaitOp(capacity_left, c(buffering - 1, i32))
tpu.SemaphoreWaitOp(capacity_right, c(buffering - 1, i32))
scf.YieldOp([])
scf.YieldOp([])
def rhs_indices(
step_with_init, rings_ref, ring_position_ref, func_op, *, is_right
):
del rings_ref # Unused.
c = _constant_cache(func_op)
is_init = arith.CmpIOp(pred_eq, step_with_init, c(0, i32))
step = arith.SelectOp(
is_init, c(0, i32), arith.SubIOp(step_with_init, c(1, i32))
)
my_id = arith.IndexCastOp(index, tpu.DeviceIdOp())
my_pos = arith.RemSIOp(
memref.LoadOp(ring_position_ref, [my_id]), c(ring_len, i32)
)
if is_right:
source_pos = arith.SubIOp(my_pos, arith.AddIOp(step, c(1, i32)))
is_oob = arith.CmpIOp(pred_lt, source_pos, c(0, i32))
source_pos = arith.SelectOp(
is_oob, arith.AddIOp(source_pos, c(ring_len, i32)), source_pos
)
else:
source_pos = arith.AddIOp(my_pos, step)
is_oob = arith.CmpIOp(pred_ge, source_pos, c(ring_len, i32))
source_pos = arith.SelectOp(
is_oob, arith.SubIOp(source_pos, c(ring_len, i32)), source_pos
)
return (source_pos.result, c(0, i32))
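  # Worked example (hypothetical): with ring_len == 8 and my_pos == 2, the
  # left stream reads RHS K-blocks 2, 3, 4, 5 over steps 0..3, while the
  # right stream reads blocks 1, 0, 7, 6, so the two streams together cover
  # all 8 K-blocks in ring_len // 2 steps.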
left_rhs_transform_indices = func.FuncOp.from_py_func(
i32,
ir.MemRefType.get(rings.shape, i32, memory_space=smem),
ir.MemRefType.get((rings.size,), i32, memory_space=smem),
name="left_rhs_transform_indices",
)(lambda *args, func_op: rhs_indices(*args, func_op, is_right=False))
right_rhs_transform_indices = func.FuncOp.from_py_func(
i32,
ir.MemRefType.get(rings.shape, i32, memory_space=smem),
ir.MemRefType.get((rings.size,), i32, memory_space=smem),
name="right_rhs_transform_indices",
)(lambda *args, func_op: rhs_indices(*args, func_op, is_right=True))
matmul.func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(i64, 2)
matmul.func_op.attributes["scratch_operands"] = ir.IntegerAttr.get(i64, 7)
# Prologue + steps
matmul.func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
[1 + (ring_len // 2)]
)
matmul.func_op.attributes["dimension_semantics"] = ir.ArrayAttr.get([
ir.Attribute.parse("#tpu.dimension_semantics<arbitrary>"),
])
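  # The four window descriptors below correspond, in order, to lhs_ref,
  # rhs_ref_left, rhs_ref_right and out_ref. The LHS and output windows are
  # always block (0, 0); the two RHS windows select the K-block matching the
  # LHS chunk currently held by each stream, via the index functions above.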
matmul.func_op.attributes["window_params"] = ir.ArrayAttr.get([
ir.DictAttr.get({
"transform_indices": ir.Attribute.parse("affine_map<(i) -> (0, 0)>"),
}),
ir.DictAttr.get({
"window_bounds": ir.DenseI64ArrayAttr.get([k_shard, m]),
"transform_indices": ir.FlatSymbolRefAttr.get(
"left_rhs_transform_indices"
),
}),
ir.DictAttr.get({
"window_bounds": ir.DenseI64ArrayAttr.get([k_shard, m]),
"transform_indices": ir.FlatSymbolRefAttr.get(
"right_rhs_transform_indices"
),
}),
ir.DictAttr.get({
"transform_indices": ir.Attribute.parse("affine_map<(i) -> (0, 0)>"),
}),
])
module = ir.Module.create()
sym_tab = ir.SymbolTable(module.operation)
for f in (
matmul.func_op,
left_rhs_transform_indices.func_op,
right_rhs_transform_indices.func_op,
):
module.body.append(f)
sym_tab.insert(f)
module.operation.verify()
kernel = mosaic.as_tpu_kernel(module, out_local, backend=backend)
def kernel_wrapper(lhs, rhs):
return kernel(
rings, ring_position, lhs, rhs, rhs, collective_id=collective_id
)
return jax.jit(kernel_wrapper)
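

# A minimal construction sketch (hypothetical shapes; assumes a TPU runtime
# with at least 8 devices forming a single ring). This only builds and
# compiles the kernel; actually running it additionally requires mapping the
# returned function across the participating devices.
if __name__ == "__main__":
  num_devices = 8
  ring = np.arange(num_devices, dtype=np.int32).reshape(1, num_devices)
  n, k_shard, m = 1024, 512, 1024
  matmul_fn = allgather_kernel(
      lhs_local=jax.ShapeDtypeStruct((n, k_shard), jnp.bfloat16),
      rhs_local=jax.ShapeDtypeStruct((k_shard * num_devices, m), jnp.bfloat16),
      out_local=jax.ShapeDtypeStruct((n, m), jnp.float32),
      rings=ring,
      collective_id=0,
  )
  print(matmul_fn)  # Smoke check: the kernel was built and compiled.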