"""An example Mosaic kernel implementing segmented matrix multiplication.
A segmented matrix multiplication is analogous to a segmented sum: it is given
the arguments, which are composed of contiguous segments, and is expected to
compute a separate output for every segment. In a way, it is a more general
version of batched matrix multiplication, where segments have to be of equal
length.
"""
import time
from typing import Sequence

from absl import app
import jax
from jax import core
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import memref
from mlir.dialects import vector
import numpy as np


class VectorFactory:

  def __init__(self, elem_thunk):
    self.elem_thunk = elem_thunk

  def __getitem__(self, idxs):
    if isinstance(idxs, int):
      idxs = (idxs,)
    return ir.VectorType.get(idxs, self.elem_thunk())


F32 = VectorFactory(ir.F32Type.get)
I32 = VectorFactory(lambda: ir.IntegerType.get_signless(32))
Index = VectorFactory(ir.IndexType.get)
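# Example: F32[8, 128] builds the MLIR type vector<8x128xf32>; these helpers
# just abbreviate the many vector-type constructions below.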


def main(argv: Sequence[str]) -> None:
  n = 4096
  tn, tk, tm = 1024, 512, 512
  # TODO(apaszke): This is the tile size that XLA chooses,
  # but we'll need memory padding for that.
  # tn = 1368
  if len(argv) == 1:
    num_bench_iters = 100
  elif len(argv) == 2:
    if argv[1] == "small":
      num_bench_iters = 1
      n = 256
      tn, tk, tm = 128, 128, 128
    else:
      num_bench_iters = int(argv[1])
  else:
    raise app.UsageError("Too many command-line arguments.")

  # Choose the segment run lengths; you can play with different values to see
  # how the performance changes.
  tile_run_lengths = jnp.array([1, 4] * 8, dtype=jnp.int32)
  tile_segments = jnp.repeat(
      jnp.arange(len(tile_run_lengths), dtype=jnp.int32), tile_run_lengths)
  nk = int(jnp.sum(tile_run_lengths) * tk)
  num_k_tiles = nk // tk
  assert tile_segments.shape == (
      num_k_tiles,), f"{tile_segments.shape} != ({num_k_tiles},)"
  assert n % tn == 0 and n % tm == 0
  assert tn % 8 == 0 and tk % 128 == 0 and tm % 128 == 0
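  # (The native TPU vector layout for f32 tiles the last two dims as (8, 128),
  # hence the alignment requirements above.)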

  with ir.Context() as ctx, ir.Location.unknown():
    tpu.register_dialect(ctx)
    i32 = ir.IntegerType.get_signless(32)
    i64 = ir.IntegerType.get_signless(64)
    f32 = ir.F32Type.get()
    index = ir.IndexType.get()
    smem = ir.Attribute.parse("#tpu.memory_space<smem>")

    # Our representation here is highly unoptimized, but it's a prototype.
    assert num_k_tiles < 2048, num_k_tiles
    tile_segments_ty = ir.MemRefType.get((num_k_tiles,), i32, None, smem)
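    # tile_segments lives in SMEM: it is the scalar-prefetch operand (see the
    # scalar_prefetch attribute below), so it can be read both by the kernel
    # body and by the output index transform.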

    @func.FuncOp.from_py_func(
        i32, i32, i32,
        tile_segments_ty,
        ir.MemRefType.get((tn, tk), f32),
        ir.MemRefType.get((tk, tm), f32),
        # FIXME(apaszke): This is a lie. It's a 3D buffer. But a harmless lie,
        # because the majormost dim is of size 1, so it doesn't matter.
        ir.MemRefType.get((tn, tm), f32),
        name="main",
    )
    def matmul(i, j, k, tile_segments, lhs, rhs, out, func_op):  # pylint: disable=unused-argument
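      # i, j, k are the grid indices (one per entry of iteration_bounds set
      # below), tile_segments is the scalar-prefetch operand, and lhs, rhs and
      # out are the windows selected for this iteration by window_params.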
      constants = {}
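
      # Cache index constants, materializing each value once at the top of
      # the entry block so that it dominates all of its uses.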
      def c(val):
        if val not in constants:
          with ir.InsertionPoint.at_block_begin(func_op.entry_block):
            ty = ir.IndexType.get()
            constants[val] = arith.ConstantOp(ty, ir.IntegerAttr.get(ty, val))
        return constants[val]

      zero_minor_tile = arith.ConstantOp(
          F32[tn, tm],
          ir.DenseElementsAttr.get_splat(
              F32[tn, tm], ir.FloatAttr.get_f32(0.0)
          ),
      )
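
      # Compare the segment id of the current k tile with that of the previous
      # one (arith.cmpi predicate 0 is eq, 1 is ne). At k == 0 there is no
      # previous tile, so clamp the load index to stay in bounds and
      # substitute -1 for the previous id, which always signals a change.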
      k_idx = arith.IndexCastOp(index, k)
      k_is_0 = arith.CmpIOp(ir.IntegerAttr.get(i64, 0), k,
                            arith.ConstantOp(i32, ir.IntegerAttr.get(i32, 0)))
      k_prev_oob = arith.SubIOp(k_idx, c(1))
      k_prev = arith.SelectOp(k_is_0, k_idx, k_prev_oob)
      segment = memref.LoadOp(tile_segments, [k_idx])
      prev_segment = memref.LoadOp(tile_segments, [k_prev])
      prev_segment = arith.SelectOp(
          k_is_0, arith.ConstantOp(i32, ir.IntegerAttr.get(i32, -1)),
          prev_segment)
      segment_v = vector.BroadcastOp(I32[tn, tm], segment)
      prev_segment_v = vector.BroadcastOp(I32[tn, tm], prev_segment)
      changed_segment = arith.CmpIOp(
          ir.IntegerAttr.get(i64, 1), segment_v, prev_segment_v)
      rhs_tile = vector.LoadOp(F32[tk, tm], rhs, [c(0), c(0)])
      lhs_tile = vector.LoadOp(F32[tn, tk], lhs, [c(0), c(0)])
      out_tile = vector.LoadOp(F32[tn, tm], out, [c(0), c(0)])
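      # On a segment boundary, restart the accumulator from zeros instead of
      # the partial sums left over from the previous segment.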
      out_tile = arith.SelectOp(changed_segment, zero_minor_tile, out_tile)
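      # The contraction below computes out_tile + lhs_tile @ rhs_tile.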
      new_out_tile = vector.ContractionOp(
          F32[tn, tm],
          lhs_tile,
          rhs_tile,
          out_tile,
          indexing_maps=ir.ArrayAttr.get([
              ir.Attribute.parse("affine_map<(i, j, k) -> (i, k)>"),
              ir.Attribute.parse("affine_map<(i, j, k) -> (k, j)>"),
              ir.Attribute.parse("affine_map<(i, j, k) -> (i, j)>"),
          ]),
          iterator_types=ir.ArrayAttr.get([
              ir.Attribute.parse("#vector.iterator_type<parallel>"),
              ir.Attribute.parse("#vector.iterator_type<parallel>"),
              ir.Attribute.parse("#vector.iterator_type<reduction>"),
          ]),
      )
      vector.StoreOp(new_out_tile, out, [c(0), c(0)])
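
    # The output window of iteration (i, j, k) is indexed by the segment id of
    # the k tile, so all k tiles within a segment accumulate into the same
    # (segment, i, j) slice of the stacked output.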
    @func.FuncOp.from_py_func(i32, i32, i32, tile_segments_ty)
    def out_transform_indices(i, j, k, tile_segments):
      k_idx = arith.IndexCastOp(index, k)
      segment = memref.LoadOp(tile_segments, [k_idx])
      return segment.result, i, j
    assert out_transform_indices.func_op.verify(), out_transform_indices.func_op

    # Wrap all of those functions in a module.
    m = ir.Module.create()
    sym_tab = ir.SymbolTable(m.operation)
    m.body.append(matmul.func_op)
    sym_tab.insert(matmul.func_op)
    m.body.append(out_transform_indices.func_op)
    sym_tab.insert(out_transform_indices.func_op)

    f = matmul.func_op
    f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
        [n // tn, n // tm, nk // tk])
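    # The n and m dimensions are independent, but the k dimension carries the
    # running accumulator, so it can only be declared arbitrary, not parallel.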
f.attributes["dimension_semantics"] = ir.ArrayAttr.get([
ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
ir.Attribute.parse("#tpu.dimension_semantics<arbitrary>"),
])
f.attributes["window_params"] = ir.ArrayAttr.get([
ir.DictAttr.get({
"transform_indices":
ir.Attribute.parse("affine_map<(n, m, k) -> (n, k)>"),
}),
ir.DictAttr.get({
"transform_indices":
ir.Attribute.parse("affine_map<(n, m, k) -> (k, m)>"),
}),
ir.DictAttr.get({
"window_bounds": ir.DenseI64ArrayAttr.get([1, tn, tm]),
"transform_indices":
ir.FlatSymbolRefAttr.get(out_transform_indices.__name__),
})
])
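    # The first two entries window lhs and rhs with plain affine maps; the
    # output entry routes its indexing through out_transform_indices and
    # carves a (1, tn, tm) window out of the stacked 3D output.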
f.attributes["scalar_prefetch"] = ir.IntegerAttr.get(i64, 1)
    assert f.verify(), f

  segmented_matmul = mosaic.as_tpu_kernel(
      m, out_type=core.ShapedArray((len(tile_run_lengths), n, n),
                                   jnp.float32))

  x = jnp.ones((n, nk), dtype=jnp.float32)
  y = jnp.broadcast_to(jnp.arange(n, dtype=jnp.float32), (nk, n))

  # To make this fully fair, this should only specialize on the number of
  # stacked matrices (i.e. len(tile_run_lengths)) and not on the actual run
  # lengths. But we can't do that, because dynamic_slice_in_dim requires the
  # output shape to be statically known!
  tile_run_lengths_list = [int(l) * tk for l in tile_run_lengths]

  @jax.jit
  def reference(x, y):
    s = 0
    outs = []
    for i, _ in enumerate(tile_run_lengths_list):
      xs = x[:, s:s + tile_run_lengths_list[i]]
      ys = y[s:s + tile_run_lengths_list[i], :]
      outs.append(xs @ ys)
      s += tile_run_lengths_list[i]
    return jnp.stack(outs, axis=0)

  k = -time.perf_counter()
  segmented_matmul.lower(tile_segments, x, y).compile()
  k += time.perf_counter()
  print(f"Custom kernel compile time: {k}s")
  r = -time.perf_counter()
  reference.lower(x, y).compile()
  r += time.perf_counter()
  print(f"XLA reference compile time: {r}s")

  run_kernel = lambda: segmented_matmul(tile_segments, x, y)
  run_reference = lambda: reference(x, y)
  # Make sure the implementation is correct + warm up the caches.
  np.testing.assert_allclose(run_kernel(), run_reference())

  s = time.perf_counter()
  for _ in range(num_bench_iters):
    run_kernel().block_until_ready()
  e = time.perf_counter()
  d1 = e - s

  s = time.perf_counter()
  for _ in range(num_bench_iters):
    run_reference().block_until_ready()
  e = time.perf_counter()
  d2 = e - s
  print(f"Run-time ratio (custom/XLA): {d1 / d2} (spent {d1 + d2}s"
        " benchmarking)")


if __name__ == "__main__":
  app.run(main)