"""A Mosaic kernel implementing the Gather 2D op."""
import functools
import time
from typing import Sequence
from absl import app
import jax
from jax import core
from jax import random
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import memref
import numpy as np
from google3.platforms.xla.sparse_core.mlo.ir import sc_tpu


class VectorFactory:
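  """Shorthand for building MLIR vector types, e.g. I32[8, 128]."""
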
def __init__(self, elem_thunk):
self.elem_thunk = elem_thunk

  def __getitem__(self, idxs):
if isinstance(idxs, int):
idxs = (idxs,)
return ir.VectorType.get(idxs, self.elem_thunk())


I32 = VectorFactory(lambda: ir.IntegerType.get_signless(32))
IDX = VectorFactory(ir.IndexType.get)


@functools.partial(jax.jit, static_argnames=["indices_tile_size", "embed_size"])
def gather2d(
operand: jax.Array,
indices: jax.Array,
indices_tile_size: int,
embed_size: int,
) -> jax.Array:
"""Builds the Gather 2D kernel.
Arguments:
operand: The operand from which to gather elements.
indices: The indices of the operand elements to gather.
indices_tile_size: The number of elements to gather per Mosaic tile.
embed_size: The number of elements in the inner dimension of the operand.
Returns:
A JAX function implementing the Gather1D kernel.
"""
with ir.Context() as ctx, ir.Location.unknown():
# The TPU dialect is required for its TPU_DimensionSemantics
tpu.register_dialect(ctx)
sc_tpu.register_dialect(ctx)
operand_size = operand.shape[0]
indices_size = indices.shape[0]
i32 = ir.IntegerType.get_signless(32)
f32 = ir.F32Type.get()
index = ir.IndexType.get()
# TODO(naumsmogers): Add memory space enums.
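    # Magic memory space IDs understood by the SparseCore backend.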
hbm = ir.IntegerAttr.get(i32, 203)
tilespmem = ir.IntegerAttr.get(i32, 201)
sflagmem = ir.IntegerAttr.get(i32, 204)
tile1d8 = ir.Attribute.parse("#tpu.tiled<(8),[1]>")
tile2d8 = ir.Attribute.parse("#tpu.tiled<(8),[1, 1]>")
row_offset_type = ir.StringAttr.get(b"row_offset")
# TODO(naumsmogers): hardcode the tile memory to tilespmem
@func.FuncOp.from_py_func(
i32,
i32,
ir.MemRefType.get((operand_size, embed_size), f32, tile2d8, hbm),
ir.MemRefType.get((indices_tile_size,), i32, tile1d8, tilespmem),
ir.MemRefType.get(
(indices_tile_size, embed_size), f32, tile2d8, tilespmem
),
name="main",
)
def kernel_main(i, j, operand, indices_tile, output_tile, func_op): # pylint: disable=unused-argument
constants = {}
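      # Caches scalar constants, materializing each distinct (value, type)
      # pair once at the top of the entry block so it dominates all uses.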
def c(val, ty=None):
ty = index if ty is None else ty
if (val, ty) not in constants:
with ir.InsertionPoint.at_block_begin(func_op.entry_block):
constants[(val, ty)] = arith.ConstantOp(
ty, ir.IntegerAttr.get(ty, val)
)
return constants[(val, ty)]
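      # Allocate a sync flag in sync-flag memory; it is used below to wait
      # for the indirect gather stream to complete.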
sflag = memref.AllocaOp(
ir.MemRefType.get(
shape=(1,), element_type=i32, memory_space=sflagmem
),
dynamicSizes=[],
symbolOperands=[],
)
# 4 bytes per float * embed_size elements per gather
single_gather_bytes = embed_size * 4
total_bytes = indices_tile_size * single_gather_bytes
# 4 byte-aligned HBM accesses
granule_size = c(4)
granules = arith.DivUIOp(c(single_gather_bytes), granule_size)
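      # Start the indirect gather stream: each entry of indices_tile selects
      # a row of the HBM operand, and one row's worth of 4-byte granules is
      # copied into the corresponding row of output_tile.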
sc_tpu.IndirectStreamStartOp(
src=operand,
src_indices=[c(0), c(0)],
dst=output_tile,
dst_indices=[c(0), c(0)],
granules=granules,
sflag=sflag,
sflag_indices=[c(0)],
offset=indices_tile,
offset_indices=[c(0)],
offset_list_size=c(indices_tile_size),
indirect_list_type=row_offset_type,
upd=False,
hbm4b=True,
filter=False,
enable_trace=False,
)
total_granules = arith.DivUIOp(c(total_bytes), granule_size)
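      # Wait until all gathered bytes, measured in 4-byte granules, have
      # landed in the output tile.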
      sc_tpu.StreamWaitOp(
          sflag=sflag, sflag_indices=[c(0)], granules=total_granules
      )
f = kernel_main.func_op
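    # One grid step per tile of indices; the second dimension is degenerate.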
f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
[indices_size // indices_tile_size, 1]
)
f.attributes["dimension_semantics"] = ir.ArrayAttr.get([
ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
ir.Attribute.parse("#tpu.dimension_semantics<parallel>"),
])
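    # One window per memref argument, in order: the operand (kept whole),
    # the indices tile, and the output tile (both windowed along the grid).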
f.attributes["window_params"] = ir.ArrayAttr.get([
ir.DictAttr.get({
"transform_indices": ir.Attribute.parse(
"affine_map<(n, m) -> (0, 0)>"
),
}),
ir.DictAttr.get({
"transform_indices": ir.Attribute.parse(
"affine_map<(n, m) -> (n)>"
),
}),
ir.DictAttr.get({
"transform_indices": ir.Attribute.parse(
"affine_map<(n, m) -> (n, 0)>"
),
}),
])
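    # The operand is marked persistent: together with its fixed (0, 0)
    # window above, the full array stays in place across grid steps.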
f.arg_attrs = [
ir.DictAttr.get({}), # i
ir.DictAttr.get({}), # j
ir.DictAttr.get({"sc.persistent": ir.UnitAttr.get()}), # operand
ir.DictAttr.get({}), # indices_tile
ir.DictAttr.get({}), # output_tile
]
assert f.verify(), f
m = ir.Module.create()
m.body.append(f)
ir.SymbolTable(m.operation).insert(f)
return mosaic.as_tpu_kernel(
m,
out_type=core.ShapedArray((indices_size, embed_size), jnp.float32),
device_type="sparsecore",
)(operand, indices)


def main(argv: Sequence[str]) -> None:
if len(argv) == 1:
num_bench_iters = 10
elif len(argv) == 2:
num_bench_iters = int(argv[1])
else:
raise app.UsageError("Too many command-line arguments.")
sample_count = 2097152
indices_size = 524288
# Tile shape for 16 SC tiles and 4 cores
# embed_size = 4; indices_tile_shape = (8192,)
# embed_size = 8; indices_tile_shape = (4096,)
# embed_size = 64; indices_tile_shape = (512,)
# embed_size = 128; indices_tile_shape = (256,)
embed_size = 64
indices_tile_shape = (512,)
operand_shape = (sample_count, embed_size)
indices_shape = (indices_size,)
k1, k2 = random.split(random.PRNGKey(1234))
operand = random.normal(key=k1, shape=operand_shape, dtype=jnp.float32)
indices = random.randint(
key=k2,
shape=indices_shape,
minval=0,
maxval=sample_count,
dtype=jnp.int32,
)
kernel = functools.partial(
gather2d, indices_tile_size=indices_tile_shape[0], embed_size=embed_size
)
outputs = kernel(operand, indices)
gold = jnp.take(operand, indices, axis=0)
np.testing.assert_array_equal(outputs, gold)
s = time.perf_counter()
for _ in range(num_bench_iters):
kernel(operand, indices).block_until_ready()
e = time.perf_counter()
d1 = e - s
  # Warm up the XLA gather so compilation is excluded from the timing.
  jnp.take(operand, indices, axis=0).block_until_ready()
  s = time.perf_counter()
  for _ in range(num_bench_iters):
    jnp.take(operand, indices, axis=0).block_until_ready()
  e = time.perf_counter()
  d2 = e - s
print(
"Run-time ratio (custom/XLA) (%d runs): %0.4f"
% (num_bench_iters, d1 / d2)
)


if __name__ == "__main__":
app.run(main)