blob: 9df3225925b33930d9a3404fea322818a21ccf20 [file] [log] [blame] [edit]
"""A small and simple interpreter for MLIR.
Only to be used for testing.
"""
from collections.abc import Callable, Mapping
from typing import Optional, Protocol, Sequence
from jax import dtypes
from mlir import ir
from mlir.dialects import func
import numpy as np
def _ir_type_to_dtype(ir_type: ir.Type) -> np.dtype:
  """Maps an MLIR type to the NumPy dtype used to hold its values.

  Scalar types are looked up in a fixed table; shaped types are mapped through
  their element type.

  Args:
    ir_type: the MLIR type to translate.

  Returns:
    The NumPy dtype representing values of `ir_type`.

  Raises:
    NotImplementedError: if `ir_type` has no entry in the table and is not a
      shaped type.
  """
  int_ty = ir.IntegerType
  # (MLIR type, NumPy scalar type) pairs, checked in this order.
  known_types = (
      (int_ty.get_signless(1), np.bool_),
      (int_ty.get_signless(8), np.int8),
      (int_ty.get_signless(16), np.int16),
      (int_ty.get_signless(32), np.int32),
      (int_ty.get_signless(64), np.int64),
      (int_ty.get_unsigned(8), np.uint8),
      (int_ty.get_unsigned(16), np.uint16),
      (int_ty.get_unsigned(32), np.uint32),
      (int_ty.get_unsigned(64), np.uint64),
      (ir.BF16Type.get(), dtypes.bfloat16),
      (ir.F16Type.get(), np.float16),
      (ir.F32Type.get(), np.float32),
      (ir.F64Type.get(), np.float64),
      # `index` is represented as a 32-bit signed integer by this interpreter.
      (ir.IndexType.get(), np.int32),
  )
  for mlir_type, np_type in known_types:
    if ir_type == mlir_type:
      return np.dtype(np_type)
  if not ir.ShapedType.isinstance(ir_type):
    raise NotImplementedError(f"Unsupported ir type: {ir_type}")
  return _ir_type_to_dtype(ir.ShapedType(ir_type).element_type)
class Rule(Protocol):
  """An interpretation rule for an MLIR primitive."""

  def __call__(self, op: ir.OpView, *args: np.ndarray) -> Sequence[np.ndarray]:
    """Evaluates `op` on NumPy values.

    Args:
      op: the operation being interpreted; a rule may read its attributes and
        result types.
      *args: the values of the operation's operands, in operand order.

    Returns:
      One value per result of `op`, in result order.
    """
def interpret(
    f: func.FuncOp,
    *args: np.ndarray,
    custom_rules: Optional[Mapping[str, Rule]] = None) -> Sequence[np.ndarray]:
  """Evaluates the given MLIR function using NumPy operations.

  Every vector argument is expected to be backed by a NumPy array of the
  corresponding shape and dtype. Every scalar argument should be a NumPy scalar
  of the corresponding dtype. When this is not the case, the inputs are
  converted to NumPy data structures.

  Notice that the MLIR and NumPy type systems are fundamentally incompatible
  with respect to unsigned data types. MLIR typically works with signless data
  types, and the signedness of an operation is defined by the operation itself.
  NumPy instead employs signed and unsigned integer data types. For this reason
  we cannot afford to be strict when checking the correspondence between NumPy
  and MLIR types.

  Only the operations directly in the function's entry block are evaluated, in
  order, up to the `func.return` that terminates it.

  Args:
    f: Function to evaluate.
    *args: The arguments to evaluate at, one per function argument.
    custom_rules: A dictionary used to extend the set of standard interpretation
      rules. Maps operation name to its rule. A custom rule takes precedence
      over the built-in rule for the same operation.

  Returns:
    The values returned by the function's `func.return`, i.e. the result of
    applying f to args.

  Raises:
    AssertionError: if the number of arguments does not match the function
      signature, or a rule returns the wrong number of results.
    KeyError: if an operation has neither a custom nor a built-in rule.
    ValueError: if the entry block ends without a `func.return`.
  """
  if custom_rules is None:
    custom_rules = {}
  assert len(f.arguments) == len(args), (
      f"Expected {len(f.arguments)} arguments, got {len(args)}")
  env = {}
  for formal, actual in zip(f.arguments, args):
    # Convert every actual argument of the block to a NumPy type.
    # Mixing the Python/JAX/NumPy type systems leads to unexpected behaviors.
    if isinstance(actual, (int, float)):
      actual = _ir_type_to_dtype(formal.type).type(actual)
    elif not isinstance(actual, (np.ndarray, np.ScalarType)):
      # Anything else that is array-like (e.g. a jax.Array) goes through
      # np.asarray, so no explicit per-type check is needed here.
      actual = np.asarray(actual)
    env[formal] = actual
  for op in f.entry_block:
    op_name = op.OPERATION_NAME
    if op_name == "func.return":
      return [env[operand] for operand in op.operands]
    rule = custom_rules.get(op_name) or _interpreter_rules.get(op_name)
    if rule is None:
      raise KeyError(f"No interpretation rule for operation {op_name!r}")
    results = rule(op, *(env[operand] for operand in op.operands))
    assert len(op.results) == len(results), (
        f"Rule for {op_name} returned {len(results)} values, expected "
        f"{len(op.results)}")
    for formal, actual in zip(op.results, results):
      env[formal] = actual
  raise ValueError("Unterminated function?")
