| """An example Mosaic kernel implementing segmented matrix multiplication. |
| |
| A segmented matrix multiplication is analogous to a segmented sum: its inputs
| are composed of contiguous segments along the contraction dimension, and it
| computes a separate output for every segment. It can be seen as a
| generalization of batched matrix multiplication, which corresponds to the
| special case where all segments have equal length.
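|
| For example, with run lengths [2, 1] counted in units of the k tile size tk,
| the contraction dimension is split into segments of 2 * tk and tk, and for
| inputs x and y the kernel computes the equivalent of
|
|   jnp.stack([x[:, :2 * tk] @ y[:2 * tk], x[:, 2 * tk:] @ y[2 * tk:]])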
| """ |
| |
| import time |
| from typing import Sequence |
| |
| from absl import app |
| import jax |
| from jax import core |
| from jax.experimental import mosaic |
| from jax.experimental.mosaic.dialects import tpu |
| import jax.numpy as jnp |
| from mlir import ir |
| from mlir.dialects import arith |
| from mlir.dialects import func |
| from mlir.dialects import memref |
| from mlir.dialects import vector |
| import numpy as np |
| |
| |
| class VectorFactory: |
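|   """Constructs MLIR vector types through indexing, e.g. F32[8, 128]."""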
| |
| def __init__(self, elem_thunk): |
| self.elem_thunk = elem_thunk |
| |
| def __getitem__(self, idxs): |
| if isinstance(idxs, int): |
| idxs = (idxs,) |
| return ir.VectorType.get(idxs, self.elem_thunk()) |
| |
| F32 = VectorFactory(ir.F32Type.get) |
| I32 = VectorFactory(lambda: ir.IntegerType.get_signless(32)) |
| Index = VectorFactory(ir.IndexType.get) |
| |
| |
| def main(argv: Sequence[str]) -> None: |
| n = 4096 |
| tn, tk, tm = 1024, 512, 512 |
| # TODO(apaszke): This is the tile size that XLA chooses, |
| # but we'll need memory padding for that. |
| # tn = 1368 |
| |
| if len(argv) == 1: |
| num_bench_iters = 100 |
| elif len(argv) == 2: |
| if argv[1] == "small": |
| num_bench_iters = 1 |
| n = 256 |
| tn, tk, tm = 128, 128, 128 |
| else: |
| num_bench_iters = int(argv[1]) |
| else: |
| raise app.UsageError("Too many command-line arguments.") |
| |
|   # Define the segment structure. You can play with different run lengths to
|   # see how the performance changes.
| tile_run_lengths = jnp.array([1, 4] * 8, dtype=jnp.int32) |
| tile_segments = jnp.repeat(jnp.arange(len(tile_run_lengths), dtype=jnp.int32), |
| tile_run_lengths) |
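|   # E.g. run lengths [1, 4] expand to segment ids [0, 1, 1, 1, 1]: one k tile
|   # in segment 0, followed by four k tiles in segment 1.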
| nk = int(jnp.sum(tile_run_lengths) * tk) |
|   num_k_tiles = nk // tk
| assert tile_segments.shape == ( |
| num_k_tiles,), f"{tile_segments.shape} != ({num_k_tiles},)" |
| |
| assert n % tn == 0 and n % tm == 0 |
| assert tn % 8 == 0 and tk % 128 == 0 and tm % 128 == 0 |
| |
| with ir.Context() as ctx, ir.Location.unknown(): |
| tpu.register_dialect(ctx) |
| i32 = ir.IntegerType.get_signless(32) |
| i64 = ir.IntegerType.get_signless(64) |
| f32 = ir.F32Type.get() |
| index = ir.IndexType.get() |
| smem = ir.Attribute.parse("#tpu.memory_space<smem>") |
| |
| # Our representation here is highly unoptimized, but it's a prototype. |
| assert num_k_tiles < 2048, num_k_tiles |
| tile_segments_ty = ir.MemRefType.get((num_k_tiles,), i32, None, smem) |
| |
| @func.FuncOp.from_py_func( |
| i32, i32, i32, |
| tile_segments_ty, |
| ir.MemRefType.get((tn, tk), f32), |
| ir.MemRefType.get((tk, tm), f32), |
| # FIXME(apaszke): This is a lie. It's a 3D buffer. But a harmless lie, |
| # because the majormost dim is of size 1, so it doesn't matter. |
| ir.MemRefType.get((tn, tm), f32), |
| name="main", |
| ) |
| def matmul(i, j, k, tile_segments, lhs, rhs, out, func_op): # pylint: disable=unused-argument |
| constants = {} |
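|       # Caches index constants at the top of the entry block, so repeated
|       # uses of e.g. c(0) share a single arith.constant op.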
| def c(val): |
| if val not in constants: |
| with ir.InsertionPoint.at_block_begin(func_op.entry_block): |
| ty = ir.IndexType.get() |
| constants[val] = arith.ConstantOp(ty, ir.IntegerAttr.get(ty, val)) |
| return constants[val] |
| |
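|       # An all-zeros tile used to reset the accumulator whenever a new
|       # segment begins.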
| zero_minor_tile = arith.ConstantOp( |
| F32[tn, tm], |
| ir.DenseElementsAttr.get_splat( |
| F32[tn, tm], ir.FloatAttr.get_f32(0.0) |
| ), |
| ) |
| |
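|       # Load the segment ids of the current and the previous k tiles. At
|       # k == 0 there is no previous tile, so we clamp the load index and
|       # substitute a sentinel segment id of -1 below (arith.cmpi predicate 0
|       # means "eq").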
| k_idx = arith.IndexCastOp(index, k) |
| k_is_0 = arith.CmpIOp(ir.IntegerAttr.get(i64, 0), k, |
| arith.ConstantOp(i32, ir.IntegerAttr.get(i32, 0))) |
| k_prev_oob = arith.SubIOp(k_idx, c(1)) |
| k_prev = arith.SelectOp(k_is_0, k_idx, k_prev_oob) |
| segment = memref.LoadOp(tile_segments, [k_idx]) |
| prev_segment = memref.LoadOp(tile_segments, [k_prev]) |
| prev_segment = arith.SelectOp( |
| k_is_0, arith.ConstantOp(i32, ir.IntegerAttr.get(i32, -1)), |
| prev_segment) |
| segment_v = vector.BroadcastOp(I32[tn, tm], segment) |
| prev_segment_v = vector.BroadcastOp(I32[tn, tm], prev_segment) |
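|       # arith.cmpi predicate 1 means "ne": the mask is all-true whenever a
|       # new segment begins at this k step.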
| changed_segment = arith.CmpIOp( |
| ir.IntegerAttr.get(i64, 1), segment_v, prev_segment_v) |
| |
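|       # Load the operand tiles, reset the accumulator wherever a new segment
|       # starts, and accumulate one (tn, tk) x (tk, tm) product. The indexing
|       # maps of the contraction spell out a standard matrix multiplication.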
| rhs_tile = vector.LoadOp(F32[tk, tm], rhs, [c(0), c(0)]) |
| lhs_tile = vector.LoadOp(F32[tn, tk], lhs, [c(0), c(0)]) |
| out_tile = vector.LoadOp(F32[tn, tm], out, [c(0), c(0)]) |
| out_tile = arith.SelectOp(changed_segment, zero_minor_tile, out_tile) |
| new_out_tile = vector.ContractionOp( |
| F32[tn, tm], |
| lhs_tile, |
| rhs_tile, |
| out_tile, |
| indexing_maps=ir.ArrayAttr.get([ |
| ir.Attribute.parse("affine_map<(i, j, k) -> (i, k)>"), |
| ir.Attribute.parse("affine_map<(i, j, k) -> (k, j)>"), |
| ir.Attribute.parse("affine_map<(i, j, k) -> (i, j)>"), |
| ]), |
| iterator_types=ir.ArrayAttr.get([ |
| ir.Attribute.parse("#vector.iterator_type<parallel>"), |
| ir.Attribute.parse("#vector.iterator_type<parallel>"), |
| ir.Attribute.parse("#vector.iterator_type<reduction>"), |
| ]), |
| ) |
| vector.StoreOp(new_out_tile, out, [c(0), c(0)]) |
| |
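|     # Computes the output window indices for grid step (i, j, k): every k
|     # tile within a segment maps to the same slice of the stacked output, so
|     # the matmul above keeps accumulating into it until the segment ends.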
| @func.FuncOp.from_py_func(i32, i32, i32, tile_segments_ty) |
| def out_transform_indices(i, j, k, tile_segments): |
| k_idx = arith.IndexCastOp(index, k) |
| segment = memref.LoadOp(tile_segments, [k_idx]) |
| return segment.result, i, j |
| assert out_transform_indices.func_op.verify(), out_transform_indices.func_op |
| |
| # Wrap all of those functions in a module. |
| m = ir.Module.create() |
| sym_tab = ir.SymbolTable(m.operation) |
| m.body.append(matmul.func_op) |
| sym_tab.insert(matmul.func_op) |
| m.body.append(out_transform_indices.func_op) |
| sym_tab.insert(out_transform_indices.func_op) |
| |
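|     # Annotate the kernel: it executes over a (n // tn, n // tm, nk // tk)
|     # grid, only the k dimension has to run in order ("arbitrary"), and the
|     # leading argument (tile_segments) is a scalar-prefetch operand.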
| f = matmul.func_op |
| f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get( |
| [n // tn, n // tm, nk // tk]) |
| f.attributes["dimension_semantics"] = ir.ArrayAttr.get([ |
| ir.Attribute.parse("#tpu.dimension_semantics<parallel>"), |
| ir.Attribute.parse("#tpu.dimension_semantics<parallel>"), |
| ir.Attribute.parse("#tpu.dimension_semantics<arbitrary>"), |
| ]) |
| f.attributes["window_params"] = ir.ArrayAttr.get([ |
| ir.DictAttr.get({ |
| "transform_indices": |
| ir.Attribute.parse("affine_map<(n, m, k) -> (n, k)>"), |
| }), |
| ir.DictAttr.get({ |
| "transform_indices": |
| ir.Attribute.parse("affine_map<(n, m, k) -> (k, m)>"), |
| }), |
| ir.DictAttr.get({ |
| "window_bounds": ir.DenseI64ArrayAttr.get([1, tn, tm]), |
| "transform_indices": |
| ir.FlatSymbolRefAttr.get(out_transform_indices.__name__), |
| }) |
| ]) |
| f.attributes["scalar_prefetch"] = ir.IntegerAttr.get(i64, 1) |
| assert f.verify(), f |
| |
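|     # Compile the module into a function callable from JAX. The prefetched
|     # tile_segments operand comes first, followed by the lhs and rhs arrays.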
| segmented_matmul = mosaic.as_tpu_kernel( |
| m, out_type=core.ShapedArray((len(tile_run_lengths), n, n), |
| jnp.float32)) |
| |
| x = jnp.ones((n, nk), dtype=jnp.float32) |
| y = jnp.broadcast_to(jnp.arange(n, dtype=jnp.float32), (nk, n)) |
| |
| # To make this fully fair, this should only specialize on the number of |
| # stacked matrices (i.e. len(tile_run_lengths)) and not on the actual run |
| # lengths. But we can't do that, because dynamic_slice_in_dim requires the |
| # output shape to be statically known! |
| tile_run_lengths_list = [int(l) * tk for l in tile_run_lengths] |
| @jax.jit |
| def reference(x, y): |
| s = 0 |
| outs = [] |
|       for run_length in tile_run_lengths_list:
|         xs = x[:, s:s + run_length]
|         ys = y[s:s + run_length, :]
|         outs.append(xs @ ys)
|         s += run_length
| return jnp.stack(outs, axis=0) |
| |
|     t_kernel = -time.perf_counter()
|     segmented_matmul.lower(tile_segments, x, y).compile()
|     t_kernel += time.perf_counter()
|     print(f"Custom kernel compile time: {t_kernel}s")
|     t_ref = -time.perf_counter()
|     reference.lower(x, y).compile()
|     t_ref += time.perf_counter()
|     print(f"XLA reference compile time: {t_ref}s")
| |
| run_kernel = lambda: segmented_matmul(tile_segments, x, y) |
| run_reference = lambda: reference(x, y) |
| |
| # Make sure the implementation is correct + warm-up the caches. |
| np.testing.assert_allclose(run_kernel(), run_reference()) |
| |
| s = time.perf_counter() |
| for _ in range(num_bench_iters): |
| run_kernel().block_until_ready() |
| e = time.perf_counter() |
| d1 = e - s |
| s = time.perf_counter() |
| for _ in range(num_bench_iters): |
| run_reference().block_until_ready() |
| e = time.perf_counter() |
| d2 = e - s |
| print(f"Run-time ratio (custom/XLA): {d1 / d2} (spent {d1 + d2}s" |
| " benchmarking)") |
| |
| |
| if __name__ == "__main__": |
| app.run(main) |