| """A Mosaic kernel implementing the Gather 2D op.""" |
| |
| import functools |
| import time |
| from typing import Sequence |
| |
| from absl import app |
| import jax |
| from jax import core |
| from jax import random |
| from jax.experimental import mosaic |
| from jax.experimental.mosaic.dialects import tpu |
| import jax.numpy as jnp |
| from mlir import ir |
| from mlir.dialects import arith |
| from mlir.dialects import func |
| from mlir.dialects import memref |
| import numpy as np |
| |
| from google3.platforms.xla.sparse_core.mlo.ir import sc_tpu |
| |
| |
| class VectorFactory: |
| """Builds MLIR vector types of a configurable element type via indexing.""" |
| |
| def __init__(self, elem_thunk): |
| self.elem_thunk = elem_thunk |
| |
| def __getitem__(self, idxs): |
| if isinstance(idxs, int): |
| idxs = (idxs,) |
| return ir.VectorType.get(idxs, self.elem_thunk()) |
| |
| |
| I32 = VectorFactory(lambda: ir.IntegerType.get_signless(32)) |
| IDX = VectorFactory(ir.IndexType.get) |
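| # For example, I32[8, 128] builds the MLIR type vector<8x128xi32>, and IDX[8] |
| # builds vector<8xindex> (both require an active ir.Context). |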
| |
| |
| @functools.partial(jax.jit, static_argnames=["indices_tile_size", "embed_size"]) |
| def gather2d( |
| operand: jax.Array, |
| indices: jax.Array, |
| indices_tile_size: int, |
| embed_size: int, |
| ) -> jax.Array: |
| """Builds the Gather 2D kernel. |
| |
| Arguments: |
| operand: The operand from which to gather elements. |
| indices: The indices of the operand elements to gather. |
| indices_tile_size: The number of elements to gather per Mosaic tile. |
| embed_size: The number of elements in the inner dimension of the operand. |
| |
| Returns: |
| A JAX function implementing the Gather1D kernel. |
| """ |
| with ir.Context() as ctx, ir.Location.unknown(): |
| # The TPU dialect is required for its TPU_DimensionSemantics |
| tpu.register_dialect(ctx) |
| sc_tpu.register_dialect(ctx) |
| |
| operand_size = operand.shape[0] |
| indices_size = indices.shape[0] |
| |
| i32 = ir.IntegerType.get_signless(32) |
| f32 = ir.F32Type.get() |
| index = ir.IndexType.get() |
| # TODO(naumsmogers): Add memory space enums. |
| hbm = ir.IntegerAttr.get(i32, 203) |
| tilespmem = ir.IntegerAttr.get(i32, 201) |
| sflagmem = ir.IntegerAttr.get(i32, 204) |
| tile1d8 = ir.Attribute.parse("#tpu.tiled<(8),[1]>") |
| tile2d8 = ir.Attribute.parse("#tpu.tiled<(8),[1, 1]>") |
| row_offset_type = ir.StringAttr.get(b"row_offset") |
| |
| # TODO(naumsmogers): hardcode the tile memory to tilespmem |
| @func.FuncOp.from_py_func( |
| i32, |
| i32, |
| ir.MemRefType.get((operand_size, embed_size), f32, tile2d8, hbm), |
| ir.MemRefType.get((indices_tile_size,), i32, tile1d8, tilespmem), |
| ir.MemRefType.get( |
| (indices_tile_size, embed_size), f32, tile2d8, tilespmem |
| ), |
| name="main", |
| ) |
| def kernel_main(i, j, operand, indices_tile, output_tile, func_op): # pylint: disable=unused-argument |
| constants = {} |
| |
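| # `c` caches arith.constant ops: each (value, type) pair is materialized once |
| # at the top of the entry block and reused throughout the kernel body. |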
| def c(val, ty=None): |
| ty = index if ty is None else ty |
| if (val, ty) not in constants: |
| with ir.InsertionPoint.at_block_begin(func_op.entry_block): |
| constants[(val, ty)] = arith.ConstantOp( |
| ty, ir.IntegerAttr.get(ty, val) |
| ) |
| return constants[(val, ty)] |
| |
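| # Allocate a sync flag in sync-flag memory. The indirect stream below updates |
| # it as data arrives, and the StreamWaitOp at the end of the kernel blocks on |
| # it until all granules for this tile have landed. |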
| sflag = memref.AllocaOp( |
| ir.MemRefType.get( |
| shape=(1,), element_type=i32, memory_space=sflagmem |
| ), |
| dynamicSizes=[], |
| symbolOperands=[], |
| ) |
| |
| # 4 bytes per float * embed_size elements per gather |
| single_gather_bytes = embed_size * 4 |
| total_bytes = indices_tile_size * single_gather_bytes |
| # HBM is accessed in 4-byte granules. |
| granule_size = c(4) |
| granules = arith.DivUIOp(c(single_gather_bytes), granule_size) |
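| # E.g. with embed_size=64, each gathered row is 256 bytes = 64 granules, and |
| # a 512-row indices tile moves 512 * 64 = 32768 granules in total. |
| # Start the indirect stream: for every entry of indices_tile (interpreted as |
| # a row offset into the operand), copy one embed_size-wide row from the HBM |
| # operand into the corresponding row of output_tile in tile SPMEM. |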
| sc_tpu.IndirectStreamStartOp( |
| src=operand, |
| src_indices=[c(0), c(0)], |
| dst=output_tile, |
| dst_indices=[c(0), c(0)], |
| granules=granules, |
| sflag=sflag, |
| sflag_indices=[c(0)], |
| offset=indices_tile, |
| offset_indices=[c(0)], |
| offset_list_size=c(indices_tile_size), |
| indirect_list_type=row_offset_type, |
| upd=False, |
| hbm4b=True, |
| filter=False, |
| enable_trace=False, |
| ) |
| |
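| # Block until all gathered rows for this tile have arrived in output_tile |
| # before the grid step completes. |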
| total_granules = arith.DivUIOp(c(total_bytes), granule_size) |
| sc_tpu.StreamWaitOp( |
| sflag=sflag, sflag_indices=c(0), granules=total_granules |
| ) |
| |
| f = kernel_main.func_op |
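| # The kernel iterates over a grid of (indices_size // indices_tile_size, 1) |
| # steps; each step gathers one tile of indices. |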
| f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get( |
| [indices_size // indices_tile_size, 1] |
| ) |
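| # Both grid dimensions are parallel, so grid steps carry no cross-iteration |
| # dependencies and may be executed in any order. |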
| f.attributes["dimension_semantics"] = ir.ArrayAttr.get([ |
| ir.Attribute.parse("#tpu.dimension_semantics<parallel>"), |
| ir.Attribute.parse("#tpu.dimension_semantics<parallel>"), |
| ]) |
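| # One window_params entry per memref argument: transform_indices maps the |
| # grid step (n, m) to the block of that argument the step operates on. The |
| # operand always uses block (0, 0), indices_tile uses block n, and |
| # output_tile uses block (n, 0). |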
| f.attributes["window_params"] = ir.ArrayAttr.get([ |
| ir.DictAttr.get({ |
| "transform_indices": ir.Attribute.parse( |
| "affine_map<(n, m) -> (0, 0)>" |
| ), |
| }), |
| ir.DictAttr.get({ |
| "transform_indices": ir.Attribute.parse( |
| "affine_map<(n, m) -> (n)>" |
| ), |
| }), |
| ir.DictAttr.get({ |
| "transform_indices": ir.Attribute.parse( |
| "affine_map<(n, m) -> (n, 0)>" |
| ), |
| }), |
| ]) |
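| # The operand carries the sc.persistent unit attribute; in this kernel it |
| # stays resident in HBM and is read through the indirect stream rather than |
| # being copied tile-by-tile into SPMEM. |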
| f.arg_attrs = [ |
| ir.DictAttr.get({}), # i |
| ir.DictAttr.get({}), # j |
| ir.DictAttr.get({"sc.persistent": ir.UnitAttr.get()}), # operand |
| ir.DictAttr.get({}), # indices_tile |
| ir.DictAttr.get({}), # output_tile |
| ] |
| assert f.verify(), f |
| m = ir.Module.create() |
| m.body.append(f) |
| ir.SymbolTable(m.operation).insert(f) |
| |
| return mosaic.as_tpu_kernel( |
| m, |
| out_type=core.ShapedArray((indices_size, embed_size), jnp.float32), |
| device_type="sparsecore", |
| )(operand, indices) |
| |
| |
| def main(argv: Sequence[str]) -> None: |
| if len(argv) == 1: |
| num_bench_iters = 10 |
| elif len(argv) == 2: |
| num_bench_iters = int(argv[1]) |
| else: |
| raise app.UsageError("Too many command-line arguments.") |
| |
| sample_count = 2097152 |
| indices_size = 524288 |
| # Tile shape for 16 SC tiles and 4 cores |
| # embed_size = 4; indices_tile_shape = (8192,) |
| # embed_size = 8; indices_tile_shape = (4096,) |
| # embed_size = 64; indices_tile_shape = (512,) |
| # embed_size = 128; indices_tile_shape = (256,) |
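| # Each combination keeps indices_tile_size * embed_size = 32768 f32 elements |
| # (128 KiB) per output tile. |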
| embed_size = 64 |
| indices_tile_shape = (512,) |
| |
| operand_shape = (sample_count, embed_size) |
| indices_shape = (indices_size,) |
| |
| k1, k2 = random.split(random.PRNGKey(1234)) |
| operand = random.normal(key=k1, shape=operand_shape, dtype=jnp.float32) |
| indices = random.randint( |
| key=k2, |
| shape=indices_shape, |
| minval=0, |
| maxval=sample_count, |
| dtype=jnp.int32, |
| ) |
| |
| kernel = functools.partial( |
| gather2d, indices_tile_size=indices_tile_shape[0], embed_size=embed_size |
| ) |
| outputs = kernel(operand, indices) |
| gold = jnp.take(operand, indices, axis=0) |
| |
| np.testing.assert_array_equal(outputs, gold) |
| |
| s = time.perf_counter() |
| for _ in range(num_bench_iters): |
| kernel(operand, indices).block_until_ready() |
| e = time.perf_counter() |
| d1 = e - s |
| |
| s = time.perf_counter() |
| for _ in range(num_bench_iters): |
| jnp.take(operand, indices, axis=0).block_until_ready() |
| e = time.perf_counter() |
| d2 = e - s |
| print( |
| "Run-time ratio (custom/XLA) (%d runs): %0.4f" |
| % (num_bench_iters, d1 / d2) |
| ) |
| |
| |
| if __name__ == "__main__": |
| app.run(main) |