| """A Mosaic-SparseCore codegen test of vector loading, tile indexing and arithmetic.""" |
| |
| import functools |
| from typing import Sequence |
| |
| from absl import app |
| import jax |
| from jax import core |
| from jax.experimental import mosaic |
| from jax.experimental.mosaic.dialects import tpu |
| import jax.numpy as jnp |
| from mlir import ir |
| from mlir.dialects import arith |
| from mlir.dialects import func |
| import numpy as np |
| |
| from google3.platforms.xla.sparse_core.mlo.ir import sc_tpu |
| |
| |
class VectorFactory:
  """Builds MLIR vector types with subscript syntax, e.g. ``factory[8]``.

  The element type is obtained by calling ``elem_thunk`` each time the factory
  is subscripted. It is not created up front, so a factory can be defined at
  module level and only produce types when subscripted.
  """

  def __init__(self, elem_thunk):
    # Zero-argument callable returning the MLIR element type.
    self.elem_thunk = elem_thunk

  def __getitem__(self, idxs):
    """Returns a vector type with shape ``idxs`` (an int or a sequence)."""
    shape = (idxs,) if isinstance(idxs, int) else idxs
    element_type = self.elem_thunk()
    return ir.VectorType.get(shape, element_type)
| |
| |
def _signless_i32():
  """Returns the MLIR signless 32-bit integer type (built on each call)."""
  return ir.IntegerType.get_signless(32)


# Subscriptable factory for i32 vector types: I32[8] is vector<8xi32>.
I32 = VectorFactory(_signless_i32)
| |
| |
| # TODO(naumsmogers): use test_util.MosaicTestCase when it supports SC. |
@functools.partial(
    jax.jit,
    static_argnames=[
        "input_size",
        "output_size",
        "input_tile_size",
        "output_tile_size",
    ],
)
def test(
    iota: jax.Array,
    input_size: int,
    output_size: int,
    input_tile_size: int,
    output_tile_size: int,
) -> jax.Array:
  """Builds and executes a Mosaic kernel testing vload, vstore and arith ops on SC.

  For grid step `i`, loads 8 elements from the input tile starting at offset
  `i % 8`, multiplies each by two and stores the result at the beginning of the
  output tile. The remaining elements of each output tile are not written.

  Arguments:
    iota: An array with an arithmetic sequence with the common difference of one
      starting at zero.
    input_size: The size of the full input.
    output_size: The size of the full output.
    input_tile_size: The size of a single input tile. The load reads 8 elements
      starting at an offset of up to 7, so this presumably must be >= 15 to stay
      in bounds.
    output_tile_size: The size of a single output tile. 8 elements are stored,
      so it presumably must be >= 8.

  Returns:
    The int32 array of shape `(output_size,)` produced by running the kernel on
    `iota`.

  Raises:
    AssertionError: If the input and output do not split into the same number
      of tiles.
  """
  # Validate the static tiling before building any IR. The single grid
  # dimension indexes both operands, so they must have the same tile count.
  assert input_size // input_tile_size == output_size // output_tile_size
  num_tiles = output_size // output_tile_size

  with ir.Context() as ctx, ir.Location.unknown():
    # The TPU dialect is required for its TPU_DimensionSemantics
    tpu.register_dialect(ctx)
    sc_tpu.register_dialect(ctx)

    i32 = ir.IntegerType.get_signless(32)
    index = ir.IndexType.get()
    # TODO(naumsmogers): Add memory space enums.
    tilespmem = ir.IntegerAttr.get(i32, 201)
    # TODO(naumsmogers): hardcode the tile size to 8 for the SC for all
    # memories.
    tile1d8 = ir.Attribute.parse("#tpu.tiled<(8),[1]>")

    # TODO(naumsmogers): hardcode the tile memory to tilespmem
    @func.FuncOp.from_py_func(
        i32,
        ir.MemRefType.get((input_tile_size,), i32, tile1d8, tilespmem),
        ir.MemRefType.get((output_tile_size,), i32, tile1d8, tilespmem),
        name="main",
    )
    def kernel_main(i, iota_tile, output_tile, func_op):  # pylint: disable=unused-argument
      # Load 8 elements starting from element (i % 8), then multiply them by two
      i_mod_eight = arith.RemUIOp(
          arith.IndexCastOp(index, i),
          arith.ConstantOp(index, ir.IntegerAttr.get(index, 8)),
      )
      iota_vector = sc_tpu.VectorLoadOp(I32[8], iota_tile, [i_mod_eight])
      twos_vector = arith.ConstantOp(
          I32[8],
          ir.DenseElementsAttr.get_splat(I32[8], ir.IntegerAttr.get(i32, 2)),
      )
      result = arith.MulIOp(iota_vector, twos_vector)

      # Store the 8 products at the start of the output tile.
      zero = arith.ConstantOp(index, ir.IntegerAttr.get(index, 0))
      sc_tpu.VectorStoreOp(result, output_tile, [zero])

    f = kernel_main.func_op
    # One grid step per tile.
    f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get([num_tiles])
    f.attributes["dimension_semantics"] = ir.ArrayAttr.get(
        [ir.Attribute.parse("#tpu.dimension_semantics<parallel>")]
    )
    # Identity index maps for both operands: presumably grid step n maps to
    # tile n of the input and of the output.
    f.attributes["window_params"] = ir.ArrayAttr.get([
        ir.DictAttr.get({
            "transform_indices": ir.Attribute.parse("affine_map<(n) -> (n)>"),
        }),
        ir.DictAttr.get({
            "transform_indices": ir.Attribute.parse("affine_map<(n) -> (n)>"),
        }),
    ])
    assert f.verify(), f
    m = ir.Module.create()
    m.body.append(f)
    # Register `main` in the module's symbol table.
    ir.SymbolTable(m.operation).insert(f)

    return mosaic.as_tpu_kernel(
        m,
        out_type=core.ShapedArray([output_size], jnp.int32),
        device_type="sparsecore",
    )(iota)
