blob: 2cd14e626389d89fa0c785573dea8fc1d726fff1 [file] [log] [blame] [edit]
"""A Mosaic kernel implementing the Gather 1D op."""
import functools
import time
from typing import Sequence
from absl import app
import jax
from jax import core
from jax import random
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import cf
from mlir.dialects import func
from mlir.dialects import memref
from mlir.dialects import scf
from mlir.dialects import vector
import numpy as np
from google3.platforms.xla.sparse_core.mlo.ir import sc_tpu
class VectorFactory:
  """Builds MLIR vector types of one element type via subscript syntax.

  `factory[8]` and `factory[(2, 8)]` both return an `ir.VectorType` whose
  shape is the subscript and whose element type is whatever `elem_thunk`
  returns. The thunk is invoked on every subscript rather than once up front,
  so the element type is only created at the point of use (presumably so that
  an MLIR context is active by then -- confirm against callers).
  """

  def __init__(self, elem_thunk):
    """Stores the zero-argument callable that produces the element type."""
    self.elem_thunk = elem_thunk

  def __getitem__(self, idxs):
    """Returns the vector type with shape `idxs` (an int or a sequence)."""
    # A bare integer subscript denotes a rank-1 vector.
    shape = (idxs,) if isinstance(idxs, int) else idxs
    element_type = self.elem_thunk()
    return ir.VectorType.get(shape, element_type)
