| """A Mosaic kernel implementing the Gather 1D op.""" |
| |
| import functools |
| import time |
| from typing import Sequence |
| |
| from absl import app |
| import jax |
| from jax import core |
| from jax import random |
| from jax.experimental import mosaic |
| from jax.experimental.mosaic.dialects import tpu |
| import jax.numpy as jnp |
| from mlir import ir |
| from mlir.dialects import arith |
| from mlir.dialects import cf |
| from mlir.dialects import func |
| from mlir.dialects import memref |
| from mlir.dialects import scf |
| from mlir.dialects import vector |
| import numpy as np |
| |
| from google3.platforms.xla.sparse_core.mlo.ir import sc_tpu |
| |
| |
| class VectorFactory: |
| |
| def __init__(self, elem_thunk): |
| self.elem_thunk = elem_thunk |
| |
| def __getitem__(self, idxs): |
| if isinstance(idxs, int): |
| idxs = (idxs,) |
| return ir.VectorType.get(idxs, self.elem_thunk()) |
| |
| |
| I32 = VectorFactory(lambda: ir.IntegerType.get_signless(32)) |
| IDX = VectorFactory(ir.IndexType.get) |
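# For example, I32[8] builds the MLIR type vector<8xi32> and IDX[8] builds
# vector<8xindex>; both must be evaluated inside an active ir.Context.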


@functools.partial(
    jax.jit,
    static_argnames=[
        "indices_tile_size",
        "check_idx_bounds",
        "check_slice_size_alignment",
    ],
)
def gather1d(
    operand: jax.Array,
    indices: jax.Array,
    indices_tile_size: int,
    check_idx_bounds: bool = False,
    check_slice_size_alignment: bool = False,
) -> jax.Array:
  """Builds and runs the Gather 1D kernel.

  Args:
    operand: The operand from which to gather elements.
    indices: The indices of the operand elements to gather.
    indices_tile_size: The number of elements to gather per Mosaic tile.
    check_idx_bounds: A boolean enabling dynamic checking of indices against
      the operand bounds.
    check_slice_size_alignment: A boolean enabling dynamic checking of gather
      slice size alignment with the HBM transfer granularity.

  Returns:
    The gathered elements, as a float32 array of the same length as `indices`.
  """
  with ir.Context() as ctx, ir.Location.unknown():
    # The TPU dialect is required for its TPU_DimensionSemantics attribute.
    tpu.register_dialect(ctx)
    sc_tpu.register_dialect(ctx)

    operand_size = operand.shape[0]
    indices_size = indices.shape[0]

    i32 = ir.IntegerType.get_signless(32)
    i64 = ir.IntegerType.get_signless(64)
    f32 = ir.F32Type.get()
    index = ir.IndexType.get()
    # TODO(naumsmogers): Add memory space enums.
    hbm = ir.IntegerAttr.get(i32, 203)
    tilespmem = ir.IntegerAttr.get(i32, 201)
    sflagmem = ir.IntegerAttr.get(i32, 204)
    # TODO(naumsmogers): Hardcode the tile size to 8 for the SC for all mems.
    tile1d8 = ir.Attribute.parse("#tpu.tiled<(8),[1]>")
    row_offset_type = ir.StringAttr.get(b"row_offset")

    pred_eq = ir.IntegerAttr.get(i64, 0)
    pred_slt = ir.IntegerAttr.get(i64, 2)
    pred_uge = ir.IntegerAttr.get(i64, 9)
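    # These are the integer encodings of the arith.CmpIPredicate enum:
    # eq == 0, slt == 2 (signed less-than), uge == 9 (unsigned greater-equal).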

    # TODO(naumsmogers): Hardcode the tile memory to tilespmem.
    @func.FuncOp.from_py_func(
        i32,
        ir.MemRefType.get((operand_size,), f32, tile1d8, hbm),
        ir.MemRefType.get((indices_tile_size,), i32, tile1d8, tilespmem),
        ir.MemRefType.get((indices_tile_size,), f32, tile1d8, tilespmem),
        name="main",
    )
    def kernel_main(i, operand, indices_tile, output_tile, func_op):  # pylint: disable=unused-argument
      constants = {}

      # Emits an arith.constant in the function's entry block, caching it so
      # that each (value, type) pair is materialized only once.
      def c(val, ty=None):
        ty = index if ty is None else ty
        if (val, ty) not in constants:
          with ir.InsertionPoint.at_block_begin(func_op.entry_block):
            constants[(val, ty)] = arith.ConstantOp(
                ty, ir.IntegerAttr.get(ty, val)
            )
        return constants[(val, ty)]

      sflag = memref.AllocaOp(
          ir.MemRefType.get(
              shape=(1,), element_type=i32, memory_space=sflagmem
          ),
          dynamicSizes=[],
          symbolOperands=[],
      )
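      # `sflag` is a sync flag allocated in sync-flag memory; the indirect
      # stream started below reports transfer progress through it, and
      # StreamWaitOp blocks on it.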

      # Check that every index is within the bounds of the operand. Lanes
      # beyond indices_tile_size in the last chunk are masked to index zero so
      # they cannot trip the assert, and the unsigned compare also flags
      # negative indices as out of bounds.
      if check_idx_bounds:
        loop_over_idx_chunks = scf.ForOp(
            lower_bound=c(0), upper_bound=c(indices_tile_size), step=c(8)
        )
        with ir.InsertionPoint(loop_over_idx_chunks.body):
          chunk_offset = loop_over_idx_chunks.induction_variable
          chunk_offset_vec = vector.SplatOp(IDX[8], chunk_offset)
          indices_chunk = arith.IndexCastOp(
              IDX[8],
              sc_tpu.VectorLoadOp(
                  I32[8],
                  indices_tile,
                  [chunk_offset],
              ),
          )
          vlaneseq = sc_tpu.VlaneseqOp(IDX[8])
          lane_offset_vec = arith.AddIOp(chunk_offset_vec, vlaneseq)
          dynamic_bound = vector.SplatOp(IDX[8], c(indices_tile_size))
          in_bounds = arith.CmpIOp(pred_slt, lane_offset_vec, dynamic_bound)
          zero_vec = vector.SplatOp(IDX[8], c(0))
          indices_chunk_bounded = arith.SelectOp(
              in_bounds, indices_chunk, zero_vec
          )
          max_chunk = vector.SplatOp(IDX[8], c(operand_size))
          compare_result = arith.CmpIOp(
              pred_uge, indices_chunk_bounded, max_chunk
          )
          compare_count_vec = sc_tpu.VmpcntOnesOp(I32[8], compare_result)
          compare_count = vector.ExtractElementOp(
              compare_count_vec, position=c(0)
          )
          compare_count_eq_zero = arith.CmpIOp(
              pred_eq, compare_count, c(0, i32)
          )
          cf.AssertOp(
              compare_count_eq_zero,
              "Index out of bounds (operand size {}).".format(operand_size),
          )
          scf.YieldOp([])

      # 4 bytes per float32, 1 element per gathered slice.
      single_gather_bytes = c(4 * 1)
      total_bytes = arith.MulIOp(single_gather_bytes, c(indices_tile_size))
      # 4-byte-aligned HBM accesses.
      granule_size = c(4)
      # Check that the size of a single gathered slice is a multiple of the
      # HBM granule size. The check is trivially true while single_gather_bytes
      # is a constant, but it is kept in case the slice size is parameterized.
      if check_slice_size_alignment:
        remainder_bytes = arith.RemUIOp(single_gather_bytes, granule_size)
        assert_multiple = arith.CmpIOp(pred_eq, remainder_bytes, c(0))
        cf.AssertOp(
            assert_multiple,
            "Stream operation transfer size must be a multiple of the granule"
            " size (4 bytes).",
        )
      granules = arith.DivUIOp(single_gather_bytes, granule_size)
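      # Start the indirect (gather) stream: for each of the indices_tile_size
      # entries of indices_tile, this presumably copies `granules` 4-byte
      # granules from the indexed position of `operand` in HBM into the
      # corresponding slot of `output_tile` in tile SPMEM. The transfer is
      # asynchronous and reports progress through `sflag`.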
      sc_tpu.IndirectStreamStartOp(
          src=operand,
          src_indices=[c(0)],
          dst=output_tile,
          dst_indices=[c(0)],
          granules=granules,
          sflag=sflag,
          sflag_indices=[c(0)],
          offset=indices_tile,
          offset_indices=[c(0)],
          offset_list_size=c(indices_tile_size),
          indirect_list_type=row_offset_type,
          upd=False,
          hbm4b=True,
          filter=False,
          enable_trace=False,
      )
      # 4-byte-aligned HBM accesses.
      granule_size = c(4)
      total_granules = arith.DivUIOp(total_bytes, granule_size)
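      # Block until the stream has delivered all granules of the output tile
      # before the kernel body completes.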
      sc_tpu.StreamWaitOp(
          sflag=sflag, sflag_indices=c(0), granules=total_granules
      )

    f = kernel_main.func_op
    f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
        [indices_size // indices_tile_size]
    )
    f.attributes["dimension_semantics"] = ir.ArrayAttr.get(
        [ir.Attribute.parse("#tpu.dimension_semantics<parallel>")]
    )
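    # Window parameters, presumably one per memref argument (operand,
    # indices_tile, output_tile): the operand window stays at offset 0 for
    # every grid step, while the indices and output windows advance with the
    # grid index n.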
    f.attributes["window_params"] = ir.ArrayAttr.get([
        ir.DictAttr.get({
            "transform_indices": ir.Attribute.parse("affine_map<(n) -> (0)>"),
        }),
        ir.DictAttr.get({
            "transform_indices": ir.Attribute.parse("affine_map<(n) -> (n)>"),
        }),
        ir.DictAttr.get({
            "transform_indices": ir.Attribute.parse("affine_map<(n) -> (n)>"),
        }),
    ])
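    # The operand is marked sc.persistent, which presumably keeps it resident
    # in HBM across grid iterations instead of windowing it into SPMEM.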
    f.arg_attrs = [
        ir.DictAttr.get({}),  # i
        ir.DictAttr.get({"sc.persistent": ir.UnitAttr.get()}),  # operand
        ir.DictAttr.get({}),  # indices_tile
        ir.DictAttr.get({}),  # output_tile
    ]
    assert f.verify(), f
    m = ir.Module.create()
    m.body.append(f)
    ir.SymbolTable(m.operation).insert(f)

  return mosaic.as_tpu_kernel(
      m,
      out_type=core.ShapedArray((indices_size,), jnp.float32),
      device_type="sparsecore",
  )(operand, indices)


def main(argv: Sequence[str]) -> None:
  if len(argv) == 1:
    num_bench_iters = 10
  elif len(argv) == 2:
    num_bench_iters = int(argv[1])
  else:
    raise app.UsageError("Too many command-line arguments.")

  sample_count = 2097152
  indices_size = 524288
  # Tile shape for 16 SC tiles and 4 cores.
  indices_tile_shape = (8192,)
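  # 524288 indices / 8192 indices per tile = 64 grid iterations, matching the
  # 16 SC tiles x 4 cores noted above.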

  operand_shape = (sample_count,)
  indices_shape = (indices_size,)

  k1, k2 = random.split(random.PRNGKey(1234))
  operand = random.normal(key=k1, shape=operand_shape, dtype=jnp.float32)
  indices = random.randint(
      key=k2,
      shape=indices_shape,
      minval=0,
      maxval=sample_count,
      dtype=jnp.int32,
  )

  kernel = functools.partial(gather1d, indices_tile_size=indices_tile_shape[0])
  outputs = kernel(operand, indices)
  gold = jnp.take(operand, indices)

  np.testing.assert_array_equal(outputs, gold)

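  # The calls above also serve as compilation warm-up, so the timed loops
  # below measure steady-state execution only.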
  s = time.perf_counter()
  for _ in range(num_bench_iters):
    kernel(operand, indices).block_until_ready()
  e = time.perf_counter()
  d1 = e - s

  s = time.perf_counter()
  for _ in range(num_bench_iters):
    jnp.take(operand, indices).block_until_ready()
  e = time.perf_counter()
  d2 = e - s
  print(
      "Run-time ratio (custom/XLA) (%d runs): %0.4f"
      % (num_bench_iters, d1 / d2)
  )


if __name__ == "__main__":
  app.run(main)