| """An example Mosaic kernel implementing the Cholesky decomposition. |
| """ |
| from typing import Sequence, Callable, Any |
| |
| from absl import app |
| import jax |
| from jax.experimental import mosaic |
| import jax.numpy as jnp |
| from mlir import ir |
| from mlir.dialects import arith |
| from mlir.dialects import func |
| from mlir.dialects import math |
| from mlir.dialects import vector |
| import numpy as np |
| |
| |
# Expose JAX's configuration options as absl command-line flags.
jax.config.config_with_absl()


# Selects the lowering used by `matvec`: when True a single "vector.contract"
# op is emitted; when False an elementwise multiply followed by a
# multi-dimensional add-reduction is emitted instead.
USE_CONTRACT = False  # TODO(apaszke): Implement vector contraction optimization
| |
| |
def vec_inner_product(lhs: ir.Value, rhs: ir.Value) -> ir.Value:
  """Emits the inner product of two vectors as a 1x1 f32 vector.

  This is a thin wrapper over `matvec` that fixes the result type to
  vector<1x1xf32>.

  Args:
    lhs: Left operand, forwarded to `matvec` as its `mat` argument.
    rhs: Right operand, forwarded to `matvec` as its `vec` argument.

  Returns:
    Whatever `matvec` returns for a vector<1x1xf32> result type.
  """
  scalar_vec_type = ir.VectorType.get([1, 1], ir.F32Type.get())
  return matvec(scalar_vec_type, lhs, rhs)
| |
| |
def matvec(out_type: ir.Type, mat: ir.Value, vec: ir.Value) -> ir.Value:
  """Emits ops that multiply a matrix by a vector, reducing over dimension 1.

  Two lowerings are available, selected by the module-level `USE_CONTRACT`:

  * `USE_CONTRACT` true: a single "vector.contract" op computing
    out[i, j] = sum_r mat[i, r] * vec[j, r] on top of a zero accumulator.
  * `USE_CONTRACT` false: `vec` is broadcast to the type of `mat`, multiplied
    elementwise with it, add-reduced over dimension 1 into a vector of
    `out_type.shape[0]` elements, and shape-cast to `out_type`.

  Args:
    out_type: Vector type of the result. Its `shape[0]` and `element_type`
      are read by the reduction lowering.
    mat: The matrix operand, either an `ir.Value` or a single-result operation.
    vec: The vector operand, either an `ir.Value` or a single-result operation.

  Returns:
    The `ir.Value` holding the product, of type `out_type`.
  """

  def as_value(operand):
    # Callers in this file pass op objects (e.g. the result of vector.LoadOp)
    # while the signature advertises ir.Value; accept both so neither form
    # fails on a missing `.result` attribute.
    if isinstance(operand, (ir.Operation, ir.OpView)):
      return operand.result
    return operand

  mat_val = as_value(mat)
  vec_val = as_value(vec)
  f32 = ir.F32Type.get()

  if USE_CONTRACT:
    zero_f32 = arith.ConstantOp(f32, 0.0).result
    acc = vector.BroadcastOp(out_type, zero_f32)
    # vector.contract is built generically from its name and attributes.
    return ir.Operation.create(
        "vector.contract",
        results=[out_type],
        operands=[mat_val, vec_val, acc.result],
        attributes={
            "indexing_maps":
                ir.ArrayAttr.get([
                    ir.Attribute.parse("affine_map<(i, j, r) -> (i, r)>"),
                    ir.Attribute.parse("affine_map<(i, j, r) -> (j, r)>"),
                    ir.Attribute.parse("affine_map<(i, j, r) -> (i, j)>"),
                ]),
            "iterator_types":
                ir.ArrayAttr.get([
                    ir.Attribute.parse("#vector.iterator_type<parallel>"),
                    ir.Attribute.parse("#vector.iterator_type<parallel>"),
                    ir.Attribute.parse("#vector.iterator_type<reduction>"),
                ])
        }).result

  i64 = ir.IntegerType.get_signless(64)
  # The reduction removes dimension 1, leaving a rank-1 vector with one
  # element per row of the result.
  red_type = ir.VectorType.get((out_type.shape[0],), out_type.element_type)
  red = vector.MultiDimReductionOp(
      ir.Attribute.parse("#vector.kind<add>"),
      arith.MulFOp(mat_val, vector.BroadcastOp(mat_val.type, vec_val)),
      arith.ConstantOp(
          red_type,
          ir.DenseElementsAttr.get_splat(
              red_type, ir.FloatAttr.get(f32, 0.0))),
      ir.ArrayAttr.get([ir.IntegerAttr.get(i64, 1)]),
  )
  # Restore the trailing unit dimension expected by the caller.
  return vector.ShapeCastOp(out_type, red).result