# Factory for vectors of signless 32-bit integers, e.g. `I32[8]`.
I32 = VectorFactory(functools.partial(ir.IntegerType.get_signless, 32))
# Factory for vectors of MLIR `index` elements, e.g. `IDX[8]`.
IDX = VectorFactory(lambda: ir.IndexType.get())
@functools.partial(
    jax.jit,
    # Every Python-level switch has to be static: the tile size shapes the
    # kernel, and the two check flags are tested with a plain `if` while the
    # kernel is being built, which a traced value cannot support.
    static_argnames=[
        "indices_tile_size",
        "check_idx_bounds",
        "check_slice_size_alignment",
    ],
)
def gather1d(
    operand: jax.Array,
    indices: jax.Array,
    indices_tile_size: int,
    check_idx_bounds: bool = False,
    check_slice_size_alignment: bool = False,
) -> jax.Array:
  """Builds and runs the Gather 1D kernel.

  The kernel is declared with an f32 operand and i32 indices; each grid step
  gathers `indices_tile_size` elements with a single indirect stream and then
  waits for that stream to complete.

  Args:
    operand: The 1-D operand from which to gather elements.
    indices: The 1-D indices of the operand elements to gather.
    indices_tile_size: The number of elements to gather per Mosaic tile. Must
      be positive and divide `indices.shape[0]` evenly.
    check_idx_bounds: A boolean enabling dynamic checking of indices against
      the operand bounds.
    check_slice_size_alignment: A boolean enabling dynamic checking of gather
      slice size alignment with the HBM transfer granularity.

  Returns:
    A float32 array of shape `(indices.shape[0],)` produced by the kernel.

  Raises:
    ValueError: If `indices_tile_size` is not positive or does not evenly
      divide the number of indices.
  """
  operand_size = operand.shape[0]
  indices_size = indices.shape[0]
  # The grid below is `indices_size // indices_tile_size` steps while the
  # output has `indices_size` elements, so a ragged tail would never be
  # written. Reject it up front instead of returning a partly filled result.
  if indices_tile_size <= 0 or indices_size % indices_tile_size != 0:
    raise ValueError(
        "indices_tile_size must be positive and evenly divide the number of"
        f" indices; got indices_tile_size={indices_tile_size} for"
        f" {indices_size} indices."
    )

  with ir.Context() as ctx, ir.Location.unknown():
    # The TPU dialect is required for its TPU_DimensionSemantics
    tpu.register_dialect(ctx)
    sc_tpu.register_dialect(ctx)

    i32 = ir.IntegerType.get_signless(32)
    i64 = ir.IntegerType.get_signless(64)
    f32 = ir.F32Type.get()
    index = ir.IndexType.get()

    # TODO(naumsmogers): Add memory space enums.
    hbm = ir.IntegerAttr.get(i32, 203)
    tilespmem = ir.IntegerAttr.get(i32, 201)
    sflagmem = ir.IntegerAttr.get(i32, 204)

    # TODO(naumsmogers): hardcode the tile size to 8 for the SC for all mems.
    tile1d8 = ir.Attribute.parse("#tpu.tiled<(8),[1]>")
    row_offset_type = ir.StringAttr.get(b"row_offset")

    # Integer comparison predicates, passed as raw i64 attribute values.
    pred_eq = ir.IntegerAttr.get(i64, 0)
    pred_slt = ir.IntegerAttr.get(i64, 2)
    pred_uge = ir.IntegerAttr.get(i64, 9)

    # TODO(naumsmogers): hardcode the tile memory to tilespmem
    @func.FuncOp.from_py_func(
        i32,
        ir.MemRefType.get((operand_size,), f32, tile1d8, hbm),
        ir.MemRefType.get((indices_tile_size,), i32, tile1d8, tilespmem),
        ir.MemRefType.get((indices_tile_size,), f32, tile1d8, tilespmem),
        name="main",
    )
    def kernel_main(i, operand, indices_tile, output_tile, func_op):  # pylint: disable=unused-argument
      # Cache of constants keyed by (value, type), so each distinct constant is
      # materialized only once.
      constants = {}

      def c(val, ty=None):
        """Returns the cached constant `val` of type `ty` (default: index)."""
        ty = index if ty is None else ty
        if (val, ty) not in constants:
          # Emit at the top of the entry block so the constant is available to
          # every later use, including uses inside nested regions.
          with ir.InsertionPoint.at_block_begin(func_op.entry_block):
            constants[(val, ty)] = arith.ConstantOp(
                ty, ir.IntegerAttr.get(ty, val)
            )
        return constants[(val, ty)]

      # Single-element flag buffer used to start and wait on the stream.
      sflag = memref.AllocaOp(
          ir.MemRefType.get(
              shape=(1,), element_type=i32, memory_space=sflagmem
          ),
          dynamicSizes=[],
          symbolOperands=[],
      )

      # Check that the indices are within the bounds of the operand.
      if check_idx_bounds:
        # Walk the index tile in chunks of 8 elements.
        loop_over_idx_chunks = scf.ForOp(
            lower_bound=c(0), upper_bound=c(indices_tile_size), step=c(8)
        )
        with ir.InsertionPoint(loop_over_idx_chunks.body):
          chunk_offset = loop_over_idx_chunks.induction_variable
          chunk_offset_vec = vector.SplatOp(IDX[8], chunk_offset)
          indices_chunk = arith.IndexCastOp(
              IDX[8],
              sc_tpu.VectorLoadOp(
                  I32[8],
                  indices_tile,
                  [chunk_offset],
              ),
          )
          # Position of every lane within the tile (presumably VlaneseqOp
          # yields the lane ids 0..7 -- confirm against the sc_tpu dialect).
          vlaneseq = sc_tpu.VlaneseqOp(IDX[8])
          lane_offset_vec = arith.AddIOp(chunk_offset_vec, vlaneseq)
          # Lanes positioned past the end of the tile get index 0 substituted
          # so that they do not take part in the comparison below.
          dynamic_bound = vector.SplatOp(IDX[8], c(indices_tile_size))
          in_bounds = arith.CmpIOp(pred_slt, lane_offset_vec, dynamic_bound)
          zero_vec = vector.SplatOp(IDX[8], c(0))
          indices_chunk_bounded = arith.SelectOp(
              in_bounds, indices_chunk, zero_vec
          )
          # The unsigned >= comparison also flags negative indices, which are
          # huge when read as unsigned.
          max_chunk = vector.SplatOp(IDX[8], c(operand_size))
          compare_result = arith.CmpIOp(
              pred_uge, indices_chunk_bounded, max_chunk
          )
          # Assert that no lane was flagged (presumably VmpcntOnesOp counts
          # the set lanes of the mask -- confirm against the sc_tpu dialect).
          compare_count_vec = sc_tpu.VmpcntOnesOp(I32[8], compare_result)
          compare_count = vector.ExtractElementOp(
              compare_count_vec, position=c(0)
          )
          compare_count_eq_zero = arith.CmpIOp(
              pred_eq, compare_count, c(0, i32)
          )
          cf.AssertOp(
              compare_count_eq_zero,
              "Element out of bounds (max {}).".format(operand_size),
          )
          scf.YieldOp([])

      # 4 bytes per float * 1 element per gather
      single_gather_bytes = c(4 * 1)
      total_bytes = arith.MulIOp(single_gather_bytes, c(indices_tile_size))
      # 4 Byte-aligned HBM accesses
      granule_size = c(4)

      # Check that the size of a single gathered slice is a multiple of the
      # HBM granule size.
      # Trivial check when single_gather_bytes is a constant. Still, leaving
      # it here in case it gets parameterized
      if check_slice_size_alignment:
        remainder_bytes = arith.RemUIOp(single_gather_bytes, granule_size)
        assert_multiple = arith.CmpIOp(pred_eq, remainder_bytes, c(0))
        cf.AssertOp(
            assert_multiple,
            "Stream operation transfer size must be a multiple of granule size:"
            " 4 in bytes.",
        )

      granules = arith.DivUIOp(single_gather_bytes, granule_size)
      sc_tpu.IndirectStreamStartOp(
          src=operand,
          src_indices=[c(0)],
          dst=output_tile,
          dst_indices=[c(0)],
          granules=granules,
          sflag=sflag,
          sflag_indices=[c(0)],
          offset=indices_tile,
          offset_indices=[c(0)],
          offset_list_size=c(indices_tile_size),
          indirect_list_type=row_offset_type,
          upd=False,
          hbm4b=True,
          filter=False,
          enable_trace=False,
      )

      # Wait for the whole tile: every gathered element contributes
      # `granules` granules to the total.
      total_granules = arith.DivUIOp(total_bytes, granule_size)
      sc_tpu.StreamWaitOp(
          sflag=sflag, sflag_indices=c(0), granules=total_granules
      )

    f = kernel_main.func_op
    # One grid step per tile of indices.
    f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
        [indices_size // indices_tile_size]
    )
    f.attributes["dimension_semantics"] = ir.ArrayAttr.get(
        [ir.Attribute.parse("#tpu.dimension_semantics<parallel>")]
    )
    # Window mappings for (operand, indices_tile, output_tile): the operand
    # always maps to block 0, while the index and output tiles follow the grid
    # step `n`.
    f.attributes["window_params"] = ir.ArrayAttr.get([
        ir.DictAttr.get({
            "transform_indices": ir.Attribute.parse("affine_map<(n) -> (0)>"),
        }),
        ir.DictAttr.get({
            "transform_indices": ir.Attribute.parse("affine_map<(n) -> (n)>"),
        }),
        ir.DictAttr.get({
            "transform_indices": ir.Attribute.parse("affine_map<(n) -> (n)>"),
        }),
    ])
    f.arg_attrs = [
        ir.DictAttr.get({}),  # i
        ir.DictAttr.get({"sc.persistent": ir.UnitAttr.get()}),  # operand
        ir.DictAttr.get({}),  # indices_tile
        ir.DictAttr.get({}),  # output_tile
    ]
    assert f.verify(), f

    m = ir.Module.create()
    m.body.append(f)
    ir.SymbolTable(m.operation).insert(f)
    return mosaic.as_tpu_kernel(
        m,
        out_type=core.ShapedArray((indices_size,), jnp.float32),
        device_type="sparsecore",
    )(operand, indices)
