| """An example Mosaic kernel implementing standard matrix multiplication.""" |
| |
| import time |
| from typing import Sequence |
| |
| from absl import app |
| from jax import config |
| from jax import core |
| from jax import random |
| from jax.experimental import mosaic |
| from jax.experimental.mosaic.dialects import tpu |
| import jax.numpy as jnp |
| from mlir import ir |
| from mlir.dialects import arith |
| from mlir.dialects import func |
| from mlir.dialects import mhlo |
| from mlir.dialects import tensor |
| from mlir.dialects import vector |
| import numpy as np |


class VectorFactory:
  """Builds MLIR vector types by indexing, e.g. F32[tn, tm].

  The element type is constructed lazily via `elem_thunk`, so instances can
  be defined at module scope, before any MLIR context is active.
  """

  def __init__(self, elem_thunk):
    self.elem_thunk = elem_thunk

  def __getitem__(self, idxs):
    if isinstance(idxs, int):
      idxs = (idxs,)
    return ir.VectorType.get(idxs, self.elem_thunk())


F32 = VectorFactory(ir.F32Type.get)
I32 = VectorFactory(lambda: ir.IntegerType.get_signless(32))


class RankedTensorFactory:
  """Builds MLIR ranked tensor types by indexing, e.g. F32Tensor[tn, tm]."""

  def __init__(self, elem_thunk):
    self.elem_thunk = elem_thunk

  def __getitem__(self, idxs):
    if isinstance(idxs, int):
      idxs = (idxs,)
    return ir.RankedTensorType.get(idxs, self.elem_thunk())


F32Tensor = RankedTensorFactory(ir.F32Type.get)


def main(argv: Sequence[str]) -> None:
  if len(argv) == 1:
    num_bench_iters = 1000
  elif len(argv) == 2:
    num_bench_iters = int(argv[1])
  else:
    raise app.UsageError("Too many command-line arguments.")

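  # Problem shape (n, m, k) and the tile shape (tn, tk, tm) handled by one
  # kernel invocation: each grid step multiplies a (tn, tk) tile of the LHS
  # by a (tk, tm) tile of the RHS and accumulates into a (tn, tm) output tile.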
  n, m, k = 2048, 1024, 512
  tn, tk, tm = 512, 512, 512
  assert n % tn == 0 and k % tk == 0 and m % tm == 0

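  # All IR below must be built under an active MLIR context and location,
  # with the MHLO and TPU dialects registered.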
  with ir.Context() as ctx, ir.Location.unknown():
    mhlo.register_mhlo_dialect(ctx)
    tpu.register_dialect(ctx)
    mhlo.register_mhlo_passes()
    i32 = ir.IntegerType.get_signless(32)
    i64 = ir.IntegerType.get_signless(64)
    f32 = ir.F32Type.get()

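    # The kernel entry point receives the three grid indices (the i32
    # arguments) followed by memrefs holding the current LHS, RHS, and output
    # tiles. Note that inside the kernel body, `k` is the grid index along
    # the contraction dimension, shadowing the outer matrix dimension `k`.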
    @func.FuncOp.from_py_func(
        i32,
        i32,
        i32,
        ir.MemRefType.get((tn, tk), f32),
        ir.MemRefType.get((tk, tm), f32),
        ir.MemRefType.get((tn, tm), f32),
        name="main",
    )
    def matmul(i, j, k, lhs, rhs, out, func_op):  # pylint: disable=unused-argument
      c0 = arith.ConstantOp(
          ir.IndexType.get(), ir.IntegerAttr.get(ir.IndexType.get(), 0)
      )
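      # On the first step along the contraction dimension (k == 0) the output
      # tile holds uninitialized memory, so a zero tile is selected in its
      # place before accumulating.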
      kv = vector.BroadcastOp(I32[tn, tm], k)
      output_is_uninitialized = arith.CmpIOp(
          ir.IntegerAttr.get(i64, 0),  # Predicate 0 is "eq".
          kv,
          arith.ConstantOp(
              I32[tn, tm],
              ir.DenseElementsAttr.get_splat(
                  I32[tn, tm], ir.IntegerAttr.get(i32, 0)
              ),
          ),
      )
      zero_tile = arith.ConstantOp(
          F32[tn, tm],
          ir.DenseElementsAttr.get_splat(
              F32[tn, tm], ir.FloatAttr.get_f32(0.0)
          ),
      )

      lhs_tile = vector.LoadOp(F32[tn, tk], lhs, [c0, c0])
      rhs_tile = vector.LoadOp(F32[tk, tm], rhs, [c0, c0])
      out_tile = arith.SelectOp(
          output_is_uninitialized,
          zero_tile,
          vector.LoadOp(F32[tn, tm], out, [c0, c0]),
      )

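      # mhlo.DotOp operates on tensors rather than vectors, so the loaded
      # tiles are staged through tensor values with vector.transfer_write
      # (and read back with vector.transfer_read below).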
      lhs_tensor = tensor.EmptyOp([tn, tk], f32)
      lhs_tensor = vector.TransferWriteOp(
          F32Tensor[tn, tk],
          lhs_tile,
          lhs_tensor,
          [c0, c0],
          ir.Attribute.parse("affine_map<(i, j) -> (i, j)>"),
      )

      rhs_tensor = tensor.EmptyOp([tk, tm], f32)
      rhs_tensor = vector.TransferWriteOp(
          F32Tensor[tk, tm],
          rhs_tile,
          rhs_tensor,
          [c0, c0],
          ir.Attribute.parse("affine_map<(i, j) -> (i, j)>"),
      )

      out_tensor = tensor.EmptyOp([tn, tm], f32)
      out_tensor = vector.TransferWriteOp(
          F32Tensor[tn, tm],
          out_tile,
          out_tensor,
          [c0, c0],
          ir.Attribute.parse("affine_map<(i, j) -> (i, j)>"),
      )

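      # The actual tile contraction: out_tile += lhs_tile @ rhs_tile.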
      new_out_tensor = mhlo.DotOp(F32Tensor[tn, tm], lhs_tensor, rhs_tensor)
      new_out_tensor = mhlo.AddOp(new_out_tensor, out_tensor)

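      # Read the result back out of the tensor into a vector and store it to
      # the output memref. transfer_read requires a padding value, but it is
      # never used here since the read stays in bounds.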
      padding = arith.ConstantOp(f32, ir.FloatAttr.get(f32, 0))
      new_out_tile = vector.TransferReadOp(
          F32[tn, tm],
          new_out_tensor,
          [c0, c0],
          ir.Attribute.parse("affine_map<(i, j) -> (i, j)>"),
          padding,
      )
      vector.StoreOp(new_out_tile, out, [c0, c0])

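    # Attach the Mosaic grid metadata: `iteration_bounds` defines the 3D grid,
    # `dimension_semantics` marks the n and m dimensions as parallel and the
    # contraction dimension as sequential ("arbitrary"), and `window_params`
    # maps grid indices to the operand tile each invocation should see.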
    f = matmul.func_op
    f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
        [n // tn, m // tm, k // tk]
    )
    f.attributes["dimension_semantics"] = ir.ArrayAttr.get([
        ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
        ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
        ir.Attribute.parse("#tpu.dimension_semantics<arbitrary>"),
    ])
    f.attributes["window_params"] = ir.ArrayAttr.get([
        ir.DictAttr.get({
            "transform_indices": ir.Attribute.parse(
                "affine_map<(n, m, k) -> (n, k)>"
            ),
        }),
        ir.DictAttr.get({
            "transform_indices": ir.Attribute.parse(
                "affine_map<(n, m, k) -> (k, m)>"
            ),
        }),
        ir.DictAttr.get({
            "transform_indices": ir.Attribute.parse(
                "affine_map<(n, m, k) -> (n, m)>"
            ),
        }),
    ])
    assert f.verify(), f

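    # Wrap the kernel in a module and lower it into a function callable
    # directly from JAX.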
    module = ir.Module.create()
    module.body.append(f)
    ir.SymbolTable(module.operation).insert(f)
    custom_matmul = mosaic.as_tpu_kernel(
        module, out_type=core.ShapedArray((n, m), jnp.float32)
    )

    k1, k2 = random.split(random.PRNGKey(1234))
    x = random.normal(k1, (n, k), jnp.float32)
    y = random.normal(k2, (k, m), jnp.float32)
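    # Round the inputs through bfloat16 so that the kernel's reduced-precision
    # MXU results closely match the float32 XLA reference.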
    x = x.astype(jnp.bfloat16).astype(jnp.float32)
    y = y.astype(jnp.bfloat16).astype(jnp.float32)

    # Check that the kernel is correct and warm up the compilation caches.
    np.testing.assert_allclose(custom_matmul(x, y), x @ y)

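    # Benchmark the custom kernel against XLA's native matmul. Each call is
    # blocked on individually, so dispatch and execution are both measured.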
    s = time.perf_counter()
    for _ in range(num_bench_iters):
      custom_matmul(x, y).block_until_ready()
    e = time.perf_counter()
    d1 = e - s
    s = time.perf_counter()
    for _ in range(num_bench_iters):
      (x @ y).block_until_ready()
    e = time.perf_counter()
    d2 = e - s
    print("Run-time ratio (custom/XLA):", d1 / d2)


if __name__ == "__main__":
  config.update("jax_mosaic_allow_hlo", True)
  app.run(main)