| """A small and simple interpreter for MLIR. |
| |
| Only to be used for testing. |
| """ |
| from collections.abc import Callable, Mapping, Sequence
| from typing import Optional, Protocol
| |
| from jax import dtypes |
| from mlir import ir |
| from mlir.dialects import func |
| import numpy as np |
| |
| |
| def _ir_type_to_dtype(ir_type: ir.Type) -> np.dtype: |
| """Convert the given mlir type to the corresponding numpy type. |
| |
| Args: |
| ir_type: mlir type to convert. |
| |
| Returns: |
| the numpy type corresponding to the given mlir type. |
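|
|   Example (a minimal sketch; the type getters require an active ir.Context):
|
|     with ir.Context():
|       _ir_type_to_dtype(ir.F32Type.get())    # np.dtype(np.float32)
|       _ir_type_to_dtype(ir.IndexType.get())  # np.dtype(np.int32)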
| """ |
| |
| if ir_type == ir.IntegerType.get_signless(1): |
| return np.dtype(np.bool_) |
| elif ir_type == ir.IntegerType.get_signless(8): |
| return np.dtype(np.int8) |
| elif ir_type == ir.IntegerType.get_signless(16): |
| return np.dtype(np.int16) |
| elif ir_type == ir.IntegerType.get_signless(32): |
| return np.dtype(np.int32) |
| elif ir_type == ir.IntegerType.get_signless(64): |
| return np.dtype(np.int64) |
| elif ir_type == ir.IntegerType.get_unsigned(8): |
| return np.dtype(np.uint8) |
| elif ir_type == ir.IntegerType.get_unsigned(16): |
| return np.dtype(np.uint16) |
| elif ir_type == ir.IntegerType.get_unsigned(32): |
| return np.dtype(np.uint32) |
| elif ir_type == ir.IntegerType.get_unsigned(64): |
| return np.dtype(np.uint64) |
| elif ir_type == ir.BF16Type.get(): |
| return np.dtype(dtypes.bfloat16) |
| elif ir_type == ir.F16Type.get(): |
| return np.dtype(np.float16) |
| elif ir_type == ir.F32Type.get(): |
| return np.dtype(np.float32) |
| elif ir_type == ir.F64Type.get(): |
| return np.dtype(np.float64) |
| elif ir_type == ir.IndexType.get(): |
| return np.dtype(np.int32) |
| elif ir.ShapedType.isinstance(ir_type): |
| return _ir_type_to_dtype(ir.ShapedType(ir_type).element_type) |
| else: |
| raise NotImplementedError(f"Unsupported ir type: {ir_type}") |
| |
| |
| class Rule(Protocol): |
| """An interpretation rule for an MLIR primitive.""" |
| |
| def __call__(self, op: ir.OpView, *args: np.ndarray) -> Sequence[np.ndarray]: |
| ... |
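|
|
| # Example of a custom rule that can be passed to `interpret` via
| # `custom_rules` (a minimal sketch; "my_dialect.my_op" is a hypothetical
| # operation name used only for illustration):
| #
| #   def _my_op_rule(op: ir.OpView, x: np.ndarray) -> Sequence[np.ndarray]:
| #     del op  # The rule only needs the operand value.
| #     return (x + 1,)
| #
| #   interpret(f, x, custom_rules={"my_dialect.my_op": _my_op_rule})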
| |
| |
| def interpret( |
| f: func.FuncOp, |
| *args: np.ndarray, |
| custom_rules: Optional[Mapping[str, Rule]] = None) -> Sequence[np.ndarray]: |
| """Evaluates the given MLIR function using NumPy operations. |
| |
|   Every vector argument is expected to be backed by a NumPy array of the
|   corresponding shape and dtype, and every scalar argument by a NumPy scalar
|   of the corresponding dtype. When this is not the case, the inputs are
|   converted to NumPy data structures.
|
|   Note that the MLIR and NumPy type systems are fundamentally incompatible
|   with respect to unsigned data types: MLIR typically works with signless
|   integer types, where the signedness is defined by the operations themselves,
|   whereas NumPy uses distinct signed and unsigned integer dtypes. For this
|   reason the correspondence between NumPy and MLIR types is not checked
|   strictly.
| |
| Args: |
| f: Function to evaluate. |
|     *args: The arguments at which to evaluate f.
|     custom_rules: A dictionary used to extend the set of standard
|       interpretation rules. Maps operation names to rules.
| |
| Returns: |
| The result of applying f to args. |
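|
|   Example (a minimal sketch, assuming the upstream MLIR Python bindings with
|   the func and arith dialects registered):
|
|     with ir.Context():
|       module = ir.Module.parse('''
|           func.func @add(%x: i32, %y: i32) -> i32 {
|             %0 = arith.addi %x, %y : i32
|             return %0 : i32
|           }''')
|       f = module.body.operations[0]
|       (result,) = interpret(f, np.int32(1), np.int32(2))  # np.int32(3)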
| """ |
| if custom_rules is None: |
| custom_rules = {} |
| env = {} |
| assert len(f.arguments) == len(args) |
| for formal, actual in zip(f.arguments, args): |
| # Convert every actual argument of the block to a numpy type. |
| # Mixing the python/JAX/numpy type systems leads to unexpected behaviors. |
|     if isinstance(actual, (int, float)):
| actual = _ir_type_to_dtype(formal.type).type(actual) |
| elif not isinstance(actual, (np.ndarray, np.ScalarType)): |
|       # This case is hit by, e.g., jax.Array. We do not want to depend on that
|       # type here, so there is no explicit check for it.
| actual = np.asarray(actual) |
| |
| env[formal] = actual |
| for op in f.entry_block: |
| op_name = op.OPERATION_NAME |
| if op_name == "func.return": |
| return [env[formal] for formal in op.operands] |
|     if (rule := custom_rules.get(op_name)) is None:
|       rule = _interpreter_rules[op_name]
| results = rule(op, *(env[arg] for arg in op.operands)) |
| assert len(op.results) == len(results) |
| for formal, actual in zip(op.results, results): |
| env[formal] = actual |
| raise ValueError("Unterminated function?") |
| |
| |
| _interpreter_rules: dict[str, Rule] = {} |
| |
| |
| def _def_rule(name: str) -> Callable[[Rule], Rule]:
|   """Returns a decorator registering a rule for the operation named `name`."""
|   def f(rule):
|     _interpreter_rules[name] = rule
|     return rule
|   return f
| |
| |
| # go/keep-sorted start |
| _def_rule("arith.addf")(lambda _, x, y: (x + y,)) |
| _def_rule("arith.addi")(lambda _, x, y: (x + y,)) |
| _def_rule("arith.andi")(lambda _, x, y: (x & y,)) |
| _def_rule("arith.divf")(lambda _, x, y: (x / y,)) |
| _def_rule("arith.index_cast")(lambda _, x: (x,)) |
| _def_rule("arith.maximumf")(lambda _, x, y: (np.maximum(x, y),)) |
| _def_rule("arith.maxsi")(lambda _, x, y: (np.maximum(x, y),)) |
| _def_rule("arith.minimumf")(lambda _, x, y: (np.minimum(x, y),)) |
| _def_rule("arith.minsi")(lambda _, x, y: (np.minimum(x, y),)) |
| _def_rule("arith.mulf")(lambda _, x, y: (x * y,)) |
| _def_rule("arith.muli")(lambda _, x, y: (x * y,)) # go/NOTYPO |
| _def_rule("arith.negf")(lambda _, x: (-x,)) |
| _def_rule("arith.ori")(lambda _, x, y: (x | y,)) |
| _def_rule("arith.select")(lambda _, c, t, f: (np.where(c, t, f),)) |
| _def_rule("arith.shrui")(lambda _, x, y: (x >> y,)) |
| _def_rule("arith.subf")(lambda _, x, y: (x - y,)) |
| _def_rule("arith.subi")(lambda _, x, y: (x - y,)) |
| _def_rule("arith.xori")(lambda _, x, y: (x ^ y,)) |
| _def_rule("math.absf")(lambda _, x: (np.abs(x),)) |
| _def_rule("math.absi")(lambda _, x: (np.abs(x),)) |
| _def_rule("math.cos")(lambda _, x: (np.cos(x),)) |
| _def_rule("math.ctlz")(lambda _, x: (f"{x:032b}1".find("1"),)) |
| _def_rule("math.exp")(lambda _, x: (np.exp(x),)) |
| _def_rule("math.exp2")(lambda _, x: (np.power(2., x),)) |
| _def_rule("math.log")(lambda _, x: (np.log(x),)) |
| _def_rule("math.log1p")(lambda _, x: (np.log1p(x),)) |
| _def_rule("math.powf")(lambda _, x, y: (np.power(x, y),)) |
| _def_rule("math.round")(lambda _, x: (np.round(x),)) |
| _def_rule("math.roundeven")(lambda _, x: (np.rint(x),)) |
| _def_rule("math.rsqrt")(lambda _, x: (1.0 / np.sqrt(x),)) |
| _def_rule("math.sin")(lambda _, x: (np.sin(x),)) |
| _def_rule("math.sqrt")(lambda _, x: (np.sqrt(x),)) |
| _def_rule("math.tanh")(lambda _, x: (np.tanh(x),)) |
| _def_rule("memref.load")(lambda _, arr, i: (arr[i],)) |
| # go/keep-sorted end |
| |
| |
| @_def_rule("arith.constant") |
| def _constant_rule(op): # pylint: disable=missing-function-docstring |
| res_ty = op.result.type |
| value = op.attributes["value"] |
| if ir.IntegerType.isinstance(res_ty) or ir.IndexType.isinstance(res_ty): |
| return (_ir_type_to_dtype(res_ty).type(ir.IntegerAttr(value).value),) |
| elif ir.VectorType.isinstance(res_ty): |
| res_ty = ir.VectorType(res_ty) |
| if ir.F32Type.isinstance(res_ty.element_type): |
| attr = ir.DenseFPElementsAttr(value) |
| dtype = np.dtype(np.float32) |
| elif res_ty.element_type == ir.IntegerType.get_signless(32): |
| attr = ir.DenseIntElementsAttr(value) |
| dtype = np.dtype(np.int32) |
| else: |
| raise NotImplementedError(f"Unsupported constant type: {res_ty}") |
| values = np.asarray(list(attr), dtype=dtype) |
| return (values.reshape(res_ty.shape),) |
| raise NotImplementedError(f"Unsupported constant type: {res_ty}") |
| |
| |
| @_def_rule("arith.cmpi") |
| def _cmpi_rule(op, x, y): # pylint: disable=missing-function-docstring |
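|   # Predicate values follow arith.cmpi's encoding: 0 is eq, 2 is slt, 5 is sge.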
| cmp = ir.IntegerAttr(op.attributes["predicate"]).value |
| if cmp == 0: |
| return (x == y,) |
| if cmp == 2: |
| if not all(np.issubdtype(v.dtype, np.signedinteger) for v in (x, y)): |
| raise RuntimeError("Signed comparison requested for unsupported dtype") |
| return (x < y,) |
| elif cmp == 5: |
| return (x >= y,) |
| else: |
| raise NotImplementedError(f"Unsupported comparison {cmp}") |
| |
| |
| @_def_rule("arith.sitofp") |
| def _sitofp_rule(op, x): |
| res_ty = op.result.type |
| if ir.F32Type.isinstance(res_ty): |
| return (x.astype(np.float32),) |
| elif ir.VectorType.isinstance(res_ty): |
| res_ty = ir.VectorType(res_ty) |
| if ir.F32Type.isinstance(res_ty.element_type): |
| return (x.astype(np.float32),) |
| raise NotImplementedError(f"Unsupported sitofp type: {res_ty}") |
| |
| |
| @_def_rule("vector.contract") |
| def _contract_rule(op, lhs, rhs, acc): # pylint: disable=missing-function-docstring |
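|   # Only plain matmul and matmul with a transposed RHS are supported, i.e.
|   # the result is acc + lhs @ rhs or acc + lhs @ rhs.T.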
| matmul_maps = ir.ArrayAttr.get([ |
| ir.Attribute.parse("affine_map<(i, j, k) -> (i, k)>"), |
| ir.Attribute.parse("affine_map<(i, j, k) -> (k, j)>"), |
| ir.Attribute.parse("affine_map<(i, j, k) -> (i, j)>"), |
| ]) |
| matmul_rhs_transpose_maps = ir.ArrayAttr.get([ |
| ir.Attribute.parse("affine_map<(i, j, k) -> (i, k)>"), |
| ir.Attribute.parse("affine_map<(i, j, k) -> (j, k)>"), |
| ir.Attribute.parse("affine_map<(i, j, k) -> (i, j)>"), |
| ]) |
| indexing_maps = op.attributes["indexing_maps"] |
| if (indexing_maps != matmul_maps and |
| indexing_maps != matmul_rhs_transpose_maps): |
| raise NotImplementedError("Unsupported indexing maps") |
| iterator_types = ir.ArrayAttr.get([ |
| ir.Attribute.parse("#vector.iterator_type<parallel>"), |
| ir.Attribute.parse("#vector.iterator_type<parallel>"), |
| ir.Attribute.parse("#vector.iterator_type<reduction>"), |
| ]) |
| if op.attributes["iterator_types"] != iterator_types: |
| raise NotImplementedError("Unsupported iterator types") |
| if len(op.operands) != 3: |
| raise NotImplementedError("Unsupported masks") |
| rhs_mult = rhs.T if indexing_maps == matmul_rhs_transpose_maps else rhs |
| return (acc + lhs @ rhs_mult,) |
| |
| |
| @_def_rule("tpu.matmul") |
| def _matmul_rule(op, lhs, rhs, acc): # pylint: disable=missing-function-docstring |
| if op.transpose_lhs.value: |
| lhs = lhs.T |
| if op.transpose_rhs.value: |
| rhs = rhs.T |
| return (acc + lhs @ rhs,) |
| |
| |
| @_def_rule("vector.load") |
| def _vector_load_rule(op, base, *indices): |
| res_ty = ir.VectorType(op.result.type) |
| assert len(indices) == len(base.shape) |
| assert len(indices) == len(res_ty.shape) |
| loc = tuple( |
| slice(indices[i], indices[i] + res_ty.shape[i]) |
| for i in range(len(indices)) |
| ) |
| return (base[loc],) |
| |
| |
| @_def_rule("vector.store") |
| def _vector_store_rule(_, val, base, *indices): |
| assert len(indices) == len(base.shape) |
| assert len(indices) == len(val.shape) |
| loc = tuple( |
| slice(indices[i], indices[i] + val.shape[i]) for i in range(len(indices)) |
| ) |
| base[loc] = val |
| return () |
| |
| |
| @_def_rule("tpu.create_mask") |
| def _create_mask_rule(op, *indices): |
| assert len(indices) % 2 == 0 |
| low, high = indices[:len(indices) // 2], indices[len(indices) // 2:] |
| res_ty = ir.VectorType(op.result.type) |
| result = np.full(res_ty.shape, False, dtype=bool) |
| result[tuple(slice(l, h) for l, h in zip(low, high))] = True |
| return (result,) |
| |
| |
| @_def_rule("tpu.all_reduce") |
| def _all_reduce_rule(op, x): |
| dim = op.dim.value |
| kind = op.attributes["kind"] |
| if kind == ir.Attribute.parse("#tpu.reduction_kind<sum>"): |
| y = np.sum(x, axis=dim, keepdims=True) |
| elif kind == ir.Attribute.parse("#tpu.reduction_kind<max>"): |
| y = np.max(x, axis=dim, keepdims=True) |
| else: |
| raise NotImplementedError(f"Unsupported reduction kind: {kind}") |
| return (np.broadcast_to(y, x.shape),) |
| |
| |
| @_def_rule("tpu.iota") |
| def _iota_rule(op): |
| dim = op.dimension.value |
| vty = ir.VectorType(op.result.type) |
| assert vty.element_type == ir.IntegerType.get_signless(32) |
| shape = vty.shape |
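|   # Build an arange along the requested dimension as the last axis, broadcast
|   # it over the remaining dimensions, then move it into position `dim`.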
| nd_arange = np.broadcast_to( |
| np.arange(shape[dim], dtype=np.int32), |
| (*shape[:dim], *shape[dim + 1:], shape[dim])) |
| return (np.moveaxis(nd_arange, -1, dim),) |
| |
| |
| @_def_rule("tpu.gather") |
| def _gather_rule(op, source): |
| dim = op.dimension.value |
| indices = ir.DenseI32ArrayAttr(op.attributes["indices"]) |
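|   # Move `dim` to the front, select the requested indices along it, and then
|   # move the gathered axis back into place.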
| return (np.swapaxes(np.swapaxes(source, dim, 0)[list(indices)], dim, 0),) |
| |
| |
| @_def_rule("tpu.dynamic_gather") |
| def _dynamic_gather_rule(op, source, indices): |
| dim = op.dimension.value |
| assert indices.ndim == 2 |
| return (np.swapaxes(np.swapaxes(source, dim, 0)[indices[0]], dim, 0),) |
| |
| |
| @_def_rule("arith.shli") |
| def _shli(_, x, y): |
|   # np.left_shift can upcast the output to a wider dtype. We do not want that,
|   # so we force the dtype of the output to match the dtype of the input.
| return (np.left_shift(x, y).astype(x.dtype),) |
| |
| |
| @_def_rule("arith.divsi") |
| def _divsi_rule(_, dividend, divisor): |
|   # Native Python integer division rounds towards negative infinity, while
|   # MLIR's arith.divsi rounds towards zero.
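|   # For example, Python gives -7 // 2 == -4, whereas `arith.divsi -7, 2`
|   # yields -3; the logic below reproduces the round-towards-zero behavior.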
| abs_dividend = np.abs(dividend) |
| abs_divisor = np.abs(divisor) |
| q = abs_dividend // abs_divisor |
| opposite_sign = np.logical_xor(np.less(dividend, 0), np.less(divisor, 0)) |
| divisor_not_zero = np.not_equal(divisor, 0) |
| choose_neg = np.logical_and(opposite_sign, divisor_not_zero) |
| return (np.where(choose_neg, -q, q),) |
| |
| |
| @_def_rule("arith.remsi") |
| def _remsi_rule(_, dividend, divisor): |
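|   # arith.remsi takes the sign of the dividend (like C's %), so that
|   # dividend == divsi(dividend, divisor) * divisor + remsi(dividend, divisor).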
| abs_dividend = np.abs(dividend) |
| abs_divisor = np.abs(divisor) |
| divsi = _interpreter_rules["arith.divsi"] |
| q = divsi(None, abs_dividend, abs_divisor)[0] |
| abs_remainder = abs_dividend - q * abs_divisor |
| out = np.where(np.less(dividend, 0), -abs_remainder, abs_remainder) |
| return (out.astype(abs_dividend.dtype),) |
| |
| |
| @_def_rule("tpu.broadcast_in_sublanes") |
| def _broadcast_in_sublanes_rule(op, source): |
| lane = op.lane.value |
| return (np.broadcast_to(source[0, lane:lane + 8, np.newaxis], source.shape),) |
| |
| |
| @_def_rule("tpu.rotate") |
| def _rotate_rule(op, source): # pylint: disable=missing-function-docstring |
| amount = op.amount.value |
| axis = op.dimension.value |
| stride = None if op.stride is None else op.stride.value |
| stride_axis = ( |
| None if op.stride_dimension is None else op.stride_dimension.value |
| ) |
| assert (stride is None) == (stride_axis is None) |
| if stride is None: |
| return (np.roll(source, amount, axis),) |
| outputs = [ |
| np.roll(xs, amount + i * stride, axis) |
| for i, xs in enumerate( |
| np.split(source, source.shape[stride_axis], stride_axis) |
| ) |
| ] |
| return (np.concatenate(outputs, stride_axis),) |
| |
| |
| @_def_rule("tpu.concatenate") |
| def _concatenate_rule(op, *arr): |
| return (np.concatenate(arr, axis=op.dimension.value),) |