"""An example LLO kernel implementing matmul for Jellyfish."""
import time
from typing import Sequence
from absl import app
import jax
from jax._src import tpu_custom_call
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import func
import numpy as np
from google3.platforms.xla.mosaic.python.dialects import llo
def main(argv: Sequence[str]) -> None:
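  """Builds the matmul kernel, checks it against XLA, and benchmarks it."""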
  if len(argv) == 1:
    num_bench_iters = 10000
  elif len(argv) == 2:
    num_bench_iters = int(argv[1])
  else:
    raise app.UsageError("Too many command-line arguments.")

  n = 1024
  tn, tk, tm = 512, 512, 512
  assert n % tn == 0 and n % tk == 0 and n % tm == 0
  assert tn % 128 == 0 and tk % 128 == 0 and tm % 128 == 0
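  # Layout note (inferred from the #tpu.tiled<(8,128),...> layouts attached
  # below): every operand lives in vmem as 128x128 tiles, each made of 16
  # stacked (8, 128) vregs. Assuming load/store displacements count 128-lane
  # vmem words (so one vreg spans 8 words), an 8-row strip of an operand with
  # C column tiles takes 8 * C words, and a full 128-row tile row takes 16x
  # that; the *_row_stride and *_tile_row_stride values below encode this.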
  num_latches_rows = tk // 128
  num_latches_cols = tm // 128
  latch_row_stride = 8 * num_latches_cols  # In vmem words.
  latch_tile_row_stride = 16 * latch_row_stride
  num_lhs_subtile_rows = tn // 128
  num_lhs_subtile_cols = tk // 128
  lhs_row_stride = 8 * num_lhs_subtile_cols
  lhs_tile_row_stride = 16 * lhs_row_stride
  out_row_stride = 8 * num_latches_cols
  out_tile_row_stride = 16 * out_row_stride
  with ir.Context() as ctx, ir.Location.unknown():
    tpu.register_dialect(ctx)
    llo.register_llo_dialect(ctx)
    i32 = ir.IntegerType.get_signless(32)
    f32 = ir.F32Type.get()
    vmask = ir.VectorType.get((8, 128), ir.IntegerType.get_signless(1))
    f32_vreg = ir.VectorType.get((8, 128), f32)
    i32_vreg = ir.VectorType.get((8, 128), i32)
    i32_vreg_c0_attr = ir.DenseElementsAttr.get_splat(
        i32_vreg, ir.IntegerAttr.get(i32, 0))
    f32_vreg_c0_attr = ir.DenseElementsAttr.get_splat(
        f32_vreg, ir.FloatAttr.get_f32(0.0))
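    # Kernel sketch (description inferred from the op names and data flow
    # below): one invocation computes a (tn, tm) output block for grid step
    # (i, j, k). Each 128x128 rhs tile is latched into the MXU gains
    # (VectorLatchOp), the matching lhs vregs are streamed through
    # VectorMatmulOp, and partial products are popped with VectorMatresOp and
    # accumulated into the output in vmem. is_first_window_vm is an all-lanes
    # mask that is true only when k == 0, so the accumulator is zeroed on the
    # first reduction step instead of reusing stale output contents.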
    @func.FuncOp.from_py_func(i32, i32, i32, i32, i32, i32, name="main")
    def matmul(i, j, k, xaddr, yaddr, oaddr):  # pylint: disable=unused-argument
      constants = {}

      def c(val):
        assert isinstance(val, int)
        if val not in constants:
          constants[val] = llo.ConstantOp(i32, ir.IntegerAttr.get(i32, val))
        return constants[val]

      i32_vreg_c0 = llo.ConstantOp(i32_vreg, i32_vreg_c0_attr)
      f32_vreg_c0 = llo.ConstantOp(f32_vreg, f32_vreg_c0_attr)
      kv = llo.ScalarToVectorOp(i32_vreg, k)
      is_first_window_vm = llo.VectorCmpEqS32Op(vmask, kv, i32_vreg_c0)
      for ln in range(num_latches_rows):
        for lm in range(num_latches_cols):
          # Load up the latch.
          latch_base_offset = ln * latch_tile_row_stride + lm * 8
          latch_tile_base = llo.ScalarAddressVmemOp(yaddr, c(latch_base_offset))
          for i in range(16):
            yt = llo.VectorLoadOp(f32_vreg, latch_tile_base,
                                  displacement=c(i * latch_row_stride))
            llo.VectorLatchOp(
                yt, ir.Attribute.parse("#llo.gain_latch_mode<f32>"))
          # Recall that num_lhs_subtile_cols == num_latches_rows.
          lhs_col = ln
          for lhs_row in range(num_lhs_subtile_rows):
            lhs_tile_base_offset = lhs_row * lhs_tile_row_stride + lhs_col * 8
            lhs_tile_base = llo.ScalarAddressVmemOp(
                xaddr, c(lhs_tile_base_offset))
            out_tile_base_offset = lhs_row * out_tile_row_stride + lm * 8
            out_tile_base = llo.ScalarAddressVmemOp(
                oaddr, c(out_tile_base_offset))
            for i in range(16):
              xt = llo.VectorLoadOp(
                  f32_vreg, lhs_tile_base, displacement=c(i * lhs_row_stride))
              llo.VectorMatmulOp(
                  xt, ir.Attribute.parse("#llo.matmul_mode<round>"))
              zt = llo.VectorMatresOp(f32_vreg)
              # Blend zt with acc.
              if ln == 0:
                raw_acc = llo.VectorLoadOp(f32_vreg, out_tile_base,
                                           displacement=c(i * out_row_stride))
                acc = llo.VectorSelectOp(
                    f32_vreg, is_first_window_vm, f32_vreg_c0, raw_acc)
                new_acc = llo.VectorAddF32Op(acc, zt)
                llo.VectorStoreOp(out_tile_base,
                                  displacement=c(i * out_row_stride),
                                  to_store=new_acc)
              else:
                acc = llo.VectorLoadOp(
                    f32_vreg, out_tile_base, displacement=c(i * out_row_stride))
                new_acc = llo.VectorAddF32Op(acc, zt)
                llo.VectorStoreOp(out_tile_base,
                                  displacement=c(i * out_row_stride),
                                  to_store=new_acc)
      llo.VectorDoneWithGainsOp()
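    # The attributes below describe how the kernel is driven: iteration_bounds
    # defines an (n // tn, n // tm, n // tk) grid, dimension_semantics marks
    # the two output dimensions as parallel and the contraction dimension as
    # arbitrary (no parallelism is assumed there, which the accumulation above
    # relies on), and window_params maps each grid point (n, m, k) to the
    # operand blocks the kernel sees through xaddr, yaddr and oaddr.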
    f = matmul.func_op
    lhs_layout = ir.Attribute.parse(f"#tpu.tiled<(8,128),[{tk//128},1]>")
    rhs_layout = ir.Attribute.parse(f"#tpu.tiled<(8,128),[{tm//128},1]>")
    out_layout = ir.Attribute.parse(f"#tpu.tiled<(8,128),[{tm//128},1]>")
    vmem = ir.Attribute.parse("#tpu.memory_space<vmem>")
    lhs_memref = ir.MemRefType.get((tn, tk), f32, lhs_layout, vmem)
    rhs_memref = ir.MemRefType.get((tk, tm), f32, rhs_layout, vmem)
    out_memref = ir.MemRefType.get((tn, tm), f32, out_layout, vmem)
    for i, ty in enumerate([i32, i32, i32, lhs_memref, rhs_memref, out_memref]):
      tpu.private_set_arg_attr(f, i, "llo.type", ir.TypeAttr.get(ty))
      if i > 2:
        tpu.private_set_arg_attr(f, i, "llo.layout", ty.layout)
    assert f.verify(), f
    f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
        [n // tn, n // tm, n // tk])
    f.attributes["dimension_semantics"] = ir.ArrayAttr.get([
        ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
        ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
        ir.Attribute.parse("#tpu.dimension_semantics<arbitrary>"),
    ])
    f.attributes["window_params"] = ir.ArrayAttr.get([
        ir.DictAttr.get({
            "transform_indices":
                ir.Attribute.parse("affine_map<(n, m, k) -> (n, k)>"),
        }),
        ir.DictAttr.get({
            "transform_indices":
                ir.Attribute.parse("affine_map<(n, m, k) -> (k, m)>"),
        }),
        ir.DictAttr.get({
            "transform_indices":
                ir.Attribute.parse("affine_map<(n, m, k) -> (n, m)>"),
        }),
    ])
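    # Host-side driver: build test inputs, wrap the lowered module as a TPU
    # custom call, check it against the XLA matmul, then benchmark both.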
    x = jnp.ones((n, n), dtype=jnp.float32)
    y = jnp.broadcast_to(jnp.arange(n, dtype=jnp.float32), (n, n))
    m = ir.Module.create()
    m.body.append(f)
    ir.SymbolTable(m.operation).insert(f)
    # We use jax.jit to make sure we hit the fast compilation cache.
    custom_matmul = tpu_custom_call._lowered_as_tpu_kernel(  # pylint: disable=protected-access
        m.operation.get_asm(binary=True),
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        serialization_format=None,
    )
    # Make sure the implementation is correct + warm up the caches.
    np.testing.assert_allclose(custom_matmul(x, y), x @ y)
    s = time.perf_counter()
    for _ in range(num_bench_iters):
      custom_matmul(x, y).block_until_ready()
    e = time.perf_counter()
    d1 = e - s
    s = time.perf_counter()
    for _ in range(num_bench_iters):
      (x @ y).block_until_ready()
    e = time.perf_counter()
    d2 = e - s
    print("Run-time ratio (custom/XLA): ", d1 / d2)


if __name__ == "__main__":
  app.run(main)