# Built-in interpretation rules, keyed by fully qualified operation name.
_interpreter_rules: dict[str, Rule] = {}


def _def_rule(name: str) -> Callable[[Rule], Rule]:
  """Returns a decorator that registers a rule for the operation `name`.

  The decorator returns the rule unchanged. Decorated functions therefore stay
  bound to their own names at module level instead of being replaced by
  `None`.

  Args:
    name: the fully qualified operation name, e.g. "arith.addf".

  Returns:
    A decorator that stores its argument in `_interpreter_rules[name]` and
    returns it.
  """

  def register(rule: Rule) -> Rule:
    _interpreter_rules[name] = rule
    return rule

  return register
def _shrui_rule(_, x, y):
  """arith.shrui: logical right shift (vacated high bits are filled with 0).

  NumPy's `>>` is an arithmetic shift for signed dtypes, and signless MLIR
  integers are usually held in signed NumPy dtypes. The bits are therefore
  shifted as an unsigned integer of the same width and then reinterpreted back.
  """
  x = np.asarray(x)
  unsigned = np.dtype(f"uint{x.dtype.itemsize * 8}")
  shifted = np.right_shift(x.view(unsigned), np.asarray(y).astype(unsigned))
  # `[()]` turns a 0-d result back into a NumPy scalar; arrays are unchanged.
  return (np.asarray(shifted).view(x.dtype)[()],)


def _ctlz_rule(_, x):
  """math.ctlz: element-wise count of leading zero bits at the operand's width.

  The bits are read as unsigned, so values with the top bit set (negative in a
  signed dtype) yield 0, and a zero element yields the full bit width.
  """
  x = np.asarray(x)
  if not np.issubdtype(x.dtype, np.integer):
    raise NotImplementedError(f"Unsupported ctlz dtype: {x.dtype}")
  width = x.dtype.itemsize * 8
  bits = x.view(np.dtype(f"uint{width}"))
  one = bits.dtype.type(1)
  count = np.zeros(x.shape, dtype=x.dtype)
  seen_one = np.zeros(x.shape, dtype=bool)
  # Scan from the most significant bit down, counting positions that precede
  # the first set bit. Shift amounts use the same unsigned dtype to avoid any
  # promotion to a wider (or floating) type.
  for pos in range(width - 1, -1, -1):
    seen_one |= ((bits >> bits.dtype.type(pos)) & one) != 0
    count += ~seen_one
  return (count[()],)


def _round_rule(_, x):
  """math.round: rounds to the nearest integer, with ties away from zero.

  `np.round` rounds ties to even, which is math.roundeven's behavior, so the
  tie-breaking is done explicitly here.
  """
  # inf - trunc(inf) is NaN; that NaN only makes the comparison False.
  with np.errstate(invalid="ignore"):
    truncated = np.trunc(x)
    # x - trunc(x) is the exact fractional part, so the tie test is exact.
    rounds_away = np.abs(x - truncated) >= 0.5
  return (np.where(rounds_away, truncated + np.sign(x), truncated)[()],)


