"""An example LLO kernel implementing a fast triangular matmul for Jellyfish.
The current implementation assumes that the RHS is lower triangular. It's
a prototype, which might not necessarily work for non-square matrices and
non-square RHS tiles.
"""
import time
from typing import Sequence

from absl import app
import jax
from jax._src import tpu_custom_call
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import func
from mlir.dialects import scf
import numpy as np

from google3.platforms.xla.mosaic.python.dialects import llo


def main(argv: Sequence[str]) -> None:
  if len(argv) == 1:
    num_bench_iters = 10000
  elif len(argv) == 2:
    num_bench_iters = int(argv[1])
  else:
    raise app.UsageError("Too many command-line arguments.")

  n = 1024
  tn, tk, tm = 512, 512, 512
  assert n % tn == 0 and n % tk == 0 and n % tm == 0
  assert tn % 128 == 0 and tk % 128 == 0 and tm % 128 == 0
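  # All strides below are measured in VMEM words. With the (8, 128) tiled
  # layouts used here, one (8, 128) f32 vreg occupies 8 words, so a full
  # 128x128 sub-tile is 16 vregs stacked vertically (hence the factors of 8
  # and 16).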
  num_latches_rows = tk // 128
  num_latches_cols = tm // 128
  latch_row_stride = 8 * num_latches_cols  # In vmem words.
  latch_tile_row_stride = 16 * latch_row_stride
  num_lhs_subtile_rows = tn // 128
  num_lhs_subtile_cols = tk // 128
  lhs_row_stride = 8 * num_lhs_subtile_cols
  lhs_tile_row_stride = 16 * lhs_row_stride
  out_row_stride = 8 * num_latches_cols
  out_tile_row_stride = 16 * out_row_stride

  with ir.Context() as ctx, ir.Location.unknown():
    tpu.register_dialect(ctx)
    llo.register_llo_dialect(ctx)
    i32 = ir.IntegerType.get_signless(32)
    f32 = ir.F32Type.get()
    vmask = ir.VectorType.get((8, 128), ir.IntegerType.get_signless(1))
    f32_vreg = ir.VectorType.get((8, 128), f32)
    i32_vreg = ir.VectorType.get((8, 128), i32)
    f32_vreg_c0_attr = ir.DenseElementsAttr.get_splat(
        f32_vreg, ir.FloatAttr.get_f32(0.0))

    @func.FuncOp.from_py_func(i32, i32, i32, i32, i32, i32, name="main")
    def matmul(i, j, k, xaddr, yaddr, oaddr, *, func_op):  # pylint: disable=unused-argument
      constants = {}
      def c(val):
        assert isinstance(val, int)
        if val not in constants:
          with ir.InsertionPoint.at_block_begin(func_op.entry_block):
            constants[val] = llo.ConstantOp(i32, ir.IntegerAttr.get(i32, val))
        return constants[val]
      f32_vreg_c0 = llo.ConstantOp(f32_vreg, f32_vreg_c0_attr)

      # This is taken pretty much exactly from llo_matmul_jf.py, with the
      # exception of the two predicates that parameterize the body.
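      # should_skip(ln, lm) lets the caller drop 128x128 RHS sub-tiles that
      # are known to be all zeros, while is_first_tile_window(ln, lm) marks
      # the sub-tiles that initialize (rather than accumulate into) the output
      # on the first contributing grid step.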
      def matmul_tile(should_skip, is_first_tile_window):
        for ln in range(num_latches_rows):
          for lm in range(num_latches_cols):
            if should_skip(ln, lm):
              continue
            # Load up the latch.
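            # Each VectorLatchOp below pushes one (8, 128) vreg of the RHS
            # sub-tile into the MXU gain latch, so the 16 iterations latch the
            # full 128x128 sub-tile.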
            latch_base_offset = ln * latch_tile_row_stride + lm * 8
            latch_tile_base = llo.ScalarAddressVmemOp(
                yaddr, c(latch_base_offset))
            for i in range(16):
              yt = llo.VectorLoadOp(f32_vreg, latch_tile_base,
                                    displacement=c(i * latch_row_stride))
              llo.VectorLatchOp(
                  yt, ir.Attribute.parse("#llo.gain_latch_mode<f32>"))
            # Recall that num_lhs_subtile_cols == num_latches_rows.
            lhs_col = ln
            for lhs_row in range(num_lhs_subtile_rows):
              lhs_tile_base_offset = lhs_row * lhs_tile_row_stride + lhs_col * 8
              lhs_tile_base = llo.ScalarAddressVmemOp(
                  xaddr, c(lhs_tile_base_offset))
              out_tile_base_offset = lhs_row * out_tile_row_stride + lm * 8
              out_tile_base = llo.ScalarAddressVmemOp(
                  oaddr, c(out_tile_base_offset))
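              # Each VectorMatmulOp pushes one (8, 128) LHS vreg through the
              # latched gains, and the matching VectorMatresOp pops the
              # (8, 128) partial product, which we accumulate into the output.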
              for i in range(16):
                xt = llo.VectorLoadOp(
                    f32_vreg, lhs_tile_base, displacement=c(i * lhs_row_stride))
                llo.VectorMatmulOp(
                    xt, ir.Attribute.parse("#llo.matmul_mode<round>"))
                zt = llo.VectorMatresOp(f32_vreg)
                # Blend zt with acc
                if is_first_tile_window(ln, lm):
                  raw_acc = llo.VectorLoadOp(f32_vreg, out_tile_base,
                                             displacement=c(i * out_row_stride))
                  acc = llo.VectorSelectOp(
                      f32_vreg, is_first_window_vm, f32_vreg_c0, raw_acc)
                  new_acc = llo.VectorAddF32Op(acc, zt)
                  llo.VectorStoreOp(out_tile_base,
                                    displacement=c(i * out_row_stride),
                                    to_store=new_acc)
                else:
                  acc = llo.VectorLoadOp(f32_vreg, out_tile_base,
                                         displacement=c(i * out_row_stride))
                  new_acc = llo.VectorAddF32Op(acc, zt)
                  llo.VectorStoreOp(out_tile_base,
                                    displacement=c(i * out_row_stride),
                                    to_store=new_acc)
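            # All matmuls against the gains latched for this sub-tile have
            # been issued.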
            llo.VectorDoneWithGainsOp()

      # Note that this differs from the standard matmul! All the major RHS
      # tiles above the diagonal do not even initialize the output. The first
      # major tile that does is the one where k == j (instead of k == 0).
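      # is_first_window_vm is an all-lanes mask that is true exactly on the
      # k == j grid step; matmul_tile uses it to replace the stale contents of
      # the output window with zeros the first time a sub-tile writes to it.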
      kv = llo.ScalarToVectorOp(i32_vreg, k)
      jv = llo.ScalarToVectorOp(i32_vreg, j)
      is_first_window_vm = llo.VectorCmpEqS32Op(vmask, kv, jv)
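
      # For a lower-triangular RHS, a major tile at (k, j) with j > k lies
      # strictly above the diagonal and is therefore all zeros, so the whole
      # grid step can be skipped.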
      rhs_upper_diag = llo.ScalarCmpGtS32Op(j, k).result
      if_op = scf.IfOp(rhs_upper_diag, hasElse=True)
      with ir.InsertionPoint(if_op.then_block):
        scf.YieldOp([])  # Nothing to do!
      with ir.InsertionPoint(if_op.else_block):
        # We split every major tile into 128x128 sub-tiles, so we can skip
        # some of those when we're working on diagonal major tiles.
        rhs_diag = llo.ScalarCmpEqS32Op(j, k).result
        if_op = scf.IfOp(rhs_diag, hasElse=True)
        with ir.InsertionPoint(if_op.then_block):
          matmul_tile(should_skip=lambda ln, lm: lm > ln,
                      is_first_tile_window=lambda ln, lm: ln == lm)
          # But the code above at the moment does not handle the case of some
          # columns being all zero. This can happen for non-square tiles.
          assert num_latches_rows >= num_latches_cols
          scf.YieldOp([])
        with ir.InsertionPoint(if_op.else_block):
          # A major tile under the diagonal gives us no guarantees,
          # so we perform a regular matmul.
          matmul_tile(should_skip=lambda _, __: False,
                      is_first_tile_window=lambda ln, _: ln == 0)
          scf.YieldOp([])
        scf.YieldOp([])

    f = matmul.func_op
    lhs_layout = ir.Attribute.parse(f"#tpu.tiled<(8,128),[{tk//128},1]>")
    rhs_layout = ir.Attribute.parse(f"#tpu.tiled<(8,128),[{tm//128},1]>")
    out_layout = ir.Attribute.parse(f"#tpu.tiled<(8,128),[{tm//128},1]>")
    vmem = ir.Attribute.parse("#tpu.memory_space<vmem>")
    lhs_memref = ir.MemRefType.get((tn, tk), f32, lhs_layout, vmem)
    rhs_memref = ir.MemRefType.get((tk, tm), f32, rhs_layout, vmem)
    out_memref = ir.MemRefType.get((tn, tm), f32, out_layout, vmem)
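
    # The first three arguments are the scalar grid indices; the last three
    # are the LHS, RHS, and output windows in VMEM, which additionally carry
    # their tiled layouts.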
    for i, ty in enumerate([i32, i32, i32, lhs_memref, rhs_memref, out_memref]):
      tpu.private_set_arg_attr(f, i, "llo.type", ir.TypeAttr.get(ty))
      if i > 2:
        tpu.private_set_arg_attr(f, i, "llo.layout", ty.layout)
    assert f.verify(), f

    f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
        [n // tn, n // tm, n // tk])
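
    # The n and m grid dimensions write disjoint output windows and are marked
    # parallel; the k dimension accumulates into the same window, so it is
    # marked arbitrary (executed in order).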
f.attributes["dimension_semantics"] = ir.ArrayAttr.get([
ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
ir.Attribute.parse("#tpu.dimension_semantics<arbitrary>"),
])
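
    # Each window_params entry maps the grid step (n, m, k) to the operand
    # tile brought into VMEM: (n, k) for the LHS, (k, m) for the RHS, and
    # (n, m) for the output.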
f.attributes["window_params"] = ir.ArrayAttr.get([
ir.DictAttr.get({
"transform_indices":
ir.Attribute.parse("affine_map<(n, m, k) -> (n, k)>"),
}),
ir.DictAttr.get({
"transform_indices":
ir.Attribute.parse("affine_map<(n, m, k) -> (k, m)>"),
}),
ir.DictAttr.get({
"transform_indices":
ir.Attribute.parse("affine_map<(n, m, k) -> (n, m)>"),
})
])
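
    # Test inputs: y is lower triangular by construction, matching the
    # assumption baked into the kernel.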
    x = jnp.ones((n, n), dtype=jnp.float32)
    y = jnp.tril(jnp.broadcast_to(jnp.arange(n, dtype=jnp.float32), (n, n)))

    m = ir.Module.create()
    m.body.append(matmul.func_op)
    ir.SymbolTable(m.operation).insert(matmul.func_op)
    # We use jax.jit to make sure we hit the fast compilation cache.
    custom_matmul = tpu_custom_call._lowered_as_tpu_kernel(  # pylint: disable=protected-access
        m.operation.get_asm(binary=True),
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        serialization_format=None,
    )
    # Make sure the implementation is correct + warm up the caches.
    np.testing.assert_allclose(custom_matmul(x, y), x @ y)

    s = time.perf_counter()
    for _ in range(num_bench_iters):
      custom_matmul(x, y).block_until_ready()
    e = time.perf_counter()
    d1 = e - s

    s = time.perf_counter()
    for _ in range(num_bench_iters):
      (x @ y).block_until_ready()
    e = time.perf_counter()
    d2 = e - s

    print("Run-time ratio (custom/XLA): ", d1 / d2)


if __name__ == "__main__":
  app.run(main)