| """An example Mosaic kernel implementing a fast triangular matmul. |
| |
| The kernel assumes that RHS of multiplication is lower-triangular. |
| """ |
| |
| import itertools |
| import time |
| from typing import Sequence |
| |
| from absl import app |
| from jax import core |
| from jax.experimental import mosaic |
| from jax.experimental.mosaic.dialects import tpu |
| import jax.numpy as jnp |
| from mlir import ir |
| from mlir.dialects import arith |
| from mlir.dialects import func |
| from mlir.dialects import scf |
| from mlir.dialects import vector |
| import numpy as np |
| |
| |
class VectorFactory:
  """Builds MLIR vector types of a fixed element type via subscripting.

  Subscripting an instance with a shape, e.g. ``factory[8, 128]`` or
  ``factory[8]``, returns ``ir.VectorType.get(shape, elem_thunk())``.
  """

  def __init__(self, elem_thunk):
    # Zero-argument callable producing the element type. It is invoked on
    # every subscript instead of once here -- presumably so that the type is
    # created inside whichever MLIR context is active at use time.
    self.elem_thunk = elem_thunk

  def __getitem__(self, idxs):
    # A single integer subscript denotes a rank-1 vector.
    shape = (idxs,) if isinstance(idxs, int) else idxs
    return ir.VectorType.get(shape, self.elem_thunk())
| |
# Shorthands for vector types with f32 / signless 32-bit integer elements,
# used as e.g. F32[128, 128].
F32 = VectorFactory(elem_thunk=ir.F32Type.get)
I32 = VectorFactory(elem_thunk=lambda: ir.IntegerType.get_signless(32))
| |
| |
def _parse_num_bench_iters(argv):
  """Returns the benchmark iteration count requested on the command line.

  Args:
    argv: Command-line arguments, program name first. An optional second
      element gives the number of benchmark iterations (default: 1000).

  Raises:
    app.UsageError: If there are too many arguments, or the iteration count is
      not a positive integer.
  """
  if len(argv) == 1:
    return 1000
  elif len(argv) == 2:
    try:
      num_bench_iters = int(argv[1])
    except ValueError as err:
      raise app.UsageError(
          f"Number of benchmark iterations must be an integer, got {argv[1]!r}."
      ) from err
    # With no iterations both timings are just timer noise, which makes the
    # reported ratio meaningless (or a division by zero).
    if num_bench_iters <= 0:
      raise app.UsageError(
          "Number of benchmark iterations must be positive, got "
          f"{num_bench_iters}.")
    return num_bench_iters
  else:
    raise app.UsageError("Too many command-line arguments.")


def _time_blocking_calls(fn, num_iters):
  """Returns the seconds taken by `num_iters` calls of `fn().block_until_ready()`."""
  start = time.perf_counter()
  for _ in range(num_iters):
    fn().block_until_ready()
  return time.perf_counter() - start