# go/keep-sorted start
_def_rule("arith.addf")(lambda _, x, y: (x + y,))
_def_rule("arith.addi")(lambda _, x, y: (x + y,))
_def_rule("arith.andi")(lambda _, x, y: (x & y,))
_def_rule("arith.divf")(lambda _, x, y: (x / y,))
_def_rule("arith.index_cast")(lambda _, x: (x,))
_def_rule("arith.maximumf")(lambda _, x, y: (np.maximum(x, y),))
_def_rule("arith.maxsi")(lambda _, x, y: (np.maximum(x, y),))
_def_rule("arith.minimumf")(lambda _, x, y: (np.minimum(x, y),))
_def_rule("arith.minsi")(lambda _, x, y: (np.minimum(x, y),))
_def_rule("arith.mulf")(lambda _, x, y: (x * y,))
_def_rule("arith.muli")(lambda _, x, y: (x * y,))  # go/NOTYPO
_def_rule("arith.negf")(lambda _, x: (-x,))
_def_rule("arith.ori")(lambda _, x, y: (x | y,))
_def_rule("arith.select")(lambda _, c, t, f: (np.where(c, t, f),))
_def_rule("arith.shrui")(_shrui_rule)
_def_rule("arith.subf")(lambda _, x, y: (x - y,))
_def_rule("arith.subi")(lambda _, x, y: (x - y,))
_def_rule("arith.xori")(lambda _, x, y: (x ^ y,))
_def_rule("math.absf")(lambda _, x: (np.abs(x),))
_def_rule("math.absi")(lambda _, x: (np.abs(x),))
_def_rule("math.cos")(lambda _, x: (np.cos(x),))
_def_rule("math.ctlz")(_ctlz_rule)
_def_rule("math.exp")(lambda _, x: (np.exp(x),))
_def_rule("math.exp2")(lambda _, x: (np.exp2(x),))
_def_rule("math.log")(lambda _, x: (np.log(x),))
_def_rule("math.log1p")(lambda _, x: (np.log1p(x),))
_def_rule("math.powf")(lambda _, x, y: (np.power(x, y),))
_def_rule("math.round")(_round_rule)
_def_rule("math.roundeven")(lambda _, x: (np.rint(x),))
_def_rule("math.rsqrt")(lambda _, x: (1.0 / np.sqrt(x),))
_def_rule("math.sin")(lambda _, x: (np.sin(x),))
_def_rule("math.sqrt")(lambda _, x: (np.sqrt(x),))
_def_rule("math.tanh")(lambda _, x: (np.tanh(x),))
_def_rule("memref.load")(lambda _, arr, *indices: (arr[tuple(indices)],))
# go/keep-sorted end
@_def_rule("arith.constant")
def _constant_rule(op):
  """Materializes an arith.constant as a NumPy scalar or array.

  Supported result types:
    * integer/index scalars and bf16/f16/f32/f64 scalars, returned as NumPy
      scalars of the matching dtype;
    * vectors of integers or of f32/f64 elements, returned as arrays of the
      vector's shape.

  Raises:
    NotImplementedError: for any other result type.
  """
  res_ty = op.result.type
  value = op.attributes["value"]
  if ir.IntegerType.isinstance(res_ty) or ir.IndexType.isinstance(res_ty):
    return (_ir_type_to_dtype(res_ty).type(ir.IntegerAttr(value).value),)
  float_types = (ir.BF16Type, ir.F16Type, ir.F32Type, ir.F64Type)
  if any(float_ty.isinstance(res_ty) for float_ty in float_types):
    return (_ir_type_to_dtype(res_ty).type(ir.FloatAttr(value).value),)
  if ir.VectorType.isinstance(res_ty):
    res_ty = ir.VectorType(res_ty)
    elem_ty = res_ty.element_type
    # Float vectors are limited to f32/f64: those are the float widths whose
    # elements DenseFPElementsAttr can read back in the Python bindings.
    if ir.F32Type.isinstance(elem_ty) or ir.F64Type.isinstance(elem_ty):
      attr = ir.DenseFPElementsAttr(value)
    elif ir.IntegerType.isinstance(elem_ty):
      attr = ir.DenseIntElementsAttr(value)
    else:
      raise NotImplementedError(f"Unsupported constant type: {res_ty}")
    values = np.asarray(list(attr), dtype=_ir_type_to_dtype(elem_ty))
    return (values.reshape(res_ty.shape),)
  raise NotImplementedError(f"Unsupported constant type: {res_ty}")
