"""An example Mosaic kernel implementing standard matrix multiplication."""
from collections.abc import Callable, Sequence
import time

from absl import app
import jax
from jax import core
from jax import random
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import vector
import numpy as np
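

# Helper for building MLIR vector types with subscript syntax, e.g.
# F32[tn, tm] yields the MLIR type vector<tn x tm x f32>. The element type is
# a thunk because it can only be constructed under an active ir.Context.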
class VectorFactory:

  def __init__(self, elem_thunk):
    self.elem_thunk = elem_thunk

  def __getitem__(self, idxs):
    if isinstance(idxs, int):
      idxs = (idxs,)
    return ir.VectorType.get(idxs, self.elem_thunk())


F32 = VectorFactory(ir.F32Type.get)
I32 = VectorFactory(lambda: ir.IntegerType.get_signless(32))


def create_matmul(
    n: int, m: int, k: int, tn: int, tm: int, tk: int
) -> Callable[[jax.typing.ArrayLike, jax.typing.ArrayLike], jax.Array]:
  """Returns a callable that performs matmul, implemented through Mosaic."""
  assert n % tn == 0 and k % tk == 0 and m % tm == 0
  assert tn % 128 == 0 and tk % 128 == 0 and tm % 128 == 0
  with ir.Context() as ctx, ir.Location.unknown():
    tpu.register_dialect(ctx)
    i32 = ir.IntegerType.get_signless(32)
    i64 = ir.IntegerType.get_signless(64)
    f32 = ir.F32Type.get()
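
    # The kernel body runs once per grid point (i, j, k). Each run sees one
    # (tn, tk) window of lhs, one (tk, tm) window of rhs, and one (tn, tm)
    # window of the output, as selected by window_params below.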
    @func.FuncOp.from_py_func(
        i32, i32, i32,
        ir.MemRefType.get((tn, tk), f32),
        ir.MemRefType.get((tk, tm), f32),
        ir.MemRefType.get((tn, tm), f32),
        name="main",
    )
    def matmul(i, j, k, lhs, rhs, out, func_op):  # pylint: disable=unused-argument
      c0 = arith.ConstantOp(
          ir.IndexType.get(), ir.IntegerAttr.get(ir.IndexType.get(), 0)
      )
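      # The output window is revisited on every step along the k grid
      # dimension and serves as the accumulator, but its initial contents are
      # unspecified. The first step (k == 0) therefore swaps in zeros below.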
      kv = vector.BroadcastOp(I32[tn, tm], k)
      output_is_uninitialized = arith.CmpIOp(
          ir.IntegerAttr.get(i64, 0),  # arith.cmpi predicate 0 is "eq".
          kv,
          arith.ConstantOp(
              I32[tn, tm],
              ir.DenseElementsAttr.get_splat(
                  I32[tn, tm], ir.IntegerAttr.get(i32, 0)
              ),
          ),
      )
      zero_minor_tile = arith.ConstantOp(
          F32[tn, tm],
          ir.DenseElementsAttr.get_splat(
              F32[tn, tm], ir.FloatAttr.get_f32(0.0)
          ),
      )
      rhs_tile = vector.LoadOp(F32[tk, tm], rhs, [c0, c0])
      lhs_tile = vector.LoadOp(F32[tn, tk], lhs, [c0, c0])
      out_tile = vector.LoadOp(F32[tn, tm], out, [c0, c0])
      out_tile = arith.SelectOp(
          output_is_uninitialized, zero_minor_tile, out_tile
      )
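      # vector.contract with the indexing maps below computes
      # out[i, j] += sum over k of lhs[i, k] * rhs[k, j].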
      new_out_tile = vector.ContractionOp(
          F32[tn, tm],
          lhs_tile,
          rhs_tile,
          out_tile,
          indexing_maps=ir.ArrayAttr.get([
              ir.Attribute.parse("affine_map<(i, j, k) -> (i, k)>"),
              ir.Attribute.parse("affine_map<(i, j, k) -> (k, j)>"),
              ir.Attribute.parse("affine_map<(i, j, k) -> (i, j)>"),
          ]),
          iterator_types=ir.ArrayAttr.get([
              ir.Attribute.parse("#vector.iterator_type<parallel>"),
              ir.Attribute.parse("#vector.iterator_type<parallel>"),
              ir.Attribute.parse("#vector.iterator_type<reduction>"),
          ]),
      )
      vector.StoreOp(new_out_tile, out, [c0, c0])

    f = matmul.func_op
    # One grid step per tile along each dimension.
    f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
        [n // tn, m // tm, k // tk])
    # The n and m dimensions touch disjoint output tiles and are parallel;
    # the k dimension accumulates into the same tile, so it is marked
    # "arbitrary" (no independence assumed) rather than parallel.
    f.attributes["dimension_semantics"] = ir.ArrayAttr.get([
        ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
        ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
        ir.Attribute.parse("#tpu.dimension_semantics<arbitrary>"),
    ])
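    # window_params select the window of each argument that a given grid step
    # operates on: lhs tile (n, k), rhs tile (k, m), and output tile (n, m).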
f.attributes["window_params"] = ir.ArrayAttr.get([
ir.DictAttr.get({
"transform_indices":
ir.Attribute.parse("affine_map<(n, m, k) -> (n, k)>"),
}),
ir.DictAttr.get({
"transform_indices":
ir.Attribute.parse("affine_map<(n, m, k) -> (k, m)>"),
}),
ir.DictAttr.get({
"transform_indices":
ir.Attribute.parse("affine_map<(n, m, k) -> (n, m)>"),
})
])
assert f.verify(), f
module = ir.Module.create()
module.body.append(f)
ir.SymbolTable(module.operation).insert(f)
return mosaic.as_tpu_kernel(
module, out_type=core.ShapedArray((n, m), jnp.float32)
)


def main(argv: Sequence[str]) -> None:
  if len(argv) == 1:
    num_bench_iters = 1000
  elif len(argv) == 2:
    num_bench_iters = int(argv[1])
  else:
    raise app.UsageError("Too many command-line arguments.")

  n, m, k = 2048, 1024, 512
  tn, tk, tm = 512, 512, 512
  custom_matmul = create_matmul(n, m, k, tn, tm, tk)
  k1, k2 = random.split(random.PRNGKey(1234))
  x = random.normal(k1, (n, k), jnp.float32)
  y = random.normal(k2, (k, m), jnp.float32)
  # Round the inputs to bfloat16 precision, presumably so that the custom
  # kernel and the XLA matmul produce nearly identical results below.
  x = x.astype(jnp.bfloat16).astype(jnp.float32)
  y = y.astype(jnp.bfloat16).astype(jnp.float32)
  # Make sure the implementation is correct + warm-up the caches.
  np.testing.assert_allclose(custom_matmul(x, y), x @ y)
  s = time.perf_counter()
  for _ in range(num_bench_iters):
    custom_matmul(x, y).block_until_ready()
  e = time.perf_counter()
  d1 = e - s

  s = time.perf_counter()
  for _ in range(num_bench_iters):
    (x @ y).block_until_ready()
  e = time.perf_counter()
  d2 = e - s
  print("Run-time ratio (custom/XLA): ", d1 / d2)


if __name__ == "__main__":
  app.run(main)