| |
| |
def main(_: Sequence[str]) -> None:
  """Runs the SparseCore kernel on an iota input and checks its output.

  Raises:
    AssertionError: If the kernel output differs from the expected values.
  """
  input_shape = (1024,)
  output_shape = (1024,)
  # Tile shapes for 16 SC tiles and 4 cores: 1024 / 16 = 64 = 16 * 4 tiles.
  input_tile_shape = (16,)
  output_tile_shape = (16,)
  n_tiles = output_shape[0] // output_tile_shape[0]
  # Number of elements the kernel loads, doubles and stores per tile.
  n_stored = 8

  iota = jnp.arange(0, input_shape[0], 1, dtype=jnp.int32)

  kernel = functools.partial(
      test,
      input_size=input_shape[0],
      output_size=output_shape[0],
      input_tile_size=input_tile_shape[0],
      output_tile_size=output_tile_shape[0],
  )
  outputs = kernel(iota)

  # Expected first 8 elements of output tile t. Input tile t starts at
  # element 16 * t and the kernel reads from offset t % 8, so row t is
  # 2 * (16 * t + t % 8 + [0, 1, .., 7]), i.e. twice each of:
  # [[0+0, 1+0, .., 7+0],
  #  [16+1, 17+1, .., 23+1],
  #  ..,
  #  [112+7, 113+7, .., 119+7],
  #  [128+0, 129+0, .., 135+0],
  #  [144+1, 145+1, .., 151+1],
  #  ..,
  #  [1008+7, 1009+7, .., 1015+7]]
  tile_ids = jnp.arange(n_tiles, dtype=jnp.int32)
  load_starts = tile_ids * input_tile_shape[0] + jnp.mod(tile_ids, 8)
  gold = 2 * (load_starts[:, None] + jnp.arange(n_stored, dtype=jnp.int32))

  # Compare only the first 8 elements of each output tile. The kernel never
  # writes the rest: the hardware fills those regions with zeros, but ISS
  # leaves garbage data there.
  np.testing.assert_array_equal(
      jnp.reshape(outputs, (-1, output_tile_shape[0]))[:, :n_stored], gold
  )
| |
| |
if __name__ == "__main__":
  # absl's app.run parses command-line flags and then calls main(argv).
  app.run(main)