#include <array>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>
#include "learning/brain/tpu/runtime/topology/tpu_topology.h"
#include "learning/brain/tpu/runtime/topology/tpu_topology_serdes.h"
#include "learning/brain/tpu/runtime/tpu_version.h"
#include "platforms/xla/mosaic/dialect/llo/llo_dialect.h"
#include "platforms/xla/mosaic/dialect/tpu/tpu_passes.h"
#include "platforms/xla/service/jellyfish/gain_latch_mode.h"
#include "platforms/xla/service/jellyfish/matmul_data_format.h"
#include "platforms/xla/service/jellyfish/target.h"
#include "platforms/xla/service/jellyfish/transfer_size_util.h"
#include "platforms/xla/service/jellyfish/vpack_format.h"
#include "third_party/absl/log/check.h"
#include "third_party/absl/log/log.h"
#include "third_party/absl/types/span.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/MathExtras.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Math/IR/Math.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/IR/SCF.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineMap.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinAttributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinTypeInterfaces.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/ValueRange.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LogicalResult.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/py/jax/jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "third_party/tensorflow/compiler/xla/mlir/utils/type_util.h"
#include "third_party/tensorflow/compiler/xla/shape_util.h"
namespace mlir::tpu {
#define GEN_PASS_DECL_LOWERTOLLOPASS
#define GEN_PASS_DEF_LOWERTOLLOPASS
#include "platforms/xla/mosaic/dialect/tpu/tpu_passes.h.inc"
namespace {
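// Returns true if the memref type has a tpu::MemorySpaceAttr equal to `space`.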
bool hasMemorySpace(MemRefType ty, tpu::MemorySpace space) {
auto memory_space =
dyn_cast_or_null<tpu::MemorySpaceAttr>(ty.getMemorySpace());
return memory_space && memory_space.getValue() == space;
}
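// Appends the constant values of `indices` to `storage`. Returns false if any
// index is not produced by an llo::ConstantOp.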
bool getConstantIndices(SmallVectorImpl<int64_t> *storage, ValueRange indices) {
CHECK(storage->empty());
for (auto ix : indices) {
auto ix_cst = ix.getDefiningOp<llo::ConstantOp>();
if (!ix_cst) {
return false;
}
storage->push_back(cast<IntegerAttr>(ix_cst.getValue()).getInt());
}
return true;
}
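// Splits a VMEM access into an (address, displacement) pair. A constant offset
// is kept as the displacement; a dynamic offset is folded into the address via
// llo::ScalarAddressVmemOp and the returned displacement is null.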
std::pair<Value, Value> vmemAddrOffsetToAddrDisplacement(
PatternRewriter &rewriter, Value addr, Value offset) {
if (!offset) {
return std::make_pair(addr, nullptr);
}
if (offset.getDefiningOp<llo::ConstantOp>()) {
return std::make_pair(addr, offset);
}
return std::make_pair(
rewriter.create<llo::ScalarAddressVmemOp>(offset.getLoc(), addr, offset)
.getResult(),
nullptr);
}
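// Wraps a callable as an OpConversionPattern so conversion patterns can be
// registered as plain lambdas.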
template <typename Op>
using PatternFunc = std::function<LogicalResult(Op, typename Op::Adaptor,
ConversionPatternRewriter &)>;
template <typename Op>
void addPattern(RewritePatternSet &patterns, PatternFunc<Op> pattern) {
class FuncConversion final : public OpConversionPattern<Op> {
public:
FuncConversion(PatternFunc<Op> pattern, MLIRContext *context)
: OpConversionPattern<Op>(context), pattern_(std::move(pattern)) {}
LogicalResult matchAndRewrite(
Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return pattern_(op, adaptor, rewriter);
}
private:
PatternFunc<Op> pattern_;
};
patterns.insert(std::make_unique<FuncConversion>(std::move(pattern),
patterns.getContext()));
}
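// Maps an element bitwidth to the corresponding compressed vpack format, or
// kInvalid for unsupported bitwidths.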
llo::VpackFormat GetCompressedByteFormat(unsigned bitwidth) {
switch (bitwidth) {
case 16:
return llo::VpackFormat::kCompressedB16;
case 8:
return llo::VpackFormat::kCompressedB8;
case 4:
return llo::VpackFormat::kCompressedB4;
case 2:
return llo::VpackFormat::kCompressedB2;
case 1:
return llo::VpackFormat::kCompressedB1;
default:
return llo::VpackFormat::kInvalid;
}
}
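// Owns a fake TPU topology and the jellyfish Target derived from it, used when
// the pass mocks a target instead of receiving one from the caller.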
class MockTpuInstance {
public:
explicit MockTpuInstance(::tpu::TpuVersion version) {
auto maybe_topology =
::tpu::TpuTopologySerdes::Construct({.version = version});
CHECK_OK(maybe_topology.status());
topology_ = std::move(maybe_topology.value());
auto maybe_target =
::xla::jellyfish::Target::CreateFromTopology(topology_.get());
CHECK_OK(maybe_target.status());
target_ = std::move(maybe_target.value());
}
::xla::jellyfish::Target *target() { return target_.get(); }
private:
std::unique_ptr<const ::tpu::TpuTopology> topology_;
std::unique_ptr<::xla::jellyfish::Target> target_;
};
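// Lowers the remaining vector/arith/math/memref/cf/tpu ops to the LLO dialect
// for the given jellyfish target.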
class LowerToLLOPass : public impl::LowerToLLOPassBase<LowerToLLOPass> {
public:
explicit LowerToLLOPass(const xla::jellyfish::Target *target,
bool unsafe_allow_multicore_remote_dma)
: target_(target),
unsafe_allow_multicore_remote_dma_(unsafe_allow_multicore_remote_dma) {}
LowerToLLOPass()
: target_(nullptr), unsafe_allow_multicore_remote_dma_(false) {}
protected:
void runOnOperation() override;
private:
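// Creates a mock TPU instance for the generation selected by the mock_target
// option (2 = Jellyfish, 3 = Dragonfish, 4 = Pufferfish), or nullptr otherwise.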
std::unique_ptr<MockTpuInstance> tryCreatingMock() {
if (mock_target == 2) {
return std::make_unique<MockTpuInstance>(::tpu::TpuVersion::kJellyfish);
}
if (mock_target == 3) {
return std::make_unique<MockTpuInstance>(::tpu::TpuVersion::kDragonfish);
}
if (mock_target == 4) {
return std::make_unique<MockTpuInstance>(::tpu::TpuVersion::kPufferfish);
}
return nullptr;
}
int64_t laneCount() { return target_->LaneCount(); }
int64_t sublaneCount() { return target_->SublaneCount(); }
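// Returns true if `ty` maps onto a single native vreg: (sublanes, lanes) for
// 32-bit elements, or (sublanes, lanes, 32 / bitwidth) for packed elements.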
bool isRepresentableVectorType(Type ty) {
auto vty = dyn_cast<VectorType>(ty);
if (!vty) {
return false;
}
auto ety = vty.getElementType();
if (!ety.isF32() && !ety.isBF16() && !ety.isSignlessInteger()) {
return false;
}
if (vty.getRank() == 2) {
return ety.getIntOrFloatBitWidth() == 32 &&
vty.getShape() == ArrayRef<int64_t>{sublaneCount(), laneCount()};
}
return vty.getShape() ==
ArrayRef<int64_t>{sublaneCount(), laneCount(),
32 / ety.getIntOrFloatBitWidth()};
}
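// Returns the native vreg shape for elements of the given bitwidth.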
std::vector<int64_t> nativeVectorShape(unsigned bitwidth) {
if (bitwidth == 32) {
return {sublaneCount(), laneCount()};
}
CHECK_LT(bitwidth, 32);
CHECK(llvm::isPowerOf2_32(bitwidth));
return {sublaneCount(), laneCount(), 32 / bitwidth};
}
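// Returns true if `ty` is an i1 mask vector whose leading dimensions match
// (sublanes, lanes) and whose optional third dimension fits within the mask
// bits the target provides per (sublane, lane).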
bool isMaskVectorType(Type ty) {
auto vty = dyn_cast<VectorType>(ty);
if (!vty || !vty.getElementType().isSignlessInteger(1) ||
vty.getRank() < 2) {
return false;
}
if (vty.getShape().take_front(2) !=
ArrayRef<int64_t>{sublaneCount(), laneCount()}) {
return false;
}
if (vty.getRank() == 3) {
int32_t max_bits = target_->BitsPerVmregLaneAndSublane();
return vty.getDimSize(2) == 1 ||
(max_bits >= 2 && vty.getDimSize(2) == 2) ||
(max_bits >= 4 && vty.getDimSize(2) == 4);
}
return vty.getRank() == 2;
}
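// Reinterprets a 1D tiled memref as an equivalent 2D memref of shape
// (size / lane_count, lane_count), updating `ty` and `indices` in place. Fails
// when the 1D layout is not the expected lane-aligned tiling.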
LogicalResult convertMemRef1DTo2D(MemRefType *ty,
SmallVector<int64_t, 2> *indices) {
CHECK_EQ(ty->getRank(), 1);
auto lane_count = laneCount();
auto layout_1d = dyn_cast<tpu::TiledLayoutAttr>(ty->getLayout());
if (!layout_1d) {
return failure();
}
auto packing = 32 / ty->getElementTypeBitWidth();
if (packing == 1) {
if (layout_1d.getTiles().size() != 1) {
return failure();
}
} else {
if (layout_1d.getTiles().size() != 3) {
return failure();
}
auto second = layout_1d.getTiles()[1].dimensions();
auto third = layout_1d.getTiles()[2].dimensions();
if (second.size() != 1 || second.front() != lane_count ||
third.size() != 2 || third[0] != packing || third[1] != 1) {
return failure();
}
}
absl::Span<const int64_t> tiling_1d =
layout_1d.getTiles().front().dimensions();
if (tiling_1d.size() != 1 || tiling_1d.front() % lane_count != 0 ||
ty->getDimSize(0) % lane_count != 0) {
return failure();
}
SmallVector<xla::Tile, 2> tiles;
tiles.push_back(::xla::Tile({tiling_1d.front() / lane_count, lane_count}));
if (packing != 1) {
tiles.push_back(::xla::Tile({packing, 1}));
}
auto layout_2d = tpu::TiledLayoutAttr::get(
ty->getContext(), tiles, {layout_1d.getTileStrides()[0], 1});
int64_t length = ty->getDimSize(0);
std::array<int64_t, 2> shape_2d{length / lane_count, lane_count};
*ty = MemRefType::get(shape_2d, ty->getElementType(), layout_2d,
ty->getMemorySpace());
indices->push_back(indices->front() % lane_count);
indices->front() /= lane_count;
return success();
}
// The read memory region should be of shape [num_read_sublanes, 128].
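// Returns the VMEM sublane offset implied by `indices`: std::nullopt if the
// access cannot be expressed as an aligned sublane offset, a null Value if the
// offset is statically zero, and an i32 offset value otherwise.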
std::optional<Value> indicesToVmemOffset(ValueRange indices,
int64_t num_read_sublanes,
MemRefType ty,
PatternRewriter &rewriter,
bool consistent_directions = true) {
if (!ty.hasStaticShape() || indices.size() != ty.getRank() ||
ty.getRank() < 1) {
return std::nullopt;
}
CHECK_NE(ty.getRank(), 0);
int64_t tiled_dims = ty.getRank() == 1 ? 1 : 2;
SmallVector<int64_t, 2> tile_idx;
if (!getConstantIndices(&tile_idx, indices.take_back(tiled_dims))) {
return std::nullopt;
}
if (ty.getRank() == 1) {
if (convertMemRef1DTo2D(&ty, &tile_idx).failed()) {
return std::nullopt;
}
}
// Now we can assume we're working with 2D data.
int64_t packing = 32 / ty.getElementTypeBitWidth();
auto layout = dyn_cast<tpu::TiledLayoutAttr>(ty.getLayout());
if (!layout) {
return std::nullopt;
}
auto tiling = layout.getTiles();
if (tiling.empty()) {
return std::nullopt;
}
auto leading_tile = tiling.front().dimensions();
if (leading_tile.size() != 2) {
return std::nullopt;
}
int64_t tile_major = leading_tile[0];
int64_t tile_minor = leading_tile[1];
if (tiling.size() == 1) {
if (packing != 1) {
return std::nullopt;
}
} else if (tiling.size() == 2) {
auto trailing_tile = tiling.back().dimensions();
if (trailing_tile.size() != 2 || trailing_tile[0] != packing ||
trailing_tile[1] != 1) {
return std::nullopt;
}
} else {
return std::nullopt;
}
auto batch_strides = layout.getTileStrides().drop_back(2);
auto tile_strides = layout.getTileStrides().take_back(2);
auto ty_tile_shape = ty.getShape().take_back(2);
if (ty_tile_shape[0] % tile_major != 0 ||
ty_tile_shape[1] % tile_minor != 0) {
return std::nullopt; // Memref shape isn't a multiple of the tile size.
}
// Compute the tile coordinates.
int64_t major_among_tiles = tile_idx[0] / tile_major;
int64_t minor_among_tiles = tile_idx[1] / tile_minor;
int64_t major_in_tile = tile_idx[0] % tile_major;
int64_t minor_in_tile = tile_idx[1] % tile_minor;
// We assume tiles are Nx128, so anything else would be unaligned.
if (minor_in_tile != 0) {
return std::nullopt;
}
// Make sure the access doesn't span multiple tiles when it isn't allowed to.
// When the minormost dimension matches the minor tile size, moving to the
// next sublane always advances the more major dimension.
if (consistent_directions && ty_tile_shape[1] != tile_minor &&
major_in_tile + num_read_sublanes > tile_major) {
return std::nullopt;
}
// Every `packing` rows share the same memory address, so accesses must be
// aligned to a multiple of the packing.
// TODO(apaszke): Check that the tile size is a multiple of the packing.
if (tile_major % packing != 0 || major_in_tile % packing != 0) {
return std::nullopt;
}
int64_t tile_bytes =
tile_major * tile_minor * ty.getElementTypeBitWidth() / 8;
// Tiles should be at most one chunk in size and should span a power-of-two
// number of sublanes.
if (tile_bytes > target_->ChunkBytes() ||
tile_bytes % target_->SublaneBytes() != 0 ||
!llvm::isPowerOf2_64(tile_bytes / target_->SublaneBytes())) {
return std::nullopt;
}
// Displacements are measured in units of sublanes.
int64_t displacement_per_tile = tile_bytes / target_->SublaneBytes();
int64_t tiled_tile_lin_idx = major_among_tiles * tile_strides[0] +
minor_among_tiles * tile_strides[1];
int64_t tiled_displacement =
tiled_tile_lin_idx * displacement_per_tile + (major_in_tile / packing);
auto batch_idx = indices.drop_back(tiled_dims);
auto s32_cst = [&](int64_t value) -> Value {
return rewriter
.create<llo::ConstantOp>(rewriter.getUnknownLoc(),
rewriter.getI32Type(),
rewriter.getI32IntegerAttr(value))
.getResult();
};
SmallVector<int64_t> non_tile_idx;
// We constant-fold the displacement computation if all indices are static.
if (getConstantIndices(&non_tile_idx, indices.drop_back(tiled_dims))) {
int64_t batch_lin_idx = 0;
CHECK_EQ(non_tile_idx.size(), batch_strides.size());
for (int i = batch_idx.size() - 1; i >= 0; --i) {
batch_lin_idx += non_tile_idx[i] * batch_strides[i];
}
int64_t total_displacement =
batch_lin_idx * displacement_per_tile + tiled_displacement;
if (total_displacement == 0) {
return Value();
}
return s32_cst(total_displacement);
}
Location loc = rewriter.getUnknownLoc();
Value batch_lin_idx_scaled = s32_cst(0);
CHECK_EQ(batch_idx.size(), batch_strides.size());
for (int i = 0; i < batch_idx.size(); ++i) {
CHECK(batch_idx[i].getType().isSignlessInteger(32));
batch_lin_idx_scaled = rewriter.create<llo::ScalarAddS32Op>(
loc, batch_lin_idx_scaled,
// We fold the multiplication by displacement_per_tile into the
// constant here, instead of adding another op at the end.
rewriter.create<llo::ScalarMulS32Op>(
loc, batch_idx[i],
s32_cst(batch_strides[i] * displacement_per_tile)));
}
return rewriter.create<llo::ScalarAddS32Op>(loc, batch_lin_idx_scaled,
s32_cst(tiled_displacement));
}
// Only used for scalar loads.
// Returns the offset that the indices point to. For packed types (<32-bit)
// it additionally returns the sub-element index within that address (or
// nullptr if the type is not packed).
// TODO(b/299280718): Handle packed 1D memrefs
std::pair<Value, Value> indicesToOffset(Location loc, MemRefType ty,
ValueRange indices,
ConversionPatternRewriter &rewriter) {
auto layout = dyn_cast<TiledLayoutAttr>(ty.getLayout());
if (!layout) {
return std::make_pair(nullptr, nullptr);
}
auto tiles = layout.getTiles();
absl::Span<const int64_t> tile = tiles.front().dimensions();
if (tile.size() == 1) {
// We can ignore the padding only if there's a single dimension.
// We should be able to reuse the code below if we ever need that.
if (ty.getRank() > 1 || tiles.size() != 1) {
return std::make_pair(nullptr, nullptr);
}
return std::make_pair(indices.front(), nullptr);
}
auto int_const = [&](int i) -> Value {
return rewriter.create<llo::ConstantOp>(loc, rewriter.getI32Type(),
rewriter.getI32IntegerAttr(i));
};
auto mod_i = [&](const Value &v, int d) -> Value {
return rewriter.create<llo::ScalarBitwiseAndOp>(loc, v, int_const(d - 1));
};
auto div_i = [&](const Value &v, int d) -> Value {
return rewriter.create<llo::ScalarDivS32Op>(loc, v, int_const(d));
};
auto mul_i = [&](const Value &v, int d) -> Value {
return rewriter.create<llo::ScalarMulS32Op>(loc, v, int_const(d));
};
auto add = [&](const Value &v1, const Value &v2) -> Value {
return rewriter.create<llo::ScalarAddS32Op>(loc, v1, v2);
};
if (tile.size() == 2) {
// Here we assume one smem word is 4 bytes.
CHECK_EQ(target_->SmemWordSizeBytes(), 4);
unsigned packing = 32 / ty.getElementTypeBitWidth();
if (packing > 1) {
if (tiles.size() != 2 ||
tiles[1].dimensions() != absl::Span<const int64_t>{packing, 1}) {
return std::make_pair(nullptr, nullptr);
}
}
int rank = indices.size();
Value tile_offset = int_const(0);
auto tile_strides = layout.getTileStrides();
for (int i = rank - 1; i >= 0; --i) {
Value index = indices[i];
// Convert element index to tile index.
if (i >= rank - 2) {
index = div_i(index, tile[i - rank + 2]);
}
tile_offset = add(tile_offset, mul_i(index, tile_strides[i]));
}
int64_t tile_bits = tile[0] * tile[1] * ty.getElementTypeBitWidth();
int64_t word_bits = 8 * target_->SmemWordSizeBytes();
CHECK_GE(tile_bits, word_bits);
// Get total smem words from tiles.
Value word_offset = mul_i(tile_offset, tile_bits / word_bits);
// Calculate the extra word offset within the tile when the indices are not
// tile-aligned.
Value part = packing > 1 ? mod_i(indices[rank - 2], packing) : nullptr;
auto major_in_tile = mod_i(indices[rank - 2], tile[0]);
auto minor_in_tile = mod_i(indices[rank - 1], tile[1]);
// word_offset += (major_in_tile / packing) * tile[1] + minor_in_tile
word_offset = add(
word_offset,
add(mul_i(div_i(major_in_tile, packing), tile[1]), minor_in_tile));
return std::make_pair(word_offset, part);
}
return std::make_pair(nullptr, nullptr);
}
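// Encodes a per-sublane read mask as a bitmask, the index of the last sublane
// read, and a flag indicating whether every sublane is read.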
std::tuple<uint32_t, int64_t, bool> encodeSublaneMask(
ArrayRef<bool> sublane_mask) {
// TODO(apaszke): Verify length.
uint32_t sublane_mask_i32 = 0;
int64_t num_read_sublanes = 0;
bool all_sublanes_read = true;
for (uint32_t i = 0; i < sublaneCount(); ++i) {
if (!sublane_mask[i]) {
all_sublanes_read = false;
continue;
}
sublane_mask_i32 |= uint32_t{1} << i;
num_read_sublanes = i;
}
return std::make_tuple(sublane_mask_i32, num_read_sublanes,
all_sublanes_read);
}
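// Conversions from the vector dialect (extract, transfer_read, load, store,
// contraction, broadcast, transpose) to LLO and TPU ops.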
void populateVectorToLLOConversionPatterns(RewritePatternSet &patterns) {
addPattern<vector::ExtractOp>(
patterns,
[this](vector::ExtractOp op, vector::ExtractOpAdaptor subst,
ConversionPatternRewriter &rewriter) -> LogicalResult {
if (!isRepresentableVectorType(op.getSourceVectorType()) ||
op.getSourceVectorType().getElementTypeBitWidth() != 32 ||
op.hasDynamicPosition()) {
return failure();
}
for (int64_t pos : op.getStaticPosition()) {
if (pos != 0) {
return failure();
}
}
rewriter.replaceOpWithNewOp<llo::VectorToScalarOp>(op, op.getType(),
subst.getVector());
return success();
});
addPattern<vector::TransferReadOp>(
patterns,
[this](vector::TransferReadOp op, vector::TransferReadOpAdaptor subst,
ConversionPatternRewriter &rewriter) -> LogicalResult {
if (!isRepresentableVectorType(op.getType())) {
return failure();
}
auto memref_ty = op.getSource().getType();
auto single_sublane_map = AffineMap::get(
memref_ty.getRank(), 0,
{rewriter.getAffineConstantExpr(0),
rewriter.getAffineDimExpr(memref_ty.getRank() - 1)},
rewriter.getContext());
auto regular_map = AffineMap::get(
memref_ty.getRank(), 0,
{rewriter.getAffineDimExpr(memref_ty.getRank() - 2),
rewriter.getAffineDimExpr(memref_ty.getRank() - 1)},
rewriter.getContext());
auto source_ty = getMemRefType(op.getSource());
auto [source_addr, _] = unpackMemRef(subst.getSource(), rewriter);
if (!hasMemorySpace(source_ty, tpu::MemorySpace::vmem)) {
return failure();
}
if (op.getPermutationMap() == single_sublane_map) {
std::optional<Value> offset =
indicesToVmemOffset(subst.getIndices(), 1, source_ty, rewriter);
if (!offset) {
return failure();
}
auto [source, displacement] = vmemAddrOffsetToAddrDisplacement(
rewriter, source_addr, *offset);
rewriter.replaceOpWithNewOp<llo::VectorLoadOp>(
op, op.getType(), source,
/*displacement=*/displacement,
/*sublane_mask=*/nullptr, /*sublane_stride=*/0);
return success();
}
if (op.getPermutationMap() == regular_map) {
std::optional<Value> offset = indicesToVmemOffset(
subst.getIndices(), sublaneCount(), source_ty, rewriter);
if (!offset) {
return failure();
}
auto [source, displacement] = vmemAddrOffsetToAddrDisplacement(
rewriter, source_addr, *offset);
rewriter.replaceOpWithNewOp<llo::VectorLoadOp>(
op, op.getType(), source,
/*displacement=*/displacement,
/*sublane_mask=*/nullptr, /*sublane_stride=*/1);
return success();
}
return failure();
});
addPattern<vector::LoadOp>(
patterns,
[this](vector::LoadOp op, vector::LoadOpAdaptor subst,
ConversionPatternRewriter &rewriter) -> LogicalResult {
if (!isRepresentableVectorType(op.getType())) {
return failure();
}
auto source_ty = getMemRefType(op.getBase());
auto [source_addr, _] = unpackMemRef(subst.getBase(), rewriter);
if (!hasMemorySpace(source_ty, tpu::MemorySpace::vmem)) {
return failure();
}
std::optional<Value> offset = indicesToVmemOffset(
subst.getIndices(), sublaneCount(), source_ty, rewriter);
if (!offset) {
return failure();
}
auto [base, displacement] =
vmemAddrOffsetToAddrDisplacement(rewriter, source_addr, *offset);
rewriter.replaceOpWithNewOp<llo::VectorLoadOp>(
op, op.getType(), base, /*displacement=*/displacement,
/*sublane_mask=*/nullptr, /*sublane_stride=*/1);
return success();
});
addPattern<vector::StoreOp>(
patterns, [this](vector::StoreOp op, vector::StoreOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
if (!isRepresentableVectorType(subst.getValueToStore().getType())) {
return failure();
}
auto source_ty = getMemRefType(op.getBase());
auto [base_addr, _] = unpackMemRef(subst.getBase(), rewriter);
if (!hasMemorySpace(source_ty, tpu::MemorySpace::vmem)) {
return failure();
}
std::optional<Value> offset = indicesToVmemOffset(
subst.getIndices(), sublaneCount(), source_ty, rewriter);
if (!offset) {
return failure();
}
auto [base, displacement] =
vmemAddrOffsetToAddrDisplacement(rewriter, base_addr, *offset);
rewriter.replaceOpWithNewOp<llo::VectorStoreOp>(
op, /*address=*/base, /*displacement=*/displacement,
/*to_store=*/subst.getValueToStore(), /*sublane_mask=*/nullptr,
/*sublane_stride=*/1);
return success();
});
addPattern<vector::ContractionOp>(
patterns,
[this](vector::ContractionOp op, vector::ContractionOpAdaptor subst,
ConversionPatternRewriter &rewriter) -> LogicalResult {
// This pattern is only meant to support a small class of matmuls.
if (subst.getKind() != vector::CombiningKind::ADD) {
return failure();
}
auto ctx = rewriter.getContext();
auto matmul_iterator_types = rewriter.getArrayAttr({
vector::IteratorTypeAttr::get(ctx,
vector::IteratorType::parallel),
vector::IteratorTypeAttr::get(ctx,
vector::IteratorType::parallel),
vector::IteratorTypeAttr::get(ctx,
vector::IteratorType::reduction),
});
if (subst.getIteratorTypes() != matmul_iterator_types) {
return failure();
}
auto maps = subst.getIndexingMaps().getValue();
auto lhs_map = mlir::AffineMapAttr::get(mlir::AffineMap::get(
3, 0,
{rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(2)},
rewriter.getContext()));
auto rhs_map = mlir::AffineMapAttr::get(mlir::AffineMap::get(
3, 0,
{rewriter.getAffineDimExpr(2), rewriter.getAffineDimExpr(1)},
rewriter.getContext()));
auto rhs_transp_map = mlir::AffineMapAttr::get(mlir::AffineMap::get(
3, 0,
{rewriter.getAffineDimExpr(1), rewriter.getAffineDimExpr(2)},
rewriter.getContext()));
auto out_map = mlir::AffineMapAttr::get(mlir::AffineMap::get(
3, 0,
{rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
rewriter.getContext()));
if (maps.size() != 3) {
return failure();
}
if (maps[0] != lhs_map ||
(maps[1] != rhs_map && maps[1] != rhs_transp_map) ||
maps[2] != out_map) {
return failure();
}
bool transposed = maps[1] == rhs_transp_map;
auto precision_attr = dyn_cast_if_present<tpu::ContractPrecisionAttr>(
op->getAttr("precision"));
rewriter.replaceOpWithNewOp<tpu::MatmulOp>(
op, op.getType(), op.getLhs(), op.getRhs(), op.getAcc(),
rewriter.getBoolAttr(false), rewriter.getBoolAttr(transposed),
precision_attr);
return success();
});
addPattern<vector::BroadcastOp>(
patterns,
[this](vector::BroadcastOp op, vector::BroadcastOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
if (!isRepresentableVectorType(op.getType())) {
return failure();
}
if (!op.getSourceType().isSignlessInteger(32) &&
!op.getSourceType().isF32()) {
return failure();
}
rewriter.replaceOpWithNewOp<llo::ScalarToVectorOp>(op, op.getType(),
subst.getSource());
return success();
});
addPattern<vector::TransposeOp>(
patterns,
[this](vector::TransposeOp op, vector::TransposeOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
int64_t transpose_width = target_->LaneCount();
auto src_ty = op.getSourceVectorType();
auto dst_ty = op.getResultVectorType();
ArrayRef<int64_t> permutation = op.getPermutation();
if (permutation.size() != 2 || permutation[0] != 1 ||
permutation[1] != 0) {
return failure();
}
int batching_factor;
if (src_ty == op.getResultVectorType() &&
src_ty.getShape() ==
ArrayRef<int64_t>{transpose_width, transpose_width}) {
batching_factor = 1; // No batching.
} else if (src_ty.getElementType() == dst_ty.getElementType() &&
src_ty.getElementTypeBitWidth() == 16 &&
src_ty.getShape() ==
ArrayRef<int64_t>{transpose_width,
transpose_width * 2} &&
dst_ty.getShape() == ArrayRef<int64_t>{transpose_width * 2,
transpose_width}) {
batching_factor = 2;
if (!target_->SupportsVsupp()) {
return failure();
}
} else {
return failure();
}
ValueRange src_vregs;
if (auto roll_op =
op.getVector().getDefiningOp<tpu::RollVectorsOp>()) {
src_vregs = roll_op.getOperands();
} else {
return failure();
}
auto packing =
target_->VectorScalarBitWidth() / src_ty.getElementTypeBitWidth();
if (src_vregs.size() != transpose_width * batching_factor /
(target_->SublaneCount() * packing)) {
return failure();
}
auto width = rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(128));
std::vector<Value> result_vregs;
if (batching_factor == 1) {
llo::VxposeMode mode;
if (src_ty.getElementTypeBitWidth() == 32) {
mode = llo::VxposeMode::kB32;
} else if (src_ty.getElementTypeBitWidth() == 16) {
mode = llo::VxposeMode::kCompressedB16;
} else {
return failure();
}
for (int32_t i = 0; i < src_vregs.size(); ++i) {
rewriter.create<llo::VectorTransposeOp>(
op.getLoc(), src_vregs[i], mode, width, i, src_vregs.size(),
/*xlu_id=*/0, /*source_bus=*/nullptr);
}
result_vregs.reserve(src_vregs.size());
for (int32_t i = 0; i < src_vregs.size(); ++i) {
result_vregs.push_back(
rewriter.create<llo::VectorTransposeResultOp>(
op.getLoc(), src_vregs[0].getType()));
}
} else if (batching_factor == 2) {
CHECK_EQ(src_vregs.size() % 2, 0);
CHECK_EQ(src_ty.getElementTypeBitWidth(), 16);
for (int32_t i = 0; i < src_vregs.size() / 2; ++i) {
rewriter.create<llo::VectorTransposeBinaryCompressedB16Op>(
op.getLoc(), src_vregs[i * 2], src_vregs[i * 2 + 1], width, i,
src_vregs.size(), /*xlu_id=*/0, /*source_bus=*/nullptr);
}
result_vregs = std::vector<Value>(src_vregs.size(), nullptr);
for (int32_t i = 0; i < src_vregs.size() / 2; ++i) {
result_vregs[i] = rewriter.create<llo::VectorTransposeResultOp>(
op.getLoc(), src_vregs[0].getType());
result_vregs[i + src_vregs.size() / 2] =
rewriter.create<llo::VectorTransposeResultOp>(
op.getLoc(), src_vregs[0].getType());
}
} else {
LOG(FATAL) << "Unrecognized batching factor";
}
rewriter.replaceOpWithNewOp<tpu::RollVectorsOp>(op, op.getType(),
result_vregs);
return success();
});
}
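// Builds a pattern that rewrites an `arity`-operand elementwise op on vreg (or
// mask) types into LLOOpType, provided every operand's element type passes
// `elem_type_filter`.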
template <typename Op, typename LLOOpType, int arity,
bool (*elem_type_filter)(Type ty)>
PatternFunc<Op> VectorElementwisePattern() {
return [this](Op op, typename Op::Adaptor subst,
ConversionPatternRewriter &rewriter) {
if (op->getNumOperands() != arity) {
return failure();
}
for (Value operand : op->getOperands()) {
if (!isRepresentableVectorType(operand.getType()) &&
!isMaskVectorType(operand.getType())) {
return failure();
}
if (!elem_type_filter(
cast<VectorType>(operand.getType()).getElementType())) {
return failure();
}
}
rewriter.replaceOpWithNewOp<LLOOpType>(op, op.getType(),
subst.getOperands());
return success();
};
}
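// Same as above, but for scalar operands: every operand type must pass
// `elem_type_filter`.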
template <typename Op, typename LLOOpType, int arity,
bool (*elem_type_filter)(Type ty)>
PatternFunc<Op> ScalarElementwisePattern() {
return [](Op op, typename Op::Adaptor subst,
ConversionPatternRewriter &rewriter) {
if (op->getNumOperands() != arity) {
return failure();
}
for (Value operand : op->getOperands()) {
if (!elem_type_filter(operand.getType())) {
return failure();
}
}
auto ty = subst.getOperands().front().getType();
rewriter.replaceOpWithNewOp<LLOOpType>(op, ty, subst.getOperands());
return success();
};
}
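// Conversions from arith (and a few math) ops to scalar and vector LLO ops:
// elementwise arithmetic, constants, casts, comparisons and selects.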
void populateArithToLLOConversionPatterns(RewritePatternSet &patterns) {
constexpr bool (*f32)(Type) = [](Type ty) { return ty.isF32(); };
constexpr bool (*i32)(Type) = [](Type ty) {
return ty.isSignlessInteger(32) || ty.isIndex();
};
constexpr bool (*i1)(Type) = [](Type ty) {
return ty.isSignlessInteger(1);
};
addPattern(
patterns,
VectorElementwisePattern<arith::AddFOp, llo::VectorAddF32Op, 2, f32>());
addPattern(
patterns,
VectorElementwisePattern<arith::SubFOp, llo::VectorSubF32Op, 2, f32>());
addPattern(
patterns,
VectorElementwisePattern<arith::MulFOp, llo::VectorMulF32Op, 2, f32>());
addPattern(patterns,
VectorElementwisePattern<arith::MaximumFOp, llo::VectorMaxF32Op,
2, f32>());
addPattern(patterns,
VectorElementwisePattern<arith::MinimumFOp, llo::VectorMinF32Op,
2, f32>());
addPattern(patterns,
VectorElementwisePattern<arith::MaxSIOp, llo::VectorMaxS32Op, 2,
i32>());
addPattern(patterns,
VectorElementwisePattern<arith::MinSIOp, llo::VectorMinS32Op, 2,
i32>());
addPattern<arith::DivFOp>(
patterns, [this](arith::DivFOp op, arith::DivFOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
if (!isRepresentableVectorType(op.getType())) {
return failure();
}
if (!cast<VectorType>(op.getType()).getElementType().isF32()) {
return failure();
}
auto rhs_inv = rewriter.create<llo::VectorRecipF32Op>(op.getLoc(),
subst.getRhs());
rewriter.replaceOpWithNewOp<llo::VectorMulF32Op>(op, subst.getLhs(),
rhs_inv);
return success();
});
addPattern(
patterns,
VectorElementwisePattern<arith::AndIOp, llo::VectorMaskAndOp, 2, i1>());
addPattern(
patterns,
VectorElementwisePattern<arith::OrIOp, llo::VectorMaskOrOp, 2, i1>());
addPattern(
patterns,
VectorElementwisePattern<arith::XOrIOp, llo::VectorMaskXorOp, 2, i1>());
addPattern(
patterns,
VectorElementwisePattern<arith::NegFOp, llo::VectorNegF32Op, 1, f32>());
addPattern(
patterns,
VectorElementwisePattern<arith::AddIOp, llo::VectorAddS32Op, 2, i32>());
addPattern(
patterns,
VectorElementwisePattern<arith::SubIOp, llo::VectorSubS32Op, 2, i32>());
addPattern(
patterns,
VectorElementwisePattern<arith::MulIOp, llo::VectorMulS32Op, 2, i32>());
addPattern(patterns,
VectorElementwisePattern<arith::DivSIOp, llo::VectorDivS32Op, 2,
i32>());
addPattern(patterns,
VectorElementwisePattern<arith::RemSIOp, llo::VectorRemS32Op, 2,
i32>());
addPattern(
patterns,
VectorElementwisePattern<arith::ShLIOp, llo::VectorShiftLeftLogicalOp,
2, i32>());
addPattern(
patterns,
VectorElementwisePattern<arith::ShRUIOp, llo::VectorShiftRightLogicalOp,
2, i32>());
addPattern(
patterns,
VectorElementwisePattern<arith::ShRSIOp,
llo::VectorShiftRightArithmeticOp, 2, i32>());
addPattern(
patterns,
VectorElementwisePattern<arith::AndIOp, llo::VectorAndU32Op, 2, i32>());
addPattern(
patterns,
VectorElementwisePattern<arith::OrIOp, llo::VectorOrU32Op, 2, i32>());
addPattern(
patterns,
VectorElementwisePattern<arith::XOrIOp, llo::VectorXOrU32Op, 2, i32>());
addPattern<arith::XOrIOp>(
patterns, [](arith::XOrIOp op, arith::XOrIOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
// Lower arith.xori of i1 inputs.
for (Value operand : op.getOperands()) {
if (!operand.getType().isSignlessInteger(1)) {
return failure();
}
}
// arith.xori lhs, -1
// is lowered to:
// llo.pnot lhs
// MLIR canonicalization passes (run before this stage) move constants
// of commutative operations to the rhs
// (https://mlir.llvm.org/docs/Canonicalization/#globally-applied-rules),
// so we assume that the constant -1 is on the rhs.
bool is_not = false;
Value op_rhs = op.getOperands().back();
auto crhs = op_rhs.getDefiningOp<arith::ConstantOp>();
if (crhs != nullptr) {
is_not = cast<mlir::IntegerAttr>(crhs.getValue()).getInt() == -1;
}
if (is_not) {
rewriter.replaceOpWithNewOp<llo::PredicateNegateOp>(
op, op.getType(), subst.getLhs());
return success();
}
Value lhs = subst.getLhs();
Value rhs = subst.getRhs();
// Implement XOR between lhs and rhs as follows:
//
// lhs_and_not_rhs = lhs & !rhs
// not_lhs_and_rhs = !lhs & rhs
// xor = lhs_and_not_rhs | not_lhs_and_rhs
auto not_lhs =
rewriter.create<llo::PredicateNegateOp>(op.getLoc(), lhs);
auto not_rhs =
rewriter.create<llo::PredicateNegateOp>(op.getLoc(), rhs);
auto lhs_and_not_rhs =
rewriter.create<llo::PredicateAndOp>(op.getLoc(), lhs, not_rhs);
auto not_lhs_and_rhs =
rewriter.create<llo::PredicateAndOp>(op.getLoc(), not_lhs, rhs);
rewriter.replaceOpWithNewOp<llo::PredicateOrOp>(
op, op.getType(), lhs_and_not_rhs, not_lhs_and_rhs);
return success();
});
addPattern(
patterns,
ScalarElementwisePattern<arith::NegFOp, llo::ScalarNegF32Op, 1, f32>());
addPattern(
patterns,
ScalarElementwisePattern<math::SqrtOp, llo::ScalarSqrtF32Op, 1, f32>());
addPattern(
patterns,
ScalarElementwisePattern<math::ExpOp, llo::ScalarExpF32Op, 1, f32>());
addPattern(
patterns,
ScalarElementwisePattern<math::LogOp, llo::ScalarLogF32Op, 1, f32>());
addPattern(patterns,
ScalarElementwisePattern<math::RoundOp, llo::ScalarRoundF32Op, 1,
f32>());
addPattern(patterns,
ScalarElementwisePattern<math::RoundEvenOp,
llo::ScalarRoundEvenF32Op, 1, f32>());
addPattern(
patterns,
ScalarElementwisePattern<arith::AddFOp, llo::ScalarAddF32Op, 2, f32>());
addPattern(
patterns,
ScalarElementwisePattern<arith::AddIOp, llo::ScalarAddS32Op, 2, i32>());
addPattern(
patterns,
ScalarElementwisePattern<arith::SubFOp, llo::ScalarSubF32Op, 2, f32>());
addPattern(
patterns,
ScalarElementwisePattern<arith::SubIOp, llo::ScalarSubS32Op, 2, i32>());
addPattern(
patterns,
ScalarElementwisePattern<arith::MulFOp, llo::ScalarMulF32Op, 2, f32>());
addPattern(
patterns,
ScalarElementwisePattern<arith::MulIOp, llo::ScalarMulS32Op, 2, i32>());
addPattern(
patterns,
ScalarElementwisePattern<arith::ShLIOp, llo::ScalarShllOp, 2, i32>());
addPattern(
patterns,
ScalarElementwisePattern<arith::ShRUIOp, llo::ScalarShrlOp, 2, i32>());
addPattern(
patterns,
ScalarElementwisePattern<arith::ShRSIOp, llo::ScalarShraOp, 2, i32>());
addPattern(patterns,
ScalarElementwisePattern<arith::OrIOp, llo::ScalarBitwiseOrOp, 2,
i32>());
addPattern(
patterns,
ScalarElementwisePattern<arith::OrIOp, llo::PredicateOrOp, 2, i1>());
addPattern(
patterns,
ScalarElementwisePattern<arith::AndIOp, llo::PredicateAndOp, 2, i1>());
addPattern(patterns,
ScalarElementwisePattern<arith::AndIOp, llo::ScalarBitwiseAndOp,
2, i32>());
addPattern(
patterns,
ScalarElementwisePattern<arith::DivFOp, llo::ScalarDivF32Op, 2, f32>());
addPattern(patterns,
ScalarElementwisePattern<arith::DivSIOp, llo::ScalarDivS32Op, 2,
i32>());
addPattern(patterns,
ScalarElementwisePattern<arith::RemSIOp, llo::ScalarRemS32Op, 2,
i32>());
addPattern(patterns,
ScalarElementwisePattern<arith::MaximumFOp, llo::ScalarMaxF32Op,
2, f32>());
addPattern(patterns,
ScalarElementwisePattern<arith::MinimumFOp, llo::ScalarMinF32Op,
2, f32>());
addPattern(patterns,
ScalarElementwisePattern<arith::MaxSIOp, llo::ScalarMaxS32Op, 2,
i32>());
addPattern(patterns,
ScalarElementwisePattern<arith::MinSIOp, llo::ScalarMinS32Op, 2,
i32>());
addPattern<arith::ConstantOp>(
patterns, [this](arith::ConstantOp op, arith::ConstantOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
auto attr = op.getValue();
if (attr.getType().isUnsignedInteger(32) ||
attr.getType().isSignlessInteger(32) ||
attr.getType().isSignlessInteger(1) || attr.getType().isF32()) {
goto supported;
}
if (auto splat = dyn_cast<SplatElementsAttr>(attr)) {
if (isRepresentableVectorType(op.getType()) ||
isMaskVectorType(op.getType())) {
goto supported;
}
}
if (attr.getType().isIndex()) {
rewriter.replaceOpWithNewOp<llo::ConstantOp>(
op, rewriter.getI32Type(),
rewriter.getI32IntegerAttr(
cast<mlir::IntegerAttr>(attr).getInt()));
return success();
}
// Constants that are not representable in LLO should become dead.
rewriter.eraseOp(op);
return success();
supported:
rewriter.replaceOpWithNewOp<llo::ConstantOp>(op, op.getType(), attr);
return success();
});
addPattern<arith::IndexCastOp>(
patterns, [](arith::IndexCastOp op, arith::IndexCastOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
auto in_ty = op.getIn().getType();
auto out_ty = op.getType();
if ((in_ty.isIndex() && out_ty.isSignlessInteger(32)) ||
(in_ty.isSignlessInteger(32) && out_ty.isIndex())) {
rewriter.replaceOp(op, subst.getIn());
return success();
}
op.emitOpError("Unsupported cast");
return failure();
});
addPattern<arith::CmpFOp>(
patterns, [this](arith::CmpFOp op, arith::CmpFOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
auto arg_ty = op.getOperandTypes().front();
if (auto vec_ty = dyn_cast<VectorType>(arg_ty)) {
if (!isRepresentableVectorType(vec_ty) ||
!isMaskVectorType(op.getType()) ||
!vec_ty.getElementType().isF32()) {
return failure();
}
switch (op.getPredicate()) {
case mlir::arith::CmpFPredicate::OEQ:
rewriter.replaceOpWithNewOp<llo::VectorCmpEqF32Op>(
op, op.getType(), subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpFPredicate::ONE:
rewriter.replaceOpWithNewOp<llo::VectorCmpNeF32Op>(
op, op.getType(), subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpFPredicate::OLT:
rewriter.replaceOpWithNewOp<llo::VectorCmpLtF32Op>(
op, op.getType(), subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpFPredicate::OLE:
rewriter.replaceOpWithNewOp<llo::VectorCmpLeF32Op>(
op, op.getType(), subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpFPredicate::OGT:
rewriter.replaceOpWithNewOp<llo::VectorCmpGtF32Op>(
op, op.getType(), subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpFPredicate::OGE:
rewriter.replaceOpWithNewOp<llo::VectorCmpGeF32Op>(
op, op.getType(), subst.getLhs(), subst.getRhs());
break;
default:
return failure();
}
} else {
if (!arg_ty.isF32()) {
return failure();
}
switch (op.getPredicate()) {
case mlir::arith::CmpFPredicate::OEQ:
rewriter.replaceOpWithNewOp<llo::ScalarCmpEqF32Op>(
op, subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpFPredicate::ONE:
rewriter.replaceOpWithNewOp<llo::ScalarCmpNeF32Op>(
op, subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpFPredicate::OLT:
rewriter.replaceOpWithNewOp<llo::ScalarCmpLtF32Op>(
op, subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpFPredicate::OLE:
rewriter.replaceOpWithNewOp<llo::ScalarCmpLeF32Op>(
op, subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpFPredicate::OGT:
rewriter.replaceOpWithNewOp<llo::ScalarCmpGtF32Op>(
op, subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpFPredicate::OGE:
rewriter.replaceOpWithNewOp<llo::ScalarCmpGeF32Op>(
op, subst.getLhs(), subst.getRhs());
break;
default:
return failure();
}
}
return success();
});
addPattern<arith::CmpIOp>(
patterns, [this](arith::CmpIOp op, arith::CmpIOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
auto arg_ty = op.getOperandTypes().front();
if (auto vec_ty = dyn_cast<VectorType>(arg_ty)) {
if (!isRepresentableVectorType(vec_ty) ||
!isMaskVectorType(op.getType()) ||
!vec_ty.getElementType().isSignlessInteger(32)) {
return failure();
}
switch (op.getPredicate()) {
case mlir::arith::CmpIPredicate::eq:
rewriter.replaceOpWithNewOp<llo::VectorCmpEqS32Op>(
op, op.getType(), subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpIPredicate::ne:
rewriter.replaceOpWithNewOp<llo::VectorCmpNeS32Op>(
op, op.getType(), subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpIPredicate::slt:
rewriter.replaceOpWithNewOp<llo::VectorCmpLtS32Op>(
op, op.getType(), subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpIPredicate::sle:
rewriter.replaceOpWithNewOp<llo::VectorCmpLeS32Op>(
op, op.getType(), subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpIPredicate::sgt:
rewriter.replaceOpWithNewOp<llo::VectorCmpGtS32Op>(
op, op.getType(), subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpIPredicate::sge:
rewriter.replaceOpWithNewOp<llo::VectorCmpGeS32Op>(
op, op.getType(), subst.getLhs(), subst.getRhs());
break;
default:
return failure();
}
} else {
if (!arg_ty.isSignlessInteger(32) && !arg_ty.isIndex()) {
return failure();
}
switch (op.getPredicate()) {
case mlir::arith::CmpIPredicate::eq:
rewriter.replaceOpWithNewOp<llo::ScalarCmpEqS32Op>(
op, subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpIPredicate::ne:
rewriter.replaceOpWithNewOp<llo::ScalarCmpNeS32Op>(
op, subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpIPredicate::slt:
rewriter.replaceOpWithNewOp<llo::ScalarCmpLtS32Op>(
op, subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpIPredicate::sle:
rewriter.replaceOpWithNewOp<llo::ScalarCmpLeS32Op>(
op, subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpIPredicate::sgt:
rewriter.replaceOpWithNewOp<llo::ScalarCmpGtS32Op>(
op, subst.getLhs(), subst.getRhs());
break;
case mlir::arith::CmpIPredicate::sge:
rewriter.replaceOpWithNewOp<llo::ScalarCmpGeS32Op>(
op, subst.getLhs(), subst.getRhs());
break;
default:
return failure();
}
}
return success();
});
addPattern<arith::SIToFPOp>(
patterns, [this](arith::SIToFPOp op, arith::SIToFPOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
auto arg_ty = subst.getIn().getType();
if (VectorType in_ty = dyn_cast<VectorType>(arg_ty)) {
VectorType out_ty = dyn_cast<VectorType>(op.getType());
if (!in_ty || !out_ty || !isRepresentableVectorType(in_ty) ||
!isRepresentableVectorType(out_ty)) {
return failure();
}
if (in_ty.getElementType() == rewriter.getI32Type() &&
out_ty.getElementType() == rewriter.getF32Type()) {
rewriter.replaceOpWithNewOp<llo::VectorConvertS32ToF32Op>(
op, out_ty, subst.getIn());
return success();
}
} else if (arg_ty == rewriter.getI32Type() &&
op.getType() == rewriter.getF32Type()) {
rewriter.replaceOpWithNewOp<llo::ScalarConvertS32ToF32Op>(
op, op.getType(), subst.getIn());
return success();
}
return failure();
});
addPattern<arith::FPToSIOp>(
patterns, [this](arith::FPToSIOp op, arith::FPToSIOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
auto arg_ty = subst.getIn().getType();
if (VectorType in_ty = dyn_cast<VectorType>(arg_ty)) {
VectorType out_ty = dyn_cast<VectorType>(op.getType());
if (!in_ty || !out_ty || !isRepresentableVectorType(in_ty) ||
!isRepresentableVectorType(out_ty)) {
return failure();
}
if (in_ty.getElementType() == rewriter.getF32Type() &&
out_ty.getElementType() == rewriter.getI32Type()) {
rewriter.replaceOpWithNewOp<
llo::VectorConvertF32ToS32TowardsZeroPseudoOp>(op, out_ty,
subst.getIn());
return success();
}
} else if (arg_ty == rewriter.getF32Type() &&
op.getType() == rewriter.getI32Type()) {
rewriter.replaceOpWithNewOp<
llo::ScalarConvertF32ToS32TowardsZeroPseudoOp>(op, op.getType(),
subst.getIn());
return success();
}
return failure();
});
addPattern<arith::SelectOp>(
patterns, [this](arith::SelectOp op, arith::SelectOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
if (auto ty = dyn_cast<VectorType>(op.getType())) {
if (!isRepresentableVectorType(ty)) {
return failure();
}
if (!isa<VectorType>(subst.getCondition().getType())) {
return failure();
}
rewriter.replaceOpWithNewOp<llo::VectorSelectOp>(
op, op.getType(), subst.getCondition(), subst.getTrueValue(),
subst.getFalseValue());
return success();
}
if (op.getType().isSignlessInteger(32) || op.getType().isIndex() ||
op.getType().isF32()) {
rewriter.replaceOpWithNewOp<llo::ScalarSelectOp>(
op, subst.getTrueValue().getType(), subst.getCondition(),
subst.getTrueValue(), subst.getFalseValue());
return success();
}
return failure();
});
addPattern<arith::ExtSIOp>(
patterns, [](arith::ExtSIOp op, arith::ExtSIOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
// Narrow integer types are already converted to 32-bit integers.
if (!op.getType().isSignlessInteger(32) ||
!subst.getIn().getType().isSignlessInteger(32)) {
return failure();
}
rewriter.replaceOp(op, subst.getIn());
return success();
});
addPattern<arith::TruncIOp>(
patterns, [](arith::TruncIOp op, arith::TruncIOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
if (!op.getIn().getType().isSignlessInteger(32)) {
return failure();
}
unsigned bitwidth = op.getType().getIntOrFloatBitWidth();
rewriter.replaceOpWithNewOp<llo::ScalarBitwiseAndOp>(
op, subst.getIn(),
rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr((1 << bitwidth) - 1)));
return success();
});
addPattern<arith::ExtUIOp>(
patterns, [](arith::ExtUIOp op, arith::ExtUIOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
Type in_ty = subst.getIn().getType();
if (in_ty == op.getType()) {
// The type of the new value is already equal to the destination
// type, so there is nothing to do.
rewriter.replaceOp(op, subst.getIn());
return success();
}
if (auto in_vec_ty = dyn_cast<VectorType>(in_ty)) {
auto out_vec_ty = cast<VectorType>(op.getType());
if (in_vec_ty.getElementTypeBitWidth() != 1 ||
out_vec_ty.getElementTypeBitWidth() != 32) {
return failure();
}
auto make_splat = [&](int64_t value) -> Value {
return rewriter.create<llo::ConstantOp>(
op.getLoc(), out_vec_ty,
SplatElementsAttr::get(out_vec_ty,
rewriter.getI32IntegerAttr(value)));
};
rewriter.replaceOpWithNewOp<llo::VectorSelectOp>(
op, out_vec_ty, subst.getIn(), make_splat(1), make_splat(0));
return success();
}
if (in_ty.isSignlessInteger(1) &&
op.getType().isSignlessInteger(32)) {
auto make_const = [&](int64_t value) -> Value {
return rewriter.create<llo::ConstantOp>(
op.getLoc(), op.getType(), rewriter.getI32IntegerAttr(value));
};
rewriter.replaceOpWithNewOp<llo::ScalarSelectOp>(
op, op.getType(), subst.getIn(), make_const(1), make_const(0));
return success();
}
return failure();
});
addPattern<arith::BitcastOp>(
patterns, [](arith::BitcastOp op, arith::BitcastOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
auto in_ty = op.getIn().getType();
auto out_ty = op.getType();
auto in_vty = dyn_cast<VectorType>(in_ty);
auto out_vty = dyn_cast<VectorType>(out_ty);
if (in_vty && out_vty &&
in_vty.getElementTypeBitWidth() ==
out_vty.getElementTypeBitWidth()) {
rewriter.replaceOpWithNewOp<llo::VectorBitcastOp>(op, out_ty,
subst.getIn());
return success();
}
if (in_ty.getIntOrFloatBitWidth() == out_ty.getIntOrFloatBitWidth()) {
// We only allow 32-bit types in LLO and any narrower type is
// represented by i32, with the valid data being in the low bits.
// So, for casts between 32-bit types we need to do an actual bitcast.
// But, for casts between narrower types, we shouldn't do anything.
if (in_ty.getIntOrFloatBitWidth() == 32) {
rewriter.replaceOpWithNewOp<llo::ScalarBitcastOp>(op, out_ty,
subst.getIn());
} else {
rewriter.replaceOp(op, subst.getIn());
}
return success();
}
op.emitOpError("Failed to arith::bitcast")
<< in_ty << " to " << out_ty;
return failure();
});
}
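// Conversions from the math dialect (transcendental, rounding and abs ops) to
// LLO vector and scalar ops.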
void populateMathToLLOConversionPatterns(RewritePatternSet &patterns) {
constexpr bool (*f32)(Type) = [](Type ty) { return ty.isF32(); };
constexpr bool (*i32)(Type) = [](Type ty) {
return ty.isSignlessInteger(32) || ty.isIndex();
};
addPattern(patterns,
VectorElementwisePattern<math::RsqrtOp, llo::VectorRsqrtF32Op, 1,
f32>());
addPattern(
patterns,
VectorElementwisePattern<math::SqrtOp, llo::VectorSqrtF32Op, 1, f32>());
addPattern(
patterns,
ScalarElementwisePattern<math::CountLeadingZerosOp,
llo::ScalarCountLeadingZerosOp, 1, i32>());
addPattern(
patterns,
VectorElementwisePattern<math::ExpOp, llo::VectorExpF32Op, 1, f32>());
addPattern(
patterns,
VectorElementwisePattern<math::Exp2Op, llo::VectorPow2F32Op, 1, f32>());
addPattern(
patterns,
VectorElementwisePattern<math::CosOp, llo::VectorCosF32Op, 1, f32>());
addPattern(
patterns,
VectorElementwisePattern<math::SinOp, llo::VectorSinF32Op, 1, f32>());
addPattern(
patterns,
VectorElementwisePattern<math::TanhOp, llo::VectorTanhF32Op, 1, f32>());
addPattern(
patterns,
VectorElementwisePattern<math::LogOp, llo::VectorLogF32Op, 1, f32>());
addPattern(patterns,
VectorElementwisePattern<math::Log1pOp, llo::VectorLog1pF32Op, 1,
f32>());
addPattern(
patterns,
VectorElementwisePattern<math::PowFOp, llo::VectorPowF32Op, 2, f32>());
addPattern(
patterns,
VectorElementwisePattern<math::AbsFOp, llo::VectorAbsF32Op, 1, f32>());
addPattern(patterns,
VectorElementwisePattern<math::RoundOp, llo::VectorRoundF32Op, 1,
f32>());
addPattern(patterns,
VectorElementwisePattern<math::RoundEvenOp,
llo::VectorRoundEvenF32Op, 1, f32>());
addPattern<math::AbsIOp>(
patterns, [this](math::AbsIOp op, math::AbsIOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
Value operand = subst.getOperand();
Type operandType = operand.getType();
if (!isRepresentableVectorType(operandType)) {
return failure();
}
VectorType operandVectorType = cast<VectorType>(operandType);
if (operandVectorType.getElementType() != rewriter.getI32Type()) {
return failure();
}
// Build the following instruction sequence:
// cmp = vcmp.lt.32 input, 0
// neg = vneg.32 input
// out = vselect cmp, neg, input
Value zero = rewriter.create<llo::ConstantOp>(
op.getLoc(), operandVectorType,
SplatElementsAttr::get(operandVectorType,
rewriter.getI32IntegerAttr(0)));
VectorType cmpVectorType = VectorType::get(
operandVectorType.getShape(), rewriter.getI1Type());
Value cmp = rewriter.create<llo::VectorCmpLtS32Op>(
op.getLoc(), cmpVectorType, operand, zero);
Value neg =
rewriter.create<llo::VectorNegS32Op>(op.getLoc(), operand);
rewriter.replaceOpWithNewOp<llo::VectorSelectOp>(
op, operandVectorType, cmp, neg, operand);
return success();
});
}
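// Conversions for SMEM scalar loads and stores (including sub-word packed
// accesses) and for VMEM/SMEM allocas.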
void populateMemrefToLLOConversionPatterns(RewritePatternSet &patterns) {
addPattern<memref::LoadOp>(
patterns, [this](memref::LoadOp op, memref::LoadOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
auto ty = getMemRefType(op.getMemRef());
auto [ref_addr, _] = unpackMemRef(subst.getMemref(), rewriter);
if (!hasMemorySpace(op.getMemRefType(), tpu::MemorySpace::smem)) {
return failure();
}
if (!op.getType().isIntOrFloat() ||
op.getType().getIntOrFloatBitWidth() > 32) {
return failure();
}
auto [offset, part] =
indicesToOffset(op.getLoc(), ty, subst.getIndices(), rewriter);
if (!offset) {
return failure();
}
Type new_type;
if (op.getType().isSignlessInteger()) {
new_type = rewriter.getI32Type();
} else if (isa<FloatType>(op.getType())) {
new_type = rewriter.getF32Type();
} else {
return failure();
}
Value result = rewriter.create<llo::ScalarLoadOp>(
op.getLoc(), new_type, ref_addr, offset);
if (part != nullptr) {
unsigned bitwidth = op.getType().getIntOrFloatBitWidth();
unsigned base_mask = (1 << bitwidth) - 1;
Value mask = rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(base_mask));
Value shift = rewriter.create<llo::ScalarMulS32Op>(
op.getLoc(), part,
rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(bitwidth)));
// Bitcast result to I32 type for shifting.
result = rewriter.create<llo::ScalarBitcastOp>(
op.getLoc(), rewriter.getI32Type(), result);
// Shift and mask the subelement we want.
result = rewriter.create<llo::ScalarBitwiseAndOp>(
op.getLoc(), mask,
rewriter.create<llo::ScalarShrlOp>(op.getLoc(), result, shift));
}
rewriter.replaceOp(op, result);
return success();
});
addPattern<memref::StoreOp>(
patterns, [this](memref::StoreOp op, memref::StoreOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
auto ty = getMemRefType(op.getMemRef());
auto [ref_addr, _] = unpackMemRef(subst.getMemref(), rewriter);
if (!hasMemorySpace(op.getMemRefType(), tpu::MemorySpace::smem)) {
return failure();
}
if (!op.getValue().getType().isIntOrFloat() ||
op.getValue().getType().getIntOrFloatBitWidth() > 32) {
return failure();
}
auto [offset, part] =
indicesToOffset(op.getLoc(), ty, subst.getIndices(), rewriter);
Value addr = rewriter.create<llo::ScalarAddressSmemOp>(
op.getLoc(), ref_addr, offset);
Value to_store = subst.getValue();
if (part != nullptr) {
Value zero = rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(0));
unsigned bitwidth = op.getValue().getType().getIntOrFloatBitWidth();
unsigned base_mask = (1 << bitwidth) - 1;
Value update_mask = rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(base_mask));
Value shift = rewriter.create<llo::ScalarMulS32Op>(
op.getLoc(), part,
rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(bitwidth)));
// Compute a mask of the bits to preserve by inverting the shifted update
// mask.
Value existing_mask = rewriter.create<llo::ScalarBitwiseXorOp>(
op.getLoc(),
// Shift the mask to where the update will happen.
rewriter.create<llo::ScalarShllOp>(op.getLoc(), update_mask,
shift),
rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(0xFFFFFFFF)));
// Bitcast to_store to I32 type for shifting.
to_store = rewriter.create<llo::ScalarBitcastOp>(
op.getLoc(), rewriter.getI32Type(), to_store);
// Blend the shifted update with the existing value, with the updated
// location masked out.
to_store = rewriter.create<llo::ScalarBitwiseOrOp>(
op.getLoc(),
rewriter.create<llo::ScalarShllOp>(op.getLoc(), to_store,
shift),
rewriter.create<llo::ScalarBitwiseAndOp>(
op.getLoc(), existing_mask,
rewriter.create<llo::ScalarLoadOp>(
op.getLoc(), to_store.getType(), addr, zero)));
}
rewriter.replaceOpWithNewOp<llo::ScalarStoreOp>(op, addr, to_store);
return success();
});
addPattern<memref::AllocaOp>(
patterns, [this](memref::AllocaOp op, memref::AllocaOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
auto ref_ty = op.getType();
auto size_bytes = getMemRefSizeInBytes(ref_ty);
if (hasMemorySpace(ref_ty, tpu::MemorySpace::vmem)) {
if (!size_bytes ||
*size_bytes % target_->VmemWordSizeBytes() != 0) {
return failure();
}
rewriter.replaceOpWithNewOp<llo::AllocaVmemOp>(
op, rewriter.getI32Type(),
rewriter.getI32IntegerAttr(*size_bytes /
target_->VmemWordSizeBytes()));
return success();
}
if (hasMemorySpace(ref_ty, tpu::MemorySpace::smem)) {
if (!size_bytes ||
*size_bytes % target_->SmemWordSizeBytes() != 0) {
return failure();
}
rewriter.replaceOpWithNewOp<llo::AllocaSmemOp>(
op, rewriter.getI32Type(),
rewriter.getI32IntegerAttr(*size_bytes /
target_->SmemWordSizeBytes()));
return success();
}
// TODO: Support semaphore allocations
op.emitOpError(
"Cannot allocate ref in non-VMEM/SMEM memory space using "
"memref.alloca.");
return failure();
});
}
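// Conversion of cf.assert into llo::ErrorIfOp on the negated predicate.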
void populateCFToLLOConversionPatterns(RewritePatternSet &patterns) {
addPattern<cf::AssertOp>(patterns, [](cf::AssertOp op,
cf::AssertOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<llo::ErrorIfOp>(
op,
rewriter.create<llo::PredicateNegateOp>(op.getLoc(), subst.getArg()),
op.getMsgAttr());
return success();
});
}
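// Conversions from TPU dialect ops (mask creation, reductions, vreg bitcasts,
// roll/unroll, strided loads, ...) to LLO.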
void populateTPUToLLOConversionPatterns(RewritePatternSet &patterns) {
addPattern<tpu::CreateMaskOp>(
patterns, [this](tpu::CreateMaskOp op, tpu::CreateMaskOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
if (op.getHigh().size() != 2) {
return failure();
}
if (!isMaskVectorType(op.getType())) {
return failure();
}
auto get_constant = [](Value v) -> std::optional<int64_t> {
auto cst_op = v.getDefiningOp<arith::ConstantOp>();
if (!cst_op) {
return std::nullopt;
}
return cast<IntegerAttr>(cst_op.getValue()).getInt();
};
auto sublane_low = get_constant(op.getLow()[0]);
auto sublane_high = get_constant(op.getHigh()[0]);
auto lane_low = get_constant(op.getLow()[1]);
auto lane_high = get_constant(op.getHigh()[1]);
if (!sublane_low || !sublane_high || !lane_low || !lane_high) {
return failure();
}
if (*sublane_low < 0 || *sublane_low > sublaneCount() ||
*sublane_high < 0 || *sublane_high > sublaneCount() ||
*lane_low < 0 || *lane_low > laneCount() || *lane_high < 0 ||
*lane_high > laneCount()) {
return failure();
}
if (*sublane_low > *sublane_high || *lane_low > *lane_high) {
return failure();
}
if (*sublane_low == *sublane_high || *lane_low == *lane_high) {
rewriter.replaceOpWithNewOp<llo::ConstantOp>(
op, op.getType(),
SplatElementsAttr::get(cast<ShapedType>(op.getType()), false));
return success();
}
// LLO create mask has inclusive upper bounds.
rewriter.replaceOpWithNewOp<llo::VectorCreateMaskOp>(
op, op.getType(),
/*sublane_start=*/*sublane_low,
/*sublane_end=*/(*sublane_high - 1),
/*lane_start=*/*lane_low, /*lane_end=*/(*lane_high - 1));
return success();
});
addPattern<tpu::CreateSubelementMaskOp>(
patterns, [this](tpu::CreateSubelementMaskOp op,
tpu::CreateSubelementMaskOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
if (!isMaskVectorType(op.getType())) {
return failure();
}
auto vty = cast<VectorType>(op.getType());
if (vty.getRank() != 3 || vty.getDimSize(2) != op.getNumSubelems()) {
return failure();
}
if (0 > op.getFrom() || op.getFrom() >= op.getTo() ||
op.getTo() > target_->SublaneCount() * op.getNumSubelems()) {
return failure();
}
auto create_sublane_mask = [&rewriter,
&op](int32_t packed_limit) -> Value {
return rewriter.create<llo::VectorCreateSublaneMaskOp>(
op.getLoc(), op.getType(),
rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(packed_limit)));
};
if (target_->DeepseaVersion() == ::tpu::TpuVersion::kPufferfish) {
if (op.getNumSubelems() != 1 && op.getNumSubelems() != 2) {
return failure();
}
int32_t from = op.getFrom();
int32_t to = op.getTo();
if (op.getNumSubelems() == 1) {
from *= 2;
to *= 2;
}
Value mask = nullptr;
if (from != 0) {
uint32_t packed_from = from | (from << 16);
mask = rewriter.create<llo::VectorMaskNegateOp>(
op.getLoc(), op.getType(), create_sublane_mask(packed_from));
}
if (to != target_->SublaneCount() * 2) {
uint32_t packed_to = to | (to << 16);
auto lt_to = create_sublane_mask(packed_to);
if (!mask) {
mask = lt_to;
} else {
mask = rewriter.create<llo::VectorMaskAndOp>(
op.getLoc(), op.getType(), mask, lt_to);
}
}
if (!mask) {
rewriter.replaceOpWithNewOp<llo::ConstantOp>(
op, op.getType(),
rewriter.getIntegerAttr(rewriter.getI1Type(), true));
return success();
}
rewriter.replaceOp(op, mask);
return success();
}
#if GCE_PERMITS_VIPERLITE || GCE_PERMITS_VIPERFISH
if (target_->DeepseaVersion() == ::tpu::TpuVersion::kViperfish) {
if (op.getNumSubelems() != 1 && op.getNumSubelems() != 2 &&
op.getNumSubelems() != 4) {
return failure();
}
uint32_t native_subelems = 4;
int32_t from =
op.getFrom() * (native_subelems / op.getNumSubelems());
// We subtract 1 because the upper bound is inclusive in Viperfish.
int32_t to =
(op.getTo() * (native_subelems / op.getNumSubelems())) - 1;
CHECK_LE(from, to);
CHECK_LE(from, 2 << 4);
CHECK_LE(to, 2 << 4);
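// Pack the inclusive [from, to] range into a single immediate: the low
// byte holds `from` and the next byte holds `to`.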
int32_t packed_limit = from | (to << 8);
rewriter.replaceOp(op, create_sublane_mask(packed_limit));
return success();
}
#endif
return failure();
});
addPattern<tpu::AllReduceOp>(
patterns, [](tpu::AllReduceOp op, tpu::AllReduceOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
if (op.getDim() == 0) {
switch (op.getKind()) {
case tpu::ReductionKind::SUM:
rewriter.replaceOpWithNewOp<llo::VectorAddSublaneReduceF32Op>(
op, subst.getInput());
break;
case tpu::ReductionKind::MAX:
rewriter.replaceOpWithNewOp<llo::VectorMaxSublaneReduceF32Op>(
op, subst.getInput());
break;
default:
return failure();
}
return success();
}
if (op.getDim() == 1) {
switch (op.getKind()) {
case tpu::ReductionKind::SUM:
rewriter.replaceOpWithNewOp<llo::VectorAddReduceF32Op>(
op, subst.getInput());
break;
case tpu::ReductionKind::MAX:
rewriter.replaceOpWithNewOp<llo::VectorMaxReduceF32Op>(
op, subst.getInput());
break;
default:
return failure();
}
return success();
}
return failure();
});
addPattern<tpu::BitcastVregOp>(
patterns, [](tpu::BitcastVregOp op, tpu::BitcastVregOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<llo::VectorBitcastOp>(op, op.getType(),
subst.getInput());
return success();
});
addPattern<tpu::RollVectorsOp>(
patterns, [](tpu::RollVectorsOp op, tpu::RollVectorsOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
rewriter.eraseOp(op);
return success();
});
addPattern<tpu::UnrollVectorsOp>(
patterns, [](tpu::UnrollVectorsOp op, tpu::UnrollVectorsOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
if (auto rop = subst.getInput().getDefiningOp<tpu::RollVectorsOp>()) {
if (rop.getNumOperands() != op.getNumResults()) {
return failure();
}
for (auto [v1, v2] :
llvm::zip(rop.getOperandTypes(), op.getResultTypes())) {
if (v1 != v2) {
return failure();
}
}
rewriter.replaceOp(op, rop->getOperands());
return success();
}
return failure();
});
addPattern<tpu::LoadOp>(
patterns,
[this](tpu::LoadOp op, tpu::LoadOpAdaptor subst,
ConversionPatternRewriter &rewriter) -> LogicalResult {
if (!isRepresentableVectorType(op.getType())) {
return failure();
}
auto base_ty = getMemRefType(op.getBase());
auto [base_addr, _] = unpackMemRef(subst.getBase(), rewriter);
if (!hasMemorySpace(base_ty, tpu::MemorySpace::vmem)) {
return failure();
}
auto sublane_mask = op.getSublaneMask();
if (sublane_mask.size() != sublaneCount()) {
return failure();
}
auto [sublane_mask_i32, num_read_sublanes, all_sublanes_read] =
encodeSublaneMask(sublane_mask);
std::optional<Value> offset = indicesToVmemOffset(
subst.getIndices(), num_read_sublanes, base_ty, rewriter,
/*consistent_directions=*/false);
if (!offset) {
return op.emitOpError("Failed to convert indices to linear offset");
}
Value sublane_mask_v = nullptr;
if (!all_sublanes_read) {
sublane_mask_v = rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getIntegerType(32, false),
rewriter.getUI32IntegerAttr(sublane_mask_i32));
}
uint32_t sublane_stride = op.getSublaneStride().value_or(1);
if (sublane_stride <= 1) {
auto [base, displacement] =
vmemAddrOffsetToAddrDisplacement(rewriter, base_addr, *offset);
rewriter.replaceOpWithNewOp<llo::VectorLoadOp>(
op, op.getType(), base, displacement, sublane_mask_v,
sublane_stride);
} else {
Value addr = base_addr;
if (*offset) {
addr = rewriter.create<llo::ScalarAddressVmemOp>(op.getLoc(),
addr, *offset);
}
rewriter.replaceOpWithNewOp<llo::VldWithArbitrarySlaneStrideOp>(
op, op.getType(), addr, sublane_stride, sublaneCount(),
sublane_mask_v);
}
return success();
});
addPattern<tpu::StoreOp>(
patterns, [this](tpu::StoreOp op, tpu::StoreOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
if (!isRepresentableVectorType(subst.getValueToStore().getType())) {
return failure();
}
auto source_ty = getMemRefType(op.getBase());
auto [base_addr, _] = unpackMemRef(subst.getBase(), rewriter);
if (!hasMemorySpace(source_ty, tpu::MemorySpace::vmem)) {
return failure();
}
auto sublane_mask = op.getSublaneMask();
if (sublane_mask.size() != sublaneCount()) {
return failure();
}
auto [sublane_mask_i32, num_read_sublanes, all_sublanes_read] =
encodeSublaneMask(sublane_mask);
std::optional<Value> offset = indicesToVmemOffset(
subst.getIndices(), num_read_sublanes, source_ty, rewriter,
/*consistent_directions=*/false);
if (!offset) {
return failure();
}
Value sublane_mask_v = nullptr;
if (!all_sublanes_read) {
sublane_mask_v = rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getIntegerType(32, false),
rewriter.getUI32IntegerAttr(sublane_mask_i32));
}
auto sublane_stride = op.getSublaneStride().value_or(1);
if (sublane_stride <= 1) {
auto [base, displacement] =
vmemAddrOffsetToAddrDisplacement(rewriter, base_addr, *offset);
if (auto mask = op.getMask()) {
rewriter.replaceOpWithNewOp<llo::VectorStoreMaskedOp>(
op, /*address=*/base,
/*displacement=*/displacement,
/*mask=*/mask, /*to_store=*/subst.getValueToStore(),
/*sublane_stride=*/sublane_stride,
/*sublanes_per_stride=*/1,
/*sublane_mask=*/sublane_mask_v);
} else {
rewriter.replaceOpWithNewOp<llo::VectorStoreOp>(
op, /*address=*/base,
/*displacement=*/displacement,
/*to_store=*/subst.getValueToStore(),
/*sublane_mask=*/sublane_mask_v,
/*sublane_stride=*/sublane_stride);
}
} else {
Value addr = base_addr;
if (*offset) {
addr = rewriter.create<llo::ScalarAddressVmemOp>(op.getLoc(),
addr, *offset);
}
if (auto mask = op.getMask()) {
rewriter
.replaceOpWithNewOp<llo::VstMaskedWithArbitrarySlaneStrideOp>(
op, addr, mask, subst.getValueToStore(), sublane_stride,
sublaneCount(), sublane_mask_v);
} else {
rewriter.replaceOpWithNewOp<llo::VstWithArbitrarySlaneStrideOp>(
op, addr, subst.getValueToStore(), sublane_stride,
sublaneCount(), sublane_mask_v);
}
}
return success();
});
addPattern<tpu::MatmulOp>(
patterns,
[this](tpu::MatmulOp op, tpu::MatmulOpAdaptor subst,
ConversionPatternRewriter &rewriter) -> LogicalResult {
bool transposed = op.getTransposeRhs();
bool high_precision = false;
if (auto precision = op.getPrecision()) {
switch (*precision) {
case tpu::ContractPrecision::kBF16:
high_precision = false;
break;
case tpu::ContractPrecision::kFP32:
high_precision = true;
break;
}
}
CHECK_EQ(target_->MxuContractingSize(),
target_->MxuNoncontractingSize());
int64_t mxu_size = target_->MxuContractingSize();
CHECK_EQ(mxu_size % laneCount(), 0);
auto lhs_type = cast<VectorType>(subst.getLhs().getType());
int64_t lhs_packing = 32 / lhs_type.getElementTypeBitWidth();
auto lhs_shape = lhs_type.getShape();
if (lhs_shape[0] <= 0 ||
lhs_shape[0] % (sublaneCount() * lhs_packing) != 0 ||
lhs_shape[1] != mxu_size ||
(high_precision && !lhs_type.getElementType().isF32())) {
op.emitOpError() << "Bad lhs type in tpu.matmul";
return failure();
}
mlir::ValueRange lhs_vectors;
if (auto rvop = subst.getLhs().getDefiningOp<tpu::RollVectorsOp>()) {
lhs_vectors = rvop->getOperands();
} else {
return failure();
}
auto rhs_type = cast<VectorType>(subst.getRhs().getType());
int64_t rhs_packing = 32 / rhs_type.getElementTypeBitWidth();
if (rhs_type.getShape() != ArrayRef<int64_t>{mxu_size, mxu_size} ||
(high_precision && !rhs_type.getElementType().isF32())) {
op.emitOpError() << "Bad rhs type in tpu.matmul";
return failure();
}
mlir::ValueRange rhs_vectors;
if (auto rvop = subst.getRhs().getDefiningOp<tpu::RollVectorsOp>()) {
rhs_vectors = rvop->getOperands();
CHECK_EQ(rhs_vectors.size(),
mxu_size * mxu_size /
(sublaneCount() * rhs_packing * laneCount()));
} else {
return failure();
}
auto acc_type = cast<VectorType>(subst.getAcc().getType());
if (acc_type.getShape() != lhs_type.getShape() ||
acc_type.getElementTypeBitWidth() != 32) {
op.emitOpError() << "Bad acc type in tpu.matmul";
return failure();
}
std::vector<Value> acc_vectors;
if (auto rvop = subst.getAcc().getDefiningOp<tpu::RollVectorsOp>()) {
acc_vectors.insert(acc_vectors.end(), rvop.operand_begin(),
rvop.operand_end());
} else {
return failure();
}
int64_t acc_lhs_compression = acc_type.getElementTypeBitWidth() /
lhs_type.getElementTypeBitWidth();
CHECK_EQ(acc_vectors.size(),
lhs_vectors.size() * acc_lhs_compression);
// Make sure the type of the accumulator matches the operand types.
bool integral = acc_type.getElementType().isSignlessInteger();
if (integral) {
if (!lhs_type.getElementType().isSignlessInteger(8) ||
!rhs_type.getElementType().isSignlessInteger(8)) {
return failure();
}
} else {
if (!isa<BFloat16Type, Float32Type>(lhs_type.getElementType()) ||
!isa<BFloat16Type, Float32Type>(rhs_type.getElementType())) {
return failure();
}
}
// The code below assumes MXU size is either 128 or 256.
if (mxu_size != 128 && mxu_size != 256) {
op.emitOpError()
<< "Not implemented: unsupported MXU size " << mxu_size;
return failure();
}
auto loc = op.getLoc();
auto push_latch = [&rhs_vectors = std::as_const(rhs_vectors),
&rewriter, mxu_size, loc,
this](llo::GainLatchMode mode) {
// Here, a "tile" refers to a 128x128 tile.
auto tile_cols = mxu_size / laneCount();
auto tile_rows = tile_cols;
auto vregs_per_tile = rhs_vectors.size() / (tile_cols * tile_rows);
bool reverse_push = target_->ShouldReverseTileForLatching();
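// Iterate over rows of 128x128 tiles; within each tile, walk the vregs
// (in reverse if the target requires it) and push the columns from last
// to first, so that the column-0 push issues the actual latch
// instruction.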
for (int tr = 0; tr < tile_rows; ++tr) {
// Vreg index in the 128x128 tile.
int vi = reverse_push ? vregs_per_tile - 1 : 0;
while (vi >= 0 && vi < vregs_per_tile) {
for (int tc = tile_cols - 1; tc >= 0; --tc) {
int idx =
tr * tile_cols * vregs_per_tile + vi * tile_cols + tc;
Value rhs_vector = rhs_vectors[idx];
// JF and DF want to be fed the latch with reversed sublanes.
if (target_->ShouldReverseForLatching()) {
rhs_vector = rewriter.create<llo::VectorSublaneReverseOp>(
loc, rhs_vector.getType(), rhs_vector);
}
if (tile_cols == 1) {
rewriter.create<llo::VectorLatchOp>(loc, rhs_vector, mode);
} else if (tc == 0) {
// On GFC, we switched from vlatch 1/2 to just vlatch1.
int variant = (reverse_push && tr > 0) ? 2 : 1;
rewriter.create<llo::VectorLatchIOp>(loc, rhs_vector,
variant, mode);
} else {
rewriter.create<llo::VectorMatprepSubrOp>(loc, rhs_vector,
mode);
}
}
vi = reverse_push ? vi - 1 : vi + 1;
}
}
};
// Watch out! Calling this function mutates acc_vectors
// by accumulating the matmul result into them.
auto matmul = [&lhs_vectors = std::as_const(lhs_vectors),
&acc_vectors, &rewriter, mxu_size, integral, lhs_type,
loc, acc_lhs_compression, this](llo::MatmulMode mode) {
llo::MatmulDataFormat data_format = llo::MatmulDataFormat::kF32;
if (lhs_type.getElementType().isBF16()) {
data_format = llo::MatmulDataFormat::kBf16;
CHECK(mode == llo::MatmulMode::kRound);
} else if (lhs_type.getElementType().isSignlessInteger(8)) {
data_format = llo::MatmulDataFormat::kS8;
CHECK(mode == llo::MatmulMode::kRound);
} else if (!lhs_type.getElementType().isF32()) {
LOG(FATAL) << "Unexpected LHS type";
}
auto col_cnt = mxu_size / laneCount();
auto row_cnt = lhs_vectors.size() / col_cnt;
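// For each row of LHS vregs, feed the columns from last to first (the
// column-0 push issues the matmul itself), then read the results back
// with VectorMatresOp and accumulate them into acc_vectors.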
for (int r = 0; r < row_cnt; ++r) {
for (int c = col_cnt - 1; c >= 0; --c) {
Value lhs_vector = lhs_vectors[r * col_cnt + c];
if (col_cnt == 1) {
rewriter.create<llo::VectorMatmulOp>(
loc, lhs_vector, mode,
/*mxu_id=*/0,
/*dwg=*/false,
/*data_format=*/data_format);
} else if (c == 0) {
rewriter.create<llo::VectorMatmulMubrOp>(
loc, lhs_vector, mode,
/*mxu_id=*/0,
/*dwg=*/false,
/*data_format=*/data_format);
} else {
rewriter.create<llo::VectorMatprepMubrOp>(loc, lhs_vector,
mode);
}
}
// Now that the matmul is done, we can accumulate the result into acc.
for (int ur = 0; ur < acc_lhs_compression; ++ur) {
for (int c = 0; c < col_cnt; ++c) {
int idx =
r * acc_lhs_compression * col_cnt + ur * col_cnt + c;
Value result = rewriter.create<llo::VectorMatresOp>(
loc, acc_vectors[idx].getType(), data_format);
if (integral) {
acc_vectors[idx] = rewriter.create<llo::VectorAddS32Op>(
loc, acc_vectors[idx], result);
} else {
acc_vectors[idx] = rewriter.create<llo::VectorAddF32Op>(
loc, acc_vectors[idx], result);
}
}
}
}
};
using llo::MatmulMode;
using llo::GainLatchMode;
using llo::MatmulDataFormat;
if (high_precision) {
// This path is only taken for fp32 inputs. All targets support
// the matmul and latch modes used here, so no checks are necessary.
using MatmulInstance = std::pair<MatmulMode, GainLatchMode>;
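// Each entry pairs a matmul mode with the latch mode it requires. The
// loop below only re-pushes the gains when the required latch mode
// changes between consecutive entries.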
static const std::array<MatmulInstance, 6> matmul_sequence = {
MatmulInstance{MatmulMode::kSoftLowEight,
GainLatchMode::kNoXposeHiF32},
MatmulInstance{MatmulMode::kHigh,
GainLatchMode::kNoXposeSoftLowEightF32},
MatmulInstance{MatmulMode::kLow, GainLatchMode::kNoXposeLowF32},
MatmulInstance{MatmulMode::kSoftMiddleEight,
GainLatchMode::kNoXposeHiF32},
MatmulInstance{MatmulMode::kHigh,
GainLatchMode::kNoXposeSoftMiddleEightF32},
MatmulInstance{MatmulMode::kHigh, GainLatchMode::kNoXposeHiF32},
};
std::optional<GainLatchMode> last_latch_mode = std::nullopt;
for (auto [matmul_mode, latch_mode_no_transp] : matmul_sequence) {
if (!last_latch_mode.has_value() ||
*last_latch_mode != latch_mode_no_transp) {
if (last_latch_mode.has_value()) {
rewriter.create<llo::VectorDoneWithGainsOp>(op.getLoc());
}
GainLatchMode latch_mode = latch_mode_no_transp;
if (transposed) {
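// This relies on the transposed latch mode immediately following its
// non-transposed counterpart in the GainLatchMode enum.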
latch_mode = static_cast<GainLatchMode>(
static_cast<int32_t>(latch_mode) + 1);
}
push_latch(latch_mode);
last_latch_mode = latch_mode_no_transp;
}
matmul(matmul_mode);
}
} else {
GainLatchMode latch_mode;
MatmulDataFormat data_format;
if (rhs_type.getElementType().isF32()) {
latch_mode = transposed ? GainLatchMode::kXposeF32
: GainLatchMode::kNoXposeF32;
data_format = MatmulDataFormat::kF32;
} else if (rhs_type.getElementType().isBF16()) {
latch_mode = transposed ? GainLatchMode::kXposePackedBf16
: GainLatchMode::kNoXposePackedBf16;
data_format = MatmulDataFormat::kBf16;
} else if (rhs_type.getElementType().isSignlessInteger(8)) {
latch_mode = transposed ? GainLatchMode::kXposeS8
: GainLatchMode::kNoXposeS8;
data_format = MatmulDataFormat::kS8;
} else {
return failure();
}
if (!target_->SupportsMatmulDataFormat(
static_cast<xla::jellyfish::MatmulDataFormat>(
data_format)) ||
!target_->SupportsGainLatchMode(
static_cast<xla::jellyfish::GainLatchMode>(latch_mode))) {
op->emitOpError(
"Unsupported input data type in matrix multiplication.");
return failure();
}
push_latch(latch_mode);
matmul(MatmulMode::kRound);
}
rewriter.create<llo::VectorDoneWithGainsOp>(op.getLoc());
rewriter.replaceOpWithNewOp<tpu::RollVectorsOp>(op, lhs_type,
acc_vectors);
return success();
});
addPattern<tpu::RotateOp>(
patterns, [this](tpu::RotateOp op, tpu::RotateOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
if (!isRepresentableVectorType(op.getType())) {
return failure();
}
int32_t amount = op.getAmount();
if (op.getDimension() == 0) {
if (amount < 0 || amount >= sublaneCount()) {
return failure();
}
if (op.getStride().has_value() ||
op.getStrideDimension().has_value()) {
op.emitOpError("Not implemented: rotate sublanes with stride.");
return failure();
}
if (amount % sublaneCount() == 0) {
rewriter.replaceOp(op, subst.getValue());
return success();
}
// NOTE: LLO rotates sublanes down.
rewriter.replaceOpWithNewOp<llo::VectorSublaneRotateOp>(
op, subst.getValue(), sublaneCount() - amount);
return success();
}
if (op.getDimension() == 1) {
if (amount < 0 || amount >= laneCount()) {
return failure();
}
if (op.getStride().has_value() &&
op.getStrideDimension().value_or(1) != 0) {
op.emitOpError("Expect stride dimension is 0");
return failure();
}
if (amount % laneCount() == 0 && op.getStride().value_or(0) == 0) {
rewriter.replaceOp(op, subst.getValue());
return success();
}
rewriter.replaceOpWithNewOp<llo::VectorRotateOp>(
op, subst.getValue(), amount, op.getStrideAttr());
return success();
}
return failure();
});
addPattern<tpu::GatherOp>(
patterns, [this](tpu::GatherOp op, tpu::GatherOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
auto indices = op.getIndices();
if (!isRepresentableVectorType(op.getType()) ||
op.getDimension() != 0 || indices.size() != sublaneCount()) {
return failure();
}
bool constant = true;
for (int32_t idx : indices) {
if (idx < 0 || idx >= sublaneCount()) {
return failure();
}
constant &= idx == indices[0];
}
if (constant) {
rewriter.replaceOpWithNewOp<llo::VectorSublaneReplicateOp>(
op, op.getType(), op.getSource(),
rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(indices[0])));
} else {
rewriter.replaceOpWithNewOp<llo::VectorSublaneShuffleOp>(
op, op.getType(), op.getSource(), op.getIndicesAttr());
}
return success();
});
addPattern<tpu::DynamicGatherOp>(
patterns,
[this](tpu::DynamicGatherOp op, tpu::DynamicGatherOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
if (!isRepresentableVectorType(op.getType())) {
return failure();
}
if (op.getDimension() != 1) {
return failure();
}
auto set_permute = rewriter.create<llo::VectorSetPermutePatternOp>(
op.getLoc(), op.getType(), subst.getIndices(),
llo::SetPermuteMode::kOneSublane, /*xlu_id=*/0,
/*source_bus=*/nullptr);
auto request = rewriter.create<llo::VectorPermuteOp>(
op.getLoc(), op.getType(), subst.getSource(), set_permute,
/*xlu_id=*/0,
/*source_bus=*/nullptr);
rewriter.replaceOpWithNewOp<llo::VectorPermuteResultOp>(
op, op.getType(), request, 0);
return success();
});
addPattern<tpu::IotaOp>(
patterns, [this](tpu::IotaOp op, tpu::IotaOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
if (!isRepresentableVectorType(op.getType()) ||
op.getType().getElementType() != rewriter.getI32Type()) {
return failure();
}
if (op.getDimension() == 0) {
rewriter.replaceOpWithNewOp<llo::VectorSublaneId>(op, op.getType());
} else if (op.getDimension() == 1) {
rewriter.replaceOpWithNewOp<llo::VectorLaneId>(op, op.getType());
} else if (!op.getDimension().has_value()) {
rewriter.replaceOpWithNewOp<llo::VectorLaneSeqOp>(op, op.getType());
} else {
return failure();
}
return success();
});
// EraseLayoutOp exists only to appease some silly canonicalization rules
// and has no operational meaning.
addPattern<tpu::EraseLayoutOp>(
patterns, [](tpu::EraseLayoutOp op, tpu::EraseLayoutOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOp(op, subst.getOperand());
return success();
});
// MemRefSqueezeOp is a type cast and has no operational meaning.
addPattern<tpu::MemRefSqueezeOp>(
patterns, [](tpu::MemRefSqueezeOp op, tpu::MemRefSqueezeOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOp(op, subst.getInput());
return success();
});
// ReinterpretCastOp is a type cast and has no operational meaning.
addPattern<tpu::ReinterpretCastOp>(
patterns,
[](tpu::ReinterpretCastOp op, tpu::ReinterpretCastOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOp(op, subst.getInput());
return success();
});
addPattern<tpu::PackSubelementsOp>(
patterns,
[this](tpu::PackSubelementsOp op, tpu::PackSubelementsOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
auto sources = subst.getSources();
auto packed_ty = op.getType().getElementType();
int packing = target_->VectorScalarBitWidth() /
packed_ty.getIntOrFloatBitWidth();
if (sources.size() != packing) {
return failure();
}
auto source_ty = cast<VectorType>(sources.front().getType());
if (!isRepresentableVectorType(source_ty) ||
!isRepresentableVectorType(op.getType())) {
return failure();
}
if (packed_ty.isBF16() && source_ty.getElementType().isF32()) {
if (target_->SupportsVectorPackOps(
xla::jellyfish::VpackFormat::kCompressedBf16)) {
rewriter.replaceOpWithNewOp<llo::VectorPackOp>(
op, op.getType(), llo::VpackFormat::kCompressedBf16,
sources[1], sources[0]);
return success();
}
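// Fallback path: pack through scratch VMEM by storing each source with
// a packing store and reloading the packed vector.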
auto exe_region =
rewriter.create<llo::RegionOp>(op.getLoc(), op.getType());
Block &block = exe_region.getRegion().emplaceBlock();
{
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
// TODO(b/280782839): Fold this into successive stores.
auto addr = rewriter.create<llo::AllocaVmemOp>(
op.getLoc(),
target_->ChunkSizeBytes() / target_->VmemWordSizeBytes());
for (int i = 0; i < sources.size(); ++i) {
rewriter.create<llo::VectorStoreAndPackX16Op>(
op.getLoc(), addr, i, sources[i], rewriter.getBF16Type());
}
auto packed = rewriter.create<llo::VectorLoadOp>(
op.getLoc(), op.getType(), addr, nullptr, nullptr);
rewriter.create<llo::YieldOp>(op.getLoc(), packed.getResult());
}
rewriter.replaceOp(op, exe_region.getResult(0));
return success();
}
if (source_ty.getElementType().isSignlessInteger(32) &&
packed_ty.isSignlessInteger()) {
unsigned bitwidth = 32;
std::vector<Value> sources(op.getSources().begin(),
op.getSources().end());
std::vector<Value> next_sources;
next_sources.reserve(sources.size() / 2);
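// Repeatedly pack pairs of vregs, halving the element bitwidth each
// round, until the requested packed bitwidth is reached.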
while (bitwidth > packed_ty.getIntOrFloatBitWidth()) {
CHECK_EQ(sources.size(),
bitwidth / packed_ty.getIntOrFloatBitWidth());
llo::VpackFormat format = GetCompressedByteFormat(bitwidth / 2);
if (format == llo::VpackFormat::kInvalid) {
op.emitOpError() << "unsupported source element type "
<< source_ty.getElementType();
return failure();
}
if (!target_->SupportsVectorPackOps(
static_cast<xla::jellyfish::VpackFormat>(format))) {
op.emitOpError()
<< "Target TPU does not support required instructions";
return failure();
}
auto packed_vec_ty =
VectorType::get(nativeVectorShape(bitwidth / 2),
rewriter.getIntegerType(bitwidth / 2));
for (int i = 0; i < sources.size(); i += 2) {
next_sources.push_back(rewriter.create<llo::VectorPackOp>(
op.getLoc(), packed_vec_ty, format, sources[i + 1],
sources[i]));
}
bitwidth /= 2;
std::swap(sources, next_sources);
next_sources.clear();
}
CHECK_EQ(bitwidth, packed_ty.getIntOrFloatBitWidth());
CHECK_EQ(sources.size(), 1);
rewriter.replaceOp(op, sources.front());
return success();
}
op.emitOpError() << "unsupported source element type "
<< source_ty.getElementType();
return failure();
});
addPattern<tpu::UnpackSubelementsOp>(
patterns, [this](tpu::UnpackSubelementsOp op,
tpu::UnpackSubelementsOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
auto unpacked_ty = op.getType().getElementType();
auto source_ty = op.getSource().getType();
if (!isRepresentableVectorType(source_ty) ||
!isRepresentableVectorType(op.getType())) {
return failure();
}
llo::VpackFormat format = llo::VpackFormat::kInvalid;
if (unpacked_ty.isSignlessInteger(32) &&
source_ty.getElementType().isSignlessInteger()) {
format =
GetCompressedByteFormat(source_ty.getElementTypeBitWidth());
} else if (unpacked_ty.isF32() &&
source_ty.getElementType().isBF16()) {
format = llo::VpackFormat::kCompressedBf16;
}
if (format == llo::VpackFormat::kInvalid) {
op.emitOpError() << "unsupported source element type "
<< source_ty.getElementType();
return failure();
}
if (target_->SupportsVectorUnpackOps(
static_cast<xla::jellyfish::VpackFormat>(format))) {
rewriter.replaceOpWithNewOp<llo::VectorUnpackOp>(
op, op.getType(), op.getIndex(), format, subst.getSource());
} else {
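// Fallback path: spill the packed vector to scratch VMEM and read the
// requested part back with an unpacking load.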
auto exe_region =
rewriter.create<llo::RegionOp>(op.getLoc(), op.getType());
Block &block = exe_region.getRegion().emplaceBlock();
{
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
// TODO(b/280782839): We only need part of this space.
// TODO(b/280782839): Fold this into preceding loads.
auto addr = rewriter.create<llo::AllocaVmemOp>(
op.getLoc(),
target_->ChunkSizeBytes() / target_->VmemWordSizeBytes());
rewriter.create<llo::VectorStoreOp>(op.getLoc(), addr, nullptr,
subst.getSource(), nullptr);
Type signed_source_element_ty = source_ty.getElementType();
if (signed_source_element_ty.isSignlessInteger()) {
signed_source_element_ty = rewriter.getIntegerType(
source_ty.getElementTypeBitWidth(), true);
}
Value unpacked = rewriter.create<llo::VectorLoadAndUnpackOp>(
op.getLoc(), op.getType(), addr, op.getIndex(),
signed_source_element_ty);
rewriter.create<llo::YieldOp>(op.getLoc(), unpacked);
}
rewriter.replaceOp(op, exe_region.getResult(0));
}
return success();
});
addPattern<tpu::AllocaSemaphoreOp>(
patterns,
[](tpu::AllocaSemaphoreOp op, tpu::AllocaSemaphoreOpAdaptor subst,
ConversionPatternRewriter &rewriter) -> LogicalResult {
MemRefType ref_ty = op.getResult().getType();
auto layout =
dyn_cast<mlir::tpu::TiledLayoutAttr>(ref_ty.getLayout());
if (!layout) {
return op.emitOpError() << "need a tiled layout";
}
if (!layout.getTiles().empty()) {
return op.emitOpError() << "tiling unsupported";
}
auto tile_strides = layout.getTileStrides();
if (tile_strides.size() != ref_ty.getRank()) {
return failure();
}
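// Walk the dims from minor to major, checking for a dense row-major
// layout; `stride` ends up equal to the total number of semaphores to
// allocate.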
int stride = 1;
for (int i = ref_ty.getRank() - 1; i >= 0; --i) {
if (tile_strides[i] != stride) {
return op.emitOpError()
<< "non-contiguous allocations not supported";
}
if (ref_ty.isDynamicDim(i)) {
return op->emitOpError()
<< "dynamic shapes not supported in allocations";
}
stride *= ref_ty.getDimSize(i);
}
rewriter.replaceOpWithNewOp<llo::AllocaSyncFlagOp>(
op, rewriter.getI32Type(), stride);
return success();
});
addPattern<tpu::DeviceIdOp>(
patterns, [](tpu::DeviceIdOp op, tpu::DeviceIdOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<llo::LogicalDeviceIdOp>(
op, rewriter.getI32Type());
return success();
});
addPattern<tpu::SemaphoreReadOp>(
patterns, [](tpu::SemaphoreReadOp op, tpu::SemaphoreReadOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<llo::VSyncRead>(op, subst.getSemaphore());
return success();
});
addPattern<tpu::SemaphoreWaitOp>(
patterns, [](tpu::SemaphoreWaitOp op, tpu::SemaphoreWaitOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
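// Lower the wait by subtracting `amount` from the semaphore value and
// then waiting until it becomes non-negative again.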
auto c0 = rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(0));
auto neg_amount = rewriter.create<llo::ScalarSubS32Op>(
op.getLoc(), rewriter.getI32Type(), c0, subst.getAmount());
// This could race if multiple cores were to wait on the same
// semaphore, but I don't think this can happen.
rewriter.create<llo::VSyncAddOp>(op.getLoc(), subst.getSemaphore(),
neg_amount);
rewriter.replaceOpWithNewOp<llo::VWaitGeOp>(op, subst.getSemaphore(),
c0);
return success();
});
addPattern<tpu::SemaphoreSignalOp>(
patterns,
[this](tpu::SemaphoreSignalOp op, tpu::SemaphoreSignalOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
if (!op.getDeviceId()) {
rewriter.replaceOpWithNewOp<llo::VSyncAddOp>(
op, subst.getSemaphore(), subst.getAmount());
} else {
// If there are multiple tensor cores, then we might need to decode
// device id and core index from subst's device_id.
if (target_->TensorCoresPerChip() > 1 &&
!unsafe_allow_multicore_remote_dma_) {
return failure();
}
auto core_index = rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(0));
rewriter.replaceOpWithNewOp<llo::VSyncAddRemoteOp>(
op, subst.getSemaphore(), subst.getDeviceId(), core_index,
subst.getAmount());
}
return success();
});
addPattern<tpu::EnqueueDMAOp>(
patterns,
[this](tpu::EnqueueDMAOp op, tpu::EnqueueDMAOpAdaptor subst,
ConversionPatternRewriter &rewriter) -> LogicalResult {
// If there are multiple tensor cores, then we might need to decode
// device id and core index from subst's device_id. Local DMAs are ok.
if (op.getDeviceId() && target_->TensorCoresPerChip() > 1 &&
!unsafe_allow_multicore_remote_dma_) {
op.emitOpError("Multicore remote DMAs are not supported.");
}
auto src_ty = getMemRefType(op.getSource());
auto tgt_ty = getMemRefType(op.getTarget());
auto [src_addr, src_dynamic_sizes] =
unpackMemRef(subst.getSource(), rewriter);
auto [tgt_addr, tgt_dynamic_sizes] =
unpackMemRef(subst.getTarget(), rewriter);
if (src_ty.getElementType() != tgt_ty.getElementType()) {
op.emitOpError("DMA source and target element type mismatch.");
return failure();
}
int64_t bytewidth = src_ty.getElementTypeBitWidth() / 8;
if (src_ty.getShape() != tgt_ty.getShape()) {
return op.emitOpError("DMA source and target shape mismatch.");
}
// Check that dynamic sizes match too.
CHECK_EQ(src_dynamic_sizes.size(), tgt_dynamic_sizes.size());
for (auto [src_size, dst_size] :
llvm::zip(src_dynamic_sizes, tgt_dynamic_sizes)) {
rewriter.create<llo::ErrorIfOp>(
op.getLoc(),
rewriter.create<llo::ScalarCmpNeS32Op>(op.getLoc(), src_size,
dst_size),
rewriter.getStringAttr("Dynamic shape mismatch"));
}
ArrayRef<int64_t> shape = src_ty.getShape();
// Note that the memory spaces can differ; only the layout remains to be checked.
auto src_layout = cast<tpu::TiledLayoutAttr>(src_ty.getLayout());
auto tgt_layout = cast<tpu::TiledLayoutAttr>(tgt_ty.getLayout());
// Strides can differ, but tiling should be the same.
if (src_layout.getTiles() != tgt_layout.getTiles()) {
op.emitOpError("DMA source and target tiling mismatch.");
return failure();
}
auto tiling = src_layout.getTiles().front().dimensions();
SmallVector<Value> tiled_dynamic_shape;
if (failed(getTiledDynamicShape(&tiled_dynamic_shape, shape,
src_dynamic_sizes, tiling, rewriter,
op))) {
return failure();
}
int32_t tile_bytes = bytewidth;
for (int64_t s : tiling) {
tile_bytes *= s;
}
if (tile_bytes % 512 != 0) {
return op.emitOpError("Tile size must be divisible by 512 bytes.");
}
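// Express the transfer size in 512-byte units (hence the divisibility
// check above).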
Value num_elements =
getDynamicNumElements(shape, src_dynamic_sizes, rewriter);
Value size_512b = rewriter.create<llo::ScalarDivS32Op>(
op.getLoc(), rewriter.getI32Type(), num_elements,
rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(512 / bytewidth)));
SmallVector<Value> steps_per_stride = tiled_dynamic_shape;
steps_per_stride.back() = rewriter.create<llo::ScalarMulS32Op>(
op.getLoc(), rewriter.getI32Type(), steps_per_stride.back(),
rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(tile_bytes)));
auto get_strides = [&](TiledLayoutAttr layout)
-> mlir::FailureOr<SmallVector<int32_t>> {
if (layout.getTileStrides().back() != 1) {
op.emitOpError("expected a 1 stride");
return failure();
}
SmallVector<int32_t> strides_in_bytes;
strides_in_bytes.reserve(layout.getTileStrides().size());
for (int64_t stride : layout.getTileStrides()) {
strides_in_bytes.push_back(stride * tile_bytes);
}
return strides_in_bytes;
};
auto src_strides = get_strides(src_layout);
auto dst_strides = get_strides(tgt_layout);
if (failed(src_strides) || failed(dst_strides)) {
return failure();
}
CHECK_EQ(steps_per_stride.size(), src_strides->size());
CHECK_EQ(steps_per_stride.size(), dst_strides->size());
auto src_tile_strides = src_layout.getTileStrides();
auto tgt_tile_strides = tgt_layout.getTileStrides();
// We try to flatten pairs of dimensions that don't need to use
// multi-level striding.
for (int i = steps_per_stride.size() - 2; i >= 0; --i) {
int64_t tiled_size;
if (auto cst = tiled_dynamic_shape[i + 1]
.getDefiningOp<llo::ConstantOp>()) {
tiled_size = cast<IntegerAttr>(cst.getValue()).getInt();
} else {
continue;
}
if (src_tile_strides[i] == tiled_size * src_tile_strides[i + 1] &&
tgt_tile_strides[i] == tiled_size * tgt_tile_strides[i + 1]) {
steps_per_stride[i + 1] = rewriter.create<llo::ScalarMulS32Op>(
op.getLoc(), rewriter.getI32Type(), steps_per_stride[i + 1],
steps_per_stride[i]);
steps_per_stride.erase(steps_per_stride.begin() + i);
src_strides->erase(src_strides->begin() + i);
dst_strides->erase(dst_strides->begin() + i);
}
}
// Remove the dimensions XLA expects to be implicit.
src_strides->pop_back();
dst_strides->pop_back();
steps_per_stride.erase(steps_per_stride.begin());
if (steps_per_stride.size() > target_->StrideLevelCount()) {
op.emitOpError("Not implemented: DMA is too complicated.");
return failure();
}
Value core_index = nullptr;
if (subst.getDeviceId()) {
core_index = rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(0));
}
rewriter.replaceOpWithNewOp<llo::EnqueueDMAOp>(
op, src_addr, subst.getSourceSemaphore(), size_512b, tgt_addr,
subst.getTargetSemaphore(),
rewriter.getDenseI32ArrayAttr(*src_strides),
rewriter.getDenseI32ArrayAttr(*dst_strides), steps_per_stride,
subst.getDeviceId(), core_index);
return success();
});
addPattern<tpu::WaitDMAOp>(
patterns,
[this](tpu::WaitDMAOp op, tpu::WaitDMAOpAdaptor subst,
ConversionPatternRewriter &rewriter) -> LogicalResult {
auto dma_ty = getMemRefType(op.getRef());
auto dma_layout = cast<tpu::TiledLayoutAttr>(dma_ty.getLayout());
auto tiling = dma_layout.getTiles().front().dimensions();
ArrayRef<int64_t> shape = dma_ty.getShape();
auto [_, dynamic_sizes] = unpackMemRef(subst.getRef(), rewriter);
SmallVector<Value> tiled_dynamic_shape;
if (failed(getTiledDynamicShape(&tiled_dynamic_shape, shape,
dynamic_sizes, tiling, rewriter,
op))) {
return failure();
}
int32_t tile_bytes = dma_ty.getElementTypeBitWidth() / 8;
for (int64_t s : tiling) {
tile_bytes *= s;
}
if (tile_bytes % 512 != 0) {
return op.emitOpError("Tile size must be divisible by 512 bytes.");
}
Value num_elements =
getDynamicNumElements(shape, dynamic_sizes, rewriter);
Value size_512b = rewriter.create<llo::ScalarDivS32Op>(
op.getLoc(), rewriter.getI32Type(), num_elements,
rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(
512 / (dma_ty.getElementTypeBitWidth() / 8))));
rewriter.replaceOpWithNewOp<llo::DMADoneOp>(
op, size_512b, subst.getSemaphore(),
hasMemorySpace(dma_ty, tpu::MemorySpace::smem));
return success();
});
addPattern<tpu::MemRefSliceOp>(
patterns,
[this](tpu::MemRefSliceOp op, tpu::MemRefSliceOpAdaptor subst,
ConversionPatternRewriter &rewriter) -> LogicalResult {
auto indices = subst.getBaseIdx();
auto slice_shape = op.getResult().getType().getShape();
auto source_ty = getMemRefType(op.getMemRef());
if (!source_ty.hasStaticShape()) {
return op.emitOpError(
"Only slicing of memrefs with static shapes is supported.");
}
auto source_shape = source_ty.getShape();
bool is_semaphore =
hasMemorySpace(source_ty, tpu::MemorySpace::kSemaphoreMem);
if (is_semaphore && !isa<SemaphoreType, DMASemaphoreType>(
source_ty.getElementType())) {
return op.emitOpError(
"References to semaphore memory space must have a semaphore "
"element type.");
}
if (indices.size() != slice_shape.size() ||
indices.size() != source_shape.size()) {
op.emitOpError("Indices and slice shapes must match.");
return failure();
}
auto tiled_layout =
dyn_cast<tpu::TiledLayoutAttr>(source_ty.getLayout());
if (!tiled_layout) {
op.emitOpError("Must have tiled layout.");
return failure();
}
absl::Span<const int64_t> tile;
if (!tiled_layout.getTiles().empty()) {
// We assume that only the highest-level tile matters.
// All lower-level tiles only rearrange the data within the tile.
tile = tiled_layout.getTiles().front().dimensions();
}
if (source_shape.size() < tile.size()) {
op.emitOpError("Source rank must not be smaller than tile rank.");
return failure();
}
// Check that sliced portions are aligned to tile boundaries.
// We use the values before dialect conversion, since ops like
// tpu.assume_multiple get eliminated.
auto tiled_indices = op.getBaseIdx().take_back(tile.size());
auto tiled_slice_shape = slice_shape.take_back(tile.size());
bool is_aligned = true;
for (int64_t i = 0; i < tile.size(); ++i) {
is_aligned &= tiled_slice_shape[i] % tile[i] == 0;
if (!isGuaranteedDivisible(tiled_indices[i], tile[i])) {
op.emitOpError(
"Failed to prove that a tile index is divisible by the "
"tiling.");
return failure();
}
}
if (!is_aligned) {
op.emitOpError("Slice shape must be aligned to tile boundaries.");
return failure();
}
auto int_const = [&](int i) -> Value {
return rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(i));
};
Value offset = int_const(0);
auto tile_strides = tiled_layout.getTileStrides();
int tile_element_size = std::accumulate(tile.begin(), tile.end(), 1,
std::multiplies<int>());
// TODO(apaszke): More often than not it's the trailing indices that
// are static and not the leading ones. If we computed the sum in the
// opposite order, we might expose constant folding opportunities.
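// The offset accumulated below is in units of top-level tiles: indices
// into tiled dims are first divided by the tile size and then scaled by
// the corresponding tile stride.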
for (int i = indices.size() - 1; i >= 0; --i) {
auto idx = indices[i];
auto dim_size = source_shape[i];
auto tile_i = i - indices.size() + tile.size();
if (tile_i >= 0 && tile_i < tile.size()) {
idx = rewriter.create<llo::ScalarDivS32Op>(
op.getLoc(), rewriter.getI32Type(), idx,
int_const(tile[tile_i]));
dim_size /= tile[tile_i];
}
auto scaled_index = rewriter.create<llo::ScalarMulS32Op>(
op.getLoc(), rewriter.getI32Type(), int_const(tile_strides[i]),
idx);
offset = rewriter.create<llo::ScalarAddS32Op>(
op.getLoc(), rewriter.getI32Type(), offset, scaled_index);
}
int bitwidth;
if (is_semaphore) {
bitwidth = target_->SflagWordSizeBytes() * 8;
} else {
bitwidth = source_ty.getElementTypeBitWidth();
}
const auto tile_size_bytes = tile_element_size * bitwidth / 8;
auto addr = rewriter.create<llo::AddrScaledOp>(
op.getLoc(), subst.getMemRef(), offset, tile_size_bytes);
if (op.getResult().getType().hasStaticShape()) {
rewriter.replaceOp(op, addr);
} else {
SmallVector<Value> memref_rep;
memref_rep.push_back(addr);
memref_rep.append(subst.getDynamicSizes().begin(),
subst.getDynamicSizes().end());
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
op, op.getType(), memref_rep);
}
return success();
});
addPattern<tpu::BroadcastInSublanesOp>(
patterns, [](tpu::BroadcastInSublanesOp op,
tpu::BroadcastInSublanesOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
Value lane = rewriter.create<llo::ConstantOp>(
op.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(op.getLane()));
rewriter.replaceOpWithNewOp<llo::VectorBroadcastSublaneChunkOp>(
op, op.getType(), subst.getSource(), lane);
return success();
});
addPattern<tpu::GetBarrierSemaphoreOp>(
patterns, [](tpu::GetBarrierSemaphoreOp op,
tpu::GetBarrierSemaphoreOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<llo::GetBarrierSyncFlagOp>(op);
return success();
});
addPattern<tpu::TraceOp>(
patterns, [](tpu::TraceOp op, tpu::TraceOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
llo::TraceOp new_op = rewriter.create<llo::TraceOp>(
op.getLoc(), TypeRange{}, op.getMessageAttr(), op.getLevelAttr());
CHECK_EQ(new_op->getNumRegions(), 1);
rewriter.cloneRegionBefore(op.getRegion(), new_op.getRegion(),
new_op.getRegion().end());
rewriter.eraseOp(op);
return success();
});
addPattern<tpu::TraceStartOp>(
patterns, [](tpu::TraceStartOp op, tpu::TraceStartOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<llo::TraceStartOp>(
op, op.getMessageAttr(), op.getLevelAttr());
return success();
});
addPattern<tpu::TraceStopOp>(
patterns, [](tpu::TraceStopOp op, tpu::TraceStopOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<llo::TraceStopOp>(op);
return success();
});
addPattern<tpu::RegionOp>(
patterns, [](tpu::RegionOp op, tpu::RegionOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
llo::RegionOp new_op =
rewriter.create<llo::RegionOp>(op.getLoc(), TypeRange{});
if (new_op->getNumRegions() != 1) {
return failure();
}
rewriter.inlineRegionBefore(op.getRegion(), new_op.getRegion(),
new_op.getRegion().end());
rewriter.replaceOp(op, new_op);
return success();
});
addPattern<tpu::YieldOp>(
patterns, [](tpu::YieldOp op, tpu::YieldOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<llo::YieldOp>(op, subst.getOperands());
return success();
});
addPattern<tpu::MaskCastOp>(
patterns, [this](tpu::MaskCastOp op, tpu::MaskCastOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
if (!isMaskVectorType(op.getInput().getType()) ||
!isMaskVectorType(op.getType())) {
return failure();
}
rewriter.replaceOp(op, subst.getInput());
return success();
});
addPattern<tpu::AssumeMultipleOp>(
patterns,
[](tpu::AssumeMultipleOp op, tpu::AssumeMultipleOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOp(op, subst.getValue());
return success();
});
addPattern<tpu::GetIterationBoundOp>(
patterns,
[](tpu::GetIterationBoundOp op, tpu::GetIterationBoundOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<llo::GetIterationBoundOp>(
op, op.getDimAttr());
return success();
});
}
std::optional<int64_t> getMemRefSizeInBytes(MemRefType type) {
auto layout = dyn_cast<tpu::TiledLayoutAttr>(type.getLayout());
if (!type.hasStaticShape() || !layout) {
return std::nullopt;
}
std::vector<int64_t> minor_to_major(type.getRank());
std::iota(minor_to_major.rbegin(), minor_to_major.rend(),
static_cast<int64_t>(0));
xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDenseLayout(
xla::ConvertMlirTypeToPrimitiveType(type.getElementType()),
type.getShape(), minor_to_major, layout.getTiles());
if (!xla::jellyfish::TransferSizeUtil::HasSupportedTiling(
*target_->Topology(), xla_shape)) {
return std::nullopt;
}
return target_->ShapeSizeCompact(xla_shape);
}
LogicalResult getTiledDynamicShape(SmallVector<Value> *tiled_dynamic_shape,
ArrayRef<int64_t> shape,
ArrayRef<Value> dynamic_sizes,
absl::Span<const int64_t> tiling,
ConversionPatternRewriter &rewriter,
Operation *op) {
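// Returns one i32 size per dimension: untiled leading dims keep their
// (possibly dynamic) size, while the tiled trailing dims are rounded up
// to a whole number of tiles.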
auto dynamic_dim_it = dynamic_sizes.begin();
for (int64_t i = 0; i < shape.size() - tiling.size(); ++i) {
Value size;
if (mlir::ShapedType::isDynamic(shape[i])) {
CHECK(dynamic_dim_it != dynamic_sizes.end());
size = *dynamic_dim_it++;
} else {
size = rewriter.create<llo::ConstantOp>(
op->getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(shape[i]));
}
tiled_dynamic_shape->push_back(size);
}
for (int i = 0; i < tiling.size(); ++i) {
int64_t size = shape[shape.size() - tiling.size() + i];
if (mlir::ShapedType::isDynamic(size)) {
return op->emitOpError("Dynamic sizes in tiled dims unsupported");
}
tiled_dynamic_shape->push_back(rewriter.create<llo::ConstantOp>(
op->getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(llvm::divideCeil(size, tiling[i]))));
}
return success();
}
Value getDynamicNumElements(ArrayRef<int64_t> shape,
ArrayRef<Value> dynamic_sizes,
ConversionPatternRewriter &rewriter) {
int64_t num_static_elements = 1;
for (int64_t size : shape) {
num_static_elements *= mlir::ShapedType::isDynamic(size) ? 1 : size;
}
Value num_elements = rewriter.create<llo::ConstantOp>(
rewriter.getUnknownLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(num_static_elements));
for (Value size : dynamic_sizes) {
num_elements = rewriter.create<llo::ScalarMulS32Op>(
rewriter.getUnknownLoc(), rewriter.getI32Type(), num_elements, size);
}
return num_elements;
}
std::pair<Value, SmallVector<Value>> unpackMemRef(
Value ref, ConversionPatternRewriter &rewriter) {
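// Memrefs are lowered to an i32 base address followed by one i32 per
// dynamic dimension; recover that representation here via an
// unrealized conversion cast.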
if (isa<IntegerType>(ref.getType())) {
return std::make_pair(ref, SmallVector<Value>{});
}
CHECK(isa<MemRefType>(ref.getType()));
auto ref_ty = getMemRefType(ref);
int64_t num_dynamic_dims = 0;
for (auto size : ref_ty.getShape()) {
if (mlir::ShapedType::isDynamic(size)) {
++num_dynamic_dims;
}
}
Type i32 = IntegerType::get(ref.getContext(), 32);
SmallVector<Type> memref_rep_tys(1 + num_dynamic_dims, i32);
auto cast = rewriter.create<UnrealizedConversionCastOp>(
rewriter.getUnknownLoc(), memref_rep_tys, ref);
auto sizes = cast.getResults().drop_front(1);
return std::make_pair(cast.getResult(0),
SmallVector<Value>(sizes.begin(), sizes.end()));
}
const xla::jellyfish::Target *target_;
const bool unsafe_allow_multicore_remote_dma_;
};
void LowerToLLOPass::runOnOperation() {
std::unique_ptr<MockTpuInstance> mock_instance;
if (target_ == nullptr) {
mock_instance = tryCreatingMock();
if (!mock_instance) {
signalPassFailure();
return;
}
target_ = mock_instance->target();
} else {
CHECK(mock_target == -1)
<< "A mock target can only be used when no target is specified";
}
TypeConverter converter;
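// Memrefs become an i32 address plus one i32 per dynamic dimension;
// narrow integers, index, and semaphore types all lower to i32.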
converter.addConversion(
[](MemRefType ty,
mlir::SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
int64_t dynamic_dims = 0;
for (auto size : ty.getShape()) {
dynamic_dims += mlir::ShapedType::isDynamic(size);
}
types.append(1 + dynamic_dims, IntegerType::get(ty.getContext(), 32));
return success();
});
converter.addConversion([](tpu::SemaphoreType ty) -> std::optional<Type> {
return IntegerType::get(ty.getContext(), 32);
});
converter.addConversion([](tpu::DMASemaphoreType ty) -> std::optional<Type> {
return IntegerType::get(ty.getContext(), 32);
});
converter.addConversion([](IntegerType ty) -> std::optional<Type> {
if (ty.getWidth() > 32 || ty.getSignedness() != IntegerType::Signless) {
return std::nullopt;
}
return IntegerType::get(ty.getContext(), 32);
});
converter.addConversion([](IndexType ty) -> std::optional<Type> {
return IntegerType::get(ty.getContext(), 32);
});
converter.addConversion(
[](Float32Type ty) -> std::optional<Type> { return ty; });
converter.addConversion([this](VectorType ty) -> std::optional<Type> {
if (!isRepresentableVectorType(ty)) {
return std::nullopt;
}
return ty;
});
ConversionTarget target(getContext());
target.addLegalDialect<mlir::llo::LLODialect>();
target.addLegalOp<func::FuncOp, func::ReturnOp>();
target.addDynamicallyLegalOp<scf::IfOp>(
[&](scf::IfOp op) { return converter.isLegal(op.getResultTypes()); });
target.addDynamicallyLegalOp<scf::ForOp>([&](scf::ForOp op) {
// We check operand types to verify that the loop bounds do not use the
// index type.
return converter.isLegal(op.getOperandTypes());
});
target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
return op->getParentOp() && isa<scf::IfOp, scf::ForOp>(op->getParentOp()) &&
converter.isLegal(op.getOperandTypes());
});
target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
for (Value arg : op.getArguments()) {
if (!arg.getType().isSignlessInteger(32)) {
return false;
}
}
return true;
});
RewritePatternSet patterns(&getContext());
populateVectorToLLOConversionPatterns(patterns);
populateArithToLLOConversionPatterns(patterns);
populateMathToLLOConversionPatterns(patterns);
populateMemrefToLLOConversionPatterns(patterns);
populateTPUToLLOConversionPatterns(patterns);
populateCFToLLOConversionPatterns(patterns);
addPattern<UnrealizedConversionCastOp>(
patterns,
[](UnrealizedConversionCastOp op, UnrealizedConversionCastOpAdaptor subst,
ConversionPatternRewriter &rewriter) {
rewriter.eraseOp(op);
return success();
});
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
converter);
scf::populateSCFStructuralTypeConversions(converter, patterns);
// Preserve the expected layouts before we proceed with the conversion.
func::FuncOp func = getOperation();
for (int i = 0; i < func.getNumArguments(); ++i) {
auto arg_ty = func.getArgument(i).getType();
// This is needed to allow for programs written in LLO directly.
if (!func.getArgAttr(i, "llo.type")) {
func.setArgAttr(i, "llo.type", TypeAttr::get(arg_ty));
}
auto memref_ty = dyn_cast<MemRefType>(arg_ty);
if (!memref_ty) {
continue;
}
auto layout_attr = dyn_cast<tpu::TiledLayoutAttr>(memref_ty.getLayout());
if (!layout_attr) {
func.emitOpError(
"All memref arguments should use the TiledLayoutAttr for layout");
signalPassFailure();
return;
}
func.setArgAttr(i, "llo.layout", layout_attr);
}
if (failed(applyFullConversion(func, target, std::move(patterns)))) {
signalPassFailure();
}
if (mock_instance) {
target_ = nullptr; // Avoid leaving a dangling pointer.
}
}
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>> createLowerToLLOPass(
const xla::jellyfish::Target &target,
bool unsafe_allow_multicore_remote_dma) {
return std::make_unique<LowerToLLOPass>(&target,
unsafe_allow_multicore_remote_dma);
}
std::unique_ptr<OperationPass<func::FuncOp>> createPartialLowerToLLOPass() {
return std::make_unique<LowerToLLOPass>();
}
} // namespace mlir::tpu