# arith.cmpi predicate value -> (NumPy comparison, signedness of the operands).
# Signedness None means the predicate compares the operands as given.
_CMPI_PREDICATES = {
    0: (np.equal, None),  # eq
    1: (np.not_equal, None),  # ne
    2: (np.less, True),  # slt
    3: (np.less_equal, True),  # sle
    4: (np.greater, True),  # sgt
    5: (np.greater_equal, True),  # sge
    6: (np.less, False),  # ult
    7: (np.less_equal, False),  # ule
    8: (np.greater, False),  # ugt
    9: (np.greater_equal, False),  # uge
}


def _cmpi_operand(value, signed):
  """Reinterprets the bits of an integer operand with the given signedness.

  MLIR integers are signless and the predicate picks the signedness, whereas
  NumPy dtypes carry one. The bits are viewed at the same width with the
  predicate's signedness.
  """
  value = np.asarray(value)
  if value.dtype == np.bool_:
    # i1 holds 0/1 when read as unsigned and 0/-1 when read as signed.
    return -value.astype(np.int8) if signed else value
  width = value.dtype.itemsize * 8
  if signed and np.issubdtype(value.dtype, np.unsignedinteger):
    return value.view(np.dtype(f"int{width}"))
  if not signed and np.issubdtype(value.dtype, np.signedinteger):
    return value.view(np.dtype(f"uint{width}"))
  return value


@_def_rule("arith.cmpi")
def _cmpi_rule(op, x, y):
  """Element-wise integer comparison for every arith.cmpi predicate.

  Signed (s*) and unsigned (u*) predicates first reinterpret both operands'
  bits with the predicate's signedness. eq/ne compare the operands as given.

  Raises:
    NotImplementedError: for a predicate value outside 0..9.
  """
  cmp = ir.IntegerAttr(op.attributes["predicate"]).value
  if cmp not in _CMPI_PREDICATES:
    raise NotImplementedError(f"Unsupported comparison {cmp}")
  compare, signed = _CMPI_PREDICATES[cmp]
  if signed is not None:
    x, y = _cmpi_operand(x, signed), _cmpi_operand(y, signed)
  return (compare(x, y),)
@_def_rule("arith.sitofp")
def _sitofp_rule(op, x):
  """Converts signed integers to the op's floating-point (vector) result type.

  The operand bits are read as signed: unsigned NumPy inputs are reinterpreted
  at the same width, and i1 `True` converts to -1.

  Raises:
    NotImplementedError: if the result (element) type is not bf16/f16/f32/f64.
  """
  res_ty = op.result.type
  elem_ty = (ir.VectorType(res_ty).element_type
             if ir.VectorType.isinstance(res_ty) else res_ty)
  float_types = (ir.BF16Type, ir.F16Type, ir.F32Type, ir.F64Type)
  if not any(float_ty.isinstance(elem_ty) for float_ty in float_types):
    raise NotImplementedError(f"Unsupported sitofp type: {res_ty}")
  x = np.asarray(x)
  if x.dtype == np.bool_:
    x = -x.astype(np.int8)
  elif np.issubdtype(x.dtype, np.unsignedinteger):
    x = x.view(np.dtype(f"int{x.dtype.itemsize * 8}"))
  # `[()]` keeps scalar inputs as NumPy scalars.
  return (x.astype(_ir_type_to_dtype(elem_ty))[()],)
