"""An example Mosaic kernel implementing standard matrix multiplication."""
import time
from typing import Sequence
from absl import app
from jax import config
from jax import core
from jax import random
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import mhlo
from mlir.dialects import tensor
from mlir.dialects import vector
import numpy as np
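

# Helpers for constructing MLIR vector and ranked tensor types: F32[a, b]
# builds a 2D f32 vector type, F32Tensor[a, b] the matching tensor type.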
class VectorFactory:

  def __init__(self, elem_thunk):
    self.elem_thunk = elem_thunk

  def __getitem__(self, idxs):
    if isinstance(idxs, int):
      idxs = (idxs,)
    return ir.VectorType.get(idxs, self.elem_thunk())


F32 = VectorFactory(ir.F32Type.get)
I32 = VectorFactory(lambda: ir.IntegerType.get_signless(32))


class RankedTensorFactory:

  def __init__(self, elem_thunk):
    self.elem_thunk = elem_thunk

  def __getitem__(self, idxs):
    if isinstance(idxs, int):
      idxs = (idxs,)
    return ir.RankedTensorType.get(idxs, self.elem_thunk())


F32Tensor = RankedTensorFactory(ir.F32Type.get)


def main(argv: Sequence[str]) -> None:
  if len(argv) == 1:
    num_bench_iters = 1000
  elif len(argv) == 2:
    num_bench_iters = int(argv[1])
  else:
    raise app.UsageError("Too many command-line arguments.")

  n, m, k = 2048, 1024, 512
  tn, tk, tm = 512, 512, 512
  assert n % tn == 0 and k % tk == 0 and m % tm == 0
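
  # Each grid step computes one (tn, tm) output tile from a (tn, tk) lhs tile
  # and a (tk, tm) rhs tile, accumulating partial products along k.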
  with ir.Context() as ctx, ir.Location.unknown():
    mhlo.register_mhlo_dialect(ctx)
    tpu.register_dialect(ctx)
    mhlo.register_mhlo_passes()
    i32 = ir.IntegerType.get_signless(32)
    i64 = ir.IntegerType.get_signless(64)
    f32 = ir.F32Type.get()
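
    # The kernel receives the three grid indices as i32 scalars and one memref
    # window per operand; the window passed at each step is selected by the
    # "window_params" attribute set below.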
    @func.FuncOp.from_py_func(
        i32,
        i32,
        i32,
        ir.MemRefType.get((tn, tk), f32),
        ir.MemRefType.get((tk, tm), f32),
        ir.MemRefType.get((tn, tm), f32),
        name="main",
    )
    def matmul(i, j, k, lhs, rhs, out, func_op):  # pylint: disable=unused-argument
      c0 = arith.ConstantOp(
          ir.IndexType.get(), ir.IntegerAttr.get(ir.IndexType.get(), 0)
      )
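      # On the first step along the contraction dimension (k == 0) the output
      # window has not been written yet, so substitute a zero tile for the
      # accumulator instead of reading it.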
      kv = vector.BroadcastOp(I32[tn, tm], k)
      output_is_uninitialized = arith.CmpIOp(
          ir.IntegerAttr.get(i64, 0),
          kv,
          arith.ConstantOp(
              I32[tn, tm],
              ir.DenseElementsAttr.get_splat(
                  I32[tn, tm], ir.IntegerAttr.get(i32, 0)
              ),
          ),
      )
      zero_tile = arith.ConstantOp(
          F32[tn, tm],
          ir.DenseElementsAttr.get_splat(
              F32[tn, tm], ir.FloatAttr.get_f32(0.0)
          ),
      )
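      # Load the current input tiles and pick the accumulator: either the
      # previously stored partial result or the zero tile.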
      lhs_tile = vector.LoadOp(F32[tn, tk], lhs, [c0, c0])
      rhs_tile = vector.LoadOp(F32[tk, tm], rhs, [c0, c0])
      out_tile = arith.SelectOp(
          output_is_uninitialized,
          zero_tile,
          vector.LoadOp(F32[tn, tm], out, [c0, c0]),
      )
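      # mhlo.DotOp operates on tensors, so copy the vector tiles into tensors
      # with vector.transfer_write.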
      lhs_tensor = tensor.EmptyOp([tn, tk], f32)
      lhs_tensor = vector.TransferWriteOp(
          F32Tensor[tn, tk],
          lhs_tile,
          lhs_tensor,
          [c0, c0],
          ir.Attribute.parse("affine_map<(i, j) -> (i, j)>"),
      )
      rhs_tensor = tensor.EmptyOp([tk, tm], f32)
      rhs_tensor = vector.TransferWriteOp(
          F32Tensor[tk, tm],
          rhs_tile,
          rhs_tensor,
          [c0, c0],
          ir.Attribute.parse("affine_map<(i, j) -> (i, j)>"),
      )
      out_tensor = tensor.EmptyOp([tn, tm], f32)
      out_tensor = vector.TransferWriteOp(
          F32Tensor[tn, tm],
          out_tile,
          out_tensor,
          [c0, c0],
          ir.Attribute.parse("affine_map<(i, j) -> (i, j)>"),
      )
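      # Accumulate the partial product (out += lhs @ rhs), read the result back
      # into a vector, and store it to the output window.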
      new_out_tensor = mhlo.DotOp(F32Tensor[tn, tm], lhs_tensor, rhs_tensor)
      new_out_tensor = mhlo.AddOp(new_out_tensor, out_tensor)
      padding = arith.ConstantOp(f32, ir.FloatAttr.get(f32, 0))
      new_out_tile = vector.TransferReadOp(
          F32[tn, tm],
          new_out_tensor,
          [c0, c0],
          ir.Attribute.parse("affine_map<(i, j) -> (i, j)>"),
          padding,
      )
      vector.StoreOp(new_out_tile, out, [c0, c0])
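
    # The kernel is executed over a (n // tn, m // tm, k // tk) grid; the two
    # output dimensions are marked parallel and the contraction dimension
    # arbitrary.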
    f = matmul.func_op
    f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
        [n // tn, m // tm, k // tk]
    )
    f.attributes["dimension_semantics"] = ir.ArrayAttr.get([
        ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
        ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
        ir.Attribute.parse("#tpu.dimension_semantics<arbitrary>"),
    ])
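    # Each "transform_indices" map takes the grid indices (n, m, k) to the
    # block of the corresponding operand (lhs, rhs, out) handed to the kernel.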
f.attributes["window_params"] = ir.ArrayAttr.get([
ir.DictAttr.get({
"transform_indices": ir.Attribute.parse(
"affine_map<(n, m, k) -> (n, k)>"
),
}),
ir.DictAttr.get({
"transform_indices": ir.Attribute.parse(
"affine_map<(n, m, k) -> (k, m)>"
),
}),
ir.DictAttr.get({
"transform_indices": ir.Attribute.parse(
"affine_map<(n, m, k) -> (n, m)>"
),
}),
])
    assert f.verify(), f
    module = ir.Module.create()
    module.body.append(f)
    ir.SymbolTable(module.operation).insert(f)
    custom_matmul = mosaic.as_tpu_kernel(
        module, out_type=core.ShapedArray((n, m), jnp.float32)
    )

  k1, k2 = random.split(random.PRNGKey(1234))
  x = random.normal(k1, (n, k), jnp.float32)
  y = random.normal(k2, (k, m), jnp.float32)
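  # Round the inputs to bfloat16 precision and cast back to float32.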
  x = x.astype(jnp.bfloat16).astype(jnp.float32)
  y = y.astype(jnp.bfloat16).astype(jnp.float32)

  # Make sure the implementation is correct + warm-up the caches.
  np.testing.assert_allclose(custom_matmul(x, y), x @ y)
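
  # Time the Mosaic kernel.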
  s = time.perf_counter()
  for _ in range(num_bench_iters):
    custom_matmul(x, y).block_until_ready()
  e = time.perf_counter()
  d1 = e - s
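
  # Time XLA's native matmul for comparison.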
  s = time.perf_counter()
  for _ in range(num_bench_iters):
    (x @ y).block_until_ready()
  e = time.perf_counter()
  d2 = e - s
  print("Run-time ratio (custom/XLA): ", d1 / d2)


if __name__ == "__main__":
  config.update("jax_mosaic_allow_hlo", True)
  app.run(main)