blob: 47d523c63027c2986e9448b04ba286669cfe9834 [file] [log] [blame]
"""An example Mosaic kernel implementing the Cholesky decomposition.
"""
from typing import Sequence, Callable, Any
from absl import app
import jax
from jax.experimental import mosaic
import jax.numpy as jnp
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import math
from mlir.dialects import vector
import numpy as np
# Presumably exposes JAX's configuration options through absl flags so that
# app.run(main) below parses them — confirm against the JAX version in use.
jax.config.config_with_absl()

# Selects how `matvec` lowers its product: True emits a single
# `vector.contract` op; False emits broadcast + elementwise multiply +
# `vector.multi_reduction`. Disabled for now (see TODO).
USE_CONTRACT = False # TODO(apaszke): Implement vector contraction optimization
def vec_inner_product(lhs: ir.Value, rhs: ir.Value) -> ir.Value:
  """Emits the inner product of two row vectors as a 1x1 f32 vector.

  This is `matvec` specialised to a single output row, so `lhs` plays the
  role of a one-row matrix and `rhs` the row vector it is multiplied with.

  Args:
    lhs: Row vector operand (callers in this file pass 1xK f32 vectors).
    rhs: Row vector operand with the same shape as `lhs`.

  Returns:
    A vector<1x1xf32> value holding sum(lhs * rhs).
  """
  one_by_one_f32 = ir.VectorType.get([1, 1], ir.F32Type.get())
  return matvec(one_by_one_f32, lhs, rhs)
def matvec(out_type: ir.Type, mat: ir.Value, vec: ir.Value) -> ir.Value:
  """Emits ``mat @ vec.T`` for a 2-D `mat` and a row vector `vec`.

  Computes ``out[i, 0] = sum_r mat[i, r] * vec[0, r]``.

  Args:
    out_type: Result vector type of shape (M, 1); its element type is also
      used for the zero accumulator.
    mat: (M, K) vector, given either as an `ir.Value` or as a single-result
      op (e.g. `vector.LoadOp`).
    vec: Row vector broadcastable to `mat`'s type (callers pass 1xK), given
      either as an `ir.Value` or as a single-result op.

  Returns:
    The (M, 1) result value.
  """
  i64 = ir.IntegerType.get_signless(64)
  # Callers pass single-result ops (which expose `.result`) as well as plain
  # values; normalise so both forms work, as the annotations promise.
  mat = mat if isinstance(mat, ir.Value) else mat.result
  vec = vec if isinstance(vec, ir.Value) else vec.result
  # Cast so `.shape` / `.element_type` are available for any ir.Type input.
  out_type = ir.VectorType(out_type)
  elem_type = out_type.element_type

  if USE_CONTRACT:
    zero = arith.ConstantOp(elem_type, 0.0).result
    acc = vector.BroadcastOp(out_type, zero).result
    return ir.Operation.create(
        "vector.contract",
        results=[out_type],
        operands=[mat, vec, acc],
        attributes={
            # i ranges over rows of `mat`, j over rows of `vec` (size 1),
            # r is the contracted K dimension.
            "indexing_maps":
                ir.ArrayAttr.get([
                    ir.Attribute.parse("affine_map<(i, j, r) -> (i, r)>"),
                    ir.Attribute.parse("affine_map<(i, j, r) -> (j, r)>"),
                    ir.Attribute.parse("affine_map<(i, j, r) -> (i, j)>"),
                ]),
            "iterator_types":
                ir.ArrayAttr.get([
                    ir.Attribute.parse("#vector.iterator_type<parallel>"),
                    ir.Attribute.parse("#vector.iterator_type<parallel>"),
                    ir.Attribute.parse("#vector.iterator_type<reduction>"),
                ]),
        }).result

  # Fallback lowering: broadcast `vec` across the rows of `mat`, multiply
  # elementwise, then sum along dimension 1 into an (M,) vector.
  red_type = ir.VectorType.get((out_type.shape[0],), elem_type)
  product = arith.MulFOp(mat, vector.BroadcastOp(mat.type, vec))
  zero_acc = arith.ConstantOp(
      red_type,
      ir.DenseElementsAttr.get_splat(
          red_type, ir.FloatAttr.get(elem_type, 0.0)))
  red = vector.MultiDimReductionOp(
      ir.Attribute.parse("#vector.kind<add>"),
      product,
      zero_acc,
      ir.ArrayAttr.get([ir.IntegerAttr.get(i64, 1)]),
  )
  # (M,) -> (M, 1) to match the requested result type.
  return vector.ShapeCastOp(out_type, red).result