@_def_rule("vector.contract")
def _contract_rule(op, lhs, rhs, acc, *masks):
  """Interprets a vector.contract that is a 2-D matmul: `acc + lhs @ rhs`.

  Only the plain matmul indexing maps and the transposed-rhs variant are
  supported, with (parallel, parallel, reduction) iterators and no masks.

  Raises:
    NotImplementedError: for other indexing maps or iterator types, or when
      mask operands are present.
  """
  matmul_maps = ir.ArrayAttr.get([
      ir.Attribute.parse("affine_map<(i, j, k) -> (i, k)>"),
      ir.Attribute.parse("affine_map<(i, j, k) -> (k, j)>"),
      ir.Attribute.parse("affine_map<(i, j, k) -> (i, j)>"),
  ])
  matmul_rhs_transpose_maps = ir.ArrayAttr.get([
      ir.Attribute.parse("affine_map<(i, j, k) -> (i, k)>"),
      ir.Attribute.parse("affine_map<(i, j, k) -> (j, k)>"),
      ir.Attribute.parse("affine_map<(i, j, k) -> (i, j)>"),
  ])
  indexing_maps = op.attributes["indexing_maps"]
  if (indexing_maps != matmul_maps and
      indexing_maps != matmul_rhs_transpose_maps):
    raise NotImplementedError("Unsupported indexing maps")
  iterator_types = ir.ArrayAttr.get([
      ir.Attribute.parse("#vector.iterator_type<parallel>"),
      ir.Attribute.parse("#vector.iterator_type<parallel>"),
      ir.Attribute.parse("#vector.iterator_type<reduction>"),
  ])
  if op.attributes["iterator_types"] != iterator_types:
    raise NotImplementedError("Unsupported iterator types")
  # Operands beyond lhs/rhs/acc (masks) land in `masks`. Collecting them lets
  # this check reject them with a clear error; with a fixed 4-parameter
  # signature the call itself would fail first with a TypeError.
  if masks:
    raise NotImplementedError("Unsupported masks")
  rhs_mult = rhs.T if indexing_maps == matmul_rhs_transpose_maps else rhs
  return (acc + lhs @ rhs_mult,)
@_def_rule("tpu.matmul")
def _matmul_rule(op, lhs, rhs, acc):
  """Interprets tpu.matmul as `acc + lhs' @ rhs'`.

  `lhs'` / `rhs'` are the operands transposed when the op's `transpose_lhs` /
  `transpose_rhs` attribute is true, and unchanged otherwise.
  """
  left = lhs.T if op.transpose_lhs.value else lhs
  right = rhs.T if op.transpose_rhs.value else rhs
  product = left @ right
  return (acc + product,)
@_def_rule("vector.load")
def _vector_load_rule(op, base, *indices):
  """Reads a vector of the result's shape from `base` starting at `indices`.

  Returns a copy. A basic slice would alias `base`, so a later vector.store
  into the same memref (which writes in place) would silently change values
  that were loaded earlier.
  """
  res_ty = ir.VectorType(op.result.type)
  assert len(indices) == len(base.shape)
  assert len(indices) == len(res_ty.shape)
  window = tuple(
      slice(start, start + size)
      for start, size in zip(indices, res_ty.shape)
  )
  return (base[window].copy(),)
@_def_rule("vector.store")
def _vector_store_rule(_, val, base, *indices):
  """Writes `val` into `base` in place at offsets `indices`; has no results."""
  rank = len(indices)
  assert rank == len(base.shape)
  assert rank == len(val.shape)
  target = tuple(
      slice(offset, offset + extent)
      for offset, extent in zip(indices, val.shape)
  )
  base[target] = val
  return ()
@_def_rule("tpu.create_mask")
def _create_mask_rule(op, *indices):
  """Builds a boolean vector that is True exactly inside a box.

  The operands are all lower bounds followed by all (exclusive) upper bounds;
  the i-th lower bound pairs with the i-th upper bound.
  """
  assert len(indices) % 2 == 0
  half = len(indices) // 2
  lows, highs = indices[:half], indices[half:]
  mask = np.zeros(ir.VectorType(op.result.type).shape, dtype=bool)
  mask[tuple(map(slice, lows, highs))] = True
  return (mask,)
@_def_rule("tpu.all_reduce")
def _all_reduce_rule(op, x):
  """Reduces `x` along `op.dim` and broadcasts the result back to `x`'s shape.

  Supports the `sum` and `max` reduction kinds.
  """
  axis = op.dim.value
  kind = op.attributes["kind"]
  reducers = (
      ("#tpu.reduction_kind<sum>", np.sum),
      ("#tpu.reduction_kind<max>", np.max),
  )
  for spelling, reduce_fn in reducers:
    if kind == ir.Attribute.parse(spelling):
      reduced = reduce_fn(x, axis=axis, keepdims=True)
      return (np.broadcast_to(reduced, x.shape),)
  raise NotImplementedError(f"Unsupported reduction kind: {kind}")
