# blob: 67293d2e3a8f8cd5d38c4a2fc4a8300fcdc4dc2f [file] [log] [blame]
"""A Mosaic-SparseCore codegen test of vector loading, tile indexing and arithmetic."""
import functools
from typing import Sequence
from absl import app
import jax
from jax import core
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
import numpy as np
from google3.platforms.xla.sparse_core.mlo.ir import sc_tpu
class VectorFactory:
  """Creates MLIR vector types of one element type through subscripting.

  `factory[8]` yields a vector type of shape `(8,)`; `factory[2, 4]` yields one
  of shape `(2, 4)`.
  """

  def __init__(self, elem_thunk):
    """Remembers how to obtain the element type.

    Args:
      elem_thunk: Zero-argument callable returning the element type. It is
        called on every subscript rather than here, so the element type is only
        materialized at use time (presumably because it needs an active MLIR
        context -- TODO confirm).
    """
    self.elem_thunk = elem_thunk

  def __getitem__(self, idxs):
    """Returns the vector type whose shape is `idxs`.

    Args:
      idxs: Either a single int (a rank-1 shape) or a sequence of dimension
        sizes that is passed through unchanged.
    """
    # A bare int such as `factory[8]` is promoted to a one-element shape.
    shape = (idxs,) if isinstance(idxs, int) else idxs
    element_type = self.elem_thunk()
    return ir.VectorType.get(shape, element_type)
