| """An example LLO kernel implementing matmul for Jellyfish.""" |
| |
| import time |
| from typing import Sequence |
| |
| from absl import app |
| import jax |
| from jax._src import tpu_custom_call |
| from jax.experimental.mosaic.dialects import tpu |
| import jax.numpy as jnp |
| from mlir import ir |
| from mlir.dialects import func |
| import numpy as np |
| |
| from google3.platforms.xla.mosaic.python.dialects import llo |
| |
| |
| def main(argv: Sequence[str]) -> None: |
  if len(argv) == 1:
    num_bench_iters = 10000
  elif len(argv) == 2:
    num_bench_iters = int(argv[1])
  else:
    raise app.UsageError("Too many command-line arguments.")

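  # Problem and tile sizes: an n x n f32 matmul is decomposed into tn x tk
  # LHS, tk x tm RHS, and tn x tm output tiles staged in VMEM.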
  n = 1024
  tn, tk, tm = 512, 512, 512
  assert n % tn == 0 and n % tk == 0 and n % tm == 0
  assert tn % 128 == 0 and tk % 128 == 0 and tm % 128 == 0
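  # Every tile is stored as 128 x 128 subtiles built out of (8, 128) vregs.
  # The *_row_stride values below are the distance between consecutive
  # (8, 128) rows of a subtile, and the *_tile_row_stride values the distance
  # between consecutive 128-row subtile rows, in the units expected by the
  # load/store displacements.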
  num_latches_rows = tk // 128
  num_latches_cols = tm // 128
  latch_row_stride = 8 * num_latches_cols  # In vmem words.
  latch_tile_row_stride = 16 * latch_row_stride
  num_lhs_subtile_rows = tn // 128
  num_lhs_subtile_cols = tk // 128
  lhs_row_stride = 8 * num_lhs_subtile_cols
  lhs_tile_row_stride = 16 * lhs_row_stride
  out_row_stride = 8 * num_latches_cols
  out_tile_row_stride = 16 * out_row_stride
  with ir.Context() as ctx, ir.Location.unknown():
    tpu.register_dialect(ctx)
    llo.register_llo_dialect(ctx)
    i32 = ir.IntegerType.get_signless(32)
    f32 = ir.F32Type.get()
    vmask = ir.VectorType.get((8, 128), ir.IntegerType.get_signless(1))
    f32_vreg = ir.VectorType.get((8, 128), f32)
    i32_vreg = ir.VectorType.get((8, 128), i32)
    i32_vreg_c0_attr = ir.DenseElementsAttr.get_splat(
        i32_vreg, ir.IntegerAttr.get(i32, 0))
    f32_vreg_c0_attr = ir.DenseElementsAttr.get_splat(
        f32_vreg, ir.FloatAttr.get_f32(0.0))

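    # The kernel body. (i, j, k) are the grid indices defined by the
    # iteration_bounds attribute below; xaddr, yaddr, and oaddr are the VMEM
    # base addresses of the LHS, RHS, and output windows selected by
    # window_params.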
    @func.FuncOp.from_py_func(i32, i32, i32, i32, i32, i32, name="main")
    def matmul(i, j, k, xaddr, yaddr, oaddr):  # pylint: disable=unused-argument
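      # Helper that materializes scalar i32 constants lazily, caching them so
      # each distinct value is only emitted once.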
      constants = {}
      def c(val):
        assert isinstance(val, int)
        if val not in constants:
          constants[val] = llo.ConstantOp(i32, ir.IntegerAttr.get(i32, val))
        return constants[val]
      i32_vreg_c0 = llo.ConstantOp(i32_vreg, i32_vreg_c0_attr)
      f32_vreg_c0 = llo.ConstantOp(f32_vreg, f32_vreg_c0_attr)

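      # k == 0 marks the first pass over the contraction dimension for the
      # current output tile; in that case the stale contents of the output
      # window must be replaced with zeros before accumulating (see the select
      # below).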
      kv = llo.ScalarToVectorOp(i32_vreg, k)
      is_first_window_vm = llo.VectorCmpEqS32Op(vmask, kv, i32_vreg_c0)

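      # For each 128 x 128 RHS subtile: load it into the gain latch, then push
      # the matching column of 128 x 128 LHS subtiles through the matrix unit
      # and accumulate the results into the output tile in VMEM.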
      for ln in range(num_latches_rows):
        for lm in range(num_latches_cols):
          # Load up the latch.
          latch_base_offset = ln * latch_tile_row_stride + lm * 8
          latch_tile_base = llo.ScalarAddressVmemOp(yaddr, c(latch_base_offset))
          for i in range(16):
            yt = llo.VectorLoadOp(f32_vreg, latch_tile_base,
                                  displacement=c(i * latch_row_stride))
            llo.VectorLatchOp(
                yt, ir.Attribute.parse("#llo.gain_latch_mode<f32>"))

          # Recall that num_lhs_subtile_cols == num_latches_rows.
          lhs_col = ln
          for lhs_row in range(num_lhs_subtile_rows):
            lhs_tile_base_offset = lhs_row * lhs_tile_row_stride + lhs_col * 8
            lhs_tile_base = llo.ScalarAddressVmemOp(
                xaddr, c(lhs_tile_base_offset))
            out_tile_base_offset = lhs_row * out_tile_row_stride + lm * 8
            out_tile_base = llo.ScalarAddressVmemOp(
                oaddr, c(out_tile_base_offset))
            for i in range(16):
              xt = llo.VectorLoadOp(
                  f32_vreg, lhs_tile_base, displacement=c(i * lhs_row_stride))
              llo.VectorMatmulOp(
                  xt, ir.Attribute.parse("#llo.matmul_mode<round>"))
              zt = llo.VectorMatresOp(f32_vreg)
              # Blend zt into the accumulator held in the output tile. The
              # stale contents of the output window only have to be zeroed out
              # on the very first accumulation into it, i.e. when k == 0
              # (first grid step along the contraction dimension) and ln == 0
              # (first latch row of this invocation).
              acc = llo.VectorLoadOp(
                  f32_vreg, out_tile_base, displacement=c(i * out_row_stride))
              if ln == 0:
                acc = llo.VectorSelectOp(
                    f32_vreg, is_first_window_vm, f32_vreg_c0, acc)
              new_acc = llo.VectorAddF32Op(acc, zt)
              llo.VectorStoreOp(out_tile_base,
                                displacement=c(i * out_row_stride),
                                to_store=new_acc)
      # Signal that the kernel is done with the latched gains.
      llo.VectorDoneWithGainsOp()
    f = matmul.func_op
    lhs_layout = ir.Attribute.parse(f"#tpu.tiled<(8,128),[{tk//128},1]>")
    rhs_layout = ir.Attribute.parse(f"#tpu.tiled<(8,128),[{tm//128},1]>")
    out_layout = ir.Attribute.parse(f"#tpu.tiled<(8,128),[{tm//128},1]>")
    vmem = ir.Attribute.parse("#tpu.memory_space<vmem>")
    lhs_memref = ir.MemRefType.get((tn, tk), f32, lhs_layout, vmem)
    rhs_memref = ir.MemRefType.get((tk, tm), f32, rhs_layout, vmem)
    out_memref = ir.MemRefType.get((tn, tm), f32, out_layout, vmem)
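    # Annotate the kernel arguments with their LLO types: the first three are
    # scalar grid indices, while the memref arguments are VMEM-resident tiles
    # that additionally carry a layout.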
    for i, ty in enumerate([i32, i32, i32, lhs_memref, rhs_memref, out_memref]):
      tpu.private_set_arg_attr(f, i, "llo.type", ir.TypeAttr.get(ty))
      if i > 2:
        tpu.private_set_arg_attr(f, i, "llo.layout", ty.layout)
    assert f.verify(), f
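    # The kernel is invoked over a (n // tn, n // tm, n // tk) grid. The first
    # two grid dimensions tile independent output blocks and are marked
    # parallel; the last one walks the contraction dimension and carries the
    # accumulation, so it is marked arbitrary. window_params maps the grid
    # indices to the operand tile passed in for each argument.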
    f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
        [n // tn, n // tm, n // tk])
    f.attributes["dimension_semantics"] = ir.ArrayAttr.get([
        ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
        ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
        ir.Attribute.parse("#tpu.dimension_semantics<arbitrary>"),
    ])
    f.attributes["window_params"] = ir.ArrayAttr.get([
        ir.DictAttr.get({
            "transform_indices":
                ir.Attribute.parse("affine_map<(n, m, k) -> (n, k)>"),
        }),
        ir.DictAttr.get({
            "transform_indices":
                ir.Attribute.parse("affine_map<(n, m, k) -> (k, m)>"),
        }),
        ir.DictAttr.get({
            "transform_indices":
                ir.Attribute.parse("affine_map<(n, m, k) -> (n, m)>"),
        }),
    ])
    x = jnp.ones((n, n), dtype=jnp.float32)
    y = jnp.broadcast_to(jnp.arange(n, dtype=jnp.float32), (n, n))

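    # Wrap the kernel function in a module so it can be serialized and handed
    # off to the TPU custom call below.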
    m = ir.Module.create()
    m.body.append(f)
    ir.SymbolTable(m.operation).insert(f)

    # We wrap the kernel in jax.jit to make sure repeated calls hit the fast
    # compilation cache instead of being re-lowered on every iteration.
    kernel = tpu_custom_call._lowered_as_tpu_kernel(  # pylint: disable=protected-access
        m.operation.get_asm(binary=True),
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        serialization_format=None,
    )
    custom_matmul = jax.jit(kernel)

    # Make sure the implementation is correct and warm up the caches.
    np.testing.assert_allclose(custom_matmul(x, y), x @ y)

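    # Benchmark the custom kernel against XLA's native matmul.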
    s = time.perf_counter()
    for _ in range(num_bench_iters):
      custom_matmul(x, y).block_until_ready()
    e = time.perf_counter()
    d1 = e - s
    s = time.perf_counter()
    for _ in range(num_bench_iters):
      (x @ y).block_until_ready()
    e = time.perf_counter()
    d2 = e - s
    print("Run-time ratio (custom/XLA): ", d1 / d2)


if __name__ == "__main__":
  app.run(main)