# blob: c709728fa4bb48c4dd65ffd098c502bdd898af47 [file] [log] [blame] [edit]
"""An example Mosaic kernel implementing a fast triangular matmul.
The kernel assumes that the RHS of the multiplication is lower-triangular.
"""
import itertools
import time
from typing import Sequence
from absl import app
from jax import core
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import scf
from mlir.dialects import vector
import numpy as np
class VectorFactory:
  """Subscript-style builder for MLIR vector types.

  `factory[8, 128]` (or `factory[128]` for a single dimension) returns
  `ir.VectorType.get(shape, elem_thunk())`. The element type is obtained by
  calling `elem_thunk` on every subscript, not once at construction time.
  """

  def __init__(self, elem_thunk):
    # Zero-argument callable that produces the vector's element type.
    self.elem_thunk = elem_thunk

  def __getitem__(self, idxs):
    # A bare int subscript is treated as a one-dimensional shape.
    shape = (idxs,) if isinstance(idxs, int) else idxs
    element_type = self.elem_thunk()
    return ir.VectorType.get(shape, element_type)
# Vector type factories: `F32[m, n]` / `I32[m, n]` build an MLIR vector type of
# that shape with f32 / signless 32-bit integer elements. The element type is
# created by calling the thunk on each subscript rather than at import time;
# every use in this file happens inside the `ir.Context` opened in `main`.
F32 = VectorFactory(ir.F32Type.get)
I32 = VectorFactory(lambda: ir.IntegerType.get_signless(32))
def main(argv: Sequence[str]) -> None:
  """Builds the triangular-matmul kernel, checks it and benchmarks it.

  The kernel multiplies an (n, n) LHS by an (n, n) lower-triangular RHS, using
  (tn, tk) / (tk, tm) / (tn, tm) windows for LHS / RHS / output. Its result is
  compared against `x @ y`, then both are timed and the ratio is printed.

  Args:
    argv: Command-line arguments as passed by `app.run`. An optional single
      positional argument sets the number of benchmark iterations (1000 when
      omitted).

  Raises:
    app.UsageError: If more than one argument is given, or if the iteration
      count is not a positive integer.
  """
  if len(argv) <= 1:
    num_bench_iters = 1000
  elif len(argv) == 2:
    try:
      num_bench_iters = int(argv[1])
    except ValueError as e:
      raise app.UsageError(
          f"Number of benchmark iterations must be an integer, got "
          f"{argv[1]!r}.") from e
    if num_bench_iters < 1:
      # With no iterations both timings are ~0 and the ratio is meaningless.
      raise app.UsageError(
          f"Number of benchmark iterations must be positive, got "
          f"{num_bench_iters}.")
  else:
    raise app.UsageError("Too many command-line arguments.")

  n = 1024
  tn, tk, tm = 512, 512, 512
  assert n % tn == 0 and n % tk == 0 and n % tm == 0
  assert tn % 128 == 0 and tk % 128 == 0 and tm % 128 == 0
  # Each (tk, tm) RHS window is processed as a grid of 128x128 minor tiles.
  rhs_minor_tiles_shape = (tk // 128, tm // 128)

  def ndrange(dims):
    """Iterates over all index tuples of an array with shape `dims`."""
    return itertools.product(*map(range, dims))

  # Integer codes of the `arith.cmpi` predicates used below.
  cmpi_eq, cmpi_slt, cmpi_sgt = 0, 2, 4

  with ir.Context() as ctx, ir.Location.unknown():
    tpu.register_dialect(ctx)
    i32 = ir.IntegerType.get_signless(32)
    i64 = ir.IntegerType.get_signless(64)
    f32 = ir.F32Type.get()

    def cmpi(predicate, a, b):
      """Emits an `arith.cmpi` of `a` and `b` with the given predicate code."""
      return arith.CmpIOp(ir.IntegerAttr.get(i64, predicate), a, b)

    @func.FuncOp.from_py_func(
        i32, i32, i32,
        ir.MemRefType.get((tn, tk), f32),
        ir.MemRefType.get((tk, tm), f32),
        ir.MemRefType.get((tn, tm), f32),
        name="main",
    )
    def matmul(i, j, k, lhs, rhs, out, func_op):  # pylint: disable=unused-argument
      """Kernel body: accumulates lhs @ rhs into out for one (i, j, k) step."""
      constants = {}

      def c(val):
        """Returns a cached index constant, emitted at the top of the kernel."""
        if val not in constants:
          with ir.InsertionPoint.at_block_begin(func_op.entry_block):
            ty = ir.IndexType.get()
            constants[val] = arith.ConstantOp(ty, ir.IntegerAttr.get(ty, val))
        return constants[val]

      def mul_by_rhs_tile(rhs_minor_row, rhs_minor_col,
                          output_is_uninitialized):
        """Accumulates lhs[:, row tile] @ rhs[row tile, col tile] into out.

        Args:
          rhs_minor_row: Row index of the 128x128 RHS minor tile.
          rhs_minor_col: Column index of the 128x128 RHS minor tile.
          output_is_uninitialized: Either None, or a mask; where the mask is
            set, zeros are used as the accumulator instead of the values
            loaded from `out`.
        """
        rhs_tile = vector.LoadOp(
            F32[128, 128], rhs,
            [c(rhs_minor_row * 128), c(rhs_minor_col * 128)])
        lhs_tile = vector.LoadOp(
            F32[tn, 128], lhs,
            [c(0), c(rhs_minor_row * 128)])
        out_tile = vector.LoadOp(
            F32[tn, 128], out,
            [c(0), c(rhs_minor_col * 128)])
        if output_is_uninitialized is not None:
          # `zero_minor_tile` is defined further down in `matmul`, before the
          # first call to this helper.
          out_tile = arith.SelectOp(
              output_is_uninitialized, zero_minor_tile, out_tile)
        new_out_tile = vector.ContractionOp(
            F32[tn, 128], lhs_tile, rhs_tile, out_tile,
            indexing_maps=ir.ArrayAttr.get([
                ir.Attribute.parse("affine_map<(i, j, k) -> (i, k)>"),
                ir.Attribute.parse("affine_map<(i, j, k) -> (k, j)>"),
                ir.Attribute.parse("affine_map<(i, j, k) -> (i, j)>"),
            ]),
            iterator_types=ir.ArrayAttr.get([
                ir.Attribute.parse("#vector.iterator_type<parallel>"),
                ir.Attribute.parse("#vector.iterator_type<parallel>"),
                ir.Attribute.parse("#vector.iterator_type<reduction>"),
            ])
        )
        vector.StoreOp(
            new_out_tile, out,
            [c(0), c(rhs_minor_col * 128)])

      # The output is treated as uninitialized when k == j.
      kv = vector.BroadcastOp(I32[tn, 128], k)
      jv = vector.BroadcastOp(I32[tn, 128], j)
      output_is_uninitialized = cmpi(cmpi_eq, kv, jv)
      zero_minor_tile = arith.ConstantOp(
          F32[tn, 128],
          ir.DenseElementsAttr.get_splat(F32[tn, 128],
                                         ir.FloatAttr.get_f32(0.0)))

      # Note that we don't do anything if we are above the diagonal!
      rhs_major_tile_on_diag = cmpi(cmpi_eq, j, k)  # j == k
      on_diag_if = scf.IfOp(rhs_major_tile_on_diag.result)
      with ir.InsertionPoint(on_diag_if.then_block):
        for (rhs_minor_row, rhs_minor_col) in ndrange(rhs_minor_tiles_shape):
          # We skip any minor tiles above the diagonal.
          if rhs_minor_col > rhs_minor_row:
            continue
          # Masking might only be necessary when we're on the diagonal, since
          # those are the first blocks that might initialize the output.
          mul_by_rhs_tile(
              rhs_minor_row, rhs_minor_col,
              output_is_uninitialized
              if rhs_minor_row == rhs_minor_col else None)
        scf.YieldOp([])

      rhs_major_tile_below_diag = cmpi(cmpi_slt, j, k)  # j < k
      below_diag_if = scf.IfOp(rhs_major_tile_below_diag.result)
      with ir.InsertionPoint(below_diag_if.then_block):
        for (rhs_minor_row, rhs_minor_col) in ndrange(rhs_minor_tiles_shape):
          # Masking might only be necessary when we're processing the first
          # row.
          # NOTE(review): this branch only runs when j < k, while the mask is
          # k == j, so the select appears to always keep the loaded tile;
          # confirm whether the mask can be dropped here.
          mul_by_rhs_tile(
              rhs_minor_row, rhs_minor_col,
              output_is_uninitialized if rhs_minor_row == 0 else None)
        scf.YieldOp([])

    @func.FuncOp.from_py_func(i32, i32, i32)
    def rhs_transform_indices(i, j, k):  # pylint: disable=unused-argument
      """Yields (j, j) when j > k and (k, j) otherwise."""
      rhs_upper_diag = cmpi(cmpi_sgt, j, k)  # j > k
      if_op = scf.IfOp(rhs_upper_diag.result, hasElse=True, results_=[i32] * 2)
      with ir.InsertionPoint(if_op.then_block):
        scf.YieldOp([j, j])  # Preload the next useful RHS tile.
      with ir.InsertionPoint(if_op.else_block):
        scf.YieldOp([k, j])
      return if_op
    assert rhs_transform_indices.func_op.verify(), rhs_transform_indices.func_op

    @func.FuncOp.from_py_func(i32, i32, i32)
    def lhs_transform_indices(i, j, k):
      """Yields (i, j) when j > k and (i, k) otherwise."""
      rhs_upper_diag = cmpi(cmpi_sgt, j, k)  # j > k
      if_op = scf.IfOp(rhs_upper_diag.result, hasElse=True, results_=[i32] * 2)
      with ir.InsertionPoint(if_op.then_block):
        scf.YieldOp([i, j])  # Preload the next useful LHS tile.
      with ir.InsertionPoint(if_op.else_block):
        scf.YieldOp([i, k])
      return if_op
    # Verify the LHS index transform too, mirroring the RHS check above.
    assert lhs_transform_indices.func_op.verify(), lhs_transform_indices.func_op

    # Wrap all of those functions in a module.
    m = ir.Module.create()
    sym_tab = ir.SymbolTable(m.operation)
    m.body.append(matmul.func_op)
    sym_tab.insert(matmul.func_op)
    m.body.append(rhs_transform_indices.func_op)
    sym_tab.insert(rhs_transform_indices.func_op)
    m.body.append(lhs_transform_indices.func_op)
    sym_tab.insert(lhs_transform_indices.func_op)

    # Attach the iteration grid and per-operand window descriptions to the
    # kernel function. The grid dimensions correspond to (i, j, k).
    f = matmul.func_op
    f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
        [n // tn, n // tm, n // tk])
    f.attributes["dimension_semantics"] = ir.ArrayAttr.get([
        ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
        ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
        ir.Attribute.parse("#tpu.dimension_semantics<arbitrary>"),
    ])
    f.attributes["window_params"] = ir.ArrayAttr.get([
        ir.DictAttr.get({
            "window_bounds": ir.DenseI64ArrayAttr.get([tn, tk]),
            "transform_indices":
                ir.FlatSymbolRefAttr.get(lhs_transform_indices.__name__),
        }),
        ir.DictAttr.get({
            "window_bounds": ir.DenseI64ArrayAttr.get([tk, tm]),
            "transform_indices":
                ir.FlatSymbolRefAttr.get(rhs_transform_indices.__name__),
        }),
        ir.DictAttr.get({
            "transform_indices":
                ir.Attribute.parse("affine_map<(n, m, k) -> (n, m)>"),
        })
    ])
    assert f.verify(), f

    custom_matmul = mosaic.as_tpu_kernel(
        m, out_type=core.ShapedArray((n, n), jnp.float32))

  x = jnp.ones((n, n), dtype=jnp.float32)
  y = jnp.tril(jnp.broadcast_to(jnp.arange(n, dtype=jnp.float32), (n, n)))
  # Make sure the implementation is correct + warm-up the caches.
  np.testing.assert_allclose(custom_matmul(x, y), x @ y)

  def time_loop(compute):
    """Returns the wall time of `num_bench_iters` blocking calls to compute."""
    start = time.perf_counter()
    for _ in range(num_bench_iters):
      compute().block_until_ready()
    return time.perf_counter() - start

  d1 = time_loop(lambda: custom_matmul(x, y))
  d2 = time_loop(lambda: x @ y)
  print("Run-time ratio (custom/XLA): ", d1 / d2)
# Script entry point: hand control to absl, which calls `main` with argv.
if __name__ == "__main__":
  app.run(main)