# blob: 70d140f2d4445f95b9d23e44e5bc75198954081d (source-browser header; kept as a comment so the file parses)
import jax
from jax.experimental import mosaic
from jax.experimental import shard_map
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import memref
from mlir.dialects import vector
import numpy as np
from google3.platforms.xla.mosaic.python import test_util
from google3.testing.pybase import googletest
# Shorthand for building partition specs in the sharding setup below.
P = jax.sharding.PartitionSpec
class EndToEndTests(test_util.MosaicTestCase):
  """End-to-end Mosaic TPU tests that build an MLIR kernel and run it via JAX."""

  def test_reduction(self):
    """Two-device all-reduce-style test.

    Hand-builds a kernel in the TPU dialect that exchanges each device's
    (8, 128) f32 shard with its peer via a remote DMA, adds the local and
    received shards, and writes the sum to the output. The kernel is lowered
    with mosaic.as_tpu_kernel and executed under shard_map on a 2-device
    mesh; the result is checked against a host-side reference.
    """
    # Tiled layout attribute for the VMEM scratch buffer (and, below, for the
    # kernel arguments).
    scratch_layout = ir.Attribute.parse("#tpu.tiled<(8,128),[1,1]>")
    vmem = ir.Attribute.parse("#tpu.memory_space<vmem>")

    @func.FuncOp.from_py_func(
        ir.MemRefType.get((8, 128), self.f32),
        ir.MemRefType.get((8, 128), self.f32),
    )
    def main(i, o):
      # Semaphore refs in semaphore memory: one plain semaphore for the
      # barrier, two DMA semaphores for the outgoing send and incoming recv.
      sem_type = ir.Type.parse("!tpu.semaphore")
      dma_sem_type = ir.Type.parse("!tpu.dma_semaphore")
      sem_mem = ir.Attribute.parse("#tpu.memory_space<semaphore_mem>")
      sem_ref_type = ir.MemRefType.get((), sem_type, memory_space=sem_mem)
      dma_sem_ref_type = ir.MemRefType.get((), dma_sem_type,
                                           memory_space=sem_mem)
      ready_flag = tpu.AllocaSemaphoreOp(sem_ref_type)
      send_done_flag = tpu.AllocaSemaphoreOp(dma_sem_ref_type)
      recv_done_flag = tpu.AllocaSemaphoreOp(dma_sem_ref_type)
      # VMEM scratch that will hold the peer's shard. The tiled layout is
      # erased afterwards so the ref can be used as a plain (8, 128) memref.
      scratch = memref.AllocaOp(
          ir.MemRefType.get((8, 128), self.f32, scratch_layout, vmem), [], []
      )
      scratch = tpu.EraseLayoutOp(
          ir.MemRefType.get((8, 128), self.f32, memory_space=vmem), scratch
      )
      # With exactly two devices, the peer's logical id is 1 - my_id.
      c1 = arith.ConstantOp(self.i32, ir.IntegerAttr.get(self.i32, 1))
      device_id = tpu.DeviceIdOp()
      other_device_id = arith.SubIOp(c1, device_id)
      # Perform a barrier: signal the peer's ready_flag, then wait until the
      # peer has signaled ours, so neither side DMAs into an unready scratch.
      tpu.SemaphoreSignalOp(
          ready_flag, c1, device_id=other_device_id
      )
      tpu.SemaphoreWaitOp(ready_flag, c1)
      # Start the DMA and wait for the data to arrive. The DMA copies our
      # input `i` into the peer's scratch (device_id=other_device_id).
      # NOTE(review): presumably recv_done_flag is signaled on the receiving
      # device and send_done_flag locally when the send drains — confirm
      # against the tpu dialect's EnqueueDMA semantics.
      tpu.EnqueueDMAOp(
          i, scratch, recv_done_flag,
          source_semaphore=send_done_flag,
          device_id=other_device_id
      )
      tpu.WaitDMAOp(recv_done_flag, scratch)
      # Elementwise add of the local shard and the received peer shard.
      # Assumes self.f32_vreg is an (8, 128) f32 vector type so one load
      # covers the whole shard — TODO confirm in test_util.
      vty = self.f32_vreg
      i0 = arith.ConstantOp(self.index, ir.IntegerAttr.get(self.index, 0))
      x = vector.LoadOp(vty, i, [i0] * 2)
      y = vector.LoadOp(vty, scratch, [i0] * 2)
      z = arith.AddFOp(x, y)
      vector.StoreOp(z, o, [i0] * 2)
      # Make sure we've completed our send before we exit.
      tpu.WaitDMAOp(send_done_flag, i)

    main = main.func_op
    # Attach the expected tiled layout to every kernel argument.
    default_layout = ir.Attribute.parse("#tpu.tiled<(8,128),[1,1]>")
    for i in range(len(main.arguments)):
      tpu.private_set_arg_attr(main, i, "llo.layout", default_layout)
    # Wrap the function in a module and register it in the symbol table
    # before lowering.
    module = ir.Module.create()
    module.body.append(main)
    ir.SymbolTable(module.operation).insert(main)
    kernel = mosaic.as_tpu_kernel(
        module, jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )
    # Use the last two devices, so that physical ids don't match logical ids.
    # NOTE(review): ("x") is a bare string, not a 1-tuple; Mesh tuples the
    # axis names, so this still yields a single axis named "x".
    mesh = jax.sharding.Mesh(jax.devices()[-2:], ("x"))
    sharding = jax.sharding.NamedSharding(mesh, P("x"))
    global_data = jnp.arange(2*8*128, dtype=jnp.float32).reshape(2 * 8, 128)
    x = jax.jit(lambda x: x, out_shardings=sharding)(global_data)
    run = jax.jit(shard_map.shard_map(
        kernel, mesh=mesh, in_specs=P("x"), out_specs=P("x"),
        check_rep=False))
    # Each device computes shard0 + shard1, so the gathered output is that
    # sum replicated across both shards.
    ref = jnp.tile(global_data[:8] + global_data[8:], (2, 1))
    np.testing.assert_allclose(run(x), ref)
if __name__ == "__main__":
  # Run under the google3 test harness when invoked directly.
  googletest.main()