| """An example Mosaic kernel implementing collective matrix multiplication.""" |
| |
| import contextlib |
| import functools |
| from typing import Callable |
| |
| import jax |
| from jax.experimental import mosaic |
| from jax.experimental.mosaic.dialects import tpu |
| import jax.numpy as jnp |
| from mlir import ir |
| from mlir.dialects import arith |
| from mlir.dialects import func |
| from mlir.dialects import memref |
| from mlir.dialects import scf |
| from mlir.dialects import vector |
| import numpy as np |
| |
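# The two directions in which LHS shards travel around the ring. These also
# index the second dimension of the scratch buffer (LEFT = 0, RIGHT = 1).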
| LEFT = False |
| RIGHT = True |
| |
| |
| @contextlib.contextmanager |
| def _trace(message, level=10): |
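  """Wraps the ops emitted within the context in a named tpu.trace region."""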
| init_trace = tpu.TraceOp([], message=message, level=level) |
| with ir.InsertionPoint(init_trace.body): |
| yield |
| tpu.YieldOp([]) |
| |
| |
| def _constant_cache( |
| func_op: func.FuncOp, |
| ) -> Callable[..., ir.Value]: |
| """A helper for more lightweight integer constant creation. |
| |
| Arguments: |
| func_op: The function op inside which the constants should be placed. |
| |
| Returns: |
    A function that can be used to create typed constants. The default constant
    type is index.
| """ |
| constants: dict[tuple[int, ir.Type], ir.Value] = {} |
| def c(val: int, ty: ir.Type | None = None): |
| if ty is None: |
| ty = ir.IndexType.get() |
| result = constants.get((val, ty), None) |
| if result is None: |
| with ir.InsertionPoint.at_block_begin(func_op.entry_block): |
| constants[(val, ty)] = result = arith.ConstantOp( |
| ty, ir.IntegerAttr.get(ty, val) |
| ).result |
| return result |
| |
| return c |
| |
| |
| def allgather_kernel( |
| lhs_local: jax.ShapeDtypeStruct, |
| rhs_local: jax.ShapeDtypeStruct, |
| out_local: jax.ShapeDtypeStruct, |
| rings: np.ndarray, |
| collective_id: int, |
| buffering: int = 2, |
| backend: str = "tpu", |
) -> Callable[..., jax.Array]:
| """Builds an MLIR module for the all-gather collective matmul kernel. |
| |
| The kernel implements an all-gather collective matmul where the right hand |
| side is the stationary operand. |
| |
| Arguments: |
| lhs_local: A type specification for a shard of the left operand. |
| rhs_local: A type specification for a shard of the right operand. |
| out_local: A type specification for an output of the operation. |
| rings: A 2D 32-bit integer array of physical device ids, specifying a number |
| of independent rings along which the collective will be performed. The |
| major dimension ranges over rings, while the minor dimension ranges over |
| devices within a single ring. |
    collective_id: An integer identifying this collective; it must be distinct
      from the ids used by any other Mosaic collective in the program.
| buffering: The degree of transfer buffering. Increasing buffering raises |
| memory usage, but can decrease synchronization costs. Should be a power |
| of 2. |
| backend: The JAX backend used to compile the programs. |
| |
  Returns:
    A jitted function implementing the collective matmul in Mosaic. It takes
    the local LHS shard and the full RHS as arguments.
  """
| with ir.Context() as ctx, ir.Location.unknown(): |
| return _allgather_kernel( |
| ctx, |
| lhs_local, |
| rhs_local, |
| out_local, |
| rings, |
| collective_id, |
| buffering, |
| backend, |
| ) |
| |
| |
| def _allgather_kernel( |
| ctx, |
| lhs_local, |
| rhs_local, |
| out_local, |
| rings, |
| collective_id, |
| buffering, |
| backend, |
| ): |
| """The implementation of allgather_kernel.""" |
| tpu.register_dialect(ctx) |
| # For now we assume that RHS is the stationary input. |
| n, k_shard = lhs_local.shape |
| k_full, m = rhs_local.shape |
| if rings.ndim != 2: |
| raise ValueError("The ring should be a 2D NumPy array") |
| ring_len = rings.shape[-1] |
| if k_shard * ring_len != k_full: |
    raise ValueError(
        f"LHS contraction dimension has size {k_shard} and the ring has"
        f" length {ring_len}, so the RHS contraction dimension should have"
        f" size {k_shard * ring_len}, but got {k_full}"
    )
| inputs_bf16 = lhs_local.dtype == rhs_local.dtype == np.dtype(jnp.bfloat16) |
| output_fp32 = out_local.dtype == np.dtype(np.float32) |
| if not inputs_bf16 or not output_fp32: |
| raise ValueError("Only bf16 inputs and fp32 outputs supported") |
| lhs_mlir_dtype = rhs_mlir_dtype = ir.BF16Type.get() |
| out_mlir_dtype = ir.F32Type.get() |
| if ring_len % 2 != 0: |
    raise ValueError(
        "Bidirectional collective matmul requires rings of even length"
    )
| if rings.dtype != np.dtype(np.int32): |
| raise ValueError("Ring description should be encoded in 32-bit integers.") |
  if buffering.bit_count() != 1:
    raise NotImplementedError("Buffering should be a power of 2")
| |
| i32 = ir.IntegerType.get_signless(32) |
| i64 = ir.IntegerType.get_signless(64) |
| index = ir.IndexType.get() |
| sem_mem = ir.Attribute.parse("#tpu.memory_space<semaphore_mem>") |
| sem_type = ir.Type.parse("!tpu.semaphore") |
| dma_sem_type = ir.Type.parse("!tpu.dma_semaphore") |
| sem_ref_type = ir.MemRefType.get((), sem_type, memory_space=sem_mem) |
| dma_sem_ref_type = ir.MemRefType.get((), dma_sem_type, memory_space=sem_mem) |
| |
| smem = ir.Attribute.parse("#tpu.memory_space<smem>") |
| scratch_type = ir.MemRefType.get( |
| (buffering, 2, *lhs_local.shape), lhs_mlir_dtype |
| ) |
| |
| lhs_vec = lambda shape: ir.VectorType.get(shape, lhs_mlir_dtype) |
| rhs_vec = lambda shape: ir.VectorType.get(shape, rhs_mlir_dtype) |
| out_vec_ty = ir.VectorType.get((n, m), out_mlir_dtype) |
| |
| pred_eq = ir.IntegerAttr.get(i64, 0) |
| pred_ne = ir.IntegerAttr.get(i64, 1) |
| pred_lt = ir.IntegerAttr.get(i64, 2) # Signed. |
| pred_ge = ir.IntegerAttr.get(i64, 5) # Signed. |
| |
| if set(rings.flat) != set(range(rings.size)): |
| raise NotImplementedError( |
| "Device IDs used in collective should span a contiguous range starting" |
| " at 0" |
| ) |
| ring_position = np.empty((rings.size,), dtype=np.int32) |
| for i, device in enumerate(rings.flat): |
| ring_position[device] = i |
| |
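  # The kernel entry point. The leading i32 is the grid index; the remaining
  # operands are described by the attributes attached below (scalar prefetch,
  # windowed arrays and scratch).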
| @func.FuncOp.from_py_func( |
| i32, |
| ir.MemRefType.get(rings.shape, i32, memory_space=smem), |
| ir.MemRefType.get((rings.size,), i32, memory_space=smem), |
| ir.MemRefType.get(lhs_local.shape, lhs_mlir_dtype), |
| ir.MemRefType.get((k_shard, m), rhs_mlir_dtype), |
| ir.MemRefType.get((k_shard, m), rhs_mlir_dtype), |
| ir.MemRefType.get((n, m), out_mlir_dtype), |
| sem_ref_type, |
| sem_ref_type, |
| dma_sem_ref_type, |
| dma_sem_ref_type, |
| dma_sem_ref_type, |
| dma_sem_ref_type, |
| scratch_type, |
| name="main", |
| ) |
| def matmul( |
| step_with_init, |
| rings_ref, |
| ring_position_ref, |
| lhs_ref, |
| rhs_ref_left, |
| rhs_ref_right, |
| out_ref, |
| capacity_left, |
| capacity_right, |
| send_right_sem, |
| send_left_sem, |
| recv_right_sem, |
| recv_left_sem, |
| scratch, |
| func_op, |
| ): # pylint: disable=unused-argument |
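    """A single grid step of the kernel; step 0 acts as the prologue."""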
| c = _constant_cache(func_op) |
| |
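    # Compute our ring, our position within it, and both neighbors, wrapping
    # around at the ends of the ring.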
| my_id = arith.IndexCastOp(index, tpu.DeviceIdOp()) |
| my_pos_lin = arith.IndexCastOp( |
| index, memref.LoadOp(ring_position_ref, [my_id]) |
| ) |
| my_ring = arith.DivSIOp(my_pos_lin, c(ring_len)) |
| my_pos = arith.RemSIOp(my_pos_lin, c(ring_len)) |
| first_pos = arith.CmpIOp(pred_eq, my_pos, c(0)) |
| left_pos = arith.SelectOp( |
| first_pos, c(ring_len - 1), arith.SubIOp(my_pos, c(1)) |
| ) |
| left_neighbor = memref.LoadOp(rings_ref, [my_ring, left_pos]) |
| last_pos = arith.CmpIOp(pred_eq, my_pos, c(ring_len - 1)) |
| right_pos = arith.SelectOp(last_pos, c(0), arith.AddIOp(my_pos, c(1))) |
| right_neighbor = memref.LoadOp(rings_ref, [my_ring, right_pos]) |
| |
| is_init = arith.CmpIOp(pred_eq, step_with_init, c(0, i32)) |
| |
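    # Grants one unit of buffer capacity to the neighbor that sends data in
    # the direction `side`: right-moving data originates at our left neighbor
    # (and vice versa), so the signal crosses over.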
| def signal(side): |
| sem = capacity_right if side == RIGHT else capacity_left |
| neighbor = left_neighbor if side == RIGHT else right_neighbor |
| tpu.SemaphoreSignalOp(sem, c(1, i32), device_id=neighbor) |
| |
| side_transfer_ty = ir.MemRefType.get( |
| (1, 1, *lhs_local.shape), lhs_mlir_dtype |
| ) |
| some_scratch_slice = tpu.MemRefSliceOp( |
| side_transfer_ty, scratch, [c(0, i32)] * 4, |
| ) |
| def shift(side, device_id, capacity_sem, send_sem, recv_sem): |
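      # Copies the current version of our `side` slot into the next version of
      # the same slot on `device_id` (dst names our local scratch, but the DMA
      # writes to the remote device). version_i32 and next_version_i32 are
      # closed over: they are defined in the non-init branch below, before
      # shift is ever called.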
| src = tpu.MemRefSliceOp( |
| side_transfer_ty, |
| scratch, |
| [version_i32, c(side, i32), c(0, i32), c(0, i32)], |
| ) |
| dst = tpu.MemRefSliceOp( |
| side_transfer_ty, |
| scratch, |
| [next_version_i32, c(side, i32), c(0, i32), c(0, i32)], |
| ) |
| # Wait until the neighbor has sufficient capacity to receive. |
| tpu.SemaphoreWaitOp(capacity_sem, c(1, i32)) |
| tpu.EnqueueDMAOp(src, dst, recv_sem, source_semaphore=send_sem, |
| device_id=device_id) |
| |
| shift_left = functools.partial( |
| shift, LEFT, left_neighbor, capacity_left, send_left_sem, recv_left_sem |
| ) |
| shift_right = functools.partial( |
| shift, RIGHT, right_neighbor, |
| capacity_right, send_right_sem, recv_right_sem |
| ) |
| |
| def compute(side, version_i32): |
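      """Accumulates (LHS scratch slot `side`) @ (its RHS shard) into out_ref."""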
| twosided_lhs_ref_ty = ir.MemRefType.get( |
| (1, 2, *lhs_local.shape), lhs_mlir_dtype |
| ) |
| current_lhs_scratch = tpu.MemRefSliceOp( |
| twosided_lhs_ref_ty, |
| scratch, |
| [version_i32, c(0, i32), c(0, i32), c(0, i32)], |
| ) |
| rhs_ref = rhs_ref_left if side == LEFT else rhs_ref_right |
| rhs = vector.LoadOp( |
| rhs_vec((k_shard, m)), rhs_ref, [c(0), c(0)] |
| ) |
| # TODO(apaszke): Support major dynamic indices in vector.LoadOp |
| lhs = vector.LoadOp( |
| lhs_vec((1, 1, *lhs_local.shape)), |
| current_lhs_scratch, |
| [c(0), c(side), c(0), c(0)], |
| ) |
| lhs = vector.ShapeCastOp(lhs_vec(lhs_local.shape), lhs) |
| old_out = vector.LoadOp(out_vec_ty, out_ref, [c(0), c(0)]).result |
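      # out += lhs @ rhs, expressed as a vector.contract with the previous
      # output as the accumulator.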
| new_out = vector.ContractionOp( |
| old_out.type, lhs, rhs, old_out, |
| indexing_maps=ir.ArrayAttr.get([ |
| ir.Attribute.parse("affine_map<(i, j, k) -> (i, k)>"), |
| ir.Attribute.parse("affine_map<(i, j, k) -> (k, j)>"), |
| ir.Attribute.parse("affine_map<(i, j, k) -> (i, j)>"), |
| ]), |
| iterator_types=ir.ArrayAttr.get([ |
| ir.Attribute.parse("#vector.iterator_type<parallel>"), |
| ir.Attribute.parse("#vector.iterator_type<parallel>"), |
| ir.Attribute.parse("#vector.iterator_type<reduction>"), |
| ]) |
| ) |
| vector.StoreOp(new_out, out_ref, [c(0), c(0)]) |
| |
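    # Step 0: the prologue. Initialize semaphores, scratch and the output, and
    # kick off the initial transfer to the right neighbor.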
| if_init = scf.IfOp(is_init.result, hasElse=True) |
| with ir.InsertionPoint(if_init.then_block): |
| # Initialize the capacity flags. |
| tpu.SemaphoreSignalOp(capacity_right, c(buffering - 1, i32)) |
| tpu.SemaphoreSignalOp(capacity_left, c(buffering - 1, i32)) |
| |
| with _trace("Init barrier"): |
| # We will only be communicating with neighbors, so we only have to |
| # synchronize with them. |
| barrier_semaphore = tpu.GetBarrierSemaphoreOp() |
| tpu.SemaphoreSignalOp( |
| barrier_semaphore, c(1, i32), device_id=left_neighbor |
| ) |
| tpu.SemaphoreSignalOp( |
| barrier_semaphore, c(1, i32), device_id=right_neighbor |
| ) |
| tpu.SemaphoreWaitOp(barrier_semaphore, c(2, i32)) |
| |
| # Initialize left part of the scratch space. |
| # We don't use a local DMA, because the tiling might be different. |
| init_lhs = vector.LoadOp(lhs_vec(lhs_local.shape), lhs_ref, [c(0), c(0)]) |
| init_lhs = vector.ShapeCastOp(lhs_vec((1, 1, *lhs_local.shape)), init_lhs) |
| vector.StoreOp(init_lhs, scratch, [c(0), c(LEFT), c(0), c(0)]) |
| |
      # Initialize the right slot of our right neighbor's scratch space via a
      # remote DMA; our own right slot is filled by the matching DMA issued by
      # our left neighbor.
| left_scratch = tpu.MemRefSliceOp( |
| side_transfer_ty, |
| scratch, |
| [c(0, i32), c(LEFT, i32), c(0, i32), c(0, i32)], |
| ) |
| right_scratch = tpu.MemRefSliceOp( |
| side_transfer_ty, |
| scratch, |
| [c(0, i32), c(RIGHT, i32), c(0, i32), c(0, i32)], |
| ) |
| tpu.EnqueueDMAOp( |
| left_scratch, right_scratch, recv_right_sem, |
| source_semaphore=send_right_sem, device_id=right_neighbor |
| ) |
| |
| # Initialize the output. |
| out_init = arith.ConstantOp( |
| out_vec_ty, |
| ir.DenseElementsAttr.get_splat( |
| out_vec_ty, ir.FloatAttr.get_f32(0.0) |
| ), |
| ) |
| vector.StoreOp(out_init, out_ref, [c(0), c(0)]) |
| |
| # Wait for the send to complete, to satisfy the loop invariant below. |
| with _trace("Send right"): |
| tpu.WaitDMAOp(send_right_sem, some_scratch_slice) |
| scf.YieldOp([]) |
| |
| with ir.InsertionPoint(if_init.else_block): |
| # Invariant: left data is here, right data has already been sent. |
| step = arith.SubIOp(step_with_init, c(1, i32)) |
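      # version = step % buffering; the bitwise AND works because buffering is
      # a power of 2.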
| version_i32 = arith.AndIOp(step, c(buffering - 1, i32)) |
| next_version_i32 = arith.AddIOp(version_i32, c(1, i32)) |
| next_version_i32 = arith.AndIOp(next_version_i32, c(buffering - 1, i32)) |
| |
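      # The two directions each cover half the ring, so ring_len // 2 steps
      # suffice.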
| last_step = c((ring_len // 2) - 1, i32) |
| more_steps = arith.CmpIOp(pred_ne, step, last_step) |
| |
| if_more_steps = scf.IfOp(more_steps.result) |
| with ir.InsertionPoint(if_more_steps.then_block): |
| with _trace("Capacity left"): |
| shift_left() |
| scf.YieldOp([]) |
| |
| with _trace("Recv right"): |
| tpu.WaitDMAOp(recv_right_sem, some_scratch_slice) |
| |
| if_more_steps = scf.IfOp(more_steps.result) |
| with ir.InsertionPoint(if_more_steps.then_block): |
| with _trace("Capacity right"): |
| shift_right() |
| scf.YieldOp([]) |
| |
| with _trace("Compute"): |
| compute(LEFT, version_i32) |
| compute(RIGHT, version_i32) |
| |
| if_more_steps = scf.IfOp(more_steps.result, hasElse=True) |
| with ir.InsertionPoint(if_more_steps.then_block): |
| with _trace("Send left"): |
| tpu.WaitDMAOp(send_left_sem, some_scratch_slice) |
| signal(LEFT) |
| with _trace("Send right"): |
| tpu.WaitDMAOp(send_right_sem, some_scratch_slice) |
| signal(RIGHT) |
| with _trace("Recv left"): |
| tpu.WaitDMAOp(recv_left_sem, some_scratch_slice) |
| scf.YieldOp([]) |
| with ir.InsertionPoint(if_more_steps.else_block): |
| with _trace("Final barrier"): |
| # We skip the increment in the last iteration to save on latency, |
| # so we only expect to see buffering - 1. |
| tpu.SemaphoreWaitOp(capacity_left, c(buffering - 1, i32)) |
| tpu.SemaphoreWaitOp(capacity_right, c(buffering - 1, i32)) |
| scf.YieldOp([]) |
| |
| scf.YieldOp([]) |
| |
| def rhs_indices( |
| step_with_init, rings_ref, ring_position_ref, func_op, *, is_right |
| ): |
| del rings_ref # Unused. |
| c = _constant_cache(func_op) |
| is_init = arith.CmpIOp(pred_eq, step_with_init, c(0, i32)) |
| step = arith.SelectOp( |
| is_init, c(0, i32), arith.SubIOp(step_with_init, c(1, i32)) |
| ) |
| my_id = arith.IndexCastOp(index, tpu.DeviceIdOp()) |
| my_pos = arith.RemSIOp( |
| memref.LoadOp(ring_position_ref, [my_id]), c(ring_len, i32) |
| ) |
| if is_right: |
| source_pos = arith.SubIOp(my_pos, arith.AddIOp(step, c(1, i32))) |
| is_oob = arith.CmpIOp(pred_lt, source_pos, c(0, i32)) |
| source_pos = arith.SelectOp( |
| is_oob, arith.AddIOp(source_pos, c(ring_len, i32)), source_pos |
| ) |
| else: |
| source_pos = arith.AddIOp(my_pos, step) |
| is_oob = arith.CmpIOp(pred_ge, source_pos, c(ring_len, i32)) |
| source_pos = arith.SelectOp( |
| is_oob, arith.SubIOp(source_pos, c(ring_len, i32)), source_pos |
| ) |
| return (source_pos.result, c(0, i32)) |
| |
| left_rhs_transform_indices = func.FuncOp.from_py_func( |
| i32, |
| ir.MemRefType.get(rings.shape, i32, memory_space=smem), |
| ir.MemRefType.get((rings.size,), i32, memory_space=smem), |
| name="left_rhs_transform_indices", |
| )(lambda *args, func_op: rhs_indices(*args, func_op, is_right=False)) |
| right_rhs_transform_indices = func.FuncOp.from_py_func( |
| i32, |
| ir.MemRefType.get(rings.shape, i32, memory_space=smem), |
| ir.MemRefType.get((rings.size,), i32, memory_space=smem), |
| name="right_rhs_transform_indices", |
| )(lambda *args, func_op: rhs_indices(*args, func_op, is_right=True)) |
| |
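  # The first two memref arguments (rings_ref and ring_position_ref) are
  # prefetched into SMEM before the kernel runs; the trailing 7 operands (two
  # semaphores, four DMA semaphores and the LHS scratch buffer) are
  # Mosaic-allocated scratch.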
| matmul.func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(i64, 2) |
| matmul.func_op.attributes["scratch_operands"] = ir.IntegerAttr.get(i64, 7) |
  # Grid size: one prologue step plus ring_len // 2 compute steps.
| matmul.func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get( |
| [1 + (ring_len // 2)] |
| ) |
| matmul.func_op.attributes["dimension_semantics"] = ir.ArrayAttr.get([ |
| ir.Attribute.parse("#tpu.dimension_semantics<arbitrary>"), |
| ]) |
| matmul.func_op.attributes["window_params"] = ir.ArrayAttr.get([ |
| ir.DictAttr.get({ |
| "transform_indices": ir.Attribute.parse("affine_map<(i) -> (0, 0)>"), |
| }), |
| ir.DictAttr.get({ |
| "window_bounds": ir.DenseI64ArrayAttr.get([k_shard, m]), |
| "transform_indices": ir.FlatSymbolRefAttr.get( |
| "left_rhs_transform_indices" |
| ), |
| }), |
| ir.DictAttr.get({ |
| "window_bounds": ir.DenseI64ArrayAttr.get([k_shard, m]), |
| "transform_indices": ir.FlatSymbolRefAttr.get( |
| "right_rhs_transform_indices" |
| ), |
| }), |
| ir.DictAttr.get({ |
| "transform_indices": ir.Attribute.parse("affine_map<(i) -> (0, 0)>"), |
| }), |
| ]) |
| |
| module = ir.Module.create() |
| sym_tab = ir.SymbolTable(module.operation) |
| for f in ( |
| matmul.func_op, |
| left_rhs_transform_indices.func_op, |
| right_rhs_transform_indices.func_op, |
| ): |
| module.body.append(f) |
| sym_tab.insert(f) |
| |
| module.operation.verify() |
| kernel = mosaic.as_tpu_kernel(module, out_local, backend=backend) |
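  # Bind the scalar-prefetch operands and pass the same RHS buffer for both
  # windowed views (left- and right-moving halves).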
| def kernel_wrapper(lhs, rhs): |
| return kernel( |
| rings, ring_position, lhs, rhs, rhs, collective_id=collective_id |
| ) |
| return jax.jit(kernel_wrapper) |
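

# Example usage (a minimal sketch, not part of the kernel itself). The shapes
# and the single-ring device layout below are illustrative assumptions; the
# returned function runs on a TPU slice with an even number of devices and is
# typically invoked per-shard, e.g. under shard_map:
#
#   ring = np.arange(len(jax.devices()), dtype=np.int32).reshape(1, -1)
#   n, k_shard, m = 1024, 512, 1024
#   k_full = k_shard * ring.shape[-1]
#   matmul = allgather_kernel(
#       jax.ShapeDtypeStruct((n, k_shard), jnp.bfloat16),
#       jax.ShapeDtypeStruct((k_full, m), jnp.bfloat16),
#       jax.ShapeDtypeStruct((n, m), jnp.float32),
#       ring,
#       collective_id=0,
#   )
#   out = matmul(lhs_shard, rhs)  # (n, m) float32 on every device.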