| |
| |
def build_cholesky(n: int) -> Callable[..., Any]:
  """Builds a kernel that computes the Cholesky factor of an n x n f32 matrix.

  An MLIR function named "main" taking two memref<n x n x f32> arguments
  (input `a`, output `l`) is generated. The column loop runs in Python while
  the function is traced, so the emitted body is fully unrolled: one group of
  ops per column j, computing

    x_inv    = rsqrt(a[j, j] - dot(l[j, :j], l[j, :j]))
    l[j:, j] = (a[j:, j] - l[j:, :j] @ l[j, :j].T) * x_inv

  The output is zero-filled first and only entries l[j:, j] are written
  afterwards, so everything above the diagonal stays zero.

  Args:
    n: Side length of the square input and output matrices. It is not
      validated here.

  Returns:
    The callable produced by `mosaic.as_tpu_kernel` for the generated module,
    declared with an (n, n) float32 output.
  """
  # None of the ops below are given an explicit location, so an unknown
  # location is made current for the whole construction.
  with ir.Location.unknown():
    f32 = ir.F32Type.get()
    index = ir.IndexType.get()
    full_ref = ir.MemRefType.get((n, n), f32)

    # The decorator turns `kernel` into a function op named "main" whose two
    # arguments have type `full_ref`; the op is read back below through
    # `kernel.func_op`.
    @func.FuncOp.from_py_func(full_ref, full_ref, name="main")
    def kernel(in_ref, out_ref):
      zero_f32 = arith.ConstantOp(f32, 0.0).result
      zero_index = arith.ConstantOp(index, 0).result
      # Zero-initialize the whole output reference.
      zero_out = vector.BroadcastOp(ir.VectorType.get([n, n], f32), zero_f32)
      vector.StoreOp(zero_out, out_ref, [zero_index, zero_index])
      # Python-level loop: every iteration appends the ops for one column.
      for j in range(0, n):
        j_index = arith.ConstantOp(index, j).result
        # row = l[j, :j]
        # x = np.sqrt(a[j, j] - np.dot(row, np.conj(row).T))
        if j == 0:
          # No columns precede column 0, so the correction term is zero.
          row_inner = vector.BroadcastOp(
              ir.VectorType.get([1, 1], f32), zero_f32)
        else:
          row_vec = ir.VectorType.get([1, j], f32)
          row = vector.LoadOp(row_vec, out_ref, [j_index, zero_index])
          # No conjugation is applied; the row is reused as-is.
          row_conj = row  # tpu.ConjOp(row)
          row_inner = vec_inner_product(row, row_conj)
        # The reciprocal square root is computed instead of x itself, so the
        # division by x below becomes a multiplication.
        x_inv = math.RsqrtOp(
            arith.SubFOp(
                vector.LoadOp(
                    ir.VectorType.get([1, 1], f32), in_ref,
                    [j_index, j_index]), row_inner))
        # l[j:, j] = (a[j:, j] - np.dot(l[j:, :j], np.conj(row).T)) / x
        panel_type = ir.VectorType.get([n - j, 1], f32)
        if j == 0:
          panel_update = vector.BroadcastOp(panel_type, zero_f32)
        else:
          # `row_conj` is only defined (and only used) when j > 0.
          mat = vector.LoadOp(
              ir.VectorType.get([n - j, j], f32), out_ref,
              [j_index, zero_index])
          panel_update = matvec(panel_type, mat, row_conj)
        panel_old_value = vector.LoadOp(panel_type, in_ref,
                                        [j_index, j_index])
        panel_new_value = arith.MulFOp(
            arith.SubFOp(panel_old_value, panel_update),
            vector.BroadcastOp(panel_type, x_inv))
        vector.StoreOp(panel_new_value, out_ref, [j_index, j_index])

    f = kernel.func_op
    # Sanity-check the generated function; the op itself is the assertion
    # message. NOTE(review): this check is skipped when Python runs with -O.
    assert f.verify(), f

    # Place the function in a fresh module and register it in the module's
    # symbol table.
    m = ir.Module.create()
    m.body.append(f)
    ir.SymbolTable(m.operation).insert(f)
    return mosaic.as_tpu_kernel(
        m, out_type=jax.ShapeDtypeStruct((n, n), jnp.float32))
| |
| |
def main(argv: Sequence[str]) -> None:
  """Builds the Cholesky kernel and checks it against a known factor.

  A lower-triangular matrix `x` is generated, the kernel is run on
  `x @ x.T`, and the result is compared with `x`. "OK!" is printed when the
  comparison passes.

  Args:
    argv: Positional command-line arguments as passed by `app.run`. If
      `argv[1]` is present it is the matrix size n; otherwise n is 128.

  Raises:
    app.UsageError: If more than one positional argument is given, or if the
      size is not a positive integer.
    AssertionError: If the kernel output does not match `x` within tolerance.
  """
  if len(argv) > 2:
    raise app.UsageError("Too many command-line arguments.")
  if len(argv) == 2:
    try:
      n = int(argv[1])
    except ValueError as e:
      # Report a malformed size as a usage problem instead of a traceback.
      raise app.UsageError(
          f"Matrix size must be an integer, got {argv[1]!r}.") from e
  else:
    n = 128
  if n < 1:
    # Zero or negative sizes cannot form the vector types the kernel needs;
    # fail early with a clear message.
    raise app.UsageError(f"Matrix size must be positive, got {n}.")

  # Keep the MLIR context current while the kernel is built and invoked.
  with ir.Context():
    f = build_cholesky(n)
    # Every entry of the lower triangle of x is >= 1, so x has a positive
    # diagonal and is the factor the kernel should recover from x @ x.T.
    x = jnp.tril(jnp.arange(n * n, dtype=jnp.float32).reshape(n, n) / n + 1)
    output = f(x.dot(x.T))
    np.testing.assert_allclose(x, output, atol=1e-5, rtol=5e-3)
  print("OK!")
| |
| |
# Script entry point: hand control to absl, which calls `main` with argv.
if __name__ == "__main__":
  app.run(main)