import jax
from jax.experimental import mosaic
from jax.experimental import shard_map
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import memref
from mlir.dialects import vector
import numpy as np

from google3.platforms.xla.mosaic.python import test_util
from google3.testing.pybase import googletest


P = jax.sharding.PartitionSpec


class EndToEndTests(test_util.MosaicTestCase):

  def test_reduction(self):
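    # Two devices exchange their input shards via remote DMA, and each adds
    # the received shard to its own.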
    scratch_layout = ir.Attribute.parse("#tpu.tiled<(8,128),[1,1]>")
    vmem = ir.Attribute.parse("#tpu.memory_space<vmem>")

    @func.FuncOp.from_py_func(
        ir.MemRefType.get((8, 128), self.f32),
        ir.MemRefType.get((8, 128), self.f32),
    )
    def main(i, o):
      sem_type = ir.Type.parse("!tpu.semaphore")
      dma_sem_type = ir.Type.parse("!tpu.dma_semaphore")
      sem_mem = ir.Attribute.parse("#tpu.memory_space<semaphore_mem>")
      sem_ref_type = ir.MemRefType.get((), sem_type, memory_space=sem_mem)
      dma_sem_ref_type = ir.MemRefType.get(
          (), dma_sem_type, memory_space=sem_mem
      )
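      # One regular semaphore for the barrier, and DMA semaphores tracking
      # completion of the send and the receive.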
      ready_flag = tpu.AllocaSemaphoreOp(sem_ref_type)
      send_done_flag = tpu.AllocaSemaphoreOp(dma_sem_ref_type)
      recv_done_flag = tpu.AllocaSemaphoreOp(dma_sem_ref_type)
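      # Allocate a VMEM scratch buffer to receive the other device's shard,
      # erasing its layout before it is used as a DMA target.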
      scratch = memref.AllocaOp(
          ir.MemRefType.get((8, 128), self.f32, scratch_layout, vmem), [], []
      )
      scratch = tpu.EraseLayoutOp(
          ir.MemRefType.get((8, 128), self.f32, memory_space=vmem), scratch
      )
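      # On a two-device mesh, 1 - device_id is the id of the other device.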
      c1 = arith.ConstantOp(self.i32, ir.IntegerAttr.get(self.i32, 1))
      device_id = tpu.DeviceIdOp()
      other_device_id = arith.SubIOp(c1, device_id)
      # Barrier: signal the peer's ready flag and wait until the peer has
      # signaled ours, so neither device issues a DMA before the other is
      # ready.
      tpu.SemaphoreSignalOp(ready_flag, c1, device_id=other_device_id)
      tpu.SemaphoreWaitOp(ready_flag, c1)
| # Start the DMA and wait for the data to arrive. |
| tpu.EnqueueDMAOp( |
| i, scratch, recv_done_flag, |
| source_semaphore=send_done_flag, |
| device_id=other_device_id |
| ) |
| tpu.WaitDMAOp(recv_done_flag, scratch) |
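      # Add our shard to the one we received from the peer and store the
      # result in the output.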
      vty = self.f32_vreg
      i0 = arith.ConstantOp(self.index, ir.IntegerAttr.get(self.index, 0))
      x = vector.LoadOp(vty, i, [i0] * 2)
      y = vector.LoadOp(vty, scratch, [i0] * 2)
      z = arith.AddFOp(x, y)
      vector.StoreOp(z, o, [i0] * 2)
      # Make sure we've completed our send before we exit.
      tpu.WaitDMAOp(send_done_flag, i)

    main = main.func_op

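    # Annotate every kernel argument with the default (8, 128) tiled layout.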
    default_layout = ir.Attribute.parse("#tpu.tiled<(8,128),[1,1]>")
    for i in range(len(main.arguments)):
      tpu.private_set_arg_attr(main, i, "llo.layout", default_layout)

    module = ir.Module.create()
    module.body.append(main)
    ir.SymbolTable(module.operation).insert(main)

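    # Lower the module into a TPU kernel callable from JAX. Each shard-local
    # output has shape (8, 128).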
    kernel = mosaic.as_tpu_kernel(
        module, jax.ShapeDtypeStruct((8, 128), jnp.float32)
    )

    # Use the last two devices, so that physical ids don't match logical ids.
    mesh = jax.sharding.Mesh(jax.devices()[-2:], ("x",))
    sharding = jax.sharding.NamedSharding(mesh, P("x"))
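    # Build a (16, 128) array and shard its rows across the two devices, so
    # each device sees an (8, 128) block.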
    global_data = jnp.arange(
        2 * 8 * 128, dtype=jnp.float32
    ).reshape(2 * 8, 128)
    x = jax.jit(lambda x: x, out_shardings=sharding)(global_data)
    run = jax.jit(shard_map.shard_map(
        kernel, mesh=mesh, in_specs=P("x"), out_specs=P("x"),
        check_rep=False))

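    # Both devices compute the same sum of the two shards, so gathering the
    # outputs yields that sum tiled twice along the rows.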
    ref = jnp.tile(global_data[:8] + global_data[8:], (2, 1))
    np.testing.assert_allclose(run(x), ref)


if __name__ == "__main__":
  googletest.main()