# Factory for vectors of signless 32-bit integers, e.g. `I32[8]`. The element
# type is requested lazily, each time the factory is subscripted.
I32 = VectorFactory(functools.partial(ir.IntegerType.get_signless, 32))
# TODO(naumsmogers): use test_util.MosaicTestCase when it supports SC.
@functools.partial(
    jax.jit,
    static_argnames=[
        "input_size",
        "output_size",
        "input_tile_size",
        "output_tile_size",
    ],
)
def test(
    iota: jax.Array,
    input_size: int,
    output_size: int,
    input_tile_size: int,
    output_tile_size: int,
) -> jax.Array:
  """Builds and runs a Mosaic kernel testing vload, vstore and arith ops on SC.

  The kernel loads 8 elements from the input tile starting at offset `i % 8`,
  multiplies each of them by two and stores the 8 results at the beginning of
  the output tile. `i` is the scalar i32 the kernel receives as its first
  argument (presumably the index of the current iteration -- TODO confirm).
  Elements of the output tile past the first 8 are never written by the kernel.

  Args:
    iota: An array with an arithmetic sequence with the common difference of
      one starting at zero.
    input_size: The size of the full input.
    output_size: The size of the full output.
    input_tile_size: The size of a single input tile.
    output_tile_size: The size of a single output tile.

  Returns:
    The result of running the kernel on `iota`, declared as an int32 array of
    `output_size` elements.

  Raises:
    AssertionError: If the input and the output do not split into the same
      number of tiles, or if the generated function fails verification.
  """
  with ir.Context() as ctx, ir.Location.unknown():
    # The TPU dialect is required for its TPU_DimensionSemantics
    tpu.register_dialect(ctx)
    sc_tpu.register_dialect(ctx)

    i32_type = ir.IntegerType.get_signless(32)
    index_type = ir.IndexType.get()
    # TODO(naumsmogers): Add memory space enums.
    tilespmem_space = ir.IntegerAttr.get(i32_type, 201)
    # TODO(naumsmogers): hardcode the tile size to 8 for the SC for all
    # memories.
    tiled_layout = ir.Attribute.parse("#tpu.tiled<(8),[1]>")
    # TODO(naumsmogers): hardcode the tile memory to tilespmem
    input_tile_type = ir.MemRefType.get(
        (input_tile_size,), i32_type, tiled_layout, tilespmem_space
    )
    output_tile_type = ir.MemRefType.get(
        (output_tile_size,), i32_type, tiled_layout, tilespmem_space
    )

    @func.FuncOp.from_py_func(
        i32_type,
        input_tile_type,
        output_tile_type,
        name="main",
    )
    def kernel_main(i, src_tile, dst_tile, func_op):  # pylint: disable=unused-argument
      vec8_type = I32[8]
      # Load 8 elements starting from element (i % 8), then multiply them by
      # two. Ops are created one statement at a time, in emission order.
      i_as_index = arith.IndexCastOp(index_type, i)
      eight = arith.ConstantOp(index_type, ir.IntegerAttr.get(index_type, 8))
      load_offset = arith.RemUIOp(i_as_index, eight)
      loaded = sc_tpu.VectorLoadOp(vec8_type, src_tile, [load_offset])
      twos = arith.ConstantOp(
          vec8_type,
          ir.DenseElementsAttr.get_splat(
              vec8_type, ir.IntegerAttr.get(i32_type, 2)
          ),
      )
      doubled = arith.MulIOp(loaded, twos)
      # The result always lands at the very start of the output tile.
      store_offset = arith.ConstantOp(
          index_type, ir.IntegerAttr.get(index_type, 0)
      )
      sc_tpu.VectorStoreOp(doubled, dst_tile, [store_offset])

    # Input and output must be cut into the same number of tiles.
    assert input_size // input_tile_size == output_size // output_tile_size

    main_func = kernel_main.func_op
    main_func.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(
        [output_size // output_tile_size]
    )
    main_func.attributes["dimension_semantics"] = ir.ArrayAttr.get(
        [ir.Attribute.parse("#tpu.dimension_semantics<parallel>")]
    )
    # Two identical entries with an identity index map (presumably one per
    # memref operand: the input tile and the output tile).
    main_func.attributes["window_params"] = ir.ArrayAttr.get([
        ir.DictAttr.get({
            "transform_indices": ir.Attribute.parse("affine_map<(n) -> (n)>"),
        })
        for _ in range(2)
    ])
    assert main_func.verify(), main_func

    module = ir.Module.create()
    module.body.append(main_func)
    ir.SymbolTable(module.operation).insert(main_func)

    run_kernel = mosaic.as_tpu_kernel(
        module,
        out_type=core.ShapedArray([output_size], jnp.int32),
        device_type="sparsecore",
    )
    return run_kernel(iota)
def main(_: Sequence[str]) -> None:
  """Runs the kernel on a 1024-element iota and checks it against a reference.

  Only the first 8 elements of every output tile are compared, because the
  kernel writes nothing else.

  Args:
    _: Command-line arguments, unused.

  Raises:
    AssertionError: If the compared part of the kernel output differs from the
      reference.
  """
  vector_len = 8  # Number of elements the kernel loads and stores per tile.
  offset_period = 8  # The kernel's load offset is the tile index modulo this.

  input_len = 1024
  # Tile shape for 16 SC tiles and 4 cores
  input_tile_len = 16
  output_len = 1024
  # Tile shape for 16 SC tiles and 4 cores
  output_tile_len = 16
  num_tiles = output_len // output_tile_len

  iota = jnp.arange(0, input_len, 1, dtype=jnp.int32)
  outputs = test(
      iota,
      input_size=input_len,
      output_size=output_len,
      input_tile_size=input_tile_len,
      output_tile_size=output_tile_len,
  )

  # Generate the version of the following, where all elements are multiplied
  # by 2:
  # [[0,1,2,3,4,5,6,7, 0,0,..,0],
  #  [16+1,17+1,..,23+1, 0,0,..,0],
  #  ..,
  #  [112+7,113+7,..,119+7, 0,0,..,0],
  #  [128+0,129+0,..,135+0, 0,0,..,0],
  #  [144+1,145+1,..,151+1, 0,0,..,0],
  #  ..,
  #  [1008+7,1009+7,..,1015+7, 0,0,..,0]]
  tile_starts = jnp.arange(
      0, input_len, step=input_tile_len, dtype=jnp.int32
  )
  load_offsets = jnp.mod(
      jnp.arange(0, num_tiles, dtype=jnp.int32), offset_period
  )
  first_loaded = jnp.add(tile_starts, load_offsets)
  lanes = jnp.arange(0, vector_len, dtype=jnp.int32)
  loaded = np.add.outer(first_loaded, lanes)  # no .outer in jnp
  doubled = jnp.multiply(2, loaded)
  padding = jnp.zeros(
      (num_tiles, output_tile_len - vector_len), dtype=jnp.int32
  )
  gold = jnp.concatenate((doubled, padding), axis=1)

  # Skip the uninitialized regions of outputs when testing since the hardware
  # fills the uninitialized regions with zeros, but ISS leaves garbage data.
  written = jnp.reshape(outputs, (-1, output_tile_len))[:, 0:vector_len]
  np.testing.assert_array_equal(written, gold[:, 0:vector_len])
# Script entry point: when executed directly (not imported), hand `main` to
# absl's `app.run`.
if __name__ == "__main__":
  app.run(main)