def build_cholesky(n: int) -> Callable[..., Any]:
  """Builds a Mosaic TPU kernel computing the Cholesky factor of an n x n matrix.

  The kernel reads `a` and writes a lower-triangular `l`, processing one
  column at a time (the loop over columns is unrolled when the IR is built):

    row      = l[j, :j]
    x        = sqrt(a[j, j] - row @ conj(row).T)
    l[j:, j] = (a[j:, j] - l[j:, :j] @ conj(row).T) / x

  Only the lower triangle (including the diagonal) of the input is read. The
  output is zero-initialised, so its strictly upper triangle is zero.

  Must be called with an active `ir.Context` (as `main` does).

  Args:
    n: Matrix dimension; must be >= 1.

  Returns:
    The callable produced by `mosaic.as_tpu_kernel`, producing an (n, n)
    float32 result.

  Raises:
    ValueError: If `n` < 1.
    RuntimeError: If the generated function fails MLIR verification (newer
      MLIR bindings may raise their own error from `verify()` instead).
  """
  if n < 1:
    raise ValueError(f"Matrix size must be positive, got {n}.")
  with ir.Location.unknown():
    f32 = ir.F32Type.get()
    index = ir.IndexType.get()
    full_ref = ir.MemRefType.get((n, n), f32)
    scalar_type = ir.VectorType.get([1, 1], f32)

    @func.FuncOp.from_py_func(full_ref, full_ref, name="main")
    def kernel(in_ref, out_ref):
      zero_f32 = arith.ConstantOp(f32, 0.0).result
      zero_index = arith.ConstantOp(index, 0).result
      # Zero-initialise the whole output: the loop only writes l[j:, j], so
      # the strictly upper triangle must already be zero.
      zero_out = vector.BroadcastOp(ir.VectorType.get([n, n], f32), zero_f32)
      vector.StoreOp(zero_out, out_ref, [zero_index, zero_index])
      for j in range(n):
        j_index = arith.ConstantOp(index, j).result
        panel_type = ir.VectorType.get([n - j, 1], f32)
        if j == 0:
          # row = l[0, :0] is empty, so both products below are zero.
          row_inner = vector.BroadcastOp(scalar_type, zero_f32)
          panel_update = vector.BroadcastOp(panel_type, zero_f32)
        else:
          # row = l[j, :j]: the already-computed part of row j.
          row = vector.LoadOp(
              ir.VectorType.get([1, j], f32), out_ref, [j_index, zero_index])
          row_conj = row  # tpu.ConjOp(row); identity for real f32 data.
          # row @ conj(row).T
          row_inner = vec_inner_product(row, row_conj)
          # l[j:, :j] @ conj(row).T
          mat = vector.LoadOp(
              ir.VectorType.get([n - j, j], f32), out_ref,
              [j_index, zero_index])
          panel_update = matvec(panel_type, mat, row_conj)
        # x_inv = 1 / sqrt(a[j, j] - row @ conj(row).T)
        diag = vector.LoadOp(scalar_type, in_ref, [j_index, j_index])
        x_inv = math.RsqrtOp(arith.SubFOp(diag, row_inner))
        # l[j:, j] = (a[j:, j] - l[j:, :j] @ conj(row).T) * x_inv
        panel_old_value = vector.LoadOp(panel_type, in_ref,
                                        [j_index, j_index])
        panel_new_value = arith.MulFOp(
            arith.SubFOp(panel_old_value, panel_update),
            vector.BroadcastOp(panel_type, x_inv))
        vector.StoreOp(panel_new_value, out_ref, [j_index, j_index])

    func_op = kernel.func_op
    # Explicit check instead of `assert`: `python -O` strips asserts, which
    # would also skip the verify() call itself.
    if not func_op.verify():
      raise RuntimeError(
          f"Generated Cholesky kernel failed verification:\n{func_op}")
    module = ir.Module.create()
    module.body.append(func_op)
    ir.SymbolTable(module.operation).insert(func_op)
    return mosaic.as_tpu_kernel(
        module, out_type=jax.ShapeDtypeStruct((n, n), jnp.float32))
def main(argv: Sequence[str]) -> None:
  """Builds the Cholesky kernel and checks it against a known factor.

  Usage: <program> [n]    (n defaults to 128)

  Args:
    argv: Command-line arguments as passed by `app.run`; `argv[1]`, if
      present, is the matrix size.

  Raises:
    app.UsageError: On too many arguments or a non-positive / non-integer n.
  """
  if len(argv) > 2:
    raise app.UsageError("Too many command-line arguments.")
  if len(argv) == 2:
    try:
      n = int(argv[1])
    except ValueError:
      raise app.UsageError(
          f"Matrix size must be an integer, got {argv[1]!r}.") from None
    if n < 1:
      raise app.UsageError(f"Matrix size must be positive, got {n}.")
  else:
    n = 128
  with ir.Context():
    f = build_cholesky(n)
    # Lower-triangular with a diagonal >= 1, hence the unique Cholesky factor
    # of x @ x.T — the kernel should reproduce it exactly (up to rounding).
    x = jnp.tril(jnp.arange(n * n, dtype=jnp.float32).reshape(n, n) / n + 1)
    # HIGHEST precision: on TPU the default f32 matmul precision is reduced,
    # which would perturb the reference input rather than test the kernel.
    a = jnp.dot(x, x.T, precision=jax.lax.Precision.HIGHEST)
    output = f(a)
    # (actual, desired): rtol is measured relative to the reference `x`.
    np.testing.assert_allclose(output, x, atol=1e-5, rtol=5e-3)
  print("OK!")
if __name__ == "__main__":
  # absl's app.run parses flags and calls main with the remaining argv.
  app.run(main)