def _time_blocking_calls(fn, num_iters: int) -> float:
  """Returns the wall-clock seconds spent on `num_iters` blocking calls of `fn`."""
  start = time.perf_counter()
  for _ in range(num_iters):
    fn().block_until_ready()
  return time.perf_counter() - start


def main(argv: Sequence[str]) -> None:
  """Checks the Gather 1D kernel against `jnp.take` and benchmarks both.

  Args:
    argv: Command-line arguments. An optional single argument gives the number
      of benchmark iterations (default 10).

  Raises:
    app.UsageError: If more than one argument is given, or if the iteration
      count is not an integer or is smaller than 1.
  """
  if len(argv) > 2:
    raise app.UsageError("Too many command-line arguments.")
  num_bench_iters = 10
  if len(argv) == 2:
    try:
      num_bench_iters = int(argv[1])
    except ValueError as err:
      raise app.UsageError(
          f"Benchmark iteration count must be an integer, got {argv[1]!r}."
      ) from err
  # With no iterations the timing ratio printed at the end is meaningless.
  if num_bench_iters < 1:
    raise app.UsageError("Benchmark iteration count must be at least 1.")

  sample_count = 2097152
  indices_size = 524288
  # Tile shape for 16 SC tiles and 4 cores
  indices_tile_shape = (8192,)
  operand_shape = (sample_count,)
  indices_shape = (indices_size,)

  k1, k2 = random.split(random.PRNGKey(1234))
  operand = random.normal(key=k1, shape=operand_shape, dtype=jnp.float32)
  indices = random.randint(
      key=k2,
      shape=indices_shape,
      minval=0,
      maxval=sample_count,
      dtype=jnp.int32,
  )

  kernel = functools.partial(gather1d, indices_tile_size=indices_tile_shape[0])

  # Correctness check. These first calls also run before the timed loops
  # below, so one-time work on first use is not measured there.
  outputs = kernel(operand, indices)
  gold = jnp.take(operand, indices)
  np.testing.assert_array_equal(outputs, gold)

  d1 = _time_blocking_calls(lambda: kernel(operand, indices), num_bench_iters)
  d2 = _time_blocking_calls(lambda: jnp.take(operand, indices), num_bench_iters)
  print(
      f"Run-time ratio (custom/XLA) ({num_bench_iters} runs): {d1 / d2:0.4f}"
  )
# Script entry point: hand `main` to absl's `app.run` when this file is
# executed directly rather than imported.
if __name__ == "__main__":
  app.run(main)