def main(argv: Sequence[str]) -> None:
  """Builds the triangular matmul kernel, checks it and benchmarks it.

  The kernel is assembled as an MLIR module, wrapped with
  `mosaic.as_tpu_kernel`, compared against `x @ y` for a lower-triangular `y`,
  and finally timed against `x @ y`. The run-time ratio is printed.

  Args:
    argv: Command-line arguments; an optional `argv[1]` is the number of
      benchmark iterations (default: 1000).

  Raises:
    app.UsageError: If the command-line arguments are invalid.
  """
  num_bench_iters = _parse_num_bench_iters(argv)

  # Problem size and the sizes of the (major) tiles handed to the kernel body.
  n = 1024
  tn, tk, tm = 512, 512, 512
  assert n % tn == 0 and n % tk == 0 and n % tm == 0
  assert tn % 128 == 0 and tk % 128 == 0 and tm % 128 == 0

  # Each major RHS tile is processed as a grid of 128x128 minor tiles.
  rhs_minor_tiles_shape = (tk // 128, tm // 128)

  def ndrange(dims):
    """Iterates over all index tuples of an array with shape `dims`."""
    return itertools.product(*map(range, dims))

  with ir.Context() as ctx, ir.Location.unknown():
    tpu.register_dialect(ctx)
    i32 = ir.IntegerType.get_signless(32)
    i64 = ir.IntegerType.get_signless(64)
    f32 = ir.F32Type.get()

    # NOTE: The integer attributes passed to arith.CmpIOp below are arith.cmpi
    # predicates: 0 is `eq`, 2 is `slt` and 4 is `sgt`.

    @func.FuncOp.from_py_func(
        i32, i32, i32,
        ir.MemRefType.get((tn, tk), f32),
        ir.MemRefType.get((tk, tm), f32),
        ir.MemRefType.get((tn, tm), f32),
        name="main",
    )
    def matmul(i, j, k, lhs, rhs, out, func_op):  # pylint: disable=unused-argument
      # Cache of index constants, so that each value is materialized only once.
      constants = {}
      def c(val):
        """Returns an index constant `val` placed at the function entry."""
        if val not in constants:
          with ir.InsertionPoint.at_block_begin(func_op.entry_block):
            ty = ir.IndexType.get()
            constants[val] = arith.ConstantOp(ty, ir.IntegerAttr.get(ty, val))
        return constants[val]

      def mul_by_rhs_tile(rhs_minor_row, rhs_minor_col,
                          output_is_uninitialized):
        """Emits out[:, col] = lhs[:, row] @ rhs[row, col] + out[:, col].

        `row` and `col` index 128-wide minor tiles. If
        `output_is_uninitialized` is not None, it is a mask selecting zeros
        instead of the loaded output values as the accumulator.
        """
        rhs_tile = vector.LoadOp(
            F32[128, 128], rhs,
            [c(rhs_minor_row * 128), c(rhs_minor_col * 128)])
        lhs_tile = vector.LoadOp(
            F32[tn, 128], lhs,
            [c(0), c(rhs_minor_row * 128)])
        out_tile = vector.LoadOp(
            F32[tn, 128], out,
            [c(0), c(rhs_minor_col * 128)])
        if output_is_uninitialized is not None:
          out_tile = arith.SelectOp(
              output_is_uninitialized, zero_minor_tile, out_tile)
        new_out_tile = vector.ContractionOp(
            F32[tn, 128], lhs_tile, rhs_tile, out_tile,
            indexing_maps=ir.ArrayAttr.get([
                ir.Attribute.parse("affine_map<(i, j, k) -> (i, k)>"),
                ir.Attribute.parse("affine_map<(i, j, k) -> (k, j)>"),
                ir.Attribute.parse("affine_map<(i, j, k) -> (i, j)>"),
            ]),
            iterator_types=ir.ArrayAttr.get([
                ir.Attribute.parse("#vector.iterator_type<parallel>"),
                ir.Attribute.parse("#vector.iterator_type<parallel>"),
                ir.Attribute.parse("#vector.iterator_type<reduction>"),
            ])
        )
        vector.StoreOp(
            new_out_tile, out,
            [c(0), c(rhs_minor_col * 128)])

      # The output mask is true iff k == j.
      kv = vector.BroadcastOp(I32[tn, 128], k)
      jv = vector.BroadcastOp(I32[tn, 128], j)
      output_is_uninitialized = arith.CmpIOp(
          ir.IntegerAttr.get(i64, 0), kv, jv)
      zero_minor_tile = arith.ConstantOp(
          F32[tn, 128],
          ir.DenseElementsAttr.get_splat(F32[tn, 128],
                                         ir.FloatAttr.get_f32(0.0)))

      # Note that we don't do anything if we are above the diagonal!
      rhs_major_tile_on_diag = arith.CmpIOp(
          ir.IntegerAttr.get(i64, 0), j, k)  # j == k
      on_diag_if = scf.IfOp(rhs_major_tile_on_diag.result)
      with ir.InsertionPoint(on_diag_if.then_block):
        for (rhs_minor_row, rhs_minor_col) in ndrange(rhs_minor_tiles_shape):
          # We skip any minor tiles above the diagonal.
          if rhs_minor_col > rhs_minor_row: continue
          # Masking might only be necessary when we're on the diagonal, since
          # those are the first blocks that might initialize the output.
          mul_by_rhs_tile(
              rhs_minor_row, rhs_minor_col, output_is_uninitialized
              if rhs_minor_row == rhs_minor_col else None)
        scf.YieldOp([])
      rhs_major_tile_below_diag = arith.CmpIOp(
          ir.IntegerAttr.get(i64, 2), j, k)  # j < k
      below_diag_if = scf.IfOp(rhs_major_tile_below_diag.result)
      with ir.InsertionPoint(below_diag_if.then_block):
        for (rhs_minor_row, rhs_minor_col) in ndrange(rhs_minor_tiles_shape):
          # Masking might only be necessary when we're processing the first row.
          mul_by_rhs_tile(
              rhs_minor_row, rhs_minor_col,
              output_is_uninitialized if rhs_minor_row == 0 else None)
        scf.YieldOp([])

    @func.FuncOp.from_py_func(i32, i32, i32)
    def rhs_transform_indices(i, j, k):  # pylint: disable=unused-argument
      rhs_upper_diag = arith.CmpIOp(ir.IntegerAttr.get(i64, 4), j, k)  # j > k
      if_op = scf.IfOp(rhs_upper_diag.result, hasElse=True, results_=[i32] * 2)
      with ir.InsertionPoint(if_op.then_block):
        scf.YieldOp([j, j])  # Preload the next useful RHS tile.
      with ir.InsertionPoint(if_op.else_block):
        scf.YieldOp([k, j])
      return if_op
    assert rhs_transform_indices.func_op.verify(), rhs_transform_indices.func_op

    @func.FuncOp.from_py_func(i32, i32, i32)
    def lhs_transform_indices(i, j, k):
      rhs_upper_diag = arith.CmpIOp(ir.IntegerAttr.get(i64, 4), j, k)  # j > k
      if_op = scf.IfOp(rhs_upper_diag.result, hasElse=True, results_=[i32] * 2)
      with ir.InsertionPoint(if_op.then_block):
        scf.YieldOp([i, j])  # Preload the next useful LHS tile.
      with ir.InsertionPoint(if_op.else_block):
        scf.YieldOp([i, k])
      return if_op
    # Verify the LHS index transform too, for parity with the RHS one.
    assert lhs_transform_indices.func_op.verify(), lhs_transform_indices.func_op

    # Wrap all of those functions in a module.
    m = ir.Module.create()
    sym_tab = ir.SymbolTable(m.operation)
    m.body.append(matmul.func_op)
    sym_tab.insert(matmul.func_op)
    m.body.append(rhs_transform_indices.func_op)
    sym_tab.insert(rhs_transform_indices.func_op)
    m.body.append(lhs_transform_indices.func_op)
    sym_tab.insert(lhs_transform_indices.func_op)

    f = matmul.func_op
    f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
        [n // tn, n // tm, n // tk])
    f.attributes["dimension_semantics"] = ir.ArrayAttr.get([
        ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
        ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
        ir.Attribute.parse("#tpu.dimension_semantics<arbitrary>"),
    ])
    # One entry per memref operand of the kernel, in order: lhs, rhs, out.
    f.attributes["window_params"] = ir.ArrayAttr.get([
        ir.DictAttr.get({
            "window_bounds": ir.DenseI64ArrayAttr.get([tn, tk]),
            "transform_indices":
                ir.FlatSymbolRefAttr.get(lhs_transform_indices.__name__),
        }),
        ir.DictAttr.get({
            "window_bounds": ir.DenseI64ArrayAttr.get([tk, tm]),
            "transform_indices":
                ir.FlatSymbolRefAttr.get(rhs_transform_indices.__name__),
        }),
        ir.DictAttr.get({
            "transform_indices":
                ir.Attribute.parse("affine_map<(n, m, k) -> (n, m)>"),
        })
    ])
    assert f.verify(), f

    custom_matmul = mosaic.as_tpu_kernel(
        m, out_type=core.ShapedArray((n, n), jnp.float32))

  x = jnp.ones((n, n), dtype=jnp.float32)
  y = jnp.tril(jnp.broadcast_to(jnp.arange(n, dtype=jnp.float32), (n, n)))

  # Make sure the implementation is correct + warm-up the caches.
  np.testing.assert_allclose(custom_matmul(x, y), x @ y)

  d1 = _time_blocking_calls(lambda: custom_matmul(x, y), num_bench_iters)
  d2 = _time_blocking_calls(lambda: x @ y, num_bench_iters)
  print("Run-time ratio (custom/XLA): ", d1 / d2)
| |
| |
# Script entry point: hand control to absl, which invokes `main` with argv.
if __name__ == "__main__":
  app.run(main)