@_def_rule("tpu.iota")
def _iota_rule(op):
  """Produces an i32 vector whose entries equal their index along `dimension`.

  The result is a read-only broadcast view.
  """
  axis = op.dimension.value
  vec_ty = ir.VectorType(op.result.type)
  assert vec_ty.element_type == ir.IntegerType.get_signless(32)
  dims = vec_ty.shape
  # A ramp that is 1-sized everywhere except along `axis`, then broadcast.
  ramp_shape = [1] * len(dims)
  ramp_shape[axis] = dims[axis]
  ramp = np.arange(dims[axis], dtype=np.int32).reshape(ramp_shape)
  return (np.broadcast_to(ramp, dims),)
@_def_rule("tpu.gather")
def _gather_rule(op, source):
  """Selects slices of `source` along `dimension` using the static `indices`."""
  axis = op.dimension.value
  picks = list(ir.DenseI32ArrayAttr(op.attributes["indices"]))
  return (np.take(source, picks, axis=axis),)
@_def_rule("tpu.dynamic_gather")
def _dynamic_gather_rule(op, source, indices):
  """Selects slices of `source` along `dimension` using a runtime index array.

  `indices` must be 2-D; only its first row is used as the selection.
  """
  axis = op.dimension.value
  assert indices.ndim == 2
  first_row = indices[0]
  return (np.take(source, first_row, axis=axis),)
@_def_rule("arith.shli")
def _shli(_, x, y):
  """Left shift whose result keeps the dtype of `x`.

  np.left_shift may promote its output to a wider dtype, so the result is cast
  back to the input's dtype.
  """
  shifted = np.left_shift(x, y)
  target_dtype = x.dtype
  return (shifted.astype(target_dtype),)
@_def_rule("arith.divsi")
def _divsi_rule(_, dividend, divisor):
  """Signed integer division rounding toward zero (MLIR arith.divsi).

  NumPy's `//` rounds toward negative infinity. The floor quotient is bumped
  up by one whenever the division is inexact and the operands have opposite
  signs. Unlike negating the quotient of absolute values, this cannot overflow
  for the most negative integer, whose absolute value is not representable.
  Division by zero yields 0 (NumPy's integer behavior).
  """
  floor_q = np.floor_divide(dividend, divisor)
  inexact = np.not_equal(np.remainder(dividend, divisor), 0)
  opposite_sign = np.logical_xor(np.less(dividend, 0), np.less(divisor, 0))
  return (floor_q + np.logical_and(inexact, opposite_sign),)
@_def_rule("arith.remsi")
def _remsi_rule(_, dividend, divisor):
  """Signed integer remainder whose sign follows the dividend (arith.remsi).

  `np.fmod` is the C-style remainder that pairs with truncating division.
  Unlike reconstructing the remainder from absolute values, it does not
  overflow for the most negative integer. A zero divisor yields 0 (NumPy's
  integer behavior). The result keeps the dividend's dtype.
  """
  dividend = np.asarray(dividend)
  remainder = np.fmod(dividend, divisor)
  return (np.asarray(remainder).astype(dividend.dtype),)
@_def_rule("tpu.broadcast_in_sublanes")
def _broadcast_in_sublanes_rule(op, source):
  """Broadcasts 8 consecutive lanes of row 0 across the result.

  For a 2-D `source`, row i of the result is filled with
  `source[0, lane + i]`, where `lane` is the op's `lane` attribute.
  """
  start = op.lane.value
  picked = source[0, start:start + 8]
  return (np.broadcast_to(picked[:, np.newaxis], source.shape),)
@_def_rule("tpu.rotate")
def _rotate_rule(op, source):
  """Cyclically shifts `source` by `amount` along `dimension`.

  When `stride`/`stride_dimension` are set, the size-1 slice at position i
  along `stride_dimension` is shifted by `amount + i * stride` instead.
  """
  shift = op.amount.value
  axis = op.dimension.value
  stride = op.stride.value if op.stride is not None else None
  stride_axis = (
      op.stride_dimension.value if op.stride_dimension is not None else None
  )
  assert (stride is None) == (stride_axis is None)
  if stride is None:
    return (np.roll(source, shift, axis),)
  pieces = np.split(source, source.shape[stride_axis], stride_axis)
  rolled = []
  for position, piece in enumerate(pieces):
    rolled.append(np.roll(piece, shift + position * stride, axis))
  return (np.concatenate(rolled, stride_axis),)
@_def_rule("tpu.concatenate")
def _concatenate_rule(op, *arr):
  """Joins all operands along the op's `dimension` attribute."""
  joined = np.concatenate(arr, axis=op.dimension.value